Update jingle_webrtc.py

This commit is contained in:
2025-12-17 17:24:34 +00:00
parent 1fb27b265c
commit 3f5f5f51f9

View File

@@ -1,16 +1,4 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""
WebRTC and Jingle handling module.
This module provides the core logic for establishing Audio/Video sessions via
XMPP Jingle (XEP-0166) and WebRTC. It manages the RTCPeerConnections,
ICE candidate negotiation, and audio broadcasting.
Key features:
- Shared memory audio broadcasting to multiple subscribers.
- Synchronization of audio streams.
- Explicit cleanup routines to prevent memory leaks with aiortc objects.
"""
import asyncio import asyncio
import gc import gc
@@ -36,29 +24,13 @@ try:
except ImportError as e: except ImportError as e:
logger.error(f"aiortc import failed: {e}") logger.error(f"aiortc import failed: {e}")
class CallState(Enum): class CallState(Enum):
"""Represents the lifecycle state of a Jingle session."""
IDLE = "idle" IDLE = "idle"
ACTIVE = "active" ACTIVE = "active"
ENDED = "ended" ENDED = "ended"
@dataclass(slots=True) @dataclass(slots=True)
class JingleSession: class JingleSession:
"""
Stores state for a single WebRTC session.
Attributes:
sid: Session ID.
peer_jid: JID of the remote peer.
pc: The RTCPeerConnection object.
audio_track: The specific audio track associated with this session.
state: Current call state.
pending_candidates: ICE candidates received before remote description was set.
local_candidates: ICE candidates generated locally.
ice_gathering_complete: Event signaling completion of ICE gathering.
"""
sid: str sid: str
peer_jid: str peer_jid: str
pc: Any = None pc: Any = None
@@ -69,22 +41,12 @@ class JingleSession:
ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event) ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event)
def cleanup(self): def cleanup(self):
"""Releases references to heavy objects to aid garbage collection."""
self.pending_candidates.clear() self.pending_candidates.clear()
self.local_candidates.clear() self.local_candidates.clear()
self.pc = None self.pc = None
self.audio_track = None self.audio_track = None
class SynchronizedRadioBroadcaster: class SynchronizedRadioBroadcaster:
"""
Central audio broadcaster using a producer-consumer model.
This class decodes a single audio stream and distributes the frames to
multiple subscribers (SynchronizedAudioTrack instances) simultaneously.
It uses weak references to subscribers to avoid reference cycles.
"""
MAX_SUBSCRIBERS = 10 MAX_SUBSCRIBERS = 10
SAMPLES_PER_FRAME = 960 SAMPLES_PER_FRAME = 960
CLEANUP_INTERVAL = 200 CLEANUP_INTERVAL = 200
@@ -93,7 +55,6 @@ class SynchronizedRadioBroadcaster:
self._player: Optional[MediaPlayer] = None self._player: Optional[MediaPlayer] = None
self._current_source: Optional[str] = None self._current_source: Optional[str] = None
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
# Use weakrefs to allow tracks to be GC'd if the connection drops unexpectedly.
self._subscribers: List[weakref.ref] = [] self._subscribers: List[weakref.ref] = []
self._broadcast_task: Optional[asyncio.Task] = None self._broadcast_task: Optional[asyncio.Task] = None
self._pts = 0 self._pts = 0
@@ -104,11 +65,9 @@ class SynchronizedRadioBroadcaster:
self._stopped = False self._stopped = False
self._frame_counter = 0 self._frame_counter = 0
# Persist resampler to avoid high CPU usage from constant reallocation.
self._resampler = None self._resampler = None
def _get_silence_frame(self) -> 'av.AudioFrame': def _get_silence_frame(self) -> 'av.AudioFrame':
"""Generates a silent audio frame to keep the RTP stream alive."""
frame = av.AudioFrame(format='s16', layout='stereo', samples=self.SAMPLES_PER_FRAME) frame = av.AudioFrame(format='s16', layout='stereo', samples=self.SAMPLES_PER_FRAME)
frame.sample_rate = 48000 frame.sample_rate = 48000
frame.time_base = self._time_base frame.time_base = self._time_base
@@ -119,13 +78,6 @@ class SynchronizedRadioBroadcaster:
return frame return frame
async def set_source(self, source: str, force_restart: bool = False): async def set_source(self, source: str, force_restart: bool = False):
"""
Updates the audio source, restarting the broadcast loop if necessary.
Args:
source: URI or path to the audio file.
force_restart: If True, restarts player even if source is identical.
"""
async with self._lock: async with self._lock:
normalized_source = source normalized_source = source
if source and not source.startswith(('http:', 'https:', 'rtmp:', 'rtsp:')): if source and not source.startswith(('http:', 'https:', 'rtmp:', 'rtsp:')):
@@ -142,7 +94,7 @@ class SynchronizedRadioBroadcaster:
not force_restart): not force_restart):
return return
logger.info(f"Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}") logger.info(f"🔄 Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}")
await self._stop_broadcast_task() await self._stop_broadcast_task()
await self._close_player() await self._close_player()
@@ -153,15 +105,13 @@ class SynchronizedRadioBroadcaster:
self._stopped = False self._stopped = False
self._frame_counter = 0 self._frame_counter = 0
# Reset resampler context to prevent timestamp drift or format mismatch.
self._resampler = None self._resampler = None
if not source: if not source:
return return
try: try:
logger.info(f"Starting broadcast: {source}") logger.info(f"🎵 Starting broadcast: {source}")
# Use a buffer for network streams to reduce jitter.
options = {'rtbufsize': '8M'} if source.startswith(('http', 'rtmp', 'rtsp')) else {} options = {'rtbufsize': '8M'} if source.startswith(('http', 'rtmp', 'rtsp')) else {}
self._player = MediaPlayer(source, options=options) self._player = MediaPlayer(source, options=options)
self._broadcast_task = asyncio.create_task(self._broadcast_loop()) self._broadcast_task = asyncio.create_task(self._broadcast_loop())
@@ -171,7 +121,6 @@ class SynchronizedRadioBroadcaster:
self._player = None self._player = None
async def _stop_broadcast_task(self): async def _stop_broadcast_task(self):
"""Cancels and awaits the broadcast task."""
if self._broadcast_task: if self._broadcast_task:
self._broadcast_task.cancel() self._broadcast_task.cancel()
try: try:
@@ -181,7 +130,6 @@ class SynchronizedRadioBroadcaster:
self._broadcast_task = None self._broadcast_task = None
async def _close_player(self): async def _close_player(self):
"""Safely closes the media player and underlying container."""
self._resampler = None self._resampler = None
if self._player: if self._player:
@@ -204,19 +152,14 @@ class SynchronizedRadioBroadcaster:
gc.collect() gc.collect()
async def stop_playback(self): async def stop_playback(self):
"""Stops the broadcast and clears the current source."""
async with self._lock: async with self._lock:
self._stopped = True self._stopped = True
await self._stop_broadcast_task() await self._stop_broadcast_task()
await self._close_player() await self._close_player()
self._current_source = None self._current_source = None
logger.info("Playback stopped") logger.info("⏹️ Playback stopped")
async def _broadcast_loop(self): async def _broadcast_loop(self):
"""
Main loop: decodes frames, resamples them, and distributes to subscribers.
Handles errors and End-Of-File (EOF) events.
"""
last_error_time = 0 last_error_time = 0
error_count_window = 0 error_count_window = 0
@@ -224,11 +167,9 @@ class SynchronizedRadioBroadcaster:
while self._player and not self._shutdown and not self._stopped: while self._player and not self._shutdown and not self._stopped:
self._frame_counter += 1 self._frame_counter += 1
# Periodically clean up dead weak references.
if self._frame_counter % self.CLEANUP_INTERVAL == 0: if self._frame_counter % self.CLEANUP_INTERVAL == 0:
self._cleanup_subscribers() self._cleanup_subscribers()
# Pause decoding if no one is listening to save CPU.
if not self._get_active_subscribers(): if not self._get_active_subscribers():
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
continue continue
@@ -247,14 +188,12 @@ class SynchronizedRadioBroadcaster:
if current_time - last_error_time > 5.0: if current_time - last_error_time > 5.0:
error_count_window = 0 error_count_window = 0
# Ensure audio is standard stereo 48kHz.
if frame.format.name != 's16' or frame.sample_rate != 48000: if frame.format.name != 's16' or frame.sample_rate != 48000:
if self._resampler is None: if self._resampler is None:
self._resampler = av.AudioResampler(format='s16', layout='stereo', rate=48000) self._resampler = av.AudioResampler(format='s16', layout='stereo', rate=48000)
resampled_frames = self._resampler.resample(frame) resampled_frames = self._resampler.resample(frame)
# Explicitly release the original frame to free C-level memory.
del frame del frame
if resampled_frames: if resampled_frames:
@@ -262,7 +201,6 @@ class SynchronizedRadioBroadcaster:
else: else:
continue continue
# Rewrite timestamps for the continuous stream.
frame.pts = self._pts frame.pts = self._pts
frame.time_base = self._time_base frame.time_base = self._time_base
self._pts += frame.samples self._pts += frame.samples
@@ -270,13 +208,12 @@ class SynchronizedRadioBroadcaster:
await self._distribute_frame(frame) await self._distribute_frame(frame)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# Send silence on network timeout to prevent RTP timeout.
s_frame = self._get_silence_frame() s_frame = self._get_silence_frame()
await self._distribute_frame(s_frame) await self._distribute_frame(s_frame)
del s_frame del s_frame
except (av.error.EOFError, StopIteration, StopAsyncIteration): except (av.error.EOFError, StopIteration, StopAsyncIteration):
logger.info("Track ended (EOF)") logger.info("📻 Track ended (EOF)")
break break
except Exception as e: except Exception as e:
@@ -296,7 +233,6 @@ class SynchronizedRadioBroadcaster:
if frame is not None: if frame is not None:
del frame del frame
# Notify the main bot that the track finished.
if self.on_track_end and not self._shutdown and not self._track_end_fired and not self._stopped: if self.on_track_end and not self._shutdown and not self._track_end_fired and not self._stopped:
self._track_end_fired = True self._track_end_fired = True
try: try:
@@ -330,7 +266,6 @@ class SynchronizedRadioBroadcaster:
if not subscribers: if not subscribers:
return return
# Dispatch frames to all tracks concurrently.
tasks = [sub._receive_frame(frame) for sub in subscribers] tasks = [sub._receive_frame(frame) for sub in subscribers]
if tasks: if tasks:
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
@@ -350,11 +285,7 @@ class SynchronizedRadioBroadcaster:
self._resampler = None self._resampler = None
gc.collect() gc.collect()
class SynchronizedAudioTrack(MediaStreamTrack): class SynchronizedAudioTrack(MediaStreamTrack):
"""
A MediaStreamTrack that receives frames from the broadcaster queue.
"""
kind = "audio" kind = "audio"
MAX_QUEUE_SIZE = 3 MAX_QUEUE_SIZE = 3
@@ -370,7 +301,6 @@ class SynchronizedAudioTrack(MediaStreamTrack):
self._broadcaster.subscribe(self) self._broadcaster.subscribe(self)
def _get_silence_frame(self) -> 'av.AudioFrame': def _get_silence_frame(self) -> 'av.AudioFrame':
"""Lazily creates and reuses a silence frame."""
if self._silence_frame is None: if self._silence_frame is None:
self._silence_frame = av.AudioFrame(format='s16', layout='stereo', samples=960) self._silence_frame = av.AudioFrame(format='s16', layout='stereo', samples=960)
self._silence_frame.sample_rate = 48000 self._silence_frame.sample_rate = 48000
@@ -387,7 +317,6 @@ class SynchronizedAudioTrack(MediaStreamTrack):
return return
try: try:
# If the queue is full, drop the oldest frame to reduce latency.
if self._frame_queue.full(): if self._frame_queue.full():
try: try:
old_frame = self._frame_queue.get_nowait() old_frame = self._frame_queue.get_nowait()
@@ -400,7 +329,6 @@ class SynchronizedAudioTrack(MediaStreamTrack):
pass pass
async def recv(self): async def recv(self):
"""Called by aiortc to pull the next frame."""
if not self._active: if not self._active:
raise Exception("Track stopped") raise Exception("Track stopped")
@@ -417,7 +345,6 @@ class SynchronizedAudioTrack(MediaStreamTrack):
self._active = False self._active = False
self._broadcaster.unsubscribe(self) self._broadcaster.unsubscribe(self)
# Drain queue to release frame references.
while not self._frame_queue.empty(): while not self._frame_queue.empty():
try: try:
frame = self._frame_queue.get_nowait() frame = self._frame_queue.get_nowait()
@@ -431,12 +358,7 @@ class SynchronizedAudioTrack(MediaStreamTrack):
super().stop() super().stop()
class JingleWebRTCHandler: class JingleWebRTCHandler:
"""
Manages the XMPP Jingle negotiation and WebRTC sessions.
Translates between Jingle XML stanzas and SDP.
"""
NS_JINGLE = "urn:xmpp:jingle:1" NS_JINGLE = "urn:xmpp:jingle:1"
NS_JINGLE_ICE = "urn:xmpp:jingle:transports:ice-udp:1" NS_JINGLE_ICE = "urn:xmpp:jingle:transports:ice-udp:1"
@@ -462,7 +384,6 @@ class JingleWebRTCHandler:
self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
async def _periodic_cleanup(self): async def _periodic_cleanup(self):
"""Periodically removes ended sessions to recover memory."""
while not self._shutdown: while not self._shutdown:
await asyncio.sleep(self.SESSION_CLEANUP_INTERVAL) await asyncio.sleep(self.SESSION_CLEANUP_INTERVAL)
try: try:
@@ -480,7 +401,6 @@ class JingleWebRTCHandler:
await self.stop_session(sid) await self.stop_session(sid)
def register_proposed_session(self, sid: str, peer_jid: str): def register_proposed_session(self, sid: str, peer_jid: str):
# Limit the proposed session cache to prevent memory exhaustion attacks.
if len(self._proposed_sessions) > 50: if len(self._proposed_sessions) > 50:
oldest = list(self._proposed_sessions.keys())[:25] oldest = list(self._proposed_sessions.keys())[:25]
for key in oldest: for key in oldest:
@@ -539,37 +459,54 @@ class JingleWebRTCHandler:
await self.stop_session(sid) await self.stop_session(sid)
def _extract_candidates_from_sdp(self, sdp: str) -> List[Dict]: def _extract_candidates_from_sdp(self, sdp: str) -> List[Dict]:
"""Parses SDP string to extract ICE candidates as dictionaries."""
candidates = [] candidates = []
for line in sdp.splitlines(): for line in sdp.splitlines():
if line.startswith('a=candidate:'): if line.startswith('a=candidate:'):
match = re.match( parts = line[12:].split()
r'a=candidate:(\S+)\s+(\d+)\s+(\w+)\s+(\d+)\s+(\S+)\s+(\d+)\s+typ\s+(\w+)(?:\s+raddr\s+(\S+))?\s*(?:rport\s+(\d+))?', if len(parts) < 8:
line continue
)
if match: cand = {
foundation, component, protocol, priority, ip, port, typ, raddr, rport = match.groups() 'foundation': parts[0],
cand = { 'component': parts[1],
'foundation': foundation, 'protocol': parts[2].lower(),
'component': component, 'priority': parts[3],
'protocol': protocol.lower(), 'ip': parts[4],
'priority': priority, 'port': parts[5],
'ip': ip, 'type': parts[7],
'port': port, 'generation': '0',
'type': typ, 'id': f"cand-{random.randint(1000, 9999)}"
'generation': '0', }
'id': f"cand-{random.randint(1000, 9999)}"
} for i in range(8, len(parts), 2):
if raddr: cand['rel-addr'] = raddr if i + 1 < len(parts):
if rport: cand['rel-port'] = rport key = parts[i]
candidates.append(cand) val = parts[i+1]
if key == 'raddr':
cand['rel-addr'] = val
elif key == 'rport':
cand['rel-port'] = val
elif key == 'generation':
cand['generation'] = val
elif key == 'network':
cand['network'] = val
candidates.append(cand)
return candidates return candidates
async def handle_session_initiate(self, jingle_xml, peer_jid, our_jid): async def handle_session_initiate(self, jingle_xml, peer_jid, our_jid):
"""Handles incoming Jingle session-initiate, creating the WebRTC answer."""
sid = jingle_xml.get('sid') sid = jingle_xml.get('sid')
session = await self.create_session(sid, peer_jid) session = await self.create_session(sid, peer_jid)
content = jingle_xml.find(f"{{{self.NS_JINGLE}}}content")
if content is None:
for child in jingle_xml:
if child.tag.endswith('content'):
content = child
break
content_name = content.get('name', '0') if content is not None else '0'
sdp = self._jingle_to_sdp(jingle_xml) sdp = self._jingle_to_sdp(jingle_xml)
offer = RTCSessionDescription(sdp=sdp, type="offer") offer = RTCSessionDescription(sdp=sdp, type="offer")
@@ -586,7 +523,6 @@ class JingleWebRTCHandler:
session.pending_candidates.clear() session.pending_candidates.clear()
answer = await session.pc.createAnswer() answer = await session.pc.createAnswer()
# Set to passive to let the other side act as DTLS server if needed.
answer_sdp = answer.sdp.replace('a=setup:active', 'a=setup:passive') answer_sdp = answer.sdp.replace('a=setup:active', 'a=setup:passive')
answer = RTCSessionDescription(sdp=answer_sdp, type="answer") answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
@@ -602,21 +538,20 @@ class JingleWebRTCHandler:
session.pc.localDescription.sdp session.pc.localDescription.sdp
) )
jingle_xml = self._build_session_accept(session, sid, our_jid) jingle_xml = self._build_session_accept(session, sid, our_jid, content_name)
return session, jingle_xml return session, jingle_xml
def _build_session_accept(self, session: JingleSession, sid: str, responder: str) -> str: def _build_session_accept(self, session: JingleSession, sid: str, responder: str, content_name: str) -> str:
"""Constructs the session-accept Jingle XML stanza."""
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
root = ET.Element('jingle', {'xmlns': self.NS_JINGLE, 'action': 'session-accept', 'sid': sid, 'responder': responder}) root = ET.Element('jingle', {'xmlns': self.NS_JINGLE, 'action': 'session-accept', 'sid': sid, 'responder': responder})
content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': '0', 'senders': 'both'}) content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': content_name, 'senders': 'both'})
desc = ET.SubElement(content, 'description', {'xmlns': 'urn:xmpp:jingle:apps:rtp:1', 'media': 'audio'}) desc = ET.SubElement(content, 'description', {'xmlns': 'urn:xmpp:jingle:apps:rtp:1', 'media': 'audio'})
sdp = session.pc.localDescription.sdp sdp = session.pc.localDescription.sdp
codecs = {} codecs = {}
for line in sdp.splitlines(): for line in sdp.splitlines():
if line.startswith("a=rtpmap:"): if line.startswith("a=rtpmap:"):
match = re.match(r'a=rtpmap:(\d+)\s+(\S+)/(\d+)(?:/(\d+))?', line) match = re.match(r'a=rtpmap:(\d+)\s+([^/]+)/(\d+)(?:/(\d+))?', line)
if match: if match:
codec_id = match.group(1) codec_id = match.group(1)
codecs[codec_id] = {'name': match.group(2), 'clockrate': match.group(3), 'channels': match.group(4) or '1', 'fmtp': {}} codecs[codec_id] = {'name': match.group(2), 'clockrate': match.group(3), 'channels': match.group(4) or '1', 'fmtp': {}}
@@ -690,12 +625,11 @@ class JingleWebRTCHandler:
s.pc = None s.pc = None
s.cleanup() s.cleanup()
del s del s
logger.info(f"Stopped session {sid}") logger.info(f"🔴 Stopped session {sid}")
self.clear_proposed_session(sid) self.clear_proposed_session(sid)
gc.collect() gc.collect()
async def end_all_sessions(self): async def end_all_sessions(self):
"""Terminates all active sessions and shuts down the broadcaster."""
self._shutdown = True self._shutdown = True
self._broadcaster.shutdown() self._broadcaster.shutdown()
if self._cleanup_task: if self._cleanup_task:
@@ -712,7 +646,6 @@ class JingleWebRTCHandler:
return [s for s in self.sessions.values() if s.state == CallState.ACTIVE] return [s for s in self.sessions.values() if s.state == CallState.ACTIVE]
def _jingle_to_sdp(self, xml): def _jingle_to_sdp(self, xml):
"""Converts Jingle XML format to standard SDP text."""
lines = ["v=0", "o=- 0 0 IN IP4 0.0.0.0", "s=-", "t=0 0", "a=group:BUNDLE 0", "a=msid-semantic: WMS *"] lines = ["v=0", "o=- 0 0 IN IP4 0.0.0.0", "s=-", "t=0 0", "a=group:BUNDLE 0", "a=msid-semantic: WMS *"]
content = xml.find(f"{{urn:xmpp:jingle:1}}content") content = xml.find(f"{{urn:xmpp:jingle:1}}content")
desc = content.find(f"{{urn:xmpp:jingle:apps:rtp:1}}description") desc = content.find(f"{{urn:xmpp:jingle:apps:rtp:1}}description")