#!/usr/bin/env python3
"""
WebRTC and Jingle handling module.

This module provides the core logic for establishing Audio/Video sessions via
XMPP Jingle (XEP-0166) and WebRTC. It manages the RTCPeerConnections,
ICE candidate negotiation, and audio broadcasting.

Key features:
- Shared memory audio broadcasting to multiple subscribers.
- Synchronization of audio streams.
- Explicit cleanup routines to prevent memory leaks with aiortc objects.
"""

import asyncio
import gc
import random
import logging
import os
import re
import weakref
from typing import Optional, Dict, List, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from fractions import Fraction

logger = logging.getLogger(__name__)

AIORTC_AVAILABLE = False
try:
    from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer
    from aiortc import MediaStreamTrack, RTCIceCandidate
    from aiortc.contrib.media import MediaPlayer
    import av
    AIORTC_AVAILABLE = True
except ImportError as e:
    logger.error(f"aiortc import failed: {e}")


class CallState(Enum):
    """Represents the lifecycle state of a Jingle session."""
    IDLE = "idle"
    ACTIVE = "active"
    ENDED = "ended"


@dataclass(slots=True)
class JingleSession:
    """
    Stores state for a single WebRTC session.

    Attributes:
        sid: Session ID.
        peer_jid: JID of the remote peer.
        pc: The RTCPeerConnection object.
        audio_track: The specific audio track associated with this session.
        state: Current call state.
        pending_candidates: ICE candidates received before remote description was set.
        local_candidates: ICE candidates generated locally.
        ice_gathering_complete: Event signaling completion of ICE gathering.
    """
    sid: str
    peer_jid: str
    pc: Any = None
    audio_track: Any = None
    state: CallState = CallState.IDLE
    pending_candidates: List[Any] = field(default_factory=list)
    local_candidates: List[Dict] = field(default_factory=list)
    ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event)

    def cleanup(self):
        """Releases references to heavy objects to aid garbage collection."""
        self.pending_candidates.clear()
        self.local_candidates.clear()
        self.pc = None
        self.audio_track = None


class SynchronizedRadioBroadcaster:
    """
    Central audio broadcaster using a producer-consumer model.

    This class decodes a single audio stream and distributes the frames to
    multiple subscribers (SynchronizedAudioTrack instances) simultaneously.
    It uses weak references to subscribers to avoid reference cycles.
    """

    MAX_SUBSCRIBERS = 10
    SAMPLES_PER_FRAME = 960
    CLEANUP_INTERVAL = 200

    def __init__(self, on_track_end_callback: Optional[Callable] = None):
        self._player: Optional[MediaPlayer] = None
        self._current_source: Optional[str] = None
        self._lock = asyncio.Lock()
        # Use weakrefs to allow tracks to be GC'd if the connection drops unexpectedly.
        self._subscribers: List[weakref.ref] = []
        self._broadcast_task: Optional[asyncio.Task] = None
        self._pts = 0
        self._time_base = Fraction(1, 48000)
        self.on_track_end = on_track_end_callback
        self._shutdown = False
        self._track_end_fired = False
        self._stopped = False
        self._frame_counter = 0

        # Persist resampler to avoid high CPU usage from constant reallocation.
        self._resampler = None

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Generates a silent audio frame to keep the RTP stream alive."""
        frame = av.AudioFrame(format='s16', layout='stereo', samples=self.SAMPLES_PER_FRAME)
        frame.sample_rate = 48000
        frame.time_base = self._time_base
        for plane in frame.planes:
            plane.update(bytes(plane.buffer_size))
        frame.pts = self._pts
        self._pts += self.SAMPLES_PER_FRAME
        return frame

    async def set_source(self, source: str, force_restart: bool = False):
        """
        Updates the audio source, restarting the broadcast loop if necessary.

        Args:
            source: URI or path to the audio file.
            force_restart: If True, restarts player even if source is identical.
        """
        async with self._lock:
            normalized_source = source
            if source and not source.startswith(('http:', 'https:', 'rtmp:', 'rtsp:')):
                normalized_source = os.path.abspath(source)

            broadcast_is_running = (
                self._broadcast_task and
                not self._broadcast_task.done()
            )

            # Nothing to do when the exact same source is already playing.
            if (normalized_source == self._current_source and
                    self._player and
                    broadcast_is_running and
                    not force_restart):
                return

            logger.info(f"Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}")

            await self._stop_broadcast_task()
            await self._close_player()

            self._current_source = normalized_source
            self._pts = 0
            self._track_end_fired = False
            self._stopped = False
            self._frame_counter = 0

            # Reset resampler context to prevent timestamp drift or format mismatch.
            self._resampler = None

            if not source:
                return

            try:
                logger.info(f"Starting broadcast: {source}")
                # Use a buffer for network streams to reduce jitter.
                options = {'rtbufsize': '8M'} if source.startswith(('http', 'rtmp', 'rtsp')) else {}
                self._player = MediaPlayer(source, options=options)
                self._broadcast_task = asyncio.create_task(self._broadcast_loop())

            except Exception as e:
                logger.error(f"MediaPlayer Error: {e}")
                self._player = None

    async def _stop_broadcast_task(self):
        """Cancels and awaits the broadcast task."""
        if self._broadcast_task:
            self._broadcast_task.cancel()
            try:
                await asyncio.wait_for(self._broadcast_task, timeout=2.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            self._broadcast_task = None

    async def _close_player(self):
        """Safely closes the media player and underlying container."""
        self._resampler = None

        if self._player:
            try:
                old_player = self._player
                self._player = None

                if hasattr(old_player, 'audio') and old_player.audio:
                    try:
                        old_player.audio.stop()
                    except Exception:
                        pass

                # NOTE(review): reaches into the private aiortc attribute
                # ``_container``; may break with aiortc upgrades — confirm.
                if hasattr(old_player, '_container') and old_player._container:
                    try:
                        old_player._container.close()
                    except Exception:
                        pass

                del old_player
            except Exception as e:
                logger.debug(f"Error closing player: {e}")

        gc.collect()

    async def stop_playback(self):
        """Stops the broadcast and clears the current source."""
        async with self._lock:
            self._stopped = True
            await self._stop_broadcast_task()
            await self._close_player()
            self._current_source = None
            logger.info("Playback stopped")

    async def _broadcast_loop(self):
        """
        Main loop: decodes frames, resamples them, and distributes to subscribers.
        Handles errors and End-Of-File (EOF) events.
        """
        last_error_time = 0.0
        error_count_window = 0

        try:
            while self._player and not self._shutdown and not self._stopped:
                self._frame_counter += 1

                # Periodically clean up dead weak references.
                if self._frame_counter % self.CLEANUP_INTERVAL == 0:
                    self._cleanup_subscribers()

                # Pause decoding if no one is listening to save CPU.
                if not self._get_active_subscribers():
                    await asyncio.sleep(0.1)
                    continue

                frame = None
                try:
                    if not self._player or not self._player.audio:
                        break

                    frame = await asyncio.wait_for(
                        self._player.audio.recv(),
                        timeout=0.5
                    )

                    current_time = asyncio.get_event_loop().time()
                    # Reset the error window once we've had 5s without errors.
                    if current_time - last_error_time > 5.0:
                        error_count_window = 0

                    # Ensure audio is standard stereo 48kHz.
                    if frame.format.name != 's16' or frame.sample_rate != 48000:
                        if self._resampler is None:
                            self._resampler = av.AudioResampler(format='s16', layout='stereo', rate=48000)

                        resampled_frames = self._resampler.resample(frame)

                        # Explicitly release the original frame to free C-level memory.
                        # BUGFIX: rebind to None after ``del`` so the ``finally``
                        # below never touches an unbound local when the resampler
                        # returns nothing and we ``continue``.
                        del frame
                        frame = None

                        if resampled_frames:
                            frame = resampled_frames[0]
                        else:
                            continue

                    # Rewrite timestamps for the continuous stream.
                    frame.pts = self._pts
                    frame.time_base = self._time_base
                    self._pts += frame.samples

                    await self._distribute_frame(frame)

                except asyncio.TimeoutError:
                    # Send silence on network timeout to prevent RTP timeout.
                    s_frame = self._get_silence_frame()
                    await self._distribute_frame(s_frame)
                    del s_frame

                except (av.error.EOFError, StopIteration, StopAsyncIteration):
                    logger.info("Track ended (EOF)")
                    break

                except Exception as e:
                    if "MediaStreamError" in str(type(e).__name__):
                        break

                    logger.debug(f"Broadcast frame warning: {e}")
                    # BUGFIX: record when the error happened so the 5-second
                    # reset window above actually measures error-free time.
                    last_error_time = asyncio.get_event_loop().time()
                    error_count_window += 1
                    if error_count_window >= 20:
                        logger.error("Too many errors, stopping broadcast")
                        break

                    await asyncio.sleep(0.1)

                finally:
                    if frame is not None:
                        del frame

            # Notify the main bot exactly once that the track finished
            # (fires after the decode loop exits, not per frame).
            if (self.on_track_end and not self._shutdown
                    and not self._track_end_fired and not self._stopped):
                self._track_end_fired = True
                try:
                    result = self.on_track_end()
                    if asyncio.iscoroutine(result):
                        asyncio.create_task(result)
                except Exception as e:
                    logger.error(f"Track end callback error: {e}")

        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Broadcast loop error: {e}")
        finally:
            self._resampler = None
            gc.collect()

    def _cleanup_subscribers(self):
        """Drops weak references whose tracks have been garbage-collected."""
        self._subscribers = [ref for ref in self._subscribers if ref() is not None]

    def _get_active_subscribers(self) -> List['SynchronizedAudioTrack']:
        """Returns live, active subscriber tracks."""
        active = []
        for ref in self._subscribers:
            track = ref()
            if track is not None and track._active:
                active.append(track)
        return active

    async def _distribute_frame(self, frame: 'av.AudioFrame'):
        """Pushes one decoded frame to every active subscriber concurrently."""
        subscribers = self._get_active_subscribers()
        if not subscribers:
            return

        # Dispatch frames to all tracks concurrently.
        tasks = [sub._receive_frame(frame) for sub in subscribers]
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    def subscribe(self, track: 'SynchronizedAudioTrack'):
        """Registers a track; silently ignored beyond MAX_SUBSCRIBERS."""
        self._cleanup_subscribers()
        if len(self._subscribers) >= self.MAX_SUBSCRIBERS:
            return
        self._subscribers.append(weakref.ref(track))

    def unsubscribe(self, track: 'SynchronizedAudioTrack'):
        """Removes a track (and any dead references) from the subscriber list."""
        self._subscribers = [ref for ref in self._subscribers
                             if ref() is not None and ref() is not track]

    def shutdown(self):
        """Flags the broadcaster for shutdown and releases the resampler."""
        self._shutdown = True
        self._resampler = None
        gc.collect()


class SynchronizedAudioTrack(MediaStreamTrack):
    """
    A MediaStreamTrack that receives frames from the broadcaster queue.
    """
    kind = "audio"
    MAX_QUEUE_SIZE = 3

    def __init__(self, broadcaster: SynchronizedRadioBroadcaster):
        super().__init__()
        self._broadcaster = broadcaster
        self._frame_queue: asyncio.Queue = asyncio.Queue(maxsize=self.MAX_QUEUE_SIZE)
        self._active = True

        self._silence_frame: Optional[av.AudioFrame] = None
        self._silence_pts = 0

        self._broadcaster.subscribe(self)

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Lazily creates and reuses a silence frame."""
        if self._silence_frame is None:
            self._silence_frame = av.AudioFrame(format='s16', layout='stereo', samples=960)
            self._silence_frame.sample_rate = 48000
            self._silence_frame.time_base = Fraction(1, 48000)
            for plane in self._silence_frame.planes:
                plane.update(bytes(plane.buffer_size))

        self._silence_frame.pts = self._silence_pts
        self._silence_pts += 960
        return self._silence_frame

    async def _receive_frame(self, frame: 'av.AudioFrame'):
        """Enqueues a frame from the broadcaster, dropping the oldest when full."""
        if not self._active:
            return

        try:
            # If the queue is full, drop the oldest frame to reduce latency.
            if self._frame_queue.full():
                try:
                    old_frame = self._frame_queue.get_nowait()
                    del old_frame
                except asyncio.QueueEmpty:
                    pass

            self._frame_queue.put_nowait(frame)
        except Exception:
            pass

    async def recv(self):
        """Called by aiortc to pull the next frame."""
        if not self._active:
            raise Exception("Track stopped")

        try:
            frame = await asyncio.wait_for(
                self._frame_queue.get(),
                timeout=0.05
            )
            return frame
        except asyncio.TimeoutError:
            # Never starve the encoder: feed silence on underrun.
            return self._get_silence_frame()

    def stop(self):
        """Unsubscribes and releases every buffered frame reference."""
        self._active = False
        self._broadcaster.unsubscribe(self)

        # Drain queue to release frame references.
        while not self._frame_queue.empty():
            try:
                frame = self._frame_queue.get_nowait()
                del frame
            except asyncio.QueueEmpty:
                break

        if self._silence_frame is not None:
            del self._silence_frame
            self._silence_frame = None

        super().stop()


class JingleWebRTCHandler:
    """
    Manages the XMPP Jingle negotiation and WebRTC sessions.
    Translates between Jingle XML stanzas and SDP.
    """
    NS_JINGLE = "urn:xmpp:jingle:1"
    NS_JINGLE_ICE = "urn:xmpp:jingle:transports:ice-udp:1"

    MAX_SESSIONS = 20
    SESSION_CLEANUP_INTERVAL = 60

    def __init__(self, stun_server: str, send_transport_info_callback: Optional[Callable] = None, on_track_end: Optional[Callable] = None):
        if stun_server and not stun_server.startswith(('stun:', 'turn:', 'stuns:')):
            self.stun_server = f"stun:{stun_server}"
        else:
            self.stun_server = stun_server

        self.sessions: Dict[str, JingleSession] = {}
        self._proposed_sessions: Dict[str, str] = {}

        self._broadcaster = SynchronizedRadioBroadcaster(on_track_end_callback=on_track_end)
        self.send_transport_info = send_transport_info_callback

        self._cleanup_task: Optional[asyncio.Task] = None
        self._set_source_task: Optional[asyncio.Task] = None
        self._shutdown = False

    async def start_cleanup_task(self):
        """Spawns the periodic session-cleanup background task."""
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    async def _periodic_cleanup(self):
        """Periodically removes ended sessions to recover memory."""
        while not self._shutdown:
            await asyncio.sleep(self.SESSION_CLEANUP_INTERVAL)
            try:
                await self._cleanup_dead_sessions()
                gc.collect()
            except Exception as e:
                logger.error(f"Cleanup error: {e}")

    async def _cleanup_dead_sessions(self):
        """Stops every session whose state is ENDED."""
        dead_sessions = [
            sid for sid, session in self.sessions.items()
            if session.state == CallState.ENDED
        ]
        for sid in dead_sessions:
            await self.stop_session(sid)

    def register_proposed_session(self, sid: str, peer_jid: str):
        """Records a proposed (not yet accepted) session."""
        # Limit the proposed session cache to prevent memory exhaustion attacks.
        if len(self._proposed_sessions) > 50:
            oldest = list(self._proposed_sessions.keys())[:25]
            for key in oldest:
                del self._proposed_sessions[key]
        self._proposed_sessions[sid] = peer_jid

    def clear_proposed_session(self, sid: str):
        """Forgets a proposed session by sid (no-op if unknown)."""
        self._proposed_sessions.pop(sid, None)

    def set_audio_source(self, source: str, force_restart: bool = False):
        """Asynchronously switches the broadcaster to a new audio source."""
        # BUGFIX: keep a reference so the task cannot be garbage-collected
        # before it completes (asyncio only holds weak refs to tasks).
        self._set_source_task = asyncio.create_task(
            self._broadcaster.set_source(source, force_restart)
        )

    async def stop_playback(self):
        """Stops broadcaster playback."""
        await self._broadcaster.stop_playback()

    def get_session(self, sid: str):
        """Returns the session for sid, or None."""
        return self.sessions.get(sid)

    async def create_session(self, sid, peer_jid) -> JingleSession:
        """Creates a new RTCPeerConnection-backed session, evicting if at capacity."""
        if len(self.sessions) >= self.MAX_SESSIONS:
            await self._cleanup_dead_sessions()
            if len(self.sessions) >= self.MAX_SESSIONS:
                oldest = list(self.sessions.keys())[0]
                await self.stop_session(oldest)

        servers = []
        if self.stun_server:
            servers.append(RTCIceServer(urls=self.stun_server))

        config = RTCConfiguration(iceServers=servers)
        pc = RTCPeerConnection(configuration=config)

        session = JingleSession(sid=sid, peer_jid=peer_jid, pc=pc)
        self.sessions[sid] = session

        @pc.on("icegatheringstatechange")
        async def on_gathering_state():
            if pc.iceGatheringState == "complete":
                session.ice_gathering_complete.set()

        @pc.on("connectionstatechange")
        async def on_state():
            if pc.connectionState == "connected":
                session.state = CallState.ACTIVE
            elif pc.connectionState in ["failed", "closed", "disconnected"]:
                session.state = CallState.ENDED
                asyncio.create_task(self._cleanup_session(sid))

        return session

    async def _cleanup_session(self, sid: str):
        """Deferred teardown used by the connection-state handler."""
        await asyncio.sleep(1)
        if sid in self.sessions:
            session = self.sessions[sid]
            if session.state == CallState.ENDED:
                await self.stop_session(sid)

    def _extract_candidates_from_sdp(self, sdp: str) -> List[Dict]:
        """Parses SDP string to extract ICE candidates as dictionaries."""
        candidates = []
        for line in sdp.splitlines():
            if line.startswith('a=candidate:'):
                match = re.match(
                    r'a=candidate:(\S+)\s+(\d+)\s+(\w+)\s+(\d+)\s+(\S+)\s+(\d+)\s+typ\s+(\w+)(?:\s+raddr\s+(\S+))?\s*(?:rport\s+(\d+))?',
                    line
                )
                if match:
                    foundation, component, protocol, priority, ip, port, typ, raddr, rport = match.groups()
                    cand = {
                        'foundation': foundation,
                        'component': component,
                        'protocol': protocol.lower(),
                        'priority': priority,
                        'ip': ip,
                        'port': port,
                        'type': typ,
                        'generation': '0',
                        'id': f"cand-{random.randint(1000, 9999)}"
                    }
                    if raddr:
                        cand['rel-addr'] = raddr
                    if rport:
                        cand['rel-port'] = rport
                    candidates.append(cand)
        return candidates

    async def handle_session_initiate(self, jingle_xml, peer_jid, our_jid):
        """Handles incoming Jingle session-initiate, creating the WebRTC answer."""
        sid = jingle_xml.get('sid')
        session = await self.create_session(sid, peer_jid)

        sdp = self._jingle_to_sdp(jingle_xml)
        offer = RTCSessionDescription(sdp=sdp, type="offer")

        track = SynchronizedAudioTrack(self._broadcaster)
        session.audio_track = track
        session.pc.addTrack(track)

        await session.pc.setRemoteDescription(offer)

        # Flush candidates that arrived before the remote description was set.
        if session.pending_candidates:
            for cand in session.pending_candidates:
                try:
                    await session.pc.addIceCandidate(cand)
                except Exception:
                    pass
            session.pending_candidates.clear()

        answer = await session.pc.createAnswer()
        # Set to passive to let the other side act as DTLS server if needed.
        answer_sdp = answer.sdp.replace('a=setup:active', 'a=setup:passive')
        answer = RTCSessionDescription(sdp=answer_sdp, type="answer")

        await session.pc.setLocalDescription(answer)

        try:
            await asyncio.wait_for(session.ice_gathering_complete.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            pass

        if session.pc.localDescription:
            session.local_candidates = self._extract_candidates_from_sdp(
                session.pc.localDescription.sdp
            )

        # Use a distinct name: do not shadow the jingle_xml parameter.
        accept_xml = self._build_session_accept(session, sid, our_jid)
        return session, accept_xml

    def _build_session_accept(self, session: JingleSession, sid: str, responder: str) -> str:
        """Constructs the session-accept Jingle XML stanza."""
        import xml.etree.ElementTree as ET
        root = ET.Element('jingle', {'xmlns': self.NS_JINGLE, 'action': 'session-accept', 'sid': sid, 'responder': responder})
        content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': '0', 'senders': 'both'})
        desc = ET.SubElement(content, 'description', {'xmlns': 'urn:xmpp:jingle:apps:rtp:1', 'media': 'audio'})

        sdp = session.pc.localDescription.sdp
        # Collect codec descriptions (rtpmap + fmtp) keyed by payload id.
        codecs = {}
        for line in sdp.splitlines():
            if line.startswith("a=rtpmap:"):
                match = re.match(r'a=rtpmap:(\d+)\s+(\S+)/(\d+)(?:/(\d+))?', line)
                if match:
                    codec_id = match.group(1)
                    codecs[codec_id] = {'name': match.group(2), 'clockrate': match.group(3), 'channels': match.group(4) or '1', 'fmtp': {}}
            elif line.startswith("a=fmtp:"):
                match = re.match(r'a=fmtp:(\d+)\s+(.+)', line)
                if match:
                    codec_id = match.group(1)
                    fmtp_params = {}
                    for param in match.group(2).split(';'):
                        if '=' in param:
                            key, val = param.strip().split('=', 1)
                            fmtp_params[key.strip()] = val.strip()
                    if codec_id in codecs:
                        codecs[codec_id]['fmtp'] = fmtp_params

        for codec_id, codec_info in codecs.items():
            pt_attrs = {'id': codec_id, 'name': codec_info['name'], 'clockrate': codec_info['clockrate']}
            if codec_info['channels'] != '1':
                pt_attrs['channels'] = codec_info['channels']
            pt = ET.SubElement(desc, 'payload-type', pt_attrs)
            for param_name, param_value in codec_info['fmtp'].items():
                ET.SubElement(pt, 'parameter', {'name': param_name, 'value': param_value})

        ET.SubElement(desc, 'rtcp-mux')
        trans = ET.SubElement(content, 'transport', {'xmlns': self.NS_JINGLE_ICE})

        for line in sdp.splitlines():
            if line.startswith("a=ice-ufrag:"):
                trans.set('ufrag', line.split(':')[1].strip())
            if line.startswith("a=ice-pwd:"):
                trans.set('pwd', line.split(':')[1].strip())
            if line.startswith("a=fingerprint:"):
                parts = line.split(':', 1)[1].strip().split(None, 1)
                if len(parts) == 2:
                    fp = ET.SubElement(trans, 'fingerprint', {'xmlns': 'urn:xmpp:jingle:apps:dtls:0', 'hash': parts[0], 'setup': 'passive'})
                    fp.text = parts[1]

        for cand in session.local_candidates:
            cand_elem = ET.SubElement(trans, 'candidate', {
                'component': str(cand.get('component', '1')), 'foundation': str(cand.get('foundation', '1')),
                'generation': str(cand.get('generation', '0')), 'id': cand.get('id', 'cand-0'),
                'ip': cand.get('ip', ''), 'port': str(cand.get('port', '0')),
                'priority': str(cand.get('priority', '1')), 'protocol': cand.get('protocol', 'udp'),
                'type': cand.get('type', 'host')
            })
            if 'rel-addr' in cand:
                cand_elem.set('rel-addr', cand['rel-addr'])
            if 'rel-port' in cand:
                cand_elem.set('rel-port', str(cand['rel-port']))

        return ET.tostring(root, encoding='unicode')

    async def add_ice_candidate(self, session: JingleSession, xml_cand):
        """Converts a Jingle <candidate> element and feeds it to the peer connection."""
        try:
            cand = RTCIceCandidate(
                foundation=xml_cand.get('foundation', '1'), component=int(xml_cand.get('component', '1')),
                protocol=xml_cand.get('protocol', 'udp'), priority=int(xml_cand.get('priority', '1')),
                ip=xml_cand.get('ip'), port=int(xml_cand.get('port')), type=xml_cand.get('type', 'host'),
                sdpMid="0", sdpMLineIndex=0
            )
            if session.pc.remoteDescription:
                await session.pc.addIceCandidate(cand)
            else:
                session.pending_candidates.append(cand)
        except Exception as e:
            logger.debug(f"Candidate error: {e}")

    async def stop_session(self, sid):
        """Tears down a session: stops the track, closes the PC, frees references."""
        if sid in self.sessions:
            s = self.sessions.pop(sid)
            if s.audio_track:
                try:
                    s.audio_track.stop()
                except Exception:
                    pass
                s.audio_track = None
            if s.pc:
                try:
                    await s.pc.close()
                except Exception:
                    pass
                s.pc = None
            s.cleanup()
            del s
            logger.info(f"Stopped session {sid}")
        self.clear_proposed_session(sid)
        gc.collect()

    async def end_all_sessions(self):
        """Terminates all active sessions and shuts down the broadcaster."""
        self._shutdown = True
        self._broadcaster.shutdown()
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
        for sid in list(self.sessions.keys()):
            await self.stop_session(sid)
        self._proposed_sessions.clear()
        gc.collect()

    def get_active_sessions(self):
        """Returns sessions currently in the ACTIVE state."""
        return [s for s in self.sessions.values() if s.state == CallState.ACTIVE]

    def _jingle_to_sdp(self, xml):
        """Converts Jingle XML format to standard SDP text."""
        lines = ["v=0", "o=- 0 0 IN IP4 0.0.0.0", "s=-", "t=0 0", "a=group:BUNDLE 0", "a=msid-semantic: WMS *"]
        content = xml.find(f"{{urn:xmpp:jingle:1}}content")
        desc = content.find(f"{{urn:xmpp:jingle:apps:rtp:1}}description")
        trans = content.find(f"{{urn:xmpp:jingle:transports:ice-udp:1}}transport")

        payloads = [pt.get('id') for pt in desc.findall(f"{{urn:xmpp:jingle:apps:rtp:1}}payload-type")]
        lines.append(f"m=audio 9 UDP/TLS/RTP/SAVPF {' '.join(payloads)}")
        lines.append("c=IN IP4 0.0.0.0")
        lines.append("a=rtcp:9 IN IP4 0.0.0.0")
        lines.append(f"a=ice-ufrag:{trans.get('ufrag')}")
        lines.append(f"a=ice-pwd:{trans.get('pwd')}")

        fp = trans.find(f"{{urn:xmpp:jingle:apps:dtls:0}}fingerprint")
        if fp is not None:
            lines.append(f"a=fingerprint:{fp.get('hash')} {fp.text}")
            setup = fp.get('setup', 'actpass')
            if setup == 'actpass':
                setup = 'active'
            lines.append(f"a=setup:{setup}")

        lines.append("a=mid:0")
        lines.append("a=sendrecv")
        lines.append("a=rtcp-mux")

        for pt in desc.findall(f"{{urn:xmpp:jingle:apps:rtp:1}}payload-type"):
            lines.append(f"a=rtpmap:{pt.get('id')} {pt.get('name')}/{pt.get('clockrate')}/{pt.get('channels', '1')}")

        return "\r\n".join(lines)