Add jingle_webrtc.py
jingle_webrtc.py (new file, 742 lines)

#!/usr/bin/env python3
"""
WebRTC and Jingle handling module.

This module provides the core logic for establishing Audio/Video sessions via
XMPP Jingle (XEP-0166) and WebRTC. It manages the RTCPeerConnections,
ICE candidate negotiation, and audio broadcasting.

Key features:
- In-memory broadcasting of a single decoded audio stream to multiple subscribers.
- Synchronization of audio streams.
- Explicit cleanup routines to prevent memory leaks with aiortc objects.
"""

import asyncio
import gc
import logging
import os
import random
import re
import weakref
from typing import Optional, Dict, List, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from fractions import Fraction

logger = logging.getLogger(__name__)

AIORTC_AVAILABLE = False
try:
    from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer
    from aiortc import MediaStreamTrack, RTCIceCandidate
    from aiortc.contrib.media import MediaPlayer
    import av
    AIORTC_AVAILABLE = True
except ImportError as e:
    logger.error(f"aiortc import failed: {e}")
    # Fallback base class so this module can still be imported without aiortc;
    # AIORTC_AVAILABLE stays False and callers are expected to check it.
    MediaStreamTrack = object


class CallState(Enum):
    """Represents the lifecycle state of a Jingle session."""
    IDLE = "idle"
    ACTIVE = "active"
    ENDED = "ended"


@dataclass(slots=True)
class JingleSession:
    """
    Stores state for a single WebRTC session.

    Attributes:
        sid: Session ID.
        peer_jid: JID of the remote peer.
        pc: The RTCPeerConnection object.
        audio_track: The specific audio track associated with this session.
        state: Current call state.
        pending_candidates: ICE candidates received before remote description was set.
        local_candidates: ICE candidates generated locally.
        ice_gathering_complete: Event signaling completion of ICE gathering.
    """
    sid: str
    peer_jid: str
    pc: Any = None
    audio_track: Any = None
    state: CallState = CallState.IDLE
    pending_candidates: List[Any] = field(default_factory=list)
    local_candidates: List[Dict] = field(default_factory=list)
    ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event)

    def cleanup(self):
        """Releases references to heavy objects to aid garbage collection."""
        self.pending_candidates.clear()
        self.local_candidates.clear()
        self.pc = None
        self.audio_track = None


class SynchronizedRadioBroadcaster:
    """
    Central audio broadcaster using a producer-consumer model.

    This class decodes a single audio stream and distributes the frames to
    multiple subscribers (SynchronizedAudioTrack instances) simultaneously.
    It uses weak references to subscribers to avoid reference cycles.
    """

    MAX_SUBSCRIBERS = 10
    SAMPLES_PER_FRAME = 960
    CLEANUP_INTERVAL = 200
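
    # Subscribers are SynchronizedAudioTrack instances: each registers itself
    # via subscribe() in its constructor and receives decoded frames through
    # _receive_frame(). Only weak references are held, so tracks that go away
    # are pruned by _cleanup_subscribers().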

    def __init__(self, on_track_end_callback: Optional[Callable] = None):
        self._player: Optional[MediaPlayer] = None
        self._current_source: Optional[str] = None
        self._lock = asyncio.Lock()
        # Use weakrefs to allow tracks to be GC'd if the connection drops unexpectedly.
        self._subscribers: List[weakref.ref] = []
        self._broadcast_task: Optional[asyncio.Task] = None
        self._pts = 0
        self._time_base = Fraction(1, 48000)
        self.on_track_end = on_track_end_callback
        self._shutdown = False
        self._track_end_fired = False
        self._stopped = False
        self._frame_counter = 0

        # Persist resampler to avoid high CPU usage from constant reallocation.
        self._resampler = None

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Generates a silent audio frame to keep the RTP stream alive."""
        frame = av.AudioFrame(format='s16', layout='stereo', samples=self.SAMPLES_PER_FRAME)
        frame.sample_rate = 48000
        frame.time_base = self._time_base
        for plane in frame.planes:
            plane.update(bytes(plane.buffer_size))
        frame.pts = self._pts
        self._pts += self.SAMPLES_PER_FRAME
        return frame

    async def set_source(self, source: str, force_restart: bool = False):
        """
        Updates the audio source, restarting the broadcast loop if necessary.

        Args:
            source: URI or path to the audio file.
            force_restart: If True, restarts player even if source is identical.
        """
        async with self._lock:
            normalized_source = source
            if source and not source.startswith(('http:', 'https:', 'rtmp:', 'rtsp:')):
                normalized_source = os.path.abspath(source)

            broadcast_is_running = (
                self._broadcast_task and
                not self._broadcast_task.done()
            )

            if (normalized_source == self._current_source and
                    self._player and
                    broadcast_is_running and
                    not force_restart):
                return

            logger.info(f"Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}")

            await self._stop_broadcast_task()
            await self._close_player()

            self._current_source = normalized_source
            self._pts = 0
            self._track_end_fired = False
            self._stopped = False
            self._frame_counter = 0

            # Reset resampler context to prevent timestamp drift or format mismatch.
            self._resampler = None

            if not source:
                return

            try:
                logger.info(f"Starting broadcast: {source}")
                # Use a buffer for network streams to reduce jitter.
                options = {'rtbufsize': '8M'} if source.startswith(('http', 'rtmp', 'rtsp')) else {}
                self._player = MediaPlayer(source, options=options)
                self._broadcast_task = asyncio.create_task(self._broadcast_loop())

            except Exception as e:
                logger.error(f"MediaPlayer Error: {e}")
                self._player = None

    async def _stop_broadcast_task(self):
        """Cancels and awaits the broadcast task."""
        if self._broadcast_task:
            self._broadcast_task.cancel()
            try:
                await asyncio.wait_for(self._broadcast_task, timeout=2.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass
            self._broadcast_task = None

    async def _close_player(self):
        """Safely closes the media player and underlying container."""
        self._resampler = None

        if self._player:
            try:
                old_player = self._player
                self._player = None

                if hasattr(old_player, 'audio') and old_player.audio:
                    try: old_player.audio.stop()
                    except Exception: pass

                if hasattr(old_player, '_container') and old_player._container:
                    try: old_player._container.close()
                    except Exception: pass

                del old_player
            except Exception as e:
                logger.debug(f"Error closing player: {e}")

        gc.collect()

    async def stop_playback(self):
        """Stops the broadcast and clears the current source."""
        async with self._lock:
            self._stopped = True
            await self._stop_broadcast_task()
            await self._close_player()
            self._current_source = None
            logger.info("Playback stopped")

    async def _broadcast_loop(self):
        """
        Main loop: decodes frames, resamples them, and distributes to subscribers.
        Handles errors and End-Of-File (EOF) events.
        """
        last_error_time = 0.0
        error_count_window = 0

        try:
            while self._player and not self._shutdown and not self._stopped:
                self._frame_counter += 1

                # Periodically clean up dead weak references.
                if self._frame_counter % self.CLEANUP_INTERVAL == 0:
                    self._cleanup_subscribers()

                # Pause decoding if no one is listening to save CPU.
                if not self._get_active_subscribers():
                    await asyncio.sleep(0.1)
                    continue

                frame = None
                try:
                    if not self._player or not self._player.audio:
                        break

                    frame = await asyncio.wait_for(
                        self._player.audio.recv(),
                        timeout=0.5
                    )

                    # Reset the error window after 5 seconds without errors.
                    current_time = asyncio.get_running_loop().time()
                    if current_time - last_error_time > 5.0:
                        error_count_window = 0

                    # Ensure audio is standard stereo 48kHz.
                    if frame.format.name != 's16' or frame.sample_rate != 48000:
                        if self._resampler is None:
                            self._resampler = av.AudioResampler(format='s16', layout='stereo', rate=48000)

                        resampled_frames = self._resampler.resample(frame)

                        # Explicitly release the original frame to free C-level memory.
                        del frame

                        if resampled_frames:
                            frame = resampled_frames[0]
                        else:
                            frame = None  # keep the name bound for the finally clause
                            continue

                    # Rewrite timestamps for the continuous stream.
                    frame.pts = self._pts
                    frame.time_base = self._time_base
                    self._pts += frame.samples

                    await self._distribute_frame(frame)

                except asyncio.TimeoutError:
                    # Send silence on network timeout to prevent RTP timeout.
                    s_frame = self._get_silence_frame()
                    await self._distribute_frame(s_frame)
                    del s_frame

                except (av.error.EOFError, StopIteration, StopAsyncIteration):
                    logger.info("Track ended (EOF)")
                    break

                except Exception as e:
                    if "MediaStreamError" in str(type(e).__name__):
                        break

                    logger.debug(f"Broadcast frame warning: {e}")
                    last_error_time = asyncio.get_running_loop().time()
                    error_count_window += 1
                    if error_count_window >= 20:
                        logger.error("Too many errors, stopping broadcast")
                        break

                    await asyncio.sleep(0.1)

                finally:
                    if frame is not None:
                        del frame

            # Notify the main bot that the track finished.
            if self.on_track_end and not self._shutdown and not self._track_end_fired and not self._stopped:
                self._track_end_fired = True
                try:
                    result = self.on_track_end()
                    if asyncio.iscoroutine(result):
                        asyncio.create_task(result)
                except Exception as e:
                    logger.error(f"Track end callback error: {e}")

        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Broadcast loop error: {e}")
        finally:
            self._resampler = None
            gc.collect()

    def _cleanup_subscribers(self):
        self._subscribers = [ref for ref in self._subscribers if ref() is not None]

    def _get_active_subscribers(self) -> List['SynchronizedAudioTrack']:
        active = []
        for ref in self._subscribers:
            track = ref()
            if track is not None and track._active:
                active.append(track)
        return active

    async def _distribute_frame(self, frame: 'av.AudioFrame'):
        subscribers = self._get_active_subscribers()
        if not subscribers:
            return

        # Dispatch frames to all tracks concurrently.
        tasks = [sub._receive_frame(frame) for sub in subscribers]
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    def subscribe(self, track: 'SynchronizedAudioTrack'):
        self._cleanup_subscribers()
        if len(self._subscribers) >= self.MAX_SUBSCRIBERS:
            return
        self._subscribers.append(weakref.ref(track))

    def unsubscribe(self, track: 'SynchronizedAudioTrack'):
        self._subscribers = [ref for ref in self._subscribers
                             if ref() is not None and ref() is not track]

    def shutdown(self):
        self._shutdown = True
        self._resampler = None
        gc.collect()


class SynchronizedAudioTrack(MediaStreamTrack):
    """
    A MediaStreamTrack that receives frames from the broadcaster queue.
    """
    kind = "audio"
    MAX_QUEUE_SIZE = 3

    def __init__(self, broadcaster: SynchronizedRadioBroadcaster):
        super().__init__()
        self._broadcaster = broadcaster
        self._frame_queue: asyncio.Queue = asyncio.Queue(maxsize=self.MAX_QUEUE_SIZE)
        self._active = True

        self._silence_frame: Optional[av.AudioFrame] = None
        self._silence_pts = 0

        self._broadcaster.subscribe(self)

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Lazily creates and reuses a silence frame."""
        if self._silence_frame is None:
            self._silence_frame = av.AudioFrame(format='s16', layout='stereo', samples=960)
            self._silence_frame.sample_rate = 48000
            self._silence_frame.time_base = Fraction(1, 48000)
            for plane in self._silence_frame.planes:
                plane.update(bytes(plane.buffer_size))

        self._silence_frame.pts = self._silence_pts
        self._silence_pts += 960
        return self._silence_frame

    async def _receive_frame(self, frame: 'av.AudioFrame'):
        if not self._active:
            return

        try:
            # If the queue is full, drop the oldest frame to reduce latency.
            if self._frame_queue.full():
                try:
                    old_frame = self._frame_queue.get_nowait()
                    del old_frame
                except asyncio.QueueEmpty:
                    pass

            self._frame_queue.put_nowait(frame)
        except Exception:
            pass

    async def recv(self):
        """Called by aiortc to pull the next frame."""
        if not self._active:
            raise Exception("Track stopped")

        try:
            frame = await asyncio.wait_for(
                self._frame_queue.get(),
                timeout=0.05
            )
            return frame
        except asyncio.TimeoutError:
            return self._get_silence_frame()

    def stop(self):
        self._active = False
        self._broadcaster.unsubscribe(self)

        # Drain queue to release frame references.
        while not self._frame_queue.empty():
            try:
                frame = self._frame_queue.get_nowait()
                del frame
            except asyncio.QueueEmpty:
                break

        if self._silence_frame is not None:
            del self._silence_frame
            self._silence_frame = None

        super().stop()


class JingleWebRTCHandler:
    """
    Manages the XMPP Jingle negotiation and WebRTC sessions.
    Translates between Jingle XML stanzas and SDP.
    """
    NS_JINGLE = "urn:xmpp:jingle:1"
    NS_JINGLE_ICE = "urn:xmpp:jingle:transports:ice-udp:1"

    MAX_SESSIONS = 20
    SESSION_CLEANUP_INTERVAL = 60

    def __init__(self, stun_server: str,
                 send_transport_info_callback: Optional[Callable] = None,
                 on_track_end: Optional[Callable] = None):
        if stun_server and not stun_server.startswith(('stun:', 'turn:', 'stuns:')):
            self.stun_server = f"stun:{stun_server}"
        else:
            self.stun_server = stun_server

        self.sessions: Dict[str, JingleSession] = {}
        self._proposed_sessions: Dict[str, str] = {}

        self._broadcaster = SynchronizedRadioBroadcaster(on_track_end_callback=on_track_end)
        self.send_transport_info = send_transport_info_callback

        self._cleanup_task: Optional[asyncio.Task] = None
        self._shutdown = False

    async def start_cleanup_task(self):
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    async def _periodic_cleanup(self):
        """Periodically removes ended sessions to recover memory."""
        while not self._shutdown:
            await asyncio.sleep(self.SESSION_CLEANUP_INTERVAL)
            try:
                await self._cleanup_dead_sessions()
                gc.collect()
            except Exception as e:
                logger.error(f"Cleanup error: {e}")

    async def _cleanup_dead_sessions(self):
        dead_sessions = [
            sid for sid, session in self.sessions.items()
            if session.state == CallState.ENDED
        ]
        for sid in dead_sessions:
            await self.stop_session(sid)

    def register_proposed_session(self, sid: str, peer_jid: str):
        # Limit the proposed session cache to prevent memory exhaustion attacks.
        if len(self._proposed_sessions) > 50:
            oldest = list(self._proposed_sessions.keys())[:25]
            for key in oldest:
                del self._proposed_sessions[key]
        self._proposed_sessions[sid] = peer_jid

    def clear_proposed_session(self, sid: str):
        self._proposed_sessions.pop(sid, None)

    def set_audio_source(self, source: str, force_restart: bool = False):
        asyncio.create_task(self._broadcaster.set_source(source, force_restart))

    async def stop_playback(self):
        await self._broadcaster.stop_playback()

    def get_session(self, sid: str):
        return self.sessions.get(sid)

    async def create_session(self, sid, peer_jid) -> JingleSession:
        if len(self.sessions) >= self.MAX_SESSIONS:
            await self._cleanup_dead_sessions()
            if len(self.sessions) >= self.MAX_SESSIONS:
                oldest = list(self.sessions.keys())[0]
                await self.stop_session(oldest)

        servers = []
        if self.stun_server:
            servers.append(RTCIceServer(urls=self.stun_server))

        config = RTCConfiguration(iceServers=servers)
        pc = RTCPeerConnection(configuration=config)

        session = JingleSession(sid=sid, peer_jid=peer_jid, pc=pc)
        self.sessions[sid] = session

        @pc.on("icegatheringstatechange")
        async def on_gathering_state():
            if pc.iceGatheringState == "complete":
                session.ice_gathering_complete.set()

        @pc.on("connectionstatechange")
        async def on_state():
            if pc.connectionState == "connected":
                session.state = CallState.ACTIVE
            elif pc.connectionState in ["failed", "closed", "disconnected"]:
                session.state = CallState.ENDED
                asyncio.create_task(self._cleanup_session(sid))

        return session

    async def _cleanup_session(self, sid: str):
        await asyncio.sleep(1)
        if sid in self.sessions:
            session = self.sessions[sid]
            if session.state == CallState.ENDED:
                await self.stop_session(sid)

    def _extract_candidates_from_sdp(self, sdp: str) -> List[Dict]:
        """Parses SDP string to extract ICE candidates as dictionaries."""
        candidates = []
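        # For illustration (hypothetical values): a host candidate line such as
        #   a=candidate:1 1 udp 2130706431 192.0.2.10 54321 typ host
        # is turned by the regex below into
        #   {'foundation': '1', 'component': '1', 'protocol': 'udp',
        #    'priority': '2130706431', 'ip': '192.0.2.10', 'port': '54321',
        #    'type': 'host', 'generation': '0', 'id': 'cand-<random>'}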
        for line in sdp.splitlines():
            if line.startswith('a=candidate:'):
                match = re.match(
                    r'a=candidate:(\S+)\s+(\d+)\s+(\w+)\s+(\d+)\s+(\S+)\s+(\d+)\s+typ\s+(\w+)(?:\s+raddr\s+(\S+))?\s*(?:rport\s+(\d+))?',
                    line
                )
                if match:
                    foundation, component, protocol, priority, ip, port, typ, raddr, rport = match.groups()
                    cand = {
                        'foundation': foundation,
                        'component': component,
                        'protocol': protocol.lower(),
                        'priority': priority,
                        'ip': ip,
                        'port': port,
                        'type': typ,
                        'generation': '0',
                        'id': f"cand-{random.randint(1000, 9999)}"
                    }
                    if raddr: cand['rel-addr'] = raddr
                    if rport: cand['rel-port'] = rport
                    candidates.append(cand)
        return candidates

    async def handle_session_initiate(self, jingle_xml, peer_jid, our_jid):
        """Handles incoming Jingle session-initiate, creating the WebRTC answer."""
        sid = jingle_xml.get('sid')
        session = await self.create_session(sid, peer_jid)

        sdp = self._jingle_to_sdp(jingle_xml)
        offer = RTCSessionDescription(sdp=sdp, type="offer")

        track = SynchronizedAudioTrack(self._broadcaster)
        session.audio_track = track
        session.pc.addTrack(track)

        await session.pc.setRemoteDescription(offer)

        if session.pending_candidates:
            for cand in session.pending_candidates:
                try: await session.pc.addIceCandidate(cand)
                except Exception: pass
            session.pending_candidates.clear()

        answer = await session.pc.createAnswer()
        # Set to passive to let the other side act as DTLS server if needed.
        answer_sdp = answer.sdp.replace('a=setup:active', 'a=setup:passive')
        answer = RTCSessionDescription(sdp=answer_sdp, type="answer")

        await session.pc.setLocalDescription(answer)

        try:
            await asyncio.wait_for(session.ice_gathering_complete.wait(), timeout=5.0)
        except asyncio.TimeoutError:
            pass

        if session.pc.localDescription:
            session.local_candidates = self._extract_candidates_from_sdp(
                session.pc.localDescription.sdp
            )

        accept_xml = self._build_session_accept(session, sid, our_jid)
        return session, accept_xml

    def _build_session_accept(self, session: JingleSession, sid: str, responder: str) -> str:
        """Constructs the session-accept Jingle XML stanza."""
        import xml.etree.ElementTree as ET
        root = ET.Element('jingle', {'xmlns': self.NS_JINGLE, 'action': 'session-accept', 'sid': sid, 'responder': responder})
        content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': '0', 'senders': 'both'})
        desc = ET.SubElement(content, 'description', {'xmlns': 'urn:xmpp:jingle:apps:rtp:1', 'media': 'audio'})

        sdp = session.pc.localDescription.sdp
        codecs = {}
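        # Walk the local SDP answer and collect each "a=rtpmap:" entry, plus any
        # matching "a=fmtp:" parameters, into a codec map; the loop further down
        # emits these as Jingle <payload-type> elements (XEP-0167).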
        for line in sdp.splitlines():
            if line.startswith("a=rtpmap:"):
                match = re.match(r'a=rtpmap:(\d+)\s+(\S+)/(\d+)(?:/(\d+))?', line)
                if match:
                    codec_id = match.group(1)
                    codecs[codec_id] = {'name': match.group(2), 'clockrate': match.group(3), 'channels': match.group(4) or '1', 'fmtp': {}}
            elif line.startswith("a=fmtp:"):
                match = re.match(r'a=fmtp:(\d+)\s+(.+)', line)
                if match:
                    codec_id = match.group(1)
                    fmtp_params = {}
                    for param in match.group(2).split(';'):
                        if '=' in param:
                            key, val = param.strip().split('=', 1)
                            fmtp_params[key.strip()] = val.strip()
                    if codec_id in codecs:
                        codecs[codec_id]['fmtp'] = fmtp_params

        for codec_id, codec_info in codecs.items():
            pt_attrs = {'id': codec_id, 'name': codec_info['name'], 'clockrate': codec_info['clockrate']}
            if codec_info['channels'] != '1': pt_attrs['channels'] = codec_info['channels']
            pt = ET.SubElement(desc, 'payload-type', pt_attrs)
            for param_name, param_value in codec_info['fmtp'].items():
                ET.SubElement(pt, 'parameter', {'name': param_name, 'value': param_value})

        ET.SubElement(desc, 'rtcp-mux')
        trans = ET.SubElement(content, 'transport', {'xmlns': self.NS_JINGLE_ICE})

        for line in sdp.splitlines():
            if line.startswith("a=ice-ufrag:"): trans.set('ufrag', line.split(':')[1].strip())
            if line.startswith("a=ice-pwd:"): trans.set('pwd', line.split(':')[1].strip())
            if line.startswith("a=fingerprint:"):
                parts = line.split(':', 1)[1].strip().split(None, 1)
                if len(parts) == 2:
                    fp = ET.SubElement(trans, 'fingerprint', {'xmlns': 'urn:xmpp:jingle:apps:dtls:0', 'hash': parts[0], 'setup': 'passive'})
                    fp.text = parts[1]

        for cand in session.local_candidates:
            cand_elem = ET.SubElement(trans, 'candidate', {
                'component': str(cand.get('component', '1')), 'foundation': str(cand.get('foundation', '1')),
                'generation': str(cand.get('generation', '0')), 'id': cand.get('id', 'cand-0'),
                'ip': cand.get('ip', ''), 'port': str(cand.get('port', '0')),
                'priority': str(cand.get('priority', '1')), 'protocol': cand.get('protocol', 'udp'),
                'type': cand.get('type', 'host')
            })
            if 'rel-addr' in cand: cand_elem.set('rel-addr', cand['rel-addr'])
            if 'rel-port' in cand: cand_elem.set('rel-port', str(cand['rel-port']))

        return ET.tostring(root, encoding='unicode')

    async def add_ice_candidate(self, session: JingleSession, xml_cand):
        try:
            cand = RTCIceCandidate(
                foundation=xml_cand.get('foundation', '1'), component=int(xml_cand.get('component', '1')),
                protocol=xml_cand.get('protocol', 'udp'), priority=int(xml_cand.get('priority', '1')),
                ip=xml_cand.get('ip'), port=int(xml_cand.get('port')), type=xml_cand.get('type', 'host'),
                sdpMid="0", sdpMLineIndex=0
            )
            if session.pc.remoteDescription: await session.pc.addIceCandidate(cand)
            else: session.pending_candidates.append(cand)
        except Exception as e:
            logger.debug(f"Candidate error: {e}")

    async def stop_session(self, sid):
        if sid in self.sessions:
            s = self.sessions.pop(sid)
            if s.audio_track:
                try: s.audio_track.stop()
                except Exception: pass
                s.audio_track = None
            if s.pc:
                try: await s.pc.close()
                except Exception: pass
                s.pc = None
            s.cleanup()
            del s
            logger.info(f"Stopped session {sid}")
        self.clear_proposed_session(sid)
        gc.collect()

    async def end_all_sessions(self):
        """Terminates all active sessions and shuts down the broadcaster."""
        self._shutdown = True
        self._broadcaster.shutdown()
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try: await self._cleanup_task
            except asyncio.CancelledError: pass
            self._cleanup_task = None
        for sid in list(self.sessions.keys()):
            await self.stop_session(sid)
        self._proposed_sessions.clear()
        gc.collect()

    def get_active_sessions(self):
        return [s for s in self.sessions.values() if s.state == CallState.ACTIVE]

    def _jingle_to_sdp(self, xml):
        """Converts Jingle XML format to standard SDP text."""
        lines = ["v=0", "o=- 0 0 IN IP4 0.0.0.0", "s=-", "t=0 0", "a=group:BUNDLE 0", "a=msid-semantic: WMS *"]
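        # The reconstructed remote offer assumes a single bundled audio
        # m-section (mid 0) with rtcp-mux, matching what _build_session_accept
        # emits for the local answer.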
        content = xml.find(f"{{urn:xmpp:jingle:1}}content")
        desc = content.find(f"{{urn:xmpp:jingle:apps:rtp:1}}description")
        trans = content.find(f"{{urn:xmpp:jingle:transports:ice-udp:1}}transport")

        payloads = [pt.get('id') for pt in desc.findall(f"{{urn:xmpp:jingle:apps:rtp:1}}payload-type")]
        lines.append(f"m=audio 9 UDP/TLS/RTP/SAVPF {' '.join(payloads)}")
        lines.append("c=IN IP4 0.0.0.0")
        lines.append("a=rtcp:9 IN IP4 0.0.0.0")
        lines.append(f"a=ice-ufrag:{trans.get('ufrag')}")
        lines.append(f"a=ice-pwd:{trans.get('pwd')}")

        fp = trans.find(f"{{urn:xmpp:jingle:apps:dtls:0}}fingerprint")
        if fp is not None:
            lines.append(f"a=fingerprint:{fp.get('hash')} {fp.text}")
            setup = fp.get('setup', 'actpass')
            if setup == 'actpass': setup = 'active'
            lines.append(f"a=setup:{setup}")

        lines.append("a=mid:0")
        lines.append("a=sendrecv")
        lines.append("a=rtcp-mux")

        for pt in desc.findall(f"{{urn:xmpp:jingle:apps:rtp:1}}payload-type"):
            lines.append(f"a=rtpmap:{pt.get('id')} {pt.get('name')}/{pt.get('clockrate')}/{pt.get('channels', '1')}")

        return "\r\n".join(lines)
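

# --- Usage sketch (illustrative) ---
# A minimal way to exercise the broadcaster side from an asyncio program. The
# STUN address and the "example.ogg" path are placeholders; a real deployment
# would also wire an XMPP client that feeds Jingle stanzas into
# handle_session_initiate() and add_ice_candidate().
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    async def _demo():
        handler = JingleWebRTCHandler(stun_server="stun.l.google.com:19302")
        await handler.start_cleanup_task()
        handler.set_audio_source("example.ogg")  # placeholder file path
        await asyncio.sleep(5)                   # give the broadcaster time to start
        await handler.end_all_sessions()

    asyncio.run(_demo())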