Files
xmpp-radio-tower/jingle_webrtc.py
2025-12-17 17:24:34 +00:00

675 lines
26 KiB
Python

#!/usr/bin/env python3
import asyncio
import gc
import random
import logging
import os
import re
import weakref
from typing import Optional, Dict, List, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from fractions import Fraction
logger = logging.getLogger(__name__)
AIORTC_AVAILABLE = False
try:
from aiortc import RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer
from aiortc import MediaStreamTrack, RTCIceCandidate
from aiortc.contrib.media import MediaPlayer
import av
AIORTC_AVAILABLE = True
except ImportError as e:
logger.error(f"aiortc import failed: {e}")
class CallState(Enum):
    """Coarse lifecycle of a Jingle call as tracked by this module."""

    # Default state of a freshly constructed JingleSession.
    IDLE = "idle"
    # Set once the peer connection reports "connected".
    ACTIVE = "active"
    # Set when the peer connection reports failed / closed / disconnected.
    ENDED = "ended"
@dataclass(slots=True)
class JingleSession:
    """Per-call bookkeeping: the peer connection, its outgoing track and ICE state."""

    sid: str                      # Jingle session id
    peer_jid: str                 # JID of the remote party
    pc: Any = None                # peer connection object, None once released
    audio_track: Any = None       # track attached to ``pc``, None once released
    state: CallState = CallState.IDLE
    # Remote candidates received before the remote description was applied.
    pending_candidates: List[Any] = field(default_factory=list)
    # Local candidates parsed out of the local SDP, as plain dicts.
    local_candidates: List[Dict] = field(default_factory=list)
    # Set when the peer connection reports ICE gathering "complete".
    ice_gathering_complete: asyncio.Event = field(default_factory=asyncio.Event)

    def cleanup(self):
        """Drop every reference this session holds so it can be collected."""
        for bucket in (self.pending_candidates, self.local_candidates):
            bucket.clear()
        self.pc = self.audio_track = None
class SynchronizedRadioBroadcaster:
    """Decode one audio source and hand identical frames to every listener.

    A single ``MediaPlayer`` is read by one background task.  Each decoded
    frame is re-stamped on a shared 48 kHz timeline and offered to all
    subscribed tracks.  Subscribers are held through weak references only,
    so a track that is garbage-collected simply disappears from the list.
    """

    MAX_SUBSCRIBERS = 10         # cap on simultaneously registered tracks
    SAMPLES_PER_FRAME = 960      # samples in a generated silence frame (20 ms at 48 kHz)
    CLEANUP_INTERVAL = 200       # loop iterations between dead-weakref sweeps
    ERROR_WINDOW_SECONDS = 5.0   # error-free time after which the error count resets
    MAX_ERRORS_IN_WINDOW = 20    # decode errors tolerated inside one window

    def __init__(self, on_track_end_callback: Optional[Callable] = None):
        """Create an idle broadcaster.

        Args:
            on_track_end_callback: Called without arguments when the broadcast
                loop ends on its own (not through ``stop_playback``,
                ``shutdown`` or cancellation).  If it returns a coroutine,
                that coroutine is scheduled as a task.
        """
        self._player: Optional[MediaPlayer] = None
        self._current_source: Optional[str] = None
        self._lock = asyncio.Lock()
        self._subscribers: List[weakref.ref] = []
        self._broadcast_task: Optional[asyncio.Task] = None
        self._pts = 0
        self._time_base = Fraction(1, 48000)
        self.on_track_end = on_track_end_callback
        self._shutdown = False
        self._track_end_fired = False
        self._stopped = False
        self._frame_counter = 0
        self._resampler = None
        # Strong references to tasks spawned for coroutine callbacks; the
        # event loop itself only keeps weak references to tasks.
        self._callback_tasks = set()

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Build a zero-filled s16 stereo frame and advance the shared timeline."""
        frame = av.AudioFrame(format='s16', layout='stereo', samples=self.SAMPLES_PER_FRAME)
        frame.sample_rate = 48000
        frame.time_base = self._time_base
        for plane in frame.planes:
            plane.update(bytes(plane.buffer_size))
        frame.pts = self._pts
        self._pts += self.SAMPLES_PER_FRAME
        return frame

    async def set_source(self, source: str, force_restart: bool = False):
        """Switch the broadcast to ``source`` (file path or stream URL).

        Does nothing when the same source is already playing, unless
        ``force_restart`` is true.  A falsy ``source`` stops the current
        broadcast and leaves the broadcaster idle.
        """
        async with self._lock:
            normalized_source = source
            if source and not source.startswith(('http:', 'https:', 'rtmp:', 'rtsp:')):
                # Local paths are compared in absolute form.
                normalized_source = os.path.abspath(source)

            broadcast_is_running = (
                self._broadcast_task is not None
                and not self._broadcast_task.done()
            )
            if (normalized_source == self._current_source
                    and self._player
                    and broadcast_is_running
                    and not force_restart):
                return

            logger.info(f"🔄 Changing broadcast: {os.path.basename(normalized_source) if normalized_source else 'None'}")
            await self._stop_broadcast_task()
            await self._close_player()

            self._current_source = normalized_source
            self._pts = 0
            self._track_end_fired = False
            self._stopped = False
            self._frame_counter = 0
            self._resampler = None
            if not source:
                return

            try:
                logger.info(f"🎵 Starting broadcast: {source}")
                options = {'rtbufsize': '8M'} if source.startswith(('http', 'rtmp', 'rtsp')) else {}
                self._player = MediaPlayer(source, options=options)
                self._broadcast_task = asyncio.create_task(self._broadcast_loop())
            except Exception as e:
                logger.error(f"MediaPlayer Error: {e}")
                self._player = None

    async def _stop_broadcast_task(self):
        """Cancel the broadcast task and wait up to two seconds for it to finish."""
        task, self._broadcast_task = self._broadcast_task, None
        if task is None:
            return
        task.cancel()
        try:
            await asyncio.wait_for(task, timeout=2.0)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            pass

    def _release_player(self):
        """Detach the current player and close it, best effort and synchronously."""
        self._resampler = None
        player, self._player = self._player, None
        if player:
            audio = getattr(player, 'audio', None)
            if audio:
                try:
                    audio.stop()
                except Exception as e:
                    logger.debug(f"Error closing player: {e}")
            container = getattr(player, '_container', None)
            if container:
                try:
                    container.close()
                except Exception as e:
                    logger.debug(f"Error closing player: {e}")
            del player
        gc.collect()

    async def _close_player(self):
        """Async-compatible wrapper around ``_release_player``."""
        self._release_player()

    async def stop_playback(self):
        """Stop the broadcast and forget the current source; no track-end callback fires."""
        async with self._lock:
            self._stopped = True
            await self._stop_broadcast_task()
            await self._close_player()
            self._current_source = None
            logger.info("⏹️ Playback stopped")

    def _normalize_frame(self, frame) -> list:
        """Return ``frame`` as a list of frames in s16 at 48 kHz.

        Frames already in that format pass through untouched.  Anything else
        goes through a lazily created resampler, which may return zero, one
        or several frames for one input frame.
        """
        if frame.format.name == 's16' and frame.sample_rate == 48000:
            return [frame]
        if self._resampler is None:
            self._resampler = av.AudioResampler(format='s16', layout='stereo', rate=48000)
        return list(self._resampler.resample(frame) or [])

    def _fire_track_end(self):
        """Invoke the track-end callback once, unless we were stopped or shut down."""
        if (not self.on_track_end or self._shutdown
                or self._track_end_fired or self._stopped):
            return
        self._track_end_fired = True
        try:
            result = self.on_track_end()
            if asyncio.iscoroutine(result):
                # Scheduled rather than awaited: this runs inside the
                # broadcast task, which set_source() cancels and awaits.
                task = asyncio.create_task(result)
                self._callback_tasks.add(task)
                task.add_done_callback(self._callback_tasks.discard)
        except Exception as e:
            logger.error(f"Track end callback error: {e}")

    async def _broadcast_loop(self):
        """Read frames from the player and distribute them until the source ends.

        While nobody is listening the player is not read at all.  A read
        timeout is bridged with a silence frame.  Decode errors are tolerated
        until ``MAX_ERRORS_IN_WINDOW`` of them accumulate without an
        error-free stretch of ``ERROR_WINDOW_SECONDS``.
        """
        loop = asyncio.get_running_loop()
        last_error_time = None
        error_count_window = 0
        try:
            while self._player and not self._shutdown and not self._stopped:
                self._frame_counter += 1
                if self._frame_counter % self.CLEANUP_INTERVAL == 0:
                    self._cleanup_subscribers()
                if not self._get_active_subscribers():
                    await asyncio.sleep(0.1)
                    continue
                try:
                    player = self._player
                    if not player or not player.audio:
                        break
                    frame = await asyncio.wait_for(player.audio.recv(), timeout=0.5)
                    if (last_error_time is not None
                            and loop.time() - last_error_time > self.ERROR_WINDOW_SECONDS):
                        error_count_window = 0
                        last_error_time = None
                    # The resampler may buffer (no output) or emit several
                    # frames; forward every one of them in order.
                    for out_frame in self._normalize_frame(frame):
                        out_frame.pts = self._pts
                        out_frame.time_base = self._time_base
                        self._pts += out_frame.samples
                        await self._distribute_frame(out_frame)
                except asyncio.TimeoutError:
                    await self._distribute_frame(self._get_silence_frame())
                except (av.error.EOFError, StopIteration, StopAsyncIteration):
                    logger.info("📻 Track ended (EOF)")
                    break
                except Exception as e:
                    # Matched by name so no extra import is needed here.
                    if "MediaStreamError" in type(e).__name__:
                        break
                    logger.debug(f"Broadcast frame warning: {e}")
                    error_count_window += 1
                    last_error_time = loop.time()
                    if error_count_window >= self.MAX_ERRORS_IN_WINDOW:
                        logger.error("Too many errors, stopping broadcast")
                        break
                    await asyncio.sleep(0.1)
            self._fire_track_end()
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Broadcast loop error: {e}")
        finally:
            self._resampler = None
            gc.collect()

    def _cleanup_subscribers(self):
        """Forget subscribers whose track object no longer exists."""
        self._subscribers = [ref for ref in self._subscribers if ref() is not None]

    def _get_active_subscribers(self) -> List['SynchronizedAudioTrack']:
        """Return the live subscribers whose ``_active`` flag is set."""
        tracks = (ref() for ref in self._subscribers)
        return [track for track in tracks if track is not None and track._active]

    async def _distribute_frame(self, frame: 'av.AudioFrame'):
        """Offer ``frame`` to every active subscriber; individual failures are ignored."""
        subscribers = self._get_active_subscribers()
        if not subscribers:
            return
        await asyncio.gather(
            *(sub._receive_frame(frame) for sub in subscribers),
            return_exceptions=True,
        )

    def subscribe(self, track: 'SynchronizedAudioTrack'):
        """Register ``track`` for frames; ignored (with a warning) once the cap is reached."""
        self._cleanup_subscribers()
        if len(self._subscribers) >= self.MAX_SUBSCRIBERS:
            logger.warning(
                f"Subscriber limit ({self.MAX_SUBSCRIBERS}) reached, track not subscribed"
            )
            return
        self._subscribers.append(weakref.ref(track))

    def unsubscribe(self, track: 'SynchronizedAudioTrack'):
        """Remove ``track`` (and any dead references) from the subscriber list."""
        self._subscribers = [
            ref for ref in self._subscribers
            if ref() is not None and ref() is not track
        ]

    def shutdown(self):
        """Flag the broadcaster as shut down and release the player.

        The broadcast loop checks the flag and ends without firing the
        track-end callback.
        """
        self._shutdown = True
        self._release_player()
class SynchronizedAudioTrack(MediaStreamTrack):
    """Per-listener audio track fed by a shared ``SynchronizedRadioBroadcaster``.

    Frames pushed by the broadcaster wait in a small bounded queue; when the
    queue is full the oldest frame is dropped.  If nothing arrives within
    50 ms, ``recv`` answers with a silence frame instead of blocking.
    """

    kind = "audio"
    MAX_QUEUE_SIZE = 3

    def __init__(self, broadcaster: 'SynchronizedRadioBroadcaster'):
        """Create the track and subscribe it to ``broadcaster``."""
        super().__init__()
        self._broadcaster = broadcaster
        self._frame_queue: asyncio.Queue = asyncio.Queue(maxsize=self.MAX_QUEUE_SIZE)
        self._active = True
        self._silence_frame: Optional[av.AudioFrame] = None
        self._silence_pts = 0
        self._broadcaster.subscribe(self)

    def _get_silence_frame(self) -> 'av.AudioFrame':
        """Return the cached zero-filled frame, stamped with the next silence pts.

        NOTE(review): silence frames count pts from 0 on their own counter,
        while broadcast frames carry the broadcaster's pts, so the two
        timelines are not aligned when they interleave -- confirm the encoder
        tolerates this.
        """
        if self._silence_frame is None:
            frame = av.AudioFrame(format='s16', layout='stereo', samples=960)
            frame.sample_rate = 48000
            frame.time_base = Fraction(1, 48000)
            for plane in frame.planes:
                plane.update(bytes(plane.buffer_size))
            self._silence_frame = frame
        self._silence_frame.pts = self._silence_pts
        self._silence_pts += 960
        return self._silence_frame

    async def _receive_frame(self, frame: 'av.AudioFrame'):
        """Queue ``frame`` for ``recv``; drops the oldest queued frame when full."""
        if not self._active:
            return
        if self._frame_queue.full():
            try:
                self._frame_queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
        try:
            self._frame_queue.put_nowait(frame)
        except asyncio.QueueFull:
            # Best effort: losing one frame is preferable to stalling the
            # broadcaster for every other listener.
            pass

    async def recv(self):
        """Return the next queued frame, or silence after a 50 ms wait.

        Raises:
            MediaStreamError: the track has been stopped.
        """
        if not self._active:
            # aiortc's own exception for an ended track (an Exception
            # subclass, so existing broad handlers still catch it).
            from aiortc.mediastreams import MediaStreamError
            raise MediaStreamError("Track stopped")
        try:
            return await asyncio.wait_for(self._frame_queue.get(), timeout=0.05)
        except asyncio.TimeoutError:
            return self._get_silence_frame()

    def stop(self):
        """Deactivate the track, unsubscribe it and drop all buffered frames."""
        self._active = False
        self._broadcaster.unsubscribe(self)
        while not self._frame_queue.empty():
            try:
                self._frame_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
        self._silence_frame = None
        super().stop()
class JingleWebRTCHandler:
    """Answer incoming Jingle audio calls with an aiortc peer connection.

    Translates Jingle session-initiate stanzas to SDP offers, builds the
    session-accept from the local SDP answer, and attaches each call to the
    shared radio broadcaster.
    """

    NS_JINGLE = "urn:xmpp:jingle:1"
    NS_JINGLE_ICE = "urn:xmpp:jingle:transports:ice-udp:1"
    NS_JINGLE_RTP = "urn:xmpp:jingle:apps:rtp:1"
    NS_JINGLE_DTLS = "urn:xmpp:jingle:apps:dtls:0"
    MAX_SESSIONS = 20
    SESSION_CLEANUP_INTERVAL = 60  # seconds between sweeps for ended sessions

    def __init__(self, stun_server: str, send_transport_info_callback: Optional[Callable] = None, on_track_end: Optional[Callable] = None):
        """Set up an empty handler.

        Args:
            stun_server: ICE server URL; a bare host[:port] gets a ``stun:``
                prefix.  A falsy value means no ICE server is configured.
            send_transport_info_callback: Stored as ``send_transport_info``;
                not called by this class.
            on_track_end: Passed to the broadcaster as its track-end callback.
        """
        if stun_server and not stun_server.startswith(('stun:', 'turn:', 'stuns:')):
            self.stun_server = f"stun:{stun_server}"
        else:
            self.stun_server = stun_server
        self.sessions: Dict[str, 'JingleSession'] = {}
        self._proposed_sessions: Dict[str, str] = {}
        self._broadcaster = SynchronizedRadioBroadcaster(on_track_end_callback=on_track_end)
        self.send_transport_info = send_transport_info_callback
        self._cleanup_task: Optional[asyncio.Task] = None
        self._shutdown = False
        # Strong references to fire-and-forget tasks (see _spawn).
        self._background_tasks = set()

    def _spawn(self, coro) -> 'asyncio.Task':
        """Schedule ``coro`` and hold a reference to the task until it finishes.

        The event loop only keeps weak references to tasks, so a task nobody
        else references can be garbage-collected before it completes.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def start_cleanup_task(self):
        """Start the periodic sweep for ended sessions (no-op if already running)."""
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())

    async def _periodic_cleanup(self):
        """Every SESSION_CLEANUP_INTERVAL seconds, stop sessions marked ENDED."""
        while not self._shutdown:
            await asyncio.sleep(self.SESSION_CLEANUP_INTERVAL)
            try:
                await self._cleanup_dead_sessions()
                gc.collect()
            except Exception as e:
                logger.error(f"Cleanup error: {e}")

    async def _cleanup_dead_sessions(self):
        """Stop every session whose state is ENDED."""
        dead_sessions = [
            sid for sid, session in self.sessions.items()
            if session.state == CallState.ENDED
        ]
        for sid in dead_sessions:
            await self.stop_session(sid)

    def register_proposed_session(self, sid: str, peer_jid: str):
        """Remember a proposed session; drops the 25 oldest once more than 50 are held."""
        if len(self._proposed_sessions) > 50:
            for key in list(self._proposed_sessions)[:25]:
                del self._proposed_sessions[key]
        self._proposed_sessions[sid] = peer_jid

    def clear_proposed_session(self, sid: str):
        """Forget a proposed session; unknown ids are ignored."""
        self._proposed_sessions.pop(sid, None)

    def set_audio_source(self, source: str, force_restart: bool = False):
        """Switch the broadcast source in the background.

        Must be called while an event loop is running.
        """
        self._spawn(self._broadcaster.set_source(source, force_restart))

    async def stop_playback(self):
        """Stop the broadcaster's playback."""
        await self._broadcaster.stop_playback()

    def get_session(self, sid: str):
        """Return the session for ``sid`` or None."""
        return self.sessions.get(sid)

    async def create_session(self, sid, peer_jid) -> 'JingleSession':
        """Create and register a session with a fresh peer connection.

        An existing session with the same ``sid`` is stopped first so its
        peer connection is not orphaned.  When MAX_SESSIONS is reached, ended
        sessions are purged and, if that is not enough, the oldest session is
        stopped.
        """
        if sid in self.sessions:
            await self.stop_session(sid)
        if len(self.sessions) >= self.MAX_SESSIONS:
            await self._cleanup_dead_sessions()
            if len(self.sessions) >= self.MAX_SESSIONS:
                await self.stop_session(next(iter(self.sessions)))

        servers = []
        if self.stun_server:
            servers.append(RTCIceServer(urls=self.stun_server))
        pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=servers))
        session = JingleSession(sid=sid, peer_jid=peer_jid, pc=pc)
        self.sessions[sid] = session

        @pc.on("icegatheringstatechange")
        async def on_gathering_state():
            if pc.iceGatheringState == "complete":
                session.ice_gathering_complete.set()

        @pc.on("connectionstatechange")
        async def on_state():
            if pc.connectionState == "connected":
                session.state = CallState.ACTIVE
            elif pc.connectionState in ["failed", "closed", "disconnected"]:
                session.state = CallState.ENDED
                self._spawn(self._cleanup_session(sid))

        return session

    async def _cleanup_session(self, sid: str):
        """After a one-second grace period, stop ``sid`` if it is still ENDED."""
        await asyncio.sleep(1)
        session = self.sessions.get(sid)
        if session is not None and session.state == CallState.ENDED:
            await self.stop_session(sid)

    def _extract_candidates_from_sdp(self, sdp: str) -> List[Dict]:
        """Parse the ``a=candidate:`` lines of ``sdp`` into Jingle candidate dicts.

        Lines with fewer than eight fields are skipped.  All values stay
        strings.  Each candidate gets an id that is unique within the
        returned list.
        """
        candidates = []
        for line in sdp.splitlines():
            if not line.startswith('a=candidate:'):
                continue
            parts = line[len('a=candidate:'):].split()
            if len(parts) < 8:
                continue
            cand = {
                'foundation': parts[0],
                'component': parts[1],
                'protocol': parts[2].lower(),
                'priority': parts[3],
                'ip': parts[4],
                'port': parts[5],
                'type': parts[7],  # parts[6] is the literal "typ"
                'generation': '0',
                # The index makes the id unique; the random part alone is not.
                'id': f"cand-{len(candidates)}-{random.randint(1000, 9999)}",
            }
            # Remaining fields are "key value" pairs.
            for key, val in zip(parts[8::2], parts[9::2]):
                if key == 'raddr':
                    cand['rel-addr'] = val
                elif key == 'rport':
                    cand['rel-port'] = val
                elif key == 'generation':
                    cand['generation'] = val
                elif key == 'network':
                    cand['network'] = val
            candidates.append(cand)
        return candidates

    def _find_content(self, jingle_xml):
        """Return the first Jingle <content/> child, or None.

        Falls back to any child whose tag ends in "content" when the element
        is not in the Jingle namespace.
        """
        content = jingle_xml.find(f"{{{self.NS_JINGLE}}}content")
        if content is None:
            for child in jingle_xml:
                if child.tag.endswith('content'):
                    return child
        return content

    async def handle_session_initiate(self, jingle_xml, peer_jid, our_jid):
        """Accept an incoming call.

        Returns:
            ``(session, session_accept_xml)`` where the second item is the
            serialized <jingle action="session-accept"/> element.

        Raises:
            Whatever negotiation raises; the half-built session is stopped
            before the exception propagates.
        """
        sid = jingle_xml.get('sid')
        session = await self.create_session(sid, peer_jid)
        try:
            content = self._find_content(jingle_xml)
            content_name = content.get('name', '0') if content is not None else '0'

            offer = RTCSessionDescription(sdp=self._jingle_to_sdp(jingle_xml), type="offer")
            track = SynchronizedAudioTrack(self._broadcaster)
            session.audio_track = track
            session.pc.addTrack(track)
            await session.pc.setRemoteDescription(offer)

            # Candidates that arrived before the remote description existed.
            for cand in session.pending_candidates:
                try:
                    await session.pc.addIceCandidate(cand)
                except Exception as e:
                    logger.debug(f"Candidate error: {e}")
            session.pending_candidates.clear()

            answer = await session.pc.createAnswer()
            # Always answer as the passive DTLS side.
            answer_sdp = answer.sdp.replace('a=setup:active', 'a=setup:passive')
            await session.pc.setLocalDescription(
                RTCSessionDescription(sdp=answer_sdp, type="answer")
            )
            try:
                await asyncio.wait_for(session.ice_gathering_complete.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                pass  # answer with whatever candidates exist so far

            if session.pc.localDescription:
                session.local_candidates = self._extract_candidates_from_sdp(
                    session.pc.localDescription.sdp
                )
            accept_xml = self._build_session_accept(session, sid, our_jid, content_name)
        except Exception:
            # Do not leave a dangling peer connection / subscribed track.
            await self.stop_session(sid)
            raise
        return session, accept_xml

    def _build_session_accept(self, session: 'JingleSession', sid: str, responder: str, content_name: str) -> str:
        """Serialize a Jingle session-accept from the session's local SDP.

        Payload types come from the ``a=rtpmap`` / ``a=fmtp`` lines; ICE
        credentials and the DTLS fingerprint from the matching attributes;
        candidates from ``session.local_candidates``.
        """
        import xml.etree.ElementTree as ET

        root = ET.Element('jingle', {'xmlns': self.NS_JINGLE, 'action': 'session-accept', 'sid': sid, 'responder': responder})
        content = ET.SubElement(root, 'content', {'creator': 'initiator', 'name': content_name, 'senders': 'both'})
        desc = ET.SubElement(content, 'description', {'xmlns': 'urn:xmpp:jingle:apps:rtp:1', 'media': 'audio'})
        sdp_lines = session.pc.localDescription.sdp.splitlines()

        codecs = {}
        for line in sdp_lines:
            if line.startswith("a=rtpmap:"):
                match = re.match(r'a=rtpmap:(\d+)\s+([^/]+)/(\d+)(?:/(\d+))?', line)
                if match:
                    codecs[match.group(1)] = {
                        'name': match.group(2),
                        'clockrate': match.group(3),
                        'channels': match.group(4) or '1',
                        'fmtp': {},
                    }
            elif line.startswith("a=fmtp:"):
                match = re.match(r'a=fmtp:(\d+)\s+(.+)', line)
                # fmtp for a payload type without a preceding rtpmap is dropped.
                if match and match.group(1) in codecs:
                    fmtp_params = {}
                    for param in match.group(2).split(';'):
                        if '=' in param:
                            key, val = param.strip().split('=', 1)
                            fmtp_params[key.strip()] = val.strip()
                    codecs[match.group(1)]['fmtp'] = fmtp_params

        for codec_id, codec_info in codecs.items():
            pt_attrs = {'id': codec_id, 'name': codec_info['name'], 'clockrate': codec_info['clockrate']}
            if codec_info['channels'] != '1':
                pt_attrs['channels'] = codec_info['channels']
            pt = ET.SubElement(desc, 'payload-type', pt_attrs)
            for param_name, param_value in codec_info['fmtp'].items():
                ET.SubElement(pt, 'parameter', {'name': param_name, 'value': param_value})
        ET.SubElement(desc, 'rtcp-mux')

        trans = ET.SubElement(content, 'transport', {'xmlns': self.NS_JINGLE_ICE})
        for line in sdp_lines:
            if line.startswith("a=ice-ufrag:"):
                trans.set('ufrag', line.split(':', 1)[1].strip())
            if line.startswith("a=ice-pwd:"):
                trans.set('pwd', line.split(':', 1)[1].strip())
            if line.startswith("a=fingerprint:"):
                parts = line.split(':', 1)[1].strip().split(None, 1)
                if len(parts) == 2:
                    fp = ET.SubElement(trans, 'fingerprint', {'xmlns': 'urn:xmpp:jingle:apps:dtls:0', 'hash': parts[0], 'setup': 'passive'})
                    fp.text = parts[1]

        for cand in session.local_candidates:
            cand_elem = ET.SubElement(trans, 'candidate', {
                'component': str(cand.get('component', '1')), 'foundation': str(cand.get('foundation', '1')),
                'generation': str(cand.get('generation', '0')), 'id': cand.get('id', 'cand-0'),
                'ip': cand.get('ip', ''), 'port': str(cand.get('port', '0')),
                'priority': str(cand.get('priority', '1')), 'protocol': cand.get('protocol', 'udp'),
                'type': cand.get('type', 'host')
            })
            if 'rel-addr' in cand:
                cand_elem.set('rel-addr', cand['rel-addr'])
            if 'rel-port' in cand:
                cand_elem.set('rel-port', str(cand['rel-port']))
        return ET.tostring(root, encoding='unicode')

    async def add_ice_candidate(self, session: 'JingleSession', xml_cand):
        """Apply a remote candidate, or queue it until the remote description is set.

        Malformed candidates are logged at debug level and dropped.
        """
        try:
            cand = RTCIceCandidate(
                foundation=xml_cand.get('foundation', '1'), component=int(xml_cand.get('component', '1')),
                protocol=xml_cand.get('protocol', 'udp'), priority=int(xml_cand.get('priority', '1')),
                ip=xml_cand.get('ip'), port=int(xml_cand.get('port')), type=xml_cand.get('type', 'host'),
                sdpMid="0", sdpMLineIndex=0
            )
            if session.pc.remoteDescription:
                await session.pc.addIceCandidate(cand)
            else:
                session.pending_candidates.append(cand)
        except Exception as e:
            logger.debug(f"Candidate error: {e}")

    async def stop_session(self, sid):
        """Tear down the session ``sid`` (if any) and forget its proposal."""
        session = self.sessions.pop(sid, None)
        if session is not None:
            track, pc = session.audio_track, session.pc
            try:
                if track:
                    try:
                        track.stop()
                    except Exception as e:
                        logger.debug(f"Error stopping track for {sid}: {e}")
                if pc:
                    try:
                        await pc.close()
                    except Exception as e:
                        logger.debug(f"Error closing peer connection for {sid}: {e}")
            finally:
                # Runs even if we are cancelled while closing.
                session.cleanup()
            logger.info(f"🔴 Stopped session {sid}")
        self.clear_proposed_session(sid)
        gc.collect()

    async def end_all_sessions(self):
        """Shut down the broadcaster, the cleanup task and every session."""
        self._shutdown = True
        self._broadcaster.shutdown()
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None
        for sid in list(self.sessions):
            await self.stop_session(sid)
        self._proposed_sessions.clear()
        gc.collect()

    def get_active_sessions(self):
        """Return the sessions currently in the ACTIVE state."""
        return [s for s in self.sessions.values() if s.state == CallState.ACTIVE]

    def _jingle_to_sdp(self, xml):
        """Translate a Jingle session-initiate element into an SDP offer string.

        Only the first content is used; it is emitted as a single audio
        m-line with mid "0".  An ``actpass`` DTLS setup is rewritten to
        ``active``.

        Raises:
            ValueError: the stanza has no content, RTP description or
                ICE-UDP transport.
        """
        content = self._find_content(xml)
        if content is None:
            raise ValueError("Jingle stanza has no <content/> element")
        desc = content.find(f"{{{self.NS_JINGLE_RTP}}}description")
        trans = content.find(f"{{{self.NS_JINGLE_ICE}}}transport")
        if desc is None or trans is None:
            raise ValueError("Jingle content lacks an RTP description or ICE-UDP transport")

        payload_types = [
            pt for pt in desc.findall(f"{{{self.NS_JINGLE_RTP}}}payload-type")
            if pt.get('id') is not None
        ]
        lines = [
            "v=0", "o=- 0 0 IN IP4 0.0.0.0", "s=-", "t=0 0",
            "a=group:BUNDLE 0", "a=msid-semantic: WMS *",
            f"m=audio 9 UDP/TLS/RTP/SAVPF {' '.join(pt.get('id') for pt in payload_types)}",
            "c=IN IP4 0.0.0.0",
            "a=rtcp:9 IN IP4 0.0.0.0",
            f"a=ice-ufrag:{trans.get('ufrag')}",
            f"a=ice-pwd:{trans.get('pwd')}",
        ]
        fp = trans.find(f"{{{self.NS_JINGLE_DTLS}}}fingerprint")
        if fp is not None:
            lines.append(f"a=fingerprint:{fp.get('hash')} {(fp.text or '').strip()}")
            setup = fp.get('setup', 'actpass')
            if setup == 'actpass':
                setup = 'active'
            lines.append(f"a=setup:{setup}")
        lines += ["a=mid:0", "a=sendrecv", "a=rtcp-mux"]
        for pt in payload_types:
            lines.append(f"a=rtpmap:{pt.get('id')} {pt.get('name')}/{pt.get('clockrate')}/{pt.get('channels', '1')}")
        return "\r\n".join(lines)