From f88840849d09f275a768d3aadf50f68e42a1bded Mon Sep 17 00:00:00 2001 From: just n Date: Sun, 28 Dec 2025 16:28:02 +0000 Subject: [PATCH] Add bot.py --- bot.py | 2199 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 2199 insertions(+) create mode 100644 bot.py diff --git a/bot.py b/bot.py new file mode 100644 index 0000000..7643ae8 --- /dev/null +++ b/bot.py @@ -0,0 +1,2199 @@ +#!/usr/bin/env python3 + +import slixmpp +import requests +import json +import time +from ratelimit import limits, sleep_and_retry +import configparser +import asyncio +import logging +import os +from pathlib import Path +from typing import Any, Dict, FrozenSet, Optional, Set +import re +import base64 +import wave +from io import BytesIO +from urllib.parse import urlparse, urljoin, urlunparse, parse_qs +from bs4 import BeautifulSoup +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.backends import default_backend + +from slixmpp.jid import JID +from slixmpp.stanza import Message +from slixmpp.xmlstream.handler import CoroutineCallback +from slixmpp.xmlstream.matcher import MatchXPath +from slixmpp.plugins import register_plugin + +from omemo.storage import Just, Maybe, Nothing, Storage +from omemo.types import DeviceInformation, JSONType +from slixmpp_omemo import TrustLevel, XEP_0384 + + +try: + from google import genai + from google.genai import types + GOOGLE_GENAI_AVAILABLE = True +except ImportError: + GOOGLE_GENAI_AVAILABLE = False + logging.warning("google-genai library not available.") + +try: + from openai import OpenAI + OPENAI_AVAILABLE = True +except ImportError: + OPENAI_AVAILABLE = False + logging.warning("openai library not available.") + +try: + from pydub import AudioSegment + PYDUB_AVAILABLE = True +except ImportError: + PYDUB_AVAILABLE = False + logging.warning("pydub library not available.") + + +class FileUploader: + + def __init__(self, service='catbox', api_key=None, xmpp_client=None): + self.service = 
def upload(self, file_data: bytes, filename: str = "file.bin", mime_type: str = "application/octet-stream") -> Optional[str]:
    """Send ``file_data`` to the configured hosting service.

    The backend is chosen by ``self.service`` and resolved to the matching
    ``_upload_<service>`` method.  Any exception raised while uploading is
    logged and swallowed.

    Args:
        file_data: Raw bytes to upload.
        filename: Name reported to the hosting service.
        mime_type: Content type reported to the hosting service.

    Returns:
        Whatever the backend returns (a URL string on success), or ``None``
        when the service is unknown or the upload raised.
    """
    # These two backends take only the payload; they ignore name and type.
    payload_only_backends = ('imgur', 'imgbb')
    # These backends take payload, file name and MIME type.
    named_backends = ('catbox', 'litterbox', '0x0', 'envs', 'uguu', 'xmpp')

    try:
        backend = self.service
        if backend in payload_only_backends:
            return getattr(self, f"_upload_{backend}")(file_data)
        if backend in named_backends:
            return getattr(self, f"_upload_{backend}")(file_data, filename, mime_type)

        logging.error(f"Unknown hosting service: {self.service}")
        return None
    except Exception as e:
        logging.error(f"File upload failed: {e}")
        return None
Ensure 'xmpp' is selected as file_host.") + return None + + try: + logging.info(f"Starting XMPP HTTP File Upload for {filename} ({len(file_data)} bytes)...") + + input_file = BytesIO(file_data) + + future = asyncio.run_coroutine_threadsafe( + self.xmpp_client['xep_0363'].upload_file( + filename=filename, + size=len(file_data), + input_file=input_file, + content_type=mime_type, + timeout=60 + ), + self.xmpp_client.loop + ) + + result = future.result(timeout=70) + + logging.info(f"Uploaded to XMPP server: {result}") + return result + + except asyncio.TimeoutError: + logging.error("XMPP file upload timed out.") + return None + except Exception as e: + logging.error(f"XMPP file upload failed: {e}") + return None + + def _upload_catbox(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]: + url = "https://catbox.moe/user/api.php" + files = {'fileToUpload': (filename, BytesIO(file_data), mime_type)} + data = {'reqtype': 'fileupload'} + if self.api_key: + data['userhash'] = self.api_key + + response = requests.post(url, files=files, data=data, timeout=30) + response.raise_for_status() + + result = response.text.strip() + if result.startswith('https://'): + logging.info(f"Uploaded to catbox: {result}") + return result + return None + + def _upload_litterbox(self, file_data: bytes, filename: str, mime_type: str, expiry: str = "72h") -> Optional[str]: + url = "https://litterbox.catbox.moe/resources/internals/api.php" + files = {'fileToUpload': (filename, BytesIO(file_data), mime_type)} + data = {'reqtype': 'fileupload', 'time': expiry} + + response = requests.post(url, files=files, data=data, timeout=30) + response.raise_for_status() + + result = response.text.strip() + if result.startswith('https://'): + logging.info(f"Uploaded to litterbox: {result}") + return result + return None + + def _upload_0x0(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]: + url = "https://0x0.st" + files = {'file': (filename, BytesIO(file_data), 
mime_type)} + + response = requests.post(url, files=files, timeout=30) + response.raise_for_status() + + result = response.text.strip() + if result.startswith('https://'): + logging.info(f"Uploaded to 0x0.st: {result}") + return result + return None + + def _upload_imgur(self, file_data: bytes) -> Optional[str]: + if not self.api_key: + logging.error("Imgur requires an API key (Client-ID)") + return None + + url = "https://api.imgur.com/3/image" + headers = {'Authorization': f'Client-ID {self.api_key}'} + data = { + 'image': base64.b64encode(file_data).decode('utf-8'), + 'type': 'base64' + } + + response = requests.post(url, headers=headers, data=data, timeout=30) + response.raise_for_status() + + result = response.json() + if result.get('success'): + link = result['data']['link'] + logging.info(f"Uploaded to imgur: {link}") + return link + return None + + def _upload_imgbb(self, file_data: bytes) -> Optional[str]: + if not self.api_key: + logging.error("imgbb requires an API key") + return None + + url = "https://api.imgbb.com/1/upload" + data = { + 'key': self.api_key, + 'image': base64.b64encode(file_data).decode('utf-8') + } + + response = requests.post(url, data=data, timeout=30) + response.raise_for_status() + + result = response.json() + if result.get('success'): + link = result['data']['url'] + logging.info(f"Uploaded to imgbb: {link}") + return link + return None + + def _upload_envs(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]: + url = "https://envs.sh" + files = {'file': (filename, BytesIO(file_data), mime_type)} + + response = requests.post(url, files=files, timeout=30) + response.raise_for_status() + + result = response.text.strip() + if result.startswith('https://'): + logging.info(f"Uploaded to envs.sh: {result}") + return result + return None + + def _upload_uguu(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]: + url = "https://uguu.se/upload.php" + files = {'files[]': (filename, BytesIO(file_data), 
class MemoryStorage:
    """JSON-file-backed store of per-room conversation history.

    ``data`` maps a room JID to a flat list of ``{"role", "content"}``
    dicts.  Every mutating method writes the whole mapping back to
    ``file_path``.
    """

    def __init__(self, file_path: str):
        """Remember ``file_path`` and load any history already stored there."""
        self.file_path = Path(file_path)
        self.data: Dict[str, list] = {}
        self._load()

    def _load(self):
        """Populate ``self.data`` from disk; fall back to ``{}`` on any problem."""
        try:
            if not self.file_path.exists():
                return
            with open(self.file_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            logging.error(f"Failed to load memory from {self.file_path}: {e}")
            self.data = {}
            return

        # Every accessor uses dict methods on self.data, so valid JSON of any
        # other shape (list, string, number) must be rejected here rather than
        # crash later in get_history().
        if not isinstance(loaded, dict):
            logging.error(
                f"Failed to load memory from {self.file_path}: "
                f"expected a JSON object, got {type(loaded).__name__}"
            )
            self.data = {}
            return

        self.data = loaded
        logging.info(f"Loaded persistent memory from {self.file_path}")

    def _save(self):
        """Write ``self.data`` to disk without ever truncating the old file.

        The JSON is written to a sibling ``<name>.tmp`` file first and then
        moved over the real file with ``os.replace``.  If serialisation or
        the write fails, the previous file contents stay intact.
        """
        tmp_path = self.file_path.with_name(self.file_path.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.file_path)
        except Exception as e:
            logging.error(f"Failed to save memory to {self.file_path}: {e}")
            # Best-effort cleanup of the partial temp file; it may not exist.
            try:
                tmp_path.unlink()
            except OSError:
                pass

    def get_history(self, room_jid: str) -> list:
        """Return a shallow copy of the history for ``room_jid`` (``[]`` if none)."""
        return self.data.get(room_jid, [])[:]

    def set_history(self, room_jid: str, history: list):
        """Replace the history for ``room_jid`` and persist."""
        self.data[room_jid] = history
        self._save()

    def append_to_history(self, room_jid: str, user_msg: str, assistant_msg: str, limit: int = 0):
        """Append one user/assistant exchange and persist.

        Args:
            room_jid: Key under which the exchange is stored.
            user_msg: Content of the ``user`` entry.
            assistant_msg: Content of the ``assistant`` entry.
            limit: Maximum number of exchanges to keep; each exchange is two
                entries, so at most ``limit * 2`` entries are retained.
                ``0`` (or a negative value) disables trimming.
        """
        entries = self.data.setdefault(room_jid, [])
        entries.extend([
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_msg}
        ])

        if limit > 0 and len(entries) > limit * 2:
            self.data[room_jid] = entries[-limit * 2:]

        self._save()

    def clear_history(self, room_jid: str):
        """Drop the history for ``room_jid`` if present, and persist."""
        if room_jid in self.data:
            del self.data[room_jid]
            self._save()
'audio/wav': 'wav', + 'audio/x-wav': 'wav', + 'audio/wave': 'wav', + 'audio/mp3': 'mp3', + 'audio/mpeg': 'mp3', + 'audio/ogg': 'ogg', + 'audio/opus': 'opus', + 'audio/aac': 'aac', + 'audio/flac': 'flac', + 'audio/m4a': 'aac', + 'audio/x-m4a': 'aac', + 'audio/mp4': 'aac', + 'audio/amr': 'amr', + 'audio/pcm': 'pcm', + 'audio/webm': 'webm', + 'audio/aiff': 'aiff', + 'audio/x-aiff': 'aiff' + } + + def __init__(self, jid, password, rooms, room_nicknames, trigger, mentions, rate_limit_calls, rate_limit_period, + max_length, nickname, api_url, privileged_users, max_retries, system_prompts, + remember_conversations, history_per_room, + quote_reply=True, mention_reply=True, skip_thinking=False, + request_timeout=20, allow_dm=True, dm_mode='whitelist', dm_list=None, + use_openai_api=False, api_token=None, openai_model="gpt-4", + enable_omemo=True, omemo_store_path="omemo_store.json", omemo_only=False, + answer_to_links=False, fetch_link_content=False, support_images=False, + join_retry_attempts=5, join_retry_delay=10, + persistent_memory=False, memory_file_path="memory.json", + imagen_trigger="!imagen", + cf_account_id=None, cf_api_token=None, + support_audio=False, + enable_url_context=False, + file_host='catbox', file_host_api_key=None, + tts_trigger="!tts", tts_enabled=False, tts_voice_name="Kore", + tts_model="gemini-2.5-flash-preview-tts", tts_auto_reply=False, + loop=None): + self.request_timeout = request_timeout + super().__init__(jid, password, loop=loop, sasl_mech='PLAIN') + self.enable_direct_tls = True + + self.rooms = rooms + self.room_nicknames = room_nicknames or {} + self.trigger = trigger + self.mentions = mentions + self.max_length = max_length + self.nickname = nickname + + self.api_url = api_url + self.use_openai_api = use_openai_api + self.api_token = api_token + self.openai_model = openai_model + + self.privileged_users = {u.lower() for u in privileged_users} + + self.max_retries = max_retries + self.system_prompts = system_prompts + + self.remember = 
remember_conversations + self.history_limit = history_per_room + self.history = {} + + self.persistent_memory = persistent_memory + self.memory_file_path = memory_file_path + self.memory_storage = None + if self.persistent_memory: + self.memory_storage = MemoryStorage(self.memory_file_path) + logging.info(f"Persistent memory enabled, using {self.memory_file_path}") + + self.quote_reply = quote_reply + self.mention_reply = mention_reply + self.skip_thinking = skip_thinking + + self.allow_dm = allow_dm + self.dm_mode = dm_mode.lower() + self.dm_list = {x.lower() for x in (dm_list or [])} + + self.enable_omemo = enable_omemo + self.omemo_store_path = omemo_store_path + self.omemo_only = omemo_only + + self.answer_to_links = answer_to_links + self.fetch_link_content = fetch_link_content + self.support_images = support_images + self.support_audio = support_audio + self.join_retry_attempts = join_retry_attempts + self.join_retry_delay = join_retry_delay + + self.imagen_trigger = imagen_trigger + self.cf_account_id = cf_account_id + self.cf_api_token = cf_api_token + + self.enable_url_context = enable_url_context + + self.file_uploader = FileUploader(service=file_host, api_key=file_host_api_key, xmpp_client=self) + + self.tts_trigger = tts_trigger + self.tts_enabled = tts_enabled + self.tts_voice_name = tts_voice_name + self.tts_model = tts_model + self.tts_auto_reply = tts_auto_reply + + self.genai_client = None + if GOOGLE_GENAI_AVAILABLE and self.api_token: + try: + self.genai_client = genai.Client(api_key=self.api_token) + logging.info("Initialized Google GenAI client") + except Exception as e: + logging.error(f"Failed to initialize Google GenAI client: {e}") + + self.openai_client = None + if self.use_openai_api and OPENAI_AVAILABLE and self.api_token: + try: + base_url = self.api_url + if base_url.endswith('/chat/completions'): + base_url = base_url[:-len('/chat/completions')] + elif base_url.endswith('/v1/chat/completions'): + base_url = 
def is_gemini_api(self):
    """Return True when ``self.api_url`` points at Google's Generative Language API.

    The check is a case-insensitive substring match on the configured URL.
    A missing or empty ``api_url`` yields ``False`` instead of raising.
    """
    api_url = self.api_url or ""
    return "generativelanguage.googleapis.com" in api_url.lower()
join {room} (attempt {attempt+1}/{self.join_retry_attempts}): {e}") + if attempt < self.join_retry_attempts - 1: + logging.info(f"Retrying in {self.join_retry_delay} seconds...") + time.sleep(self.join_retry_delay) + else: + logging.error(f"Failed to join {room} after {self.join_retry_attempts} attempts") + return False + + def extract_urls(self, text): + return self.url_pattern.findall(text) + + def clean_aesgcm_urls(self, text): + return self.aesgcm_pattern.sub('', text).strip() + + def fetch_url_content(self, url, max_length=5000): + try: + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36' + } + response = requests.get(url, headers=headers, timeout=10, allow_redirects=True) + response.raise_for_status() + + content_type = response.headers.get('content-type', '').lower() + + if 'text/html' in content_type: + soup = BeautifulSoup(response.content, 'html.parser') + + for script in soup(["script", "style", "nav", "footer", "header"]): + script.decompose() + + text = soup.get_text(separator='\n', strip=True) + + lines = [line.strip() for line in text.splitlines() if line.strip()] + text = '\n'.join(lines) + + if len(text) > max_length: + text = text[:max_length] + "... (truncated)" + + return f"Content from {url}:\n{text}" + + elif 'text/plain' in content_type: + text = response.text + if len(text) > max_length: + text = text[:max_length] + "... 
def decrypt_aesgcm_url(self, url):
    """Download and decrypt the file referenced by an ``aesgcm://`` URL.

    The URL fragment is a hex string holding the IV followed by a 32-byte
    AES-256 key.  The IV is 12 bytes (44 bytes in total) or, for legacy
    senders, 16 bytes (48 bytes in total).  The ciphertext is fetched from
    the same location over ``https``.

    Args:
        url: The ``aesgcm://host/path#<iv-hex><key-hex>`` URL.

    Returns:
        The decrypted bytes, or ``None`` if the URL is not a valid aesgcm
        URL, the download fails, or decryption fails.
    """
    try:
        parsed = urlparse(url)

        if parsed.scheme != 'aesgcm':
            return None

        if not parsed.fragment:
            logging.error("AESGCM URL missing fragment with key/IV")
            return None

        try:
            key_iv = bytes.fromhex(parsed.fragment)
        except ValueError as e:
            logging.error(f"Invalid hex in AESGCM fragment: {e}")
            return None

        key_iv_len = len(key_iv)
        # Only 12+32 and 16+32 are valid; anything else is malformed, and
        # guessing a split would just produce an opaque decryption failure.
        if key_iv_len not in (44, 48):
            logging.error(f"AESGCM key/IV unexpected length: {key_iv_len} bytes")
            return None

        # The key is always the trailing 32 bytes; the IV is whatever
        # precedes it.  This holds for both the 12- and the 16-byte IV.
        iv = key_iv[:-32]
        key = key_iv[-32:]

        # Same location over https; the fragment (secret material) is dropped.
        https_url = urlunparse(('https', parsed.netloc, parsed.path, parsed.params, parsed.query, ''))

        logging.info(f"Decrypting AESGCM URL: {https_url} (key_iv_len={key_iv_len})")

        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36'
        }
        response = requests.get(https_url, timeout=30, headers=headers, allow_redirects=True)
        response.raise_for_status()

        encrypted_data = response.content

        aesgcm = AESGCM(key)
        decrypted_data = aesgcm.decrypt(iv, encrypted_data, None)

        logging.info(f"Successfully decrypted AESGCM data: {len(decrypted_data)} bytes")

        return decrypted_data

    except Exception as e:
        logging.error(f"Error decrypting AESGCM URL: {e}", exc_info=True)
        return None
def get_audio_format_from_mime(self, mime_type):
    """Map a MIME type or ``Content-Type`` header value to an audio format name.

    Only the media type itself is looked up in ``self.AUDIO_MIME_TYPES``:
    parameters after ``;`` (for example ``; codecs=opus``), surrounding
    whitespace and letter case are ignored.

    Args:
        mime_type: A media type string, optionally with parameters.

    Returns:
        The mapped format name, or ``'wav'`` when the type is unknown,
        empty or ``None``.
    """
    if not mime_type:
        return 'wav'

    # Header values may carry parameters, which would make an exact
    # dictionary lookup miss.
    media_type = mime_type.split(';', 1)[0].strip().lower()
    return self.AUDIO_MIME_TYPES.get(media_type, 'wav')
def extract_audio_from_message(self, msg, decrypted_body=None):
    """Find and download the first retrievable audio attachment of a message.

    Sources are tried in this order:

    1. an out-of-band (``jabber:x:oob``) URL, if it looks like audio;
    2. audio URLs found in ``decrypted_body``;
    3. audio URLs found in the stanza's plain ``<body>``, unless the body
       contains the text "doesn't seem to support that" or "OMEMO".

    For sources 2 and 3 every candidate URL is tried until one download
    succeeds.

    Args:
        msg: Object exposing the stanza XML as ``msg.xml``, or an object
            with ``_get_stanza_values`` that is itself searched.
        decrypted_body: Decrypted message text, if any.

    Returns:
        The result of ``self.fetch_audio_from_url`` for the first URL that
        could be fetched, or ``None`` when nothing was found or an error
        occurred.
    """
    try:
        if hasattr(msg, 'xml'):
            xml_elem = msg.xml
        elif hasattr(msg, '_get_stanza_values'):
            xml_elem = msg
        else:
            logging.debug("Message object doesn't have accessible XML")
            return None

        oob = xml_elem.find('.//{jabber:x:oob}url')
        if oob is not None and oob.text:
            audio_url = oob.text.strip()
            if self.is_audio_url(audio_url):
                logging.info(f"Found OOB audio URL: {audio_url}")
                return self.fetch_audio_from_url(audio_url)

        if decrypted_body:
            url_data = self.extract_urls_and_media(decrypted_body)
            if url_data['audio_urls']:
                logging.info(f"Found {len(url_data['audio_urls'])} audio URLs in decrypted body")
                for audio_url in url_data['audio_urls']:
                    audio_data = self.fetch_audio_from_url(audio_url)
                    if audio_data:
                        return audio_data

        body_elem = xml_elem.find('.//{jabber:client}body')
        if body_elem is None:
            body_elem = xml_elem.find('.//body')

        if body_elem is not None and body_elem.text:
            body_text = body_elem.text
            if "doesn't seem to support that" not in body_text and "OMEMO" not in body_text:
                url_data = self.extract_urls_and_media(body_text)
                if url_data['audio_urls']:
                    logging.info(f"Found {len(url_data['audio_urls'])} audio URLs in body")
                    # Keep trying later URLs when an earlier one cannot be
                    # fetched, as the decrypted-body loop above does.
                    for audio_url in url_data['audio_urls']:
                        audio_data = self.fetch_audio_from_url(audio_url)
                        if audio_data:
                            return audio_data

    except Exception as e:
        logging.error(f"Error extracting audio: {e}", exc_info=True)

    return None
def extract_image_from_message(self, msg, decrypted_body=None):
    """Find and load the first retrievable image attached to a message.

    Sources are tried in this order:

    1. inline Bits-of-Binary data (``urn:xmpp:bob``), base64-decoded;
    2. an out-of-band (``jabber:x:oob``) URL, if it looks like an image;
    3. image URLs found in ``decrypted_body``;
    4. image URLs found in the stanza's plain ``<body>``, unless the body
       contains the text "doesn't seem to support that" or "OMEMO".

    For sources 3 and 4 every candidate URL is tried until one download
    succeeds.

    Args:
        msg: Object exposing the stanza XML as ``msg.xml``, or an object
            with ``_get_stanza_values`` that is itself searched.
        decrypted_body: Decrypted message text, if any.

    Returns:
        A dict with ``mime_type`` and ``data`` for inline images, the result
        of ``self.fetch_image_from_url`` for URLs, or ``None`` when nothing
        was found or an error occurred.
    """
    try:
        if hasattr(msg, 'xml'):
            xml_elem = msg.xml
        elif hasattr(msg, '_get_stanza_values'):
            xml_elem = msg
        else:
            logging.debug("Message object doesn't have accessible XML")
            return None

        bob = xml_elem.find('.//{urn:xmpp:bob}data')
        if bob is not None:
            mime_type = bob.get('type', 'image/jpeg')
            image_data = bob.text
            if image_data:
                logging.info(f"Found BoB image: {mime_type}")
                return {
                    'mime_type': mime_type,
                    'data': base64.b64decode(image_data.strip())
                }

        oob = xml_elem.find('.//{jabber:x:oob}url')
        if oob is not None and oob.text:
            image_url = oob.text.strip()
            logging.info(f"Found OOB URL: {image_url}")

            if self.is_image_url(image_url):
                logging.info(f"OOB URL is an image, fetching...")
                return self.fetch_image_from_url(image_url)

        if decrypted_body:
            logging.debug(f"Parsing decrypted body for image URLs...")
            url_data = self.extract_urls_and_media(decrypted_body)

            if url_data['image_urls']:
                logging.info(f"Found {len(url_data['image_urls'])} image URLs in decrypted OMEMO body")
                for img_url in url_data['image_urls']:
                    logging.info(f"Fetching image from decrypted body URL: {img_url}")
                    image_data = self.fetch_image_from_url(img_url)
                    if image_data:
                        return image_data

        body_elem = xml_elem.find('.//{jabber:client}body')
        if body_elem is None:
            body_elem = xml_elem.find('.//body')

        if body_elem is not None and body_elem.text:
            body_text = body_elem.text

            if "doesn't seem to support that" not in body_text and "OMEMO" not in body_text:
                url_data = self.extract_urls_and_media(body_text)

                # Keep trying later URLs when an earlier one cannot be
                # fetched, as the decrypted-body loop above does.
                for img_url in url_data['image_urls']:
                    image_data = self.fetch_image_from_url(img_url)
                    if image_data:
                        return image_data

    except Exception as e:
        logging.error(f"Error extracting image: {e}", exc_info=True)

    return None
def is_image_url(self, url):
    """Return True when ``url`` appears to point at an image file.

    The decision is based on the file extension (``.jpg``, ``.jpeg``,
    ``.png``, ``.gif``, ``.webp``, ``.bmp``), compared case-insensitively.
    For ``aesgcm://`` URLs only the path component is examined.  For every
    other URL the path component is examined first, so a query string or
    fragment after the file name does not hide the extension; the whole
    URL's suffix is still accepted as well.

    Args:
        url: The URL to classify.

    Returns:
        ``True`` if an image extension is recognised, otherwise ``False``
        (including for an empty or ``None`` URL).
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')

    if not url:
        return False

    if url.startswith('aesgcm://'):
        path = urlparse(url.replace('aesgcm://', 'https://')).path.lower()
        return path.endswith(image_extensions)

    try:
        path = urlparse(url).path.lower()
    except ValueError:
        # Malformed URL: fall back to the raw-suffix check alone.
        path = ''

    return path.endswith(image_extensions) or url.lower().endswith(image_extensions)
def _encrypt_and_upload(self, data: bytes, filename: str, should_encrypt: bool = False) -> Optional[str]:
    """Upload *data* through ``self.file_uploader``, optionally AES-GCM encrypted.

    Plain uploads are sent with a MIME type chosen from the filename
    extension (png/jpg/wav/ogg, otherwise ``application/octet-stream``).
    Encrypted uploads are sealed with a fresh 32-byte key and 12-byte IV,
    sent as ``application/octet-stream`` and returned as an ``aesgcm://`` URL
    whose fragment is ``hex(iv) + hex(key)``.

    Returns:
        The URL to share, or ``None`` if the upload or the encryption failed.
    """
    if not should_encrypt:
        upload_mime = "application/octet-stream"
        for suffix, mime in (
            ('.png', "image/png"),
            ('.jpg', "image/jpeg"),
            ('.wav', "audio/wav"),
            ('.ogg', "audio/ogg"),
        ):
            if filename.endswith(suffix):
                upload_mime = mime
                break
        return self.file_uploader.upload(data, filename, upload_mime)

    try:
        key = os.urandom(32)
        iv = os.urandom(12)
        encrypted_data = AESGCM(key).encrypt(iv, data, None)

        upload_url = self.file_uploader.upload(encrypted_data, filename, "application/octet-stream")
        if not upload_url:
            return None

        parsed = urlparse(upload_url)
        # Keep params/query: some hosts put an access token there and the
        # link is useless without it.
        new_url = urlunparse(
            ('aesgcm', parsed.netloc, parsed.path, parsed.params, parsed.query, iv.hex() + key.hex())
        )

        # The fragment is the decryption key - never write it to the log.
        logging.info(f"Encrypted upload successful: aesgcm://{parsed.netloc}{parsed.path}")
        return new_url
    except Exception as e:
        logging.error(f"Encryption failed: {e}")
        return None


def synthesize_speech(self, text, should_encrypt=False, max_retries=3):
    """Turn *text* into speech with the configured Gemini TTS model.

    The PCM returned by the model is converted to Ogg/Opus when pydub is
    available (falling back to WAV), then uploaded via
    ``_encrypt_and_upload``.

    Args:
        text: text to speak.
        should_encrypt: upload the audio AES-GCM encrypted.
        max_retries: attempts made when the request raises.

    Returns:
        ``{"type": "url", "content": url}`` on a successful upload,
        ``{"type": "base64", "content": b64_audio}`` when the upload failed,
        or ``None`` when TTS is unavailable, the response carries no audio,
        or every attempt raised.
    """
    if not self.tts_enabled:
        logging.error("TTS not enabled")
        return None

    if not self.genai_client:
        logging.error("GenAI client not initialized for TTS")
        return None

    for attempt in range(max_retries):
        try:
            logging.info(f"Synthesizing speech with Gemini TTS: {text[:50]}...")

            response = self.genai_client.models.generate_content(
                model=self.tts_model,
                contents=text,
                config=types.GenerateContentConfig(
                    response_modalities=["AUDIO"],
                    speech_config=types.SpeechConfig(
                        voice_config=types.VoiceConfig(
                            prebuilt_voice_config=types.PrebuiltVoiceConfig(
                                voice_name=self.tts_voice_name,
                            )
                        )
                    ),
                )
            )

            candidates = response.candidates
            parts = candidates[0].content.parts if candidates and candidates[0].content else None
            inline_data = getattr(parts[0], 'inline_data', None) if parts else None
            if not inline_data:
                # A well-formed but empty answer is not retried.
                logging.error("No audio data in TTS response")
                return None

            raw_pcm_data = inline_data.data
            final_data = None
            filename = "tts.wav"

            if PYDUB_AVAILABLE:
                try:
                    sound = AudioSegment(
                        data=raw_pcm_data,
                        sample_width=2,
                        frame_rate=24000,
                        channels=1
                    )
                    ogg_buffer = BytesIO()
                    sound.export(ogg_buffer, format="ogg", codec="libopus")
                    final_data = ogg_buffer.getvalue()
                    filename = "tts.ogg"
                    logging.info(f"Converted TTS to Ogg Opus, size: {len(final_data)} bytes")
                except Exception as e:
                    logging.error(f"Audio conversion failed (missing ffmpeg?): {e}")

            if final_data is None:
                logging.info("Falling back to WAV format")
                final_data = self._create_wave_file(raw_pcm_data)

            upload_url = self._encrypt_and_upload(final_data, filename, should_encrypt)
            if upload_url:
                return {"type": "url", "content": upload_url}
            return {"type": "base64", "content": base64.b64encode(final_data).decode('utf-8')}

        except Exception as e:
            logging.error(f"TTS error (attempt {attempt+1}): {e}")
            # Back off only when another attempt follows.
            if attempt + 1 < max_retries:
                time.sleep(1)

    return None


def generate_image(self, prompt, should_encrypt=False, max_retries=3):
    """Generate a 1024x1024 image for *prompt* with Cloudflare Workers AI.

    Args:
        prompt: text prompt sent to the model.
        should_encrypt: upload the image AES-GCM encrypted.
        max_retries: attempts made when the request raises.

    Returns:
        ``{"type": "url", "content": url}`` on a successful upload,
        ``{"type": "base64", "content": b64_image}`` when the upload failed,
        or ``None`` when credentials are missing, the response has no image,
        or every attempt raised.
    """
    if not self.cf_account_id or not self.cf_api_token:
        logging.error("Cloudflare credentials not configured for image generation")
        return None

    url = f"https://api.cloudflare.com/client/v4/accounts/{self.cf_account_id}/ai/run/@cf/black-forest-labs/flux-2-dev"
    headers = {"Authorization": f"Bearer {self.cf_api_token}"}

    # (None, value) tuples make requests send plain multipart form fields.
    files = {
        'prompt': (None, prompt),
        'width': (None, '1024'),
        'height': (None, '1024'),
        'steps': (None, '20')
    }

    for attempt in range(max_retries):
        try:
            logging.info(f"Generating image with Cloudflare Workers AI: {prompt[:50]}...")

            response = requests.post(url, headers=headers, files=files, timeout=60)
            response.raise_for_status()
            result = response.json()

            # The image is either nested under "result" or at the top level.
            b64_image = None
            nested = result.get('result') if 'result' in result else None
            if isinstance(nested, dict) and 'image' in nested:
                b64_image = nested['image']
            elif 'image' in result:
                b64_image = result['image']

            if not b64_image:
                logging.error(f"No image in Cloudflare response: {result}")
                return None

            image_bytes = base64.b64decode(b64_image)
            logging.info(f"Image generated, size: {len(image_bytes)} bytes")

            upload_url = self._encrypt_and_upload(image_bytes, "generated.png", should_encrypt)
            if upload_url:
                return {"type": "url", "content": upload_url}
            return {"type": "base64", "content": b64_image}

        except Exception as e:
            logging.error(f"Image generation error (attempt {attempt+1}): {e}")
            # Back off only when another attempt follows.
            if attempt + 1 < max_retries:
                time.sleep(2)

    return None
def direct_message(self, msg):
    """Synchronous handler for plaintext one-to-one messages.

    Ignores the message when DMs are disabled, the stanza is not a
    chat/normal message, it comes from the bot's own bare JID, the sender
    fails the whitelist/blacklist check, OMEMO-only mode is active, or the
    sanitised body is empty or longer than ``self.max_length``. Otherwise it
    serves the image and TTS triggers, or forwards the text (plus optional
    image, audio and fetched link content) to ``self.rate_limited_send`` and
    replies with the answer.
    """
    if not self.allow_dm or msg['type'] not in ('chat', 'normal'):
        return

    if msg['from'].bare == self.boundjid.bare:
        return

    sender = msg['from'].bare.lower()
    listed = sender in self.dm_list
    if (self.dm_mode == 'whitelist' and not listed) or (self.dm_mode == 'blacklist' and listed):
        return

    # This handler only ever sees plaintext, so OMEMO-only mode drops it all.
    if self.enable_omemo and self.omemo_only:
        logging.warning(f"Ignoring plaintext DM from {sender} (OMEMO only mode active).")
        return

    body = self.sanitize_input(msg['body'])
    if not body or len(body) > self.max_length:
        return

    logging.info(f"Direct message from {sender}: {body[:50]}...")

    def send_link(link):
        # Body and out-of-band URL both carry the link.
        stanza = msg.reply(link)
        stanza['oob']['url'] = link
        stanza.send()

    def answer_media(outcome, base64_notice, failure_notice):
        if not outcome:
            msg.reply(failure_notice).send()
        elif outcome["type"] == "url":
            send_link(outcome['content'])
        else:
            msg.reply(base64_notice).send()

    if body.startswith(self.imagen_trigger):
        prompt = body[len(self.imagen_trigger):].strip()
        if prompt:
            answer_media(
                self.generate_image(prompt, should_encrypt=False),
                "Image generated (base64)",
                "Failed to generate image.",
            )
        return

    if self.tts_enabled and body.startswith(self.tts_trigger):
        text = body[len(self.tts_trigger):].strip()
        if text:
            answer_media(
                self.synthesize_speech(text, should_encrypt=False),
                "Speech synthesized (base64)",
                "Failed to synthesize speech.",
            )
        return

    # Prefer the part the user typed over what they quoted.
    _, non_quoted_text = self.extract_quoted_text(body)
    query = non_quoted_text.strip() or body

    has_links = self.contains_links(query)

    image_data = None
    if self.support_images:
        image_data = self.extract_image_from_message(msg)
        if image_data:
            logging.info(f"Image detected in DM from {sender}")

    audio_data = None
    if self.support_audio:
        audio_data = self.extract_audio_from_message(msg)
        if audio_data:
            logging.info(f"Audio detected in DM from {sender}")

    query = self.clean_aesgcm_urls(query)

    if self.fetch_link_content and has_links:
        fetched = []
        for url in self.extract_urls_and_media(query)['regular_urls'][:3]:
            if url.startswith('aesgcm://'):
                continue
            content = self.fetch_url_content(url)
            if content:
                fetched.append(f"\n\n{content}\n")

        link_context = ''.join(fetched)
        if link_context:
            query = f"{query}\n\n--- Additional context from links ---{link_context}"

    response = self.rate_limited_send(
        query,
        max_retries=self.max_retries,
        room_jid=sender,
        image_data=image_data,
        audio_data=audio_data,
    )
    if not response:
        return

    msg.reply(response).send()
    logging.info(f"Replied to {sender}")

    if self.tts_auto_reply and self.tts_enabled:
        tts_result = self.synthesize_speech(response, should_encrypt=False)
        if tts_result and tts_result["type"] == "url":
            send_link(tts_result['content'])
async def direct_message_async(self, stanza: "Message") -> None:
    """Handle a one-to-one message, decrypting OMEMO payloads when present.

    Applies the same DM filters as the plaintext handler (DMs enabled,
    chat/normal type, not from the bot's own bare JID, whitelist/blacklist,
    OMEMO-only mode, body length), serves the image and TTS triggers, or
    forwards the text to ``self.rate_limited_send``. Replies are encrypted
    exactly when the incoming message was. Slow helpers run in the default
    executor so the event loop stays responsive.
    """
    if not self.allow_dm:
        return

    mfrom = stanza["from"]
    mtype = stanza["type"]

    if mtype not in {"chat", "normal"} or mfrom.bare == self.boundjid.bare:
        return

    sender = mfrom.bare.lower()

    if self.dm_mode == 'whitelist' and sender not in self.dm_list:
        return
    if self.dm_mode == 'blacklist' and sender in self.dm_list:
        return

    xep_0384 = self["xep_0384"]

    body = None
    is_encrypted = False
    decrypted_body = None

    if xep_0384.is_encrypted(stanza):
        logging.debug(f"Encrypted message received from {mfrom}")
        try:
            decrypted_msg, _device_info = await xep_0384.decrypt_message(stanza)
            if decrypted_msg.get("body"):
                body = self.sanitize_input(decrypted_msg["body"])
                decrypted_body = body
                is_encrypted = True
                logging.info(f"Decrypted message from {sender}: {body[:50]}...")
        except Exception as e:
            logging.error(f"Decryption failed: {e}")
            await self._plain_reply(mfrom, mtype, f"Error decrypting message: {e}")
            return
    elif stanza["body"]:
        body = self.sanitize_input(stanza["body"])
        logging.info(f"Plaintext message from {sender}: {body[:50]}...")

    if self.omemo_only and not is_encrypted:
        logging.warning(f"Ignoring plaintext DM from {sender} (OMEMO only mode active).")
        return

    if not body or len(body) > self.max_length:
        return

    loop = asyncio.get_running_loop()

    async def respond(text):
        # Answer over the same channel (encrypted or not) the request used.
        if is_encrypted:
            await self._encrypted_reply(mfrom, mtype, text)
        else:
            await self._plain_reply(mfrom, mtype, text)

    if body.startswith(self.imagen_trigger):
        prompt = body[len(self.imagen_trigger):].strip()
        if prompt:
            result = await loop.run_in_executor(None, self.generate_image, prompt, is_encrypted)
            if not result:
                await respond("Failed to generate image.")
            elif result["type"] == "url":
                await respond(result['content'])
            else:
                await respond("Image generated (base64)")
        return

    if self.tts_enabled and body.startswith(self.tts_trigger):
        text = body[len(self.tts_trigger):].strip()
        if text:
            result = await loop.run_in_executor(None, self.synthesize_speech, text, is_encrypted)
            if not result:
                await respond("Failed to synthesize speech.")
            elif result["type"] == "url":
                await respond(result['content'])
            else:
                await respond("Speech synthesized (base64)")
        return

    # Prefer the part the user typed over what they quoted.
    _, non_quoted_text = self.extract_quoted_text(body)
    query = non_quoted_text.strip() or body

    has_links = self.contains_links(query)
    query = self.clean_aesgcm_urls(query)

    link_context = ""
    if self.fetch_link_content and has_links:
        for url in self.extract_urls(query)[:3]:
            if url.startswith('aesgcm://'):
                continue
            # fetch_url_content is a synchronous call (presumably blocking
            # HTTP - confirm); keep it off the event loop like the other
            # slow helpers in this coroutine.
            content = await loop.run_in_executor(None, self.fetch_url_content, url)
            if content:
                link_context += f"\n\n{content}\n"

        if link_context:
            query = f"{query}\n\n--- Additional context from links ---{link_context}"

    image_data = None
    audio_data = None

    # NOTE(review): these extractors may also download media synchronously;
    # if so they should move to the executor as well - confirm.
    if self.support_images:
        image_data = self.extract_image_from_message(stanza, decrypted_body)
        if image_data:
            logging.info(f"Image detected in DM from {sender}")

    if self.support_audio:
        audio_data = self.extract_audio_from_message(stanza, decrypted_body)
        if audio_data:
            logging.info(f"Audio detected in DM from {sender}")

    response = await loop.run_in_executor(
        None,
        self.rate_limited_send,
        query,
        self.max_retries,
        sender,
        image_data,
        audio_data
    )

    if not response:
        return

    await respond(response)
    logging.info(f"Replied to {sender}")

    if self.tts_auto_reply and self.tts_enabled:
        tts_result = await loop.run_in_executor(None, self.synthesize_speech, response, is_encrypted)
        if tts_result and tts_result["type"] == "url":
            await respond(tts_result['content'])


async def _plain_reply(self, mto: "JID", mtype: str, reply_text: str) -> None:
    """Send *reply_text* unencrypted to *mto*.

    When the whole reply is a single http(s) URL it is also attached as an
    out-of-band URL. Text that merely begins with "http" (for example an LLM
    answer) is sent as a plain body only.
    """
    msg = self.make_message(mto=mto, mtype=mtype)
    msg["body"] = reply_text

    is_bare_url = (
        reply_text.startswith(("http://", "https://"))
        and not any(ch.isspace() for ch in reply_text)
    )
    if is_bare_url:
        msg['oob']['url'] = reply_text
    msg.send()


async def _encrypted_reply(self, mto: "JID", mtype: str, reply_text: str) -> None:
    """Send *reply_text* to the bare JID of *mto*, encrypted via XEP-0384.

    One stanza is sent per namespace returned by ``encrypt_message``, each
    tagged with its encryption namespace and mechanism name. Non-fatal
    encryption errors are logged. If encryption raises, a plaintext error
    notice (not the reply itself) is sent instead.
    """
    xep_0384 = self["xep_0384"]

    target_jid = mto.bare if isinstance(mto, JID) else JID(mto).bare

    msg = self.make_message(mto=target_jid, mtype=mtype)
    msg["body"] = reply_text

    encrypt_for: Set[JID] = {JID(target_jid)}

    try:
        messages, encryption_errors = await xep_0384.encrypt_message(msg, encrypt_for)

        if encryption_errors:
            logging.warning(f"Encryption errors: {encryption_errors}")

        for namespace, message in messages.items():
            message["eme"]["namespace"] = namespace
            message["eme"]["name"] = self["xep_0380"].mechanisms[namespace]
            message.send()
            logging.debug(f"Sent encrypted message to {target_jid}")
    except Exception as e:
        logging.error(f"Failed to send encrypted reply: {e}")
        await self._plain_reply(mto, mtype, f"Error encrypting reply: {e}")
def groupchat_message(self, msg):
    """Handle a MUC message and answer it when it is addressed to the bot.

    Messages from the bot's own nick and bodies longer than
    ``self.max_length`` are ignored. The image and TTS triggers are served
    first. Otherwise a query is derived, in priority order, from: the text
    trigger, an ``@nick`` mention, a leading ``nick`` followed by one of
    ``: , ; - —`` or a space, a quoted reply to the bot, a message containing
    links (when ``answer_to_links`` is on), or any message from a privileged
    user. The query, plus optional image/audio and fetched link content, is
    sent through ``self.rate_limited_send`` and the answer is posted back,
    optionally prefixed with the sender's nick and a quote of their message.
    """
    if msg['type'] != 'groupchat':
        return

    room_jid = msg['from'].bare
    bot_nick = self.room_nicknames.get(room_jid, self.nickname)

    if msg['mucnick'] == bot_nick:
        return

    body = self.sanitize_input(msg['body'])
    if len(body) > self.max_length:
        return

    sender_nick = msg['mucnick']
    is_privileged = sender_nick.lower() in self.privileged_users

    def send_link(link, text):
        # Body carries *text*; the out-of-band URL carries the bare link.
        stanza = msg.reply(text)
        stanza['oob']['url'] = link
        stanza.send()

    def answer_media(outcome, base64_notice, failure_notice):
        if not outcome:
            msg.reply(f"{sender_nick}: {failure_notice}").send()
        elif outcome["type"] == "url":
            send_link(outcome['content'], f"{sender_nick}: {outcome['content']}")
        else:
            msg.reply(f"{sender_nick}: {base64_notice}").send()

    if body.startswith(self.imagen_trigger):
        prompt = body[len(self.imagen_trigger):].strip()
        if prompt:
            answer_media(
                self.generate_image(prompt, should_encrypt=False),
                "Image generated (base64)",
                "Failed to generate image.",
            )
        return

    if self.tts_enabled and body.startswith(self.tts_trigger):
        text = body[len(self.tts_trigger):].strip()
        if text:
            answer_media(
                self.synthesize_speech(text, should_encrypt=False),
                "Speech synthesized (base64)",
                "Failed to synthesize speech.",
            )
        return

    is_reply = self.is_replying_to_bot(msg)
    has_links = self.contains_links(body)
    nick_separators = (':', ',', ';', '-', '—', ' ')
    nick_lower = bot_nick.lower()

    # Which separator, if any, follows a leading (case-insensitive) bot nick.
    leading_separator = None
    if self.mentions:
        body_lower = body.lower()
        for sep in nick_separators:
            if body_lower.startswith(f"{nick_lower}{sep}"):
                leading_separator = sep
                break

    query = None
    triggered_by_reply_only = False
    triggered_by_links = False
    mention_separator = ':'

    if body.startswith(self.trigger):
        query = body[len(self.trigger):].strip()
    elif self.mentions and f"@{bot_nick}" in body:
        query = body.replace(f"@{bot_nick}", "").strip()
    elif leading_separator is not None:
        query = body[len(bot_nick) + 1:].strip()
        # Address the sender with the separator they used for the bot.
        mention_separator = leading_separator
    elif is_reply:
        _, non_quoted = self.extract_quoted_text(body)
        query = non_quoted.strip()
        triggered_by_reply_only = True
        logging.info(f"Detected reply to bot from {sender_nick}")
    elif self.answer_to_links and has_links:
        query = body
        triggered_by_links = True
        logging.info(f"Detected message with links from {sender_nick}")
    elif is_privileged and body:
        query = body

    if not query:
        return

    logging.info(f"Query from {sender_nick} in {room_jid}: {query[:50]}...")

    image_data = None
    audio_data = None

    if self.support_images:
        image_data = self.extract_image_from_message(msg)
        if not image_data:
            # No attachment: fall back to the first linked image that loads.
            for img_url in self.extract_urls_and_media(query)['image_urls']:
                image_data = self.fetch_image_from_url(img_url)
                if image_data:
                    break

    if self.support_audio:
        audio_data = self.extract_audio_from_message(msg)
        if not audio_data:
            for audio_url in self.extract_urls_and_media(query)['audio_urls']:
                audio_data = self.fetch_audio_from_url(audio_url)
                if audio_data:
                    break

    query = self.clean_aesgcm_urls(query)

    if self.fetch_link_content and has_links:
        link_context = ""
        for url in self.extract_urls_and_media(body)['regular_urls'][:3]:
            if url.startswith('aesgcm://'):
                continue
            content = self.fetch_url_content(url)
            if content:
                link_context += f"\n\n{content}\n"

        if link_context:
            query = f"{query}\n\n--- Additional context from links ---{link_context}"

    response = self.rate_limited_send(
        query,
        max_retries=self.max_retries,
        room_jid=room_jid,
        image_data=image_data,
        audio_data=audio_data,
    )
    if not response:
        return

    if self.mention_reply:
        # rstrip() makes a bare-space separator collapse to the single space
        # added here, so the prefix never ends in two spaces.
        response = f"{sender_nick}{mention_separator.rstrip()} {response}"

    if self.quote_reply:
        _, non_quoted_original = self.extract_quoted_text(body)
        message_to_quote = non_quoted_original.strip() or body

        if message_to_quote:
            lines = message_to_quote.split('\n')
            first_line = lines[0].strip().lower()
            already_has_bot_name = any(
                first_line.startswith(f"{nick_lower}{sep}") for sep in nick_separators
            )

            # Mark the quote as addressed to the bot so a later reply to it
            # is recognised by is_replying_to_bot.
            if (triggered_by_reply_only or triggered_by_links) and not already_has_bot_name:
                lines[0] = f"{bot_nick}: {lines[0]}"

            quoted = '\n'.join(f"> {line}" for line in lines)
            response = f"{quoted}\n{response}"

    msg.reply(response).send()
    logging.info(f"Replied in {room_jid}")

    if self.tts_auto_reply and self.tts_enabled:
        tts_text = response
        if self.quote_reply:
            # Do not read the quoted request back to the room.
            spoken = [line for line in response.split('\n') if not line.startswith('>')]
            tts_text = '\n'.join(spoken).strip()

        tts_result = self.synthesize_speech(tts_text, should_encrypt=False)
        if tts_result and tts_result["type"] == "url":
            send_link(tts_result['content'], tts_result['content'])
def send_to_llm(self, message, max_retries, room_jid=None, image_data=None, audio_data=None):
    """Dispatch *message* to the configured LLM backend and return its result.

    Backend selection, first match wins:
      1. native Gemini - ``self.is_gemini_api()`` is true and a GenAI client exists;
      2. OpenAI client library - ``use_openai_api`` is set and ``self.openai_client`` exists;
      3. OpenAI via raw requests - ``use_openai_api`` is set without a client;
      4. custom API - everything else.

    All arguments are passed through unchanged to the chosen backend.
    """
    if self.is_gemini_api() and self.genai_client:
        backend = self._send_to_gemini_native
    elif not self.use_openai_api:
        backend = self._send_to_custom_api
    elif self.openai_client:
        backend = self._send_to_openai_library
    else:
        backend = self._send_to_openai_requests

    return backend(message, max_retries, room_jid, image_data, audio_data)
if __name__ == '__main__':
    # Script entry point: read config.ini from the working directory, build
    # the bot from the [XMPP] and [Bot] sections, connect, and run until the
    # connection is closed.
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips missing files; fail with a clear
    # message instead of a KeyError on the first section lookup.
    if not config.read('config.ini'):
        raise SystemExit("config.ini not found or unreadable")

    def _split_csv(raw, lowercase=False):
        """Split a comma-separated option into stripped, non-empty items."""
        items = [item.strip() for item in raw.split(',') if item.strip()]
        return [item.lower() for item in items] if lowercase else items

    def _per_room_options(prefix):
        """Collect non-empty ``<prefix><room> = value`` options from [Bot], keyed by room."""
        found = {}
        for key in config['Bot']:
            if key.startswith(prefix):
                value = config['Bot'][key].strip()
                if value:
                    found[key.split('.', 1)[1].strip()] = value
        return found

    jid = config['XMPP']['jid']
    password = config['XMPP']['password']
    # Empty entries (e.g. a trailing comma) are dropped rather than joined.
    rooms = _split_csv(config['XMPP']['rooms'])

    api_url = config['Bot']['api_url']
    use_openai_api = config.getboolean('Bot', 'use_openai_api', fallback=False)
    api_token = config.get('Bot', 'api_token', fallback='')
    openai_model = config.get('Bot', 'openai_model', fallback='gpt-4')

    trigger = config.get('Bot', 'trigger', fallback='!aibot')
    mentions = config.getboolean('Bot', 'mentions', fallback=True)
    rate_limit_calls = config.getint('Bot', 'rate_limit_calls', fallback=20)
    rate_limit_period = config.getint('Bot', 'rate_limit_period', fallback=60)
    max_length = config.getint('Bot', 'max_length', fallback=4000)
    nickname = config.get('Bot', 'nickname', fallback='LLMBot')
    # The groupchat handler looks up the lower-cased sender nick in this
    # list, so the entries are stored lower-cased.
    privileged_users = _split_csv(config.get('Bot', 'privileged_users', fallback=''), lowercase=True)
    max_retries = config.getint('Bot', 'max_retries', fallback=20)

    remember_conversations = config.getboolean('Bot', 'remember_conversations', fallback=False)
    history_per_room = config.getint('Bot', 'history_per_room', fallback=10)

    persistent_memory = config.getboolean('Bot', 'persistent_memory', fallback=False)
    memory_file_path = config.get('Bot', 'memory_file_path', fallback='memory.json')

    quote_reply = config.getboolean('Bot', 'quote_reply', fallback=True)
    mention_reply = config.getboolean('Bot', 'mention_reply', fallback=True)
    skip_thinking = config.getboolean('Bot', 'skip_thinking_models', fallback=False)
    request_timeout = config.getint('Bot', 'request_timeout', fallback=20)

    allow_dm = config.getboolean('Bot', 'allow_dm', fallback=True)
    dm_mode = config.get('Bot', 'dm_mode', fallback='none')
    # The DM handlers compare the lower-cased bare JID of the sender against
    # this list, so the entries are stored lower-cased.
    dm_list = _split_csv(config.get('Bot', 'dm_list', fallback=''), lowercase=True)

    enable_omemo = config.getboolean('Bot', 'enable_omemo', fallback=True)
    omemo_store_path = config.get('Bot', 'omemo_store_path', fallback='omemo_store.json')
    omemo_only = config.getboolean('Bot', 'omemo_only', fallback=False)

    answer_to_links = config.getboolean('Bot', 'answer_to_links', fallback=False)
    fetch_link_content = config.getboolean('Bot', 'fetch_link_content', fallback=False)
    support_images = config.getboolean('Bot', 'support_images', fallback=False)
    support_audio = config.getboolean('Bot', 'support_audio', fallback=False)
    join_retry_attempts = config.getint('Bot', 'join_retry_attempts', fallback=5)
    join_retry_delay = config.getint('Bot', 'join_retry_delay', fallback=10)

    imagen_trigger = config.get('Bot', 'imagen_trigger', fallback='!imagen')
    cf_account_id = config.get('Bot', 'cloudflare_account_id', fallback=None)
    cf_api_token = config.get('Bot', 'cloudflare_api_token', fallback=None)

    enable_url_context = config.getboolean('Bot', 'enable_url_context', fallback=False)

    file_host = config.get('Bot', 'file_host', fallback='catbox')
    file_host_api_key = config.get('Bot', 'file_host_api_key', fallback='')

    tts_trigger = config.get('Bot', 'tts_trigger', fallback='!tts')
    tts_enabled = config.getboolean('Bot', 'tts_enabled', fallback=False)
    tts_voice_name = config.get('Bot', 'tts_voice_name', fallback='Kore')
    tts_model = config.get('Bot', 'tts_model', fallback='gemini-2.5-flash-preview-tts')
    tts_auto_reply = config.getboolean('Bot', 'tts_auto_reply', fallback=False)

    # "global" is the fallback prompt; "system_prompt.<room>" entries are
    # applied afterwards so a room-specific value wins on a key clash.
    system_prompts = {}
    global_prompt = config.get('Bot', 'system_prompt', fallback='').strip()
    if global_prompt:
        system_prompts["global"] = global_prompt
    system_prompts.update(_per_room_options('system_prompt.'))

    # "nickname.<room>" overrides the default nickname in that room.
    room_nicknames = _per_room_options('nickname.')

    async def main():
        """Create the bot on the running loop, connect, and wait for disconnect."""
        logging.basicConfig(level=logging.DEBUG, format="%(levelname)-8s %(message)s")

        loop = asyncio.get_running_loop()

        bot = LLMBot(jid, password, rooms, room_nicknames, trigger, mentions, rate_limit_calls, rate_limit_period,
                     max_length, nickname, api_url, privileged_users, max_retries, system_prompts,
                     remember_conversations, history_per_room,
                     quote_reply=quote_reply, mention_reply=mention_reply,
                     skip_thinking=skip_thinking,
                     request_timeout=request_timeout, allow_dm=allow_dm,
                     dm_mode=dm_mode, dm_list=dm_list,
                     use_openai_api=use_openai_api, api_token=api_token,
                     openai_model=openai_model,
                     enable_omemo=enable_omemo, omemo_store_path=omemo_store_path, omemo_only=omemo_only,
                     answer_to_links=answer_to_links, fetch_link_content=fetch_link_content,
                     support_images=support_images, support_audio=support_audio,
                     join_retry_attempts=join_retry_attempts, join_retry_delay=join_retry_delay,
                     persistent_memory=persistent_memory, memory_file_path=memory_file_path,
                     imagen_trigger=imagen_trigger,
                     cf_account_id=cf_account_id, cf_api_token=cf_api_token,
                     enable_url_context=enable_url_context,
                     file_host=file_host, file_host_api_key=file_host_api_key,
                     tts_trigger=tts_trigger, tts_enabled=tts_enabled,
                     tts_voice_name=tts_voice_name, tts_model=tts_model,
                     tts_auto_reply=tts_auto_reply,
                     loop=loop)

        bot.connect()
        logging.info("Bot connected, waiting for messages...")

        await bot.disconnected

    asyncio.run(main())