Update jingle_webrtc.py
jingle_webrtc.py
@@ -1,16 +1,4 @@
#!/usr/bin/env python3
"""
WebRTC and Jingle handling module.

This module provides the core logic for establishing Audio/Video sessions via
XMPP Jingle (XEP-0166) and WebRTC. It manages the RTCPeerConnections,
ICE candidate negotiation, and audio broadcasting.

Key features:
- Shared memory audio broadcasting to multiple subscribers.
- Synchronization of audio streams.
- Explicit cleanup routines to prevent memory leaks with aiortc objects.
"""

import asyncio
import gc
@@ -36,29 +24,13 @@ try:
except ImportError as e:
    logger.error(f"aiortc import failed: {e}")


class CallState(Enum):
    """Represents the lifecycle state of a Jingle session."""
    IDLE = "idle"
    ACTIVE = "active"
    ENDED = "ended"


@dataclass(slots=True)
class JingleSession:
    """
    Stores state for a single WebRTC session.

    Attributes:
        sid: Session ID.
        peer_jid: JID of the remote peer.
        pc: The RTCPeerConnection object.
        audio_track: The specific audio track associated with this session.
        state: Current call state.
        pending_candidates: ICE candidates received before remote description was set.
        local_candidates: ICE candidates generated locally.
        ice_gathering_complete: Event signaling completion of ICE gathering.
    """
    sid: str
    peer_jid: str
    pc: Any = None
@@ -69,22 +41,12 @@ class JingleSession:
    ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event)

    def cleanup(self):
        """Releases references to heavy objects to aid garbage collection."""
        self.pending_candidates.clear()
        self.local_candidates.clear()
        self.pc = None
        self.audio_track = None
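The hunks above only show fragments of the JingleSession dataclass. Below is a rough sketch of how those pieces plausibly fit together; the defaults for audio_track, state, and the two candidate lists are assumptions inferred from the attribute docstring, not the committed code.

# Illustrative sketch, not part of this commit.
from dataclasses import dataclass, field
from typing import Any, Dict, List
import asyncio

@dataclass(slots=True)
class JingleSessionSketch:
    sid: str
    peer_jid: str
    pc: Any = None                      # RTCPeerConnection (assumption)
    audio_track: Any = None             # per-session audio track (assumption)
    state: Any = None                   # CallState.IDLE in the real module (assumption)
    pending_candidates: List[Dict] = field(default_factory=list)
    local_candidates: List[Dict] = field(default_factory=list)
    ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event)

    def cleanup(self) -> None:
        # Drop references to heavy aiortc objects so GC can reclaim them.
        self.pending_candidates.clear()
        self.local_candidates.clear()
        self.pc = None
        self.audio_track = None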
class SynchronizedRadioBroadcaster:
    """
    Central audio broadcaster using a producer-consumer model.

    This class decodes a single audio stream and distributes the frames to
    multiple subscribers (SynchronizedAudioTrack instances) simultaneously.
    It uses weak references to subscribers to avoid reference cycles.
    """

    MAX_SUBSCRIBERS = 10
    SAMPLES_PER_FRAME = 960
    CLEANUP_INTERVAL = 200
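The subscribe and cleanup helpers referenced later in the loop (_cleanup_subscribers, _get_active_subscribers) are not visible in this diff. A minimal sketch of the weak-reference subscriber pattern the class docstring describes; the class name and method bodies here are assumptions for illustration only.

import weakref
from typing import List

class _WeakSubscriberRegistry:
    """Sketch of the weakref bookkeeping described above."""

    def __init__(self) -> None:
        self._subscribers: List[weakref.ref] = []

    def subscribe(self, track) -> None:
        # Hold only a weak reference so a dropped track can be garbage collected.
        self._subscribers.append(weakref.ref(track))

    def unsubscribe(self, track) -> None:
        self._subscribers = [r for r in self._subscribers if r() is not track]

    def _cleanup_subscribers(self) -> None:
        # Drop refs whose targets have already been collected.
        self._subscribers = [r for r in self._subscribers if r() is not None]

    def _get_active_subscribers(self) -> list:
        # Dereference and keep only live, active tracks.
        return [t for t in (r() for r in self._subscribers)
                if t is not None and getattr(t, "_active", True)]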
@@ -93,7 +55,6 @@ class SynchronizedRadioBroadcaster:
        self._player: Optional[MediaPlayer] = None
        self._current_source: Optional[str] = None
        self._lock = asyncio.Lock()
        # Use weakrefs to allow tracks to be GC'd if the connection drops unexpectedly.
        self._subscribers: List[weakref.ref] = []
        self._broadcast_task: Optional[asyncio.Task] = None
        self._pts = 0
@@ -104,11 +65,9 @@ class SynchronizedRadioBroadcaster:
        self._stopped = False
        self._frame_counter = 0

        # Persist resampler to avoid high CPU usage from constant reallocation.
        self._resampler = None

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Generates a silent audio frame to keep the RTP stream alive."""
        frame = av.AudioFrame(format='s16', layout='stereo', samples=self.SAMPLES_PER_FRAME)
        frame.sample_rate = 48000
        frame.time_base = self._time_base
@@ -119,13 +78,6 @@ class SynchronizedRadioBroadcaster:
        return frame
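The hunk above cuts off before the frame is actually filled. A standalone sketch of producing a 20 ms silent frame like the one described; the zero-fill step is an assumption about the elided lines.

import av
from fractions import Fraction

def make_silence_frame(samples: int = 960) -> av.AudioFrame:
    # 960 samples at 48 kHz = 20 ms, the usual Opus/WebRTC frame size.
    frame = av.AudioFrame(format='s16', layout='stereo', samples=samples)
    frame.sample_rate = 48000
    frame.time_base = Fraction(1, 48000)
    # s16 stereo is packed: a single plane; zero the whole buffer (incl. padding).
    frame.planes[0].update(bytes(frame.planes[0].buffer_size))
    return frame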
    async def set_source(self, source: str, force_restart: bool = False):
        """
        Updates the audio source, restarting the broadcast loop if necessary.

        Args:
            source: URI or path to the audio file.
            force_restart: If True, restarts player even if source is identical.
        """
        async with self._lock:
            normalized_source = source
            if source and not source.startswith(('http:', 'https:', 'rtmp:', 'rtsp:')):
@@ -142,7 +94,7 @@ class SynchronizedRadioBroadcaster:
                    not force_restart):
                return

            logger.info(f"Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}")
            logger.info(f"🔄 Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}")

            await self._stop_broadcast_task()
            await self._close_player()
@@ -153,15 +105,13 @@ class SynchronizedRadioBroadcaster:
            self._stopped = False
            self._frame_counter = 0

            # Reset resampler context to prevent timestamp drift or format mismatch.
            self._resampler = None

            if not source:
                return

            try:
                logger.info(f"Starting broadcast: {source}")
                # Use a buffer for network streams to reduce jitter.
                logger.info(f"🎵 Starting broadcast: {source}")
                options = {'rtbufsize': '8M'} if source.startswith(('http', 'rtmp', 'rtsp')) else {}
                self._player = MediaPlayer(source, options=options)
                self._broadcast_task = asyncio.create_task(self._broadcast_loop())
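For context, a minimal standalone sketch of the producer side: aiortc's MediaPlayer decodes the source with FFmpeg, and its audio track yields av.AudioFrame objects that a loop like _broadcast_loop can fan out. The file path below is a placeholder.

import asyncio
from aiortc.contrib.media import MediaPlayer

async def read_a_few_frames(path: str = "example.mp3") -> None:  # placeholder path
    # Mirror the commit: only set the FFmpeg buffer for network sources.
    options = {'rtbufsize': '8M'} if path.startswith(('http', 'rtmp', 'rtsp')) else {}
    player = MediaPlayer(path, options=options)
    try:
        for _ in range(5):
            frame = await player.audio.recv()   # av.AudioFrame
            print(frame.samples, frame.sample_rate, frame.format.name)
    finally:
        if player.audio:
            player.audio.stop()

# asyncio.run(read_a_few_frames())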
@@ -171,7 +121,6 @@ class SynchronizedRadioBroadcaster:
                self._player = None

    async def _stop_broadcast_task(self):
        """Cancels and awaits the broadcast task."""
        if self._broadcast_task:
            self._broadcast_task.cancel()
            try:
@@ -181,7 +130,6 @@ class SynchronizedRadioBroadcaster:
            self._broadcast_task = None

    async def _close_player(self):
        """Safely closes the media player and underlying container."""
        self._resampler = None

        if self._player:
@@ -204,19 +152,14 @@ class SynchronizedRadioBroadcaster:
            gc.collect()

    async def stop_playback(self):
        """Stops the broadcast and clears the current source."""
        async with self._lock:
            self._stopped = True
            await self._stop_broadcast_task()
            await self._close_player()
            self._current_source = None
            logger.info("Playback stopped")
            logger.info("⏹️ Playback stopped")

    async def _broadcast_loop(self):
        """
        Main loop: decodes frames, resamples them, and distributes to subscribers.
        Handles errors and End-Of-File (EOF) events.
        """
        last_error_time = 0
        error_count_window = 0

@@ -224,11 +167,9 @@ class SynchronizedRadioBroadcaster:
        while self._player and not self._shutdown and not self._stopped:
            self._frame_counter += 1

            # Periodically clean up dead weak references.
            if self._frame_counter % self.CLEANUP_INTERVAL == 0:
                self._cleanup_subscribers()

            # Pause decoding if no one is listening to save CPU.
            if not self._get_active_subscribers():
                await asyncio.sleep(0.1)
                continue
@@ -247,14 +188,12 @@ class SynchronizedRadioBroadcaster:
                if current_time - last_error_time > 5.0:
                    error_count_window = 0

                # Ensure audio is standard stereo 48kHz.
                if frame.format.name != 's16' or frame.sample_rate != 48000:
                    if self._resampler is None:
                        self._resampler = av.AudioResampler(format='s16', layout='stereo', rate=48000)

                    resampled_frames = self._resampler.resample(frame)

                    # Explicitly release the original frame to free C-level memory.
                    del frame

                    if resampled_frames:
@@ -262,7 +201,6 @@ class SynchronizedRadioBroadcaster:
                    else:
                        continue

                # Rewrite timestamps for the continuous stream.
                frame.pts = self._pts
                frame.time_base = self._time_base
                self._pts += frame.samples
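A small standalone illustration of the resampling and timestamp rewriting done above. In recent PyAV versions AudioResampler.resample() returns a list of frames, which is why the loop indexes resampled_frames; the input samples here are synthetic.

import av
import numpy as np
from fractions import Fraction

# Synthetic mono 44.1 kHz frame, just to exercise the resampler.
samples = np.zeros((1, 1024), dtype=np.int16)
src = av.AudioFrame.from_ndarray(samples, format='s16', layout='mono')
src.sample_rate = 44100

resampler = av.AudioResampler(format='s16', layout='stereo', rate=48000)
frames = resampler.resample(src)          # list of av.AudioFrame (may be empty)

pts = 0
time_base = Fraction(1, 48000)
for frame in frames:
    # Rewrite timestamps so the outgoing stream stays continuous.
    frame.pts = pts
    frame.time_base = time_base
    pts += frame.samples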
@@ -270,13 +208,12 @@ class SynchronizedRadioBroadcaster:
                await self._distribute_frame(frame)

            except asyncio.TimeoutError:
                # Send silence on network timeout to prevent RTP timeout.
                s_frame = self._get_silence_frame()
                await self._distribute_frame(s_frame)
                del s_frame

            except (av.error.EOFError, StopIteration, StopAsyncIteration):
                logger.info("Track ended (EOF)")
                logger.info("📻 Track ended (EOF)")
                break

            except Exception as e:
@@ -296,7 +233,6 @@ class SynchronizedRadioBroadcaster:
                if frame is not None:
                    del frame

        # Notify the main bot that the track finished.
        if self.on_track_end and not self._shutdown and not self._track_end_fired and not self._stopped:
            self._track_end_fired = True
            try:
@@ -330,7 +266,6 @@ class SynchronizedRadioBroadcaster:
        if not subscribers:
            return

        # Dispatch frames to all tracks concurrently.
        tasks = [sub._receive_frame(frame) for sub in subscribers]
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
@@ -350,11 +285,7 @@ class SynchronizedRadioBroadcaster:
        self._resampler = None
        gc.collect()


class SynchronizedAudioTrack(MediaStreamTrack):
    """
    A MediaStreamTrack that receives frames from the broadcaster queue.
    """
    kind = "audio"
    MAX_QUEUE_SIZE = 3

@@ -370,7 +301,6 @@ class SynchronizedAudioTrack(MediaStreamTrack):
        self._broadcaster.subscribe(self)

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Lazily creates and reuses a silence frame."""
        if self._silence_frame is None:
            self._silence_frame = av.AudioFrame(format='s16', layout='stereo', samples=960)
            self._silence_frame.sample_rate = 48000
@@ -387,7 +317,6 @@ class SynchronizedAudioTrack(MediaStreamTrack):
            return

        try:
            # If the queue is full, drop the oldest frame to reduce latency.
            if self._frame_queue.full():
                try:
                    old_frame = self._frame_queue.get_nowait()
@@ -400,7 +329,6 @@ class SynchronizedAudioTrack(MediaStreamTrack):
            pass

    async def recv(self):
        """Called by aiortc to pull the next frame."""
        if not self._active:
            raise Exception("Track stopped")
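The queue handling above amounts to a bounded drop-oldest buffer with a silence fallback. A compact sketch of that pattern; the class and method names here are illustrative, not the committed implementation.

import asyncio

class DropOldestBuffer:
    """Keep at most max_size frames; prefer fresh audio over backlog."""

    def __init__(self, max_size: int = 3):
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=max_size)

    def push(self, frame) -> None:
        if self._queue.full():
            try:
                self._queue.get_nowait()      # drop the oldest frame
            except asyncio.QueueEmpty:
                pass
        self._queue.put_nowait(frame)

    async def pull(self, silence, timeout: float = 0.02):
        # Fall back to a silence frame if the producer stalls,
        # so the RTP stream keeps flowing.
        try:
            return await asyncio.wait_for(self._queue.get(), timeout)
        except asyncio.TimeoutError:
            return silence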
@@ -417,7 +345,6 @@ class SynchronizedAudioTrack(MediaStreamTrack):
        self._active = False
        self._broadcaster.unsubscribe(self)

        # Drain queue to release frame references.
        while not self._frame_queue.empty():
            try:
                frame = self._frame_queue.get_nowait()
@@ -431,12 +358,7 @@ class SynchronizedAudioTrack(MediaStreamTrack):

        super().stop()


class JingleWebRTCHandler:
    """
    Manages the XMPP Jingle negotiation and WebRTC sessions.
    Translates between Jingle XML stanzas and SDP.
    """
    NS_JINGLE = "urn:xmpp:jingle:1"
    NS_JINGLE_ICE = "urn:xmpp:jingle:transports:ice-udp:1"

@@ -462,7 +384,6 @@ class JingleWebRTCHandler:
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    async def _periodic_cleanup(self):
        """Periodically removes ended sessions to recover memory."""
        while not self._shutdown:
            await asyncio.sleep(self.SESSION_CLEANUP_INTERVAL)
            try:
@@ -480,7 +401,6 @@ class JingleWebRTCHandler:
                await self.stop_session(sid)

    def register_proposed_session(self, sid: str, peer_jid: str):
        # Limit the proposed session cache to prevent memory exhaustion attacks.
        if len(self._proposed_sessions) > 50:
            oldest = list(self._proposed_sessions.keys())[:25]
            for key in oldest:
@@ -539,37 +459,54 @@ class JingleWebRTCHandler:
            await self.stop_session(sid)

    def _extract_candidates_from_sdp(self, sdp: str) -> List[Dict]:
        """Parses SDP string to extract ICE candidates as dictionaries."""
        candidates = []
        for line in sdp.splitlines():
            if line.startswith('a=candidate:'):
                match = re.match(
                    r'a=candidate:(\S+)\s+(\d+)\s+(\w+)\s+(\d+)\s+(\S+)\s+(\d+)\s+typ\s+(\w+)(?:\s+raddr\s+(\S+))?\s*(?:rport\s+(\d+))?',
                    line
                )
                if match:
                    foundation, component, protocol, priority, ip, port, typ, raddr, rport = match.groups()
                    cand = {
                        'foundation': foundation,
                        'component': component,
                        'protocol': protocol.lower(),
                        'priority': priority,
                        'ip': ip,
                        'port': port,
                        'type': typ,
                        'generation': '0',
                        'id': f"cand-{random.randint(1000, 9999)}"
                    }
                    if raddr: cand['rel-addr'] = raddr
                    if rport: cand['rel-port'] = rport
                    candidates.append(cand)
                parts = line[12:].split()
                if len(parts) < 8:
                    continue

                cand = {
                    'foundation': parts[0],
                    'component': parts[1],
                    'protocol': parts[2].lower(),
                    'priority': parts[3],
                    'ip': parts[4],
                    'port': parts[5],
                    'type': parts[7],
                    'generation': '0',
                    'id': f"cand-{random.randint(1000, 9999)}"
                }

                for i in range(8, len(parts), 2):
                    if i + 1 < len(parts):
                        key = parts[i]
                        val = parts[i+1]
                        if key == 'raddr':
                            cand['rel-addr'] = val
                        elif key == 'rport':
                            cand['rel-port'] = val
                        elif key == 'generation':
                            cand['generation'] = val
                        elif key == 'network':
                            cand['network'] = val

                candidates.append(cand)
        return candidates
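As an example of what the rewritten parser produces: given the single candidate line below (documentation IPs), calling _extract_candidates_from_sdp on a handler instance would yield roughly the dictionary shown in the comment. The 'id' suffix is random at runtime.

sdp_fragment = (
    "a=candidate:842163049 1 udp 1677729535 203.0.113.7 49203 "
    "typ srflx raddr 192.0.2.10 rport 49203 generation 0"
)
# handler._extract_candidates_from_sdp(sdp_fragment) -> roughly:
# [{'foundation': '842163049', 'component': '1', 'protocol': 'udp',
#   'priority': '1677729535', 'ip': '203.0.113.7', 'port': '49203',
#   'type': 'srflx', 'generation': '0',
#   'rel-addr': '192.0.2.10', 'rel-port': '49203',
#   'id': 'cand-NNNN'}]   # 'NNNN' is a random suffix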
    async def handle_session_initiate(self, jingle_xml, peer_jid, our_jid):
        """Handles incoming Jingle session-initiate, creating the WebRTC answer."""
        sid = jingle_xml.get('sid')
        session = await self.create_session(sid, peer_jid)

        content = jingle_xml.find(f"{{{self.NS_JINGLE}}}content")
        if content is None:
            for child in jingle_xml:
                if child.tag.endswith('content'):
                    content = child
                    break

        content_name = content.get('name', '0') if content is not None else '0'

        sdp = self._jingle_to_sdp(jingle_xml)
        offer = RTCSessionDescription(sdp=sdp, type="offer")

@@ -586,7 +523,6 @@ class JingleWebRTCHandler:
        session.pending_candidates.clear()

        answer = await session.pc.createAnswer()
        # Set to passive to let the other side act as DTLS server if needed.
        answer_sdp = answer.sdp.replace('a=setup:active', 'a=setup:passive')
        answer = RTCSessionDescription(sdp=answer_sdp, type="answer")
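For orientation, the surrounding negotiation follows the standard aiortc offer/answer flow. A condensed sketch is below; the track parameter stands in for the per-session SynchronizedAudioTrack, and the helper itself is illustrative rather than the committed code.

from aiortc import RTCPeerConnection, RTCSessionDescription

async def answer_offer(offer_sdp: str, track) -> RTCSessionDescription:
    pc = RTCPeerConnection()
    pc.addTrack(track)                                   # outgoing audio
    await pc.setRemoteDescription(RTCSessionDescription(sdp=offer_sdp, type="offer"))
    answer = await pc.createAnswer()
    # Advertise a=setup:passive so the caller may act as the DTLS server,
    # mirroring the replacement done in the hunk above.
    patched = answer.sdp.replace('a=setup:active', 'a=setup:passive')
    await pc.setLocalDescription(RTCSessionDescription(sdp=patched, type="answer"))
    return pc.localDescription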
@@ -602,21 +538,20 @@ class JingleWebRTCHandler:
            session.pc.localDescription.sdp
        )

        jingle_xml = self._build_session_accept(session, sid, our_jid)
        jingle_xml = self._build_session_accept(session, sid, our_jid, content_name)
        return session, jingle_xml

    def _build_session_accept(self, session: JingleSession, sid: str, responder: str) -> str:
        """Constructs the session-accept Jingle XML stanza."""
    def _build_session_accept(self, session: JingleSession, sid: str, responder: str, content_name: str) -> str:
        import xml.etree.ElementTree as ET
        root = ET.Element('jingle', {'xmlns': self.NS_JINGLE, 'action': 'session-accept', 'sid': sid, 'responder': responder})
        content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': '0', 'senders': 'both'})
        content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': content_name, 'senders': 'both'})
        desc = ET.SubElement(content, 'description', {'xmlns': 'urn:xmpp:jingle:apps:rtp:1', 'media': 'audio'})

        sdp = session.pc.localDescription.sdp
        codecs = {}
        for line in sdp.splitlines():
            if line.startswith("a=rtpmap:"):
                match = re.match(r'a=rtpmap:(\d+)\s+(\S+)/(\d+)(?:/(\d+))?', line)
                match = re.match(r'a=rtpmap:(\d+)\s+([^/]+)/(\d+)(?:/(\d+))?', line)
                if match:
                    codec_id = match.group(1)
                    codecs[codec_id] = {'name': match.group(2), 'clockrate': match.group(3), 'channels': match.group(4) or '1', 'fmtp': {}}
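The change from (\S+) to ([^/]+) matters for multi-part rtpmap values. With the updated pattern an Opus line splits cleanly into payload type, codec name, clock rate, and channel count:

import re

line = "a=rtpmap:111 opus/48000/2"
match = re.match(r'a=rtpmap:(\d+)\s+([^/]+)/(\d+)(?:/(\d+))?', line)
print(match.groups())   # ('111', 'opus', '48000', '2')
# With the previous (\S+) pattern, backtracking yields
# ('111', 'opus/48000', '2', None), i.e. a wrong name and clock rate.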
@@ -690,12 +625,11 @@ class JingleWebRTCHandler:
        s.pc = None
        s.cleanup()
        del s
        logger.info(f"Stopped session {sid}")
        logger.info(f"🔴 Stopped session {sid}")
        self.clear_proposed_session(sid)
        gc.collect()

    async def end_all_sessions(self):
        """Terminates all active sessions and shuts down the broadcaster."""
        self._shutdown = True
        self._broadcaster.shutdown()
        if self._cleanup_task:
@@ -712,7 +646,6 @@ class JingleWebRTCHandler:
        return [s for s in self.sessions.values() if s.state == CallState.ACTIVE]

    def _jingle_to_sdp(self, xml):
        """Converts Jingle XML format to standard SDP text."""
        lines = ["v=0", "o=- 0 0 IN IP4 0.0.0.0", "s=-", "t=0 0", "a=group:BUNDLE 0", "a=msid-semantic: WMS *"]
        content = xml.find(f"{{urn:xmpp:jingle:1}}content")
        desc = content.find(f"{{urn:xmpp:jingle:apps:rtp:1}}description")
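To make the direction of _jingle_to_sdp concrete, here is a minimal sketch of the core mapping it performs: each payload-type element under the Jingle RTP description becomes an a=rtpmap: SDP line. This illustrates the XEP-0167 convention, not the committed function; the helper name is made up.

import xml.etree.ElementTree as ET

NS_JINGLE = "urn:xmpp:jingle:1"
NS_RTP = "urn:xmpp:jingle:apps:rtp:1"

def payload_types_to_rtpmap(jingle_xml: ET.Element) -> list:
    """Collect a=rtpmap lines from a Jingle RTP description (sketch)."""
    lines = []
    content = jingle_xml.find(f"{{{NS_JINGLE}}}content")
    desc = content.find(f"{{{NS_RTP}}}description")
    for pt in desc.findall(f"{{{NS_RTP}}}payload-type"):
        rtpmap = f"a=rtpmap:{pt.get('id')} {pt.get('name')}/{pt.get('clockrate')}"
        channels = pt.get('channels')
        if channels and channels != '1':
            rtpmap += f"/{channels}"
        lines.append(rtpmap)
    return lines

example = ET.fromstring(
    f'<jingle xmlns="{NS_JINGLE}"><content name="0">'
    f'<description xmlns="{NS_RTP}" media="audio">'
    f'<payload-type id="111" name="opus" clockrate="48000" channels="2"/>'
    f'</description></content></jingle>'
)
print(payload_types_to_rtpmap(example))   # ['a=rtpmap:111 opus/48000/2']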