Files
xmpp-radio-tower/jingle_webrtc.py
2025-12-13 18:32:34 +00:00

742 lines
29 KiB
Python

#!/usr/bin/env python3
"""
WebRTC and Jingle handling module.
This module provides the core logic for establishing Audio/Video sessions via
XMPP Jingle (XEP-0166) and WebRTC. It manages the RTCPeerConnections,
ICE candidate negotiation, and audio broadcasting.
Key features:
- Shared memory audio broadcasting to multiple subscribers.
- Synchronization of audio streams.
- Explicit cleanup routines to prevent memory leaks with aiortc objects.
"""
import asyncio
import gc
import random
import logging
import os
import re
import weakref
from typing import Optional, Dict, List, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from fractions import Fraction
logger = logging.getLogger(__name__)

# Becomes True only when every optional media dependency imported cleanly.
AIORTC_AVAILABLE = False
try:
    from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer
    from aiortc import MediaStreamTrack, RTCIceCandidate
    from aiortc.contrib.media import MediaPlayer
    import av
    AIORTC_AVAILABLE = True
except ImportError as e:
    logger.error(f"aiortc import failed: {e}")
    # A class later in this module subclasses MediaStreamTrack at import time.
    # Without a stand-in the module would fail with NameError right after
    # logging the error, and callers could never inspect AIORTC_AVAILABLE.
    if "MediaStreamTrack" not in globals():
        class MediaStreamTrack:  # type: ignore[no-redef]
            """Inert placeholder base class used only when aiortc is missing."""

            def stop(self):
                """Do nothing; there is no real media track to stop."""
class CallState(Enum):
    """Lifecycle phases of a Jingle session.

    Each member's value is its lowercase name as a plain string.
    """

    # No media connection has been reported yet.
    IDLE = "idle"
    # The peer connection is up and media can flow.
    ACTIVE = "active"
    # The connection is gone; the session is waiting to be removed.
    ENDED = "ended"
@dataclass(slots=True)
class JingleSession:
    """
    Per-call record tying a Jingle session id to its WebRTC objects.

    Attributes:
        sid: Jingle session id.
        peer_jid: JID of the remote party.
        pc: The RTCPeerConnection serving this call (None once released).
        audio_track: Outgoing audio track attached to ``pc`` (None once released).
        state: Current lifecycle phase; starts as ``CallState.IDLE``.
        pending_candidates: Remote ICE candidates that arrived before the
            remote description was applied.
        local_candidates: Locally gathered ICE candidates as dictionaries.
        ice_gathering_complete: Event set when ICE gathering has finished.
    """

    sid: str
    peer_jid: str
    pc: Any = None
    audio_track: Any = None
    state: CallState = CallState.IDLE
    pending_candidates: List[Any] = field(default_factory=list)
    local_candidates: List[Dict] = field(default_factory=list)
    ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event)

    def cleanup(self):
        """Drop references to the heavy objects so they can be collected."""
        self.pc = None
        self.audio_track = None
        # The lists are emptied in place rather than replaced.
        for bucket in (self.pending_candidates, self.local_candidates):
            bucket.clear()
class SynchronizedRadioBroadcaster:
    """
    Central audio broadcaster using a producer-consumer model.

    One ``MediaPlayer`` is decoded by a single background task and every
    resulting frame is handed to all active subscribers
    (``SynchronizedAudioTrack`` instances). Subscribers are held through weak
    references, so a track that disappears without unsubscribing is not kept
    alive by this class.
    """

    MAX_SUBSCRIBERS = 10
    SAMPLES_PER_FRAME = 960
    CLEANUP_INTERVAL = 200
    # Output format every distributed frame is normalised to.
    SAMPLE_RATE = 48000
    # Decode errors are counted per window; too many in one window stops the loop.
    ERROR_WINDOW_SECONDS = 5.0
    MAX_ERRORS_PER_WINDOW = 20
    # Sources with one of these prefixes are URLs, not filesystem paths.
    _NETWORK_PREFIXES = ('http:', 'https:', 'rtmp:', 'rtmps:', 'rtsp:', 'rtsps:')

    def __init__(self, on_track_end_callback: Optional[Callable] = None):
        """
        Args:
            on_track_end_callback: Optional callable invoked once when the
                current source finishes on its own. It may return a coroutine,
                which is then scheduled as a task.
        """
        self._player: Optional['MediaPlayer'] = None
        self._current_source: Optional[str] = None
        self._lock = asyncio.Lock()
        # Use weakrefs to allow tracks to be GC'd if the connection drops unexpectedly.
        self._subscribers: List[weakref.ref] = []
        self._broadcast_task: Optional[asyncio.Task] = None
        self._pts = 0
        self._time_base = Fraction(1, self.SAMPLE_RATE)
        self.on_track_end = on_track_end_callback
        self._shutdown = False
        self._track_end_fired = False
        self._stopped = False
        self._frame_counter = 0
        # Persist resampler to avoid high CPU usage from constant reallocation.
        self._resampler = None
        # Strong references to callback tasks; the event loop itself only
        # keeps weak references to tasks.
        self._callback_tasks: set = set()

    def _is_network_source(self, source: str) -> bool:
        """Return True when ``source`` starts with a known streaming URL scheme."""
        return bool(source) and source.startswith(self._NETWORK_PREFIXES)

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Build a zero-filled frame stamped with the next broadcast pts."""
        frame = av.AudioFrame(format='s16', layout='stereo', samples=self.SAMPLES_PER_FRAME)
        frame.sample_rate = self.SAMPLE_RATE
        frame.time_base = self._time_base
        for plane in frame.planes:
            plane.update(bytes(plane.buffer_size))
        frame.pts = self._pts
        self._pts += self.SAMPLES_PER_FRAME
        return frame

    async def set_source(self, source: str, force_restart: bool = False):
        """
        Switch the audio source, restarting the broadcast loop if necessary.

        Args:
            source: URI or path to the audio file. A falsy value stops the
                current broadcast without starting a new one.
            force_restart: If True, restart the player even when the source
                is unchanged and already playing.
        """
        async with self._lock:
            is_network = self._is_network_source(source)
            normalized_source = source
            if source and not is_network:
                normalized_source = os.path.abspath(source)
            broadcast_is_running = (
                self._broadcast_task is not None
                and not self._broadcast_task.done()
            )
            if (normalized_source == self._current_source
                    and self._player
                    and broadcast_is_running
                    and not force_restart):
                return
            logger.info(f"Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}")
            await self._stop_broadcast_task()
            await self._close_player()
            self._current_source = normalized_source
            self._pts = 0
            self._track_end_fired = False
            self._stopped = False
            self._frame_counter = 0
            # Reset resampler context to prevent timestamp drift or format mismatch.
            self._resampler = None
            if not source:
                return
            try:
                logger.info(f"Starting broadcast: {source}")
                # Use a buffer for network streams to reduce jitter.
                options = {'rtbufsize': '8M'} if is_network else {}
                self._player = MediaPlayer(source, options=options)
                self._broadcast_task = asyncio.create_task(self._broadcast_loop())
            except Exception as e:
                logger.error(f"MediaPlayer Error: {e}")
                self._player = None

    async def _stop_broadcast_task(self):
        """Cancel the broadcast task and wait up to two seconds for it to end."""
        task, self._broadcast_task = self._broadcast_task, None
        if task is None:
            return
        task.cancel()
        try:
            await asyncio.wait_for(task, timeout=2.0)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            pass

    def _release_player(self) -> bool:
        """
        Detach and close the current player without awaiting anything.

        Returns:
            True if a player was attached and has been released.
        """
        self._resampler = None
        player, self._player = self._player, None
        if player is None:
            return False
        audio = getattr(player, 'audio', None)
        if audio:
            try:
                audio.stop()
            except Exception as e:
                logger.debug(f"Error stopping player audio: {e}")
        container = getattr(player, '_container', None)
        if container:
            try:
                container.close()
            except Exception as e:
                logger.debug(f"Error closing player: {e}")
        return True

    async def _close_player(self):
        """Safely close the media player and underlying container."""
        if self._release_player():
            gc.collect()

    async def stop_playback(self):
        """Stop the broadcast and clear the current source."""
        async with self._lock:
            self._stopped = True
            await self._stop_broadcast_task()
            await self._close_player()
            self._current_source = None
            logger.info("Playback stopped")

    def _normalize_frame(self, frame: 'av.AudioFrame') -> list:
        """
        Convert ``frame`` to s16 / stereo / 48 kHz.

        Returns:
            A list of zero or more frames. The list is empty when the
            resampler hands back nothing for this input.
        """
        if (frame.format.name == 's16'
                and frame.sample_rate == self.SAMPLE_RATE
                and frame.layout.name == 'stereo'):
            return [frame]
        if self._resampler is None:
            self._resampler = av.AudioResampler(
                format='s16', layout='stereo', rate=self.SAMPLE_RATE
            )
        resampled = self._resampler.resample(frame)
        if not resampled:
            return []
        if isinstance(resampled, (list, tuple)):
            return list(resampled)
        return [resampled]

    async def _publish(self, frames: list):
        """Stamp each frame with the continuous broadcast clock and distribute it."""
        for item in frames:
            item.pts = self._pts
            item.time_base = self._time_base
            self._pts += item.samples
            await self._distribute_frame(item)

    async def _broadcast_loop(self):
        """
        Main loop: decode frames, normalise them, and distribute to subscribers.

        The loop ends when the player disappears, the broadcaster is stopped
        or shut down, the source reaches EOF, or too many decode errors occur
        within one error window. If it ended on its own, ``on_track_end`` is
        invoked exactly once.
        """
        loop = asyncio.get_running_loop()
        last_error_time = 0.0
        error_count_window = 0
        try:
            while self._player and not self._shutdown and not self._stopped:
                self._frame_counter += 1
                # Periodically clean up dead weak references.
                if self._frame_counter % self.CLEANUP_INTERVAL == 0:
                    self._cleanup_subscribers()
                # Pause decoding if no one is listening to save CPU.
                if not self._get_active_subscribers():
                    await asyncio.sleep(0.1)
                    continue
                player = self._player
                if not player or not player.audio:
                    break
                frame = None
                try:
                    frame = await asyncio.wait_for(player.audio.recv(), timeout=0.5)
                    # A good frame after a quiet period starts a new error window.
                    if loop.time() - last_error_time > self.ERROR_WINDOW_SECONDS:
                        error_count_window = 0
                    await self._publish(self._normalize_frame(frame))
                except asyncio.TimeoutError:
                    # Send silence on network timeout to prevent RTP timeout.
                    await self._distribute_frame(self._get_silence_frame())
                except (av.error.EOFError, StopIteration, StopAsyncIteration):
                    logger.info("Track ended (EOF)")
                    break
                except Exception as e:
                    if "MediaStreamError" in type(e).__name__:
                        break
                    logger.debug(f"Broadcast frame warning: {e}")
                    error_count_window += 1
                    last_error_time = loop.time()
                    if error_count_window >= self.MAX_ERRORS_PER_WINDOW:
                        logger.error("Too many errors, stopping broadcast")
                        break
                    await asyncio.sleep(0.1)
                finally:
                    # Rebind instead of `del` so the name always exists here,
                    # whichever branch above was taken.
                    frame = None
            # Notify the main bot that the track finished.
            if (self.on_track_end
                    and not self._shutdown
                    and not self._track_end_fired
                    and not self._stopped):
                self._track_end_fired = True
                try:
                    result = self.on_track_end()
                    if asyncio.iscoroutine(result):
                        task = asyncio.create_task(result)
                        self._callback_tasks.add(task)
                        task.add_done_callback(self._callback_tasks.discard)
                except Exception as e:
                    logger.error(f"Track end callback error: {e}")
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Broadcast loop error: {e}")
        finally:
            self._resampler = None
            gc.collect()

    def _cleanup_subscribers(self):
        """Forget weak references whose track has been garbage collected."""
        self._subscribers = [ref for ref in self._subscribers if ref() is not None]

    def _get_active_subscribers(self) -> List['SynchronizedAudioTrack']:
        """Return the live tracks whose ``_active`` flag is set."""
        active = []
        for ref in self._subscribers:
            track = ref()
            if track is not None and track._active:
                active.append(track)
        return active

    async def _distribute_frame(self, frame: 'av.AudioFrame'):
        """Hand ``frame`` to every active subscriber concurrently."""
        subscribers = self._get_active_subscribers()
        if not subscribers:
            return
        # Per-subscriber failures are collected, not raised, so one bad track
        # cannot interrupt delivery to the others.
        await asyncio.gather(
            *(sub._receive_frame(frame) for sub in subscribers),
            return_exceptions=True,
        )

    def subscribe(self, track: 'SynchronizedAudioTrack') -> bool:
        """
        Register ``track`` to receive frames.

        Returns:
            True if the track is subscribed (including when it already was),
            False if ``MAX_SUBSCRIBERS`` has been reached.
        """
        self._cleanup_subscribers()
        if any(ref() is track for ref in self._subscribers):
            return True
        if len(self._subscribers) >= self.MAX_SUBSCRIBERS:
            logger.warning(
                "Subscriber limit (%d) reached; track will not receive audio",
                self.MAX_SUBSCRIBERS,
            )
            return False
        self._subscribers.append(weakref.ref(track))
        return True

    def unsubscribe(self, track: 'SynchronizedAudioTrack'):
        """Remove ``track`` (and any dead references) from the subscriber list."""
        self._subscribers = [
            ref for ref in self._subscribers
            if ref() is not None and ref() is not track
        ]

    def shutdown(self):
        """Flag the broadcaster as shut down and release the player."""
        self._shutdown = True
        # The loop stops by itself once it sees the flag, but nothing else
        # would ever close the player afterwards.
        self._release_player()
        gc.collect()
class SynchronizedAudioTrack(MediaStreamTrack):
    """
    A MediaStreamTrack fed by a SynchronizedRadioBroadcaster.

    The broadcaster pushes frames into a small bounded queue and ``recv``
    pulls them out. When no frame arrives in time a silence frame is returned
    instead.

    NOTE(review): silence frames are stamped from a counter that starts at 0,
    independent of the broadcaster's pts, so the pts sequence seen by the
    consumer is not monotonic when silence and real audio alternate. Confirm
    whether the consumer relies on pts continuity.
    """

    kind = "audio"
    MAX_QUEUE_SIZE = 3

    def __init__(self, broadcaster: 'SynchronizedRadioBroadcaster'):
        """
        Args:
            broadcaster: Frame source; the track subscribes to it immediately.
        """
        super().__init__()
        self._broadcaster = broadcaster
        self._frame_queue: asyncio.Queue = asyncio.Queue(maxsize=self.MAX_QUEUE_SIZE)
        self._active = True
        self._silence_frame: Optional['av.AudioFrame'] = None
        self._silence_pts = 0
        # Keep the silence frame the same size as the broadcaster's frames.
        self._silence_samples = getattr(broadcaster, 'SAMPLES_PER_FRAME', 960)
        self._broadcaster.subscribe(self)

    @staticmethod
    def _stopped_error() -> Exception:
        """Build the exception raised by ``recv`` on a stopped track."""
        try:
            from aiortc.mediastreams import MediaStreamError
        except ImportError:
            return Exception("Track stopped")
        # MediaStreamError subclasses Exception, so existing handlers still match.
        return MediaStreamError("Track stopped")

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Return the reusable silence frame, advancing its pts by one frame."""
        if self._silence_frame is None:
            silence = av.AudioFrame(format='s16', layout='stereo', samples=self._silence_samples)
            silence.sample_rate = 48000
            silence.time_base = Fraction(1, 48000)
            for plane in silence.planes:
                plane.update(bytes(plane.buffer_size))
            self._silence_frame = silence
        self._silence_frame.pts = self._silence_pts
        self._silence_pts += self._silence_samples
        return self._silence_frame

    async def _receive_frame(self, frame: 'av.AudioFrame'):
        """Queue ``frame`` for ``recv``, discarding the oldest one when full."""
        if not self._active:
            return
        # If the queue is full, drop the oldest frame to reduce latency.
        if self._frame_queue.full():
            try:
                self._frame_queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
        try:
            self._frame_queue.put_nowait(frame)
        except asyncio.QueueFull:
            # Best effort: losing a frame is preferable to blocking the broadcaster.
            pass

    async def recv(self):
        """
        Called by aiortc to pull the next frame.

        Returns:
            The next queued frame, or a silence frame if none arrives within 50 ms.

        Raises:
            MediaStreamError: If the track has been stopped. Falls back to a
                plain ``Exception`` when aiortc's error class cannot be imported.
        """
        if not self._active:
            raise self._stopped_error()
        try:
            return await asyncio.wait_for(self._frame_queue.get(), timeout=0.05)
        except asyncio.TimeoutError:
            return self._get_silence_frame()

    def stop(self):
        """Deactivate the track, unsubscribe, and release all held frames."""
        self._active = False
        self._broadcaster.unsubscribe(self)
        # Drain queue to release frame references.
        while not self._frame_queue.empty():
            try:
                self._frame_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
        self._silence_frame = None
        super().stop()
class JingleWebRTCHandler:
"""
Manages the XMPP Jingle negotiation and WebRTC sessions.
Translates between Jingle XML stanzas and SDP.
"""
NS_JINGLE = "urn:xmpp:jingle:1"
NS_JINGLE_ICE = "urn:xmpp:jingle:transports:ice-udp:1"
MAX_SESSIONS = 20
SESSION_CLEANUP_INTERVAL = 60
def __init__(self, stun_server: str, send_transport_info_callback: Optional[Callable] = None, on_track_end: Optional[Callable] = None):
if stun_server and not stun_server.startswith(('stun:', 'turn:', 'stuns:')):
self.stun_server = f"stun:{stun_server}"
else:
self.stun_server = stun_server
self.sessions: Dict[str, JingleSession] = {}
self._proposed_sessions: Dict[str, str] = {}
self._broadcaster = SynchronizedRadioBroadcaster(on_track_end_callback=on_track_end)
self.send_transport_info = send_transport_info_callback
self._cleanup_task: Optional[asyncio.Task] = None
self._shutdown = False
async def start_cleanup_task(self):
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
async def _periodic_cleanup(self):
"""Periodically removes ended sessions to recover memory."""
while not self._shutdown:
await asyncio.sleep(self.SESSION_CLEANUP_INTERVAL)
try:
await self._cleanup_dead_sessions()
gc.collect()
except Exception as e:
logger.error(f"Cleanup error: {e}")
async def _cleanup_dead_sessions(self):
dead_sessions = [
sid for sid, session in self.sessions.items()
if session.state == CallState.ENDED
]
for sid in dead_sessions:
await self.stop_session(sid)
def register_proposed_session(self, sid: str, peer_jid: str):
# Limit the proposed session cache to prevent memory exhaustion attacks.
if len(self._proposed_sessions) > 50:
oldest = list(self._proposed_sessions.keys())[:25]
for key in oldest:
del self._proposed_sessions[key]
self._proposed_sessions[sid] = peer_jid
def clear_proposed_session(self, sid: str):
self._proposed_sessions.pop(sid, None)
def set_audio_source(self, source: str, force_restart: bool = False):
asyncio.create_task(self._broadcaster.set_source(source, force_restart))
async def stop_playback(self):
await self._broadcaster.stop_playback()
def get_session(self, sid: str):
return self.sessions.get(sid)
async def create_session(self, sid, peer_jid) -> JingleSession:
if len(self.sessions) >= self.MAX_SESSIONS:
await self._cleanup_dead_sessions()
if len(self.sessions) >= self.MAX_SESSIONS:
oldest = list(self.sessions.keys())[0]
await self.stop_session(oldest)
servers = []
if self.stun_server:
servers.append(RTCIceServer(urls=self.stun_server))
config = RTCConfiguration(iceServers=servers)
pc = RTCPeerConnection(configuration=config)
session = JingleSession(sid=sid, peer_jid=peer_jid, pc=pc)
self.sessions[sid] = session
@pc.on("icegatheringstatechange")
async def on_gathering_state():
if pc.iceGatheringState == "complete":
session.ice_gathering_complete.set()
@pc.on("connectionstatechange")
async def on_state():
if pc.connectionState == "connected":
session.state = CallState.ACTIVE
elif pc.connectionState in ["failed", "closed", "disconnected"]:
session.state = CallState.ENDED
asyncio.create_task(self._cleanup_session(sid))
return session
async def _cleanup_session(self, sid: str):
await asyncio.sleep(1)
if sid in self.sessions:
session = self.sessions[sid]
if session.state == CallState.ENDED:
await self.stop_session(sid)
def _extract_candidates_from_sdp(self, sdp: str) -> List[Dict]:
"""Parses SDP string to extract ICE candidates as dictionaries."""
candidates = []
for line in sdp.splitlines():
if line.startswith('a=candidate:'):
match = re.match(
r'a=candidate:(\S+)\s+(\d+)\s+(\w+)\s+(\d+)\s+(\S+)\s+(\d+)\s+typ\s+(\w+)(?:\s+raddr\s+(\S+))?\s*(?:rport\s+(\d+))?',
line
)
if match:
foundation, component, protocol, priority, ip, port, typ, raddr, rport = match.groups()
cand = {
'foundation': foundation,
'component': component,
'protocol': protocol.lower(),
'priority': priority,
'ip': ip,
'port': port,
'type': typ,
'generation': '0',
'id': f"cand-{random.randint(1000, 9999)}"
}
if raddr: cand['rel-addr'] = raddr
if rport: cand['rel-port'] = rport
candidates.append(cand)
return candidates
async def handle_session_initiate(self, jingle_xml, peer_jid, our_jid):
"""Handles incoming Jingle session-initiate, creating the WebRTC answer."""
sid = jingle_xml.get('sid')
session = await self.create_session(sid, peer_jid)
sdp = self._jingle_to_sdp(jingle_xml)
offer = RTCSessionDescription(sdp=sdp, type="offer")
track = SynchronizedAudioTrack(self._broadcaster)
session.audio_track = track
session.pc.addTrack(track)
await session.pc.setRemoteDescription(offer)
if session.pending_candidates:
for cand in session.pending_candidates:
try: await session.pc.addIceCandidate(cand)
except: pass
session.pending_candidates.clear()
answer = await session.pc.createAnswer()
# Set to passive to let the other side act as DTLS server if needed.
answer_sdp = answer.sdp.replace('a=setup:active', 'a=setup:passive')
answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
await session.pc.setLocalDescription(answer)
try:
await asyncio.wait_for(session.ice_gathering_complete.wait(), timeout=5.0)
except asyncio.TimeoutError:
pass
if session.pc.localDescription:
session.local_candidates = self._extract_candidates_from_sdp(
session.pc.localDescription.sdp
)
jingle_xml = self._build_session_accept(session, sid, our_jid)
return session, jingle_xml
def _build_session_accept(self, session: JingleSession, sid: str, responder: str) -> str:
"""Constructs the session-accept Jingle XML stanza."""
import xml.etree.ElementTree as ET
root = ET.Element('jingle', {'xmlns': self.NS_JINGLE, 'action': 'session-accept', 'sid': sid, 'responder': responder})
content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': '0', 'senders': 'both'})
desc = ET.SubElement(content, 'description', {'xmlns': 'urn:xmpp:jingle:apps:rtp:1', 'media': 'audio'})
sdp = session.pc.localDescription.sdp
codecs = {}
for line in sdp.splitlines():
if line.startswith("a=rtpmap:"):
match = re.match(r'a=rtpmap:(\d+)\s+(\S+)/(\d+)(?:/(\d+))?', line)
if match:
codec_id = match.group(1)
codecs[codec_id] = {'name': match.group(2), 'clockrate': match.group(3), 'channels': match.group(4) or '1', 'fmtp': {}}
elif line.startswith("a=fmtp:"):
match = re.match(r'a=fmtp:(\d+)\s+(.+)', line)
if match:
codec_id = match.group(1)
fmtp_params = {}
for param in match.group(2).split(';'):
if '=' in param:
key, val = param.strip().split('=', 1)
fmtp_params[key.strip()] = val.strip()
if codec_id in codecs:
codecs[codec_id]['fmtp'] = fmtp_params
for codec_id, codec_info in codecs.items():
pt_attrs = {'id': codec_id, 'name': codec_info['name'], 'clockrate': codec_info['clockrate']}
if codec_info['channels'] != '1': pt_attrs['channels'] = codec_info['channels']
pt = ET.SubElement(desc, 'payload-type', pt_attrs)
for param_name, param_value in codec_info['fmtp'].items():
ET.SubElement(pt, 'parameter', {'name': param_name, 'value': param_value})
ET.SubElement(desc, 'rtcp-mux')
trans = ET.SubElement(content, 'transport', {'xmlns': self.NS_JINGLE_ICE})
for line in sdp.splitlines():
if line.startswith("a=ice-ufrag:"): trans.set('ufrag', line.split(':')[1].strip())
if line.startswith("a=ice-pwd:"): trans.set('pwd', line.split(':')[1].strip())
if line.startswith("a=fingerprint:"):
parts = line.split(':', 1)[1].strip().split(None, 1)
if len(parts) == 2:
fp = ET.SubElement(trans, 'fingerprint', {'xmlns': 'urn:xmpp:jingle:apps:dtls:0', 'hash': parts[0], 'setup': 'passive'})
fp.text = parts[1]
for cand in session.local_candidates:
cand_elem = ET.SubElement(trans, 'candidate', {
'component': str(cand.get('component', '1')), 'foundation': str(cand.get('foundation', '1')),
'generation': str(cand.get('generation', '0')), 'id': cand.get('id', 'cand-0'),
'ip': cand.get('ip', ''), 'port': str(cand.get('port', '0')),
'priority': str(cand.get('priority', '1')), 'protocol': cand.get('protocol', 'udp'),
'type': cand.get('type', 'host')
})
if 'rel-addr' in cand: cand_elem.set('rel-addr', cand['rel-addr'])
if 'rel-port' in cand: cand_elem.set('rel-port', str(cand['rel-port']))
return ET.tostring(root, encoding='unicode')
async def add_ice_candidate(self, session: JingleSession, xml_cand):
try:
cand = RTCIceCandidate(
foundation=xml_cand.get('foundation', '1'), component=int(xml_cand.get('component', '1')),
protocol=xml_cand.get('protocol', 'udp'), priority=int(xml_cand.get('priority', '1')),
ip=xml_cand.get('ip'), port=int(xml_cand.get('port')), type=xml_cand.get('type', 'host'),
sdpMid="0", sdpMLineIndex=0
)
if session.pc.remoteDescription: await session.pc.addIceCandidate(cand)
else: session.pending_candidates.append(cand)
except Exception as e:
logger.debug(f"Candidate error: {e}")
async def stop_session(self, sid):
if sid in self.sessions:
s = self.sessions.pop(sid)
if s.audio_track:
try: s.audio_track.stop()
except: pass
s.audio_track = None
if s.pc:
try: await s.pc.close()
except: pass
s.pc = None
s.cleanup()
del s
logger.info(f"Stopped session {sid}")
self.clear_proposed_session(sid)
gc.collect()
async def end_all_sessions(self):
"""Terminates all active sessions and shuts down the broadcaster."""
self._shutdown = True
self._broadcaster.shutdown()
if self._cleanup_task:
self._cleanup_task.cancel()
try: await self._cleanup_task
except asyncio.CancelledError: pass
self._cleanup_task = None
for sid in list(self.sessions.keys()):
await self.stop_session(sid)
self._proposed_sessions.clear()
gc.collect()
def get_active_sessions(self):
return [s for s in self.sessions.values() if s.state == CallState.ACTIVE]
def _jingle_to_sdp(self, xml):
"""Converts Jingle XML format to standard SDP text."""
lines = ["v=0", "o=- 0 0 IN IP4 0.0.0.0", "s=-", "t=0 0", "a=group:BUNDLE 0", "a=msid-semantic: WMS *"]
content = xml.find(f"{{urn:xmpp:jingle:1}}content")
desc = content.find(f"{{urn:xmpp:jingle:apps:rtp:1}}description")
trans = content.find(f"{{urn:xmpp:jingle:transports:ice-udp:1}}transport")
payloads = [pt.get('id') for pt in desc.findall(f"{{urn:xmpp:jingle:apps:rtp:1}}payload-type")]
lines.append(f"m=audio 9 UDP/TLS/RTP/SAVPF {' '.join(payloads)}")
lines.append("c=IN IP4 0.0.0.0")
lines.append("a=rtcp:9 IN IP4 0.0.0.0")
lines.append(f"a=ice-ufrag:{trans.get('ufrag')}")
lines.append(f"a=ice-pwd:{trans.get('pwd')}")
fp = trans.find(f"{{urn:xmpp:jingle:apps:dtls:0}}fingerprint")
if fp is not None:
lines.append(f"a=fingerprint:{fp.get('hash')} {fp.text}")
setup = fp.get('setup', 'actpass')
if setup == 'actpass': setup = 'active'
lines.append(f"a=setup:{setup}")
lines.append("a=mid:0")
lines.append("a=sendrecv")
lines.append("a=rtcp-mux")
for pt in desc.findall(f"{{urn:xmpp:jingle:apps:rtp:1}}payload-type"):
lines.append(f"a=rtpmap:{pt.get('id')} {pt.get('name')}/{pt.get('clockrate')}/{pt.get('channels', '1')}")
return "\r\n".join(lines)