#!/usr/bin/env python3 """ WebRTC and Jingle handling module. This module provides the core logic for establishing Audio/Video sessions via XMPP Jingle (XEP-0166) and WebRTC. It manages the RTCPeerConnections, ICE candidate negotiation, and audio broadcasting. Key features: - Shared memory audio broadcasting to multiple subscribers. - Synchronization of audio streams. - Explicit cleanup routines to prevent memory leaks with aiortc objects. """ import asyncio import gc import random import logging import os import re import weakref from typing import Optional, Dict, List, Any, Callable from dataclasses import dataclass, field from enum import Enum from fractions import Fraction logger = logging.getLogger(__name__) AIORTC_AVAILABLE = False try: from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer from aiortc import MediaStreamTrack, RTCIceCandidate from aiortc.contrib.media import MediaPlayer import av AIORTC_AVAILABLE = True except ImportError as e: logger.error(f"aiortc import failed: {e}") class CallState(Enum): """Represents the lifecycle state of a Jingle session.""" IDLE = "idle" ACTIVE = "active" ENDED = "ended" @dataclass(slots=True) class JingleSession: """ Stores state for a single WebRTC session. Attributes: sid: Session ID. peer_jid: JID of the remote peer. pc: The RTCPeerConnection object. audio_track: The specific audio track associated with this session. state: Current call state. pending_candidates: ICE candidates received before remote description was set. local_candidates: ICE candidates generated locally. ice_gathering_complete: Event signaling completion of ICE gathering. """ sid: str peer_jid: str pc: Any = None audio_track: Any = None state: CallState = CallState.IDLE pending_candidates: List[Any] = field(default_factory=list) local_candidates: List[Dict] = field(default_factory=list) ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event) def cleanup(self): """Releases references to heavy objects to aid garbage collection.""" self.pending_candidates.clear() self.local_candidates.clear() self.pc = None self.audio_track = None class SynchronizedRadioBroadcaster: """ Central audio broadcaster using a producer-consumer model. This class decodes a single audio stream and distributes the frames to multiple subscribers (SynchronizedAudioTrack instances) simultaneously. It uses weak references to subscribers to avoid reference cycles. """ MAX_SUBSCRIBERS = 10 SAMPLES_PER_FRAME = 960 CLEANUP_INTERVAL = 200 def __init__(self, on_track_end_callback: Optional[Callable] = None): self._player: Optional[MediaPlayer] = None self._current_source: Optional[str] = None self._lock = asyncio.Lock() # Use weakrefs to allow tracks to be GC'd if the connection drops unexpectedly. self._subscribers: List[weakref.ref] = [] self._broadcast_task: Optional[asyncio.Task] = None self._pts = 0 self._time_base = Fraction(1, 48000) self.on_track_end = on_track_end_callback self._shutdown = False self._track_end_fired = False self._stopped = False self._frame_counter = 0 # Persist resampler to avoid high CPU usage from constant reallocation. self._resampler = None def _get_silence_frame(self) -> 'av.AudioFrame': """Generates a silent audio frame to keep the RTP stream alive.""" frame = av.AudioFrame(format='s16', layout='stereo', samples=self.SAMPLES_PER_FRAME) frame.sample_rate = 48000 frame.time_base = self._time_base for plane in frame.planes: plane.update(bytes(plane.buffer_size)) frame.pts = self._pts self._pts += self.SAMPLES_PER_FRAME return frame async def set_source(self, source: str, force_restart: bool = False): """ Updates the audio source, restarting the broadcast loop if necessary. Args: source: URI or path to the audio file. force_restart: If True, restarts player even if source is identical. """ async with self._lock: normalized_source = source if source and not source.startswith(('http:', 'https:', 'rtmp:', 'rtsp:')): normalized_source = os.path.abspath(source) broadcast_is_running = ( self._broadcast_task and not self._broadcast_task.done() ) if (normalized_source == self._current_source and self._player and broadcast_is_running and not force_restart): return logger.info(f"Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}") await self._stop_broadcast_task() await self._close_player() self._current_source = normalized_source self._pts = 0 self._track_end_fired = False self._stopped = False self._frame_counter = 0 # Reset resampler context to prevent timestamp drift or format mismatch. self._resampler = None if not source: return try: logger.info(f"Starting broadcast: {source}") # Use a buffer for network streams to reduce jitter. options = {'rtbufsize': '8M'} if source.startswith(('http', 'rtmp', 'rtsp')) else {} self._player = MediaPlayer(source, options=options) self._broadcast_task = asyncio.create_task(self._broadcast_loop()) except Exception as e: logger.error(f"MediaPlayer Error: {e}") self._player = None async def _stop_broadcast_task(self): """Cancels and awaits the broadcast task.""" if self._broadcast_task: self._broadcast_task.cancel() try: await asyncio.wait_for(self._broadcast_task, timeout=2.0) except (asyncio.CancelledError, asyncio.TimeoutError): pass self._broadcast_task = None async def _close_player(self): """Safely closes the media player and underlying container.""" self._resampler = None if self._player: try: old_player = self._player self._player = None if hasattr(old_player, 'audio') and old_player.audio: try: old_player.audio.stop() except: pass if hasattr(old_player, '_container') and old_player._container: try: old_player._container.close() except: pass del old_player except Exception as e: logger.debug(f"Error closing player: {e}") gc.collect() async def stop_playback(self): """Stops the broadcast and clears the current source.""" async with self._lock: self._stopped = True await self._stop_broadcast_task() await self._close_player() self._current_source = None logger.info("Playback stopped") async def _broadcast_loop(self): """ Main loop: decodes frames, resamples them, and distributes to subscribers. Handles errors and End-Of-File (EOF) events. """ last_error_time = 0 error_count_window = 0 try: while self._player and not self._shutdown and not self._stopped: self._frame_counter += 1 # Periodically clean up dead weak references. if self._frame_counter % self.CLEANUP_INTERVAL == 0: self._cleanup_subscribers() # Pause decoding if no one is listening to save CPU. if not self._get_active_subscribers(): await asyncio.sleep(0.1) continue frame = None try: if not self._player or not self._player.audio: break frame = await asyncio.wait_for( self._player.audio.recv(), timeout=0.5 ) current_time = asyncio.get_event_loop().time() if current_time - last_error_time > 5.0: error_count_window = 0 # Ensure audio is standard stereo 48kHz. if frame.format.name != 's16' or frame.sample_rate != 48000: if self._resampler is None: self._resampler = av.AudioResampler(format='s16', layout='stereo', rate=48000) resampled_frames = self._resampler.resample(frame) # Explicitly release the original frame to free C-level memory. del frame if resampled_frames: frame = resampled_frames[0] else: continue # Rewrite timestamps for the continuous stream. frame.pts = self._pts frame.time_base = self._time_base self._pts += frame.samples await self._distribute_frame(frame) except asyncio.TimeoutError: # Send silence on network timeout to prevent RTP timeout. s_frame = self._get_silence_frame() await self._distribute_frame(s_frame) del s_frame except (av.error.EOFError, StopIteration, StopAsyncIteration): logger.info("Track ended (EOF)") break except Exception as e: error_msg = str(e) if "MediaStreamError" in str(type(e).__name__): break logger.debug(f"Broadcast frame warning: {e}") error_count_window += 1 if error_count_window >= 20: logger.error("Too many errors, stopping broadcast") break await asyncio.sleep(0.1) finally: if frame is not None: del frame # Notify the main bot that the track finished. if self.on_track_end and not self._shutdown and not self._track_end_fired and not self._stopped: self._track_end_fired = True try: result = self.on_track_end() if asyncio.iscoroutine(result): asyncio.create_task(result) except Exception as e: logger.error(f"Track end callback error: {e}") except asyncio.CancelledError: pass except Exception as e: logger.error(f"Broadcast loop error: {e}") finally: self._resampler = None gc.collect() def _cleanup_subscribers(self): self._subscribers = [ref for ref in self._subscribers if ref() is not None] def _get_active_subscribers(self) -> List['SynchronizedAudioTrack']: active = [] for ref in self._subscribers: track = ref() if track is not None and track._active: active.append(track) return active async def _distribute_frame(self, frame: 'av.AudioFrame'): subscribers = self._get_active_subscribers() if not subscribers: return # Dispatch frames to all tracks concurrently. tasks = [sub._receive_frame(frame) for sub in subscribers] if tasks: await asyncio.gather(*tasks, return_exceptions=True) def subscribe(self, track: 'SynchronizedAudioTrack'): self._cleanup_subscribers() if len(self._subscribers) >= self.MAX_SUBSCRIBERS: return self._subscribers.append(weakref.ref(track)) def unsubscribe(self, track: 'SynchronizedAudioTrack'): self._subscribers = [ref for ref in self._subscribers if ref() is not None and ref() is not track] def shutdown(self): self._shutdown = True self._resampler = None gc.collect() class SynchronizedAudioTrack(MediaStreamTrack): """ A MediaStreamTrack that receives frames from the broadcaster queue. """ kind = "audio" MAX_QUEUE_SIZE = 3 def __init__(self, broadcaster: SynchronizedRadioBroadcaster): super().__init__() self._broadcaster = broadcaster self._frame_queue: asyncio.Queue = asyncio.Queue(maxsize=self.MAX_QUEUE_SIZE) self._active = True self._silence_frame: Optional[av.AudioFrame] = None self._silence_pts = 0 self._broadcaster.subscribe(self) def _get_silence_frame(self) -> 'av.AudioFrame': """Lazily creates and reuses a silence frame.""" if self._silence_frame is None: self._silence_frame = av.AudioFrame(format='s16', layout='stereo', samples=960) self._silence_frame.sample_rate = 48000 self._silence_frame.time_base = Fraction(1, 48000) for plane in self._silence_frame.planes: plane.update(bytes(plane.buffer_size)) self._silence_frame.pts = self._silence_pts self._silence_pts += 960 return self._silence_frame async def _receive_frame(self, frame: 'av.AudioFrame'): if not self._active: return try: # If the queue is full, drop the oldest frame to reduce latency. if self._frame_queue.full(): try: old_frame = self._frame_queue.get_nowait() del old_frame except asyncio.QueueEmpty: pass self._frame_queue.put_nowait(frame) except Exception: pass async def recv(self): """Called by aiortc to pull the next frame.""" if not self._active: raise Exception("Track stopped") try: frame = await asyncio.wait_for( self._frame_queue.get(), timeout=0.05 ) return frame except asyncio.TimeoutError: return self._get_silence_frame() def stop(self): self._active = False self._broadcaster.unsubscribe(self) # Drain queue to release frame references. while not self._frame_queue.empty(): try: frame = self._frame_queue.get_nowait() del frame except asyncio.QueueEmpty: break if self._silence_frame is not None: del self._silence_frame self._silence_frame = None super().stop() class JingleWebRTCHandler: """ Manages the XMPP Jingle negotiation and WebRTC sessions. Translates between Jingle XML stanzas and SDP. """ NS_JINGLE = "urn:xmpp:jingle:1" NS_JINGLE_ICE = "urn:xmpp:jingle:transports:ice-udp:1" MAX_SESSIONS = 20 SESSION_CLEANUP_INTERVAL = 60 def __init__(self, stun_server: str, send_transport_info_callback: Optional[Callable] = None, on_track_end: Optional[Callable] = None): if stun_server and not stun_server.startswith(('stun:', 'turn:', 'stuns:')): self.stun_server = f"stun:{stun_server}" else: self.stun_server = stun_server self.sessions: Dict[str, JingleSession] = {} self._proposed_sessions: Dict[str, str] = {} self._broadcaster = SynchronizedRadioBroadcaster(on_track_end_callback=on_track_end) self.send_transport_info = send_transport_info_callback self._cleanup_task: Optional[asyncio.Task] = None self._shutdown = False async def start_cleanup_task(self): self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) async def _periodic_cleanup(self): """Periodically removes ended sessions to recover memory.""" while not self._shutdown: await asyncio.sleep(self.SESSION_CLEANUP_INTERVAL) try: await self._cleanup_dead_sessions() gc.collect() except Exception as e: logger.error(f"Cleanup error: {e}") async def _cleanup_dead_sessions(self): dead_sessions = [ sid for sid, session in self.sessions.items() if session.state == CallState.ENDED ] for sid in dead_sessions: await self.stop_session(sid) def register_proposed_session(self, sid: str, peer_jid: str): # Limit the proposed session cache to prevent memory exhaustion attacks. if len(self._proposed_sessions) > 50: oldest = list(self._proposed_sessions.keys())[:25] for key in oldest: del self._proposed_sessions[key] self._proposed_sessions[sid] = peer_jid def clear_proposed_session(self, sid: str): self._proposed_sessions.pop(sid, None) def set_audio_source(self, source: str, force_restart: bool = False): asyncio.create_task(self._broadcaster.set_source(source, force_restart)) async def stop_playback(self): await self._broadcaster.stop_playback() def get_session(self, sid: str): return self.sessions.get(sid) async def create_session(self, sid, peer_jid) -> JingleSession: if len(self.sessions) >= self.MAX_SESSIONS: await self._cleanup_dead_sessions() if len(self.sessions) >= self.MAX_SESSIONS: oldest = list(self.sessions.keys())[0] await self.stop_session(oldest) servers = [] if self.stun_server: servers.append(RTCIceServer(urls=self.stun_server)) config = RTCConfiguration(iceServers=servers) pc = RTCPeerConnection(configuration=config) session = JingleSession(sid=sid, peer_jid=peer_jid, pc=pc) self.sessions[sid] = session @pc.on("icegatheringstatechange") async def on_gathering_state(): if pc.iceGatheringState == "complete": session.ice_gathering_complete.set() @pc.on("connectionstatechange") async def on_state(): if pc.connectionState == "connected": session.state = CallState.ACTIVE elif pc.connectionState in ["failed", "closed", "disconnected"]: session.state = CallState.ENDED asyncio.create_task(self._cleanup_session(sid)) return session async def _cleanup_session(self, sid: str): await asyncio.sleep(1) if sid in self.sessions: session = self.sessions[sid] if session.state == CallState.ENDED: await self.stop_session(sid) def _extract_candidates_from_sdp(self, sdp: str) -> List[Dict]: """Parses SDP string to extract ICE candidates as dictionaries.""" candidates = [] for line in sdp.splitlines(): if line.startswith('a=candidate:'): match = re.match( r'a=candidate:(\S+)\s+(\d+)\s+(\w+)\s+(\d+)\s+(\S+)\s+(\d+)\s+typ\s+(\w+)(?:\s+raddr\s+(\S+))?\s*(?:rport\s+(\d+))?', line ) if match: foundation, component, protocol, priority, ip, port, typ, raddr, rport = match.groups() cand = { 'foundation': foundation, 'component': component, 'protocol': protocol.lower(), 'priority': priority, 'ip': ip, 'port': port, 'type': typ, 'generation': '0', 'id': f"cand-{random.randint(1000, 9999)}" } if raddr: cand['rel-addr'] = raddr if rport: cand['rel-port'] = rport candidates.append(cand) return candidates async def handle_session_initiate(self, jingle_xml, peer_jid, our_jid): """Handles incoming Jingle session-initiate, creating the WebRTC answer.""" sid = jingle_xml.get('sid') session = await self.create_session(sid, peer_jid) sdp = self._jingle_to_sdp(jingle_xml) offer = RTCSessionDescription(sdp=sdp, type="offer") track = SynchronizedAudioTrack(self._broadcaster) session.audio_track = track session.pc.addTrack(track) await session.pc.setRemoteDescription(offer) if session.pending_candidates: for cand in session.pending_candidates: try: await session.pc.addIceCandidate(cand) except: pass session.pending_candidates.clear() answer = await session.pc.createAnswer() # Set to passive to let the other side act as DTLS server if needed. answer_sdp = answer.sdp.replace('a=setup:active', 'a=setup:passive') answer = RTCSessionDescription(sdp=answer_sdp, type="answer") await session.pc.setLocalDescription(answer) try: await asyncio.wait_for(session.ice_gathering_complete.wait(), timeout=5.0) except asyncio.TimeoutError: pass if session.pc.localDescription: session.local_candidates = self._extract_candidates_from_sdp( session.pc.localDescription.sdp ) jingle_xml = self._build_session_accept(session, sid, our_jid) return session, jingle_xml def _build_session_accept(self, session: JingleSession, sid: str, responder: str) -> str: """Constructs the session-accept Jingle XML stanza.""" import xml.etree.ElementTree as ET root = ET.Element('jingle', {'xmlns': self.NS_JINGLE, 'action': 'session-accept', 'sid': sid, 'responder': responder}) content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': '0', 'senders': 'both'}) desc = ET.SubElement(content, 'description', {'xmlns': 'urn:xmpp:jingle:apps:rtp:1', 'media': 'audio'}) sdp = session.pc.localDescription.sdp codecs = {} for line in sdp.splitlines(): if line.startswith("a=rtpmap:"): match = re.match(r'a=rtpmap:(\d+)\s+(\S+)/(\d+)(?:/(\d+))?', line) if match: codec_id = match.group(1) codecs[codec_id] = {'name': match.group(2), 'clockrate': match.group(3), 'channels': match.group(4) or '1', 'fmtp': {}} elif line.startswith("a=fmtp:"): match = re.match(r'a=fmtp:(\d+)\s+(.+)', line) if match: codec_id = match.group(1) fmtp_params = {} for param in match.group(2).split(';'): if '=' in param: key, val = param.strip().split('=', 1) fmtp_params[key.strip()] = val.strip() if codec_id in codecs: codecs[codec_id]['fmtp'] = fmtp_params for codec_id, codec_info in codecs.items(): pt_attrs = {'id': codec_id, 'name': codec_info['name'], 'clockrate': codec_info['clockrate']} if codec_info['channels'] != '1': pt_attrs['channels'] = codec_info['channels'] pt = ET.SubElement(desc, 'payload-type', pt_attrs) for param_name, param_value in codec_info['fmtp'].items(): ET.SubElement(pt, 'parameter', {'name': param_name, 'value': param_value}) ET.SubElement(desc, 'rtcp-mux') trans = ET.SubElement(content, 'transport', {'xmlns': self.NS_JINGLE_ICE}) for line in sdp.splitlines(): if line.startswith("a=ice-ufrag:"): trans.set('ufrag', line.split(':')[1].strip()) if line.startswith("a=ice-pwd:"): trans.set('pwd', line.split(':')[1].strip()) if line.startswith("a=fingerprint:"): parts = line.split(':', 1)[1].strip().split(None, 1) if len(parts) == 2: fp = ET.SubElement(trans, 'fingerprint', {'xmlns': 'urn:xmpp:jingle:apps:dtls:0', 'hash': parts[0], 'setup': 'passive'}) fp.text = parts[1] for cand in session.local_candidates: cand_elem = ET.SubElement(trans, 'candidate', { 'component': str(cand.get('component', '1')), 'foundation': str(cand.get('foundation', '1')), 'generation': str(cand.get('generation', '0')), 'id': cand.get('id', 'cand-0'), 'ip': cand.get('ip', ''), 'port': str(cand.get('port', '0')), 'priority': str(cand.get('priority', '1')), 'protocol': cand.get('protocol', 'udp'), 'type': cand.get('type', 'host') }) if 'rel-addr' in cand: cand_elem.set('rel-addr', cand['rel-addr']) if 'rel-port' in cand: cand_elem.set('rel-port', str(cand['rel-port'])) return ET.tostring(root, encoding='unicode') async def add_ice_candidate(self, session: JingleSession, xml_cand): try: cand = RTCIceCandidate( foundation=xml_cand.get('foundation', '1'), component=int(xml_cand.get('component', '1')), protocol=xml_cand.get('protocol', 'udp'), priority=int(xml_cand.get('priority', '1')), ip=xml_cand.get('ip'), port=int(xml_cand.get('port')), type=xml_cand.get('type', 'host'), sdpMid="0", sdpMLineIndex=0 ) if session.pc.remoteDescription: await session.pc.addIceCandidate(cand) else: session.pending_candidates.append(cand) except Exception as e: logger.debug(f"Candidate error: {e}") async def stop_session(self, sid): if sid in self.sessions: s = self.sessions.pop(sid) if s.audio_track: try: s.audio_track.stop() except: pass s.audio_track = None if s.pc: try: await s.pc.close() except: pass s.pc = None s.cleanup() del s logger.info(f"Stopped session {sid}") self.clear_proposed_session(sid) gc.collect() async def end_all_sessions(self): """Terminates all active sessions and shuts down the broadcaster.""" self._shutdown = True self._broadcaster.shutdown() if self._cleanup_task: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass self._cleanup_task = None for sid in list(self.sessions.keys()): await self.stop_session(sid) self._proposed_sessions.clear() gc.collect() def get_active_sessions(self): return [s for s in self.sessions.values() if s.state == CallState.ACTIVE] def _jingle_to_sdp(self, xml): """Converts Jingle XML format to standard SDP text.""" lines = ["v=0", "o=- 0 0 IN IP4 0.0.0.0", "s=-", "t=0 0", "a=group:BUNDLE 0", "a=msid-semantic: WMS *"] content = xml.find(f"{{urn:xmpp:jingle:1}}content") desc = content.find(f"{{urn:xmpp:jingle:apps:rtp:1}}description") trans = content.find(f"{{urn:xmpp:jingle:transports:ice-udp:1}}transport") payloads = [pt.get('id') for pt in desc.findall(f"{{urn:xmpp:jingle:apps:rtp:1}}payload-type")] lines.append(f"m=audio 9 UDP/TLS/RTP/SAVPF {' '.join(payloads)}") lines.append("c=IN IP4 0.0.0.0") lines.append("a=rtcp:9 IN IP4 0.0.0.0") lines.append(f"a=ice-ufrag:{trans.get('ufrag')}") lines.append(f"a=ice-pwd:{trans.get('pwd')}") fp = trans.find(f"{{urn:xmpp:jingle:apps:dtls:0}}fingerprint") if fp is not None: lines.append(f"a=fingerprint:{fp.get('hash')} {fp.text}") setup = fp.get('setup', 'actpass') if setup == 'actpass': setup = 'active' lines.append(f"a=setup:{setup}") lines.append("a=mid:0") lines.append("a=sendrecv") lines.append("a=rtcp-mux") for pt in desc.findall(f"{{urn:xmpp:jingle:apps:rtp:1}}payload-type"): lines.append(f"a=rtpmap:{pt.get('id')} {pt.get('name')}/{pt.get('clockrate')}/{pt.get('channels', '1')}") return "\r\n".join(lines)