2200 lines
84 KiB
Python
2200 lines
84 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import slixmpp
|
|
import requests
|
|
import json
|
|
import time
|
|
from ratelimit import limits, sleep_and_retry
|
|
import configparser
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, FrozenSet, Optional, Set
|
|
import re
|
|
import base64
|
|
import wave
|
|
from io import BytesIO
|
|
from urllib.parse import urlparse, urljoin, urlunparse, parse_qs
|
|
from bs4 import BeautifulSoup
|
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
from cryptography.hazmat.backends import default_backend
|
|
|
|
from slixmpp.jid import JID
|
|
from slixmpp.stanza import Message
|
|
from slixmpp.xmlstream.handler import CoroutineCallback
|
|
from slixmpp.xmlstream.matcher import MatchXPath
|
|
from slixmpp.plugins import register_plugin
|
|
|
|
from omemo.storage import Just, Maybe, Nothing, Storage
|
|
from omemo.types import DeviceInformation, JSONType
|
|
from slixmpp_omemo import TrustLevel, XEP_0384
|
|
|
|
|
|
try:
|
|
from google import genai
|
|
from google.genai import types
|
|
GOOGLE_GENAI_AVAILABLE = True
|
|
except ImportError:
|
|
GOOGLE_GENAI_AVAILABLE = False
|
|
logging.warning("google-genai library not available.")
|
|
|
|
try:
|
|
from openai import OpenAI
|
|
OPENAI_AVAILABLE = True
|
|
except ImportError:
|
|
OPENAI_AVAILABLE = False
|
|
logging.warning("openai library not available.")
|
|
|
|
try:
|
|
from pydub import AudioSegment
|
|
PYDUB_AVAILABLE = True
|
|
except ImportError:
|
|
PYDUB_AVAILABLE = False
|
|
logging.warning("pydub library not available.")
|
|
|
|
|
|
class FileUploader:
    """Upload bytes to a file-hosting service and return the public URL.

    ``service`` (case-insensitive) selects the backend: 'catbox', 'litterbox',
    '0x0', 'imgur', 'imgbb', 'envs', 'uguu' or 'xmpp'.  ``api_key`` is sent as
    the catbox userhash, the imgur Client-ID or the imgbb key.  ``xmpp_client``
    is required for 'xmpp' (XEP-0363 HTTP File Upload via its ``xep_0363``
    plugin and event ``loop``).
    """

    def __init__(self, service='catbox', api_key=None, xmpp_client=None):
        self.service = (service or 'catbox').lower()
        self.api_key = api_key
        self.xmpp_client = xmpp_client

    def upload(self, file_data: bytes, filename: str = "file.bin",
               mime_type: str = "application/octet-stream") -> Optional[str]:
        """Upload ``file_data`` with the configured service.

        Returns the resulting URL, or None when the service is unknown, the
        upload fails, or the service response is not understood.  Errors are
        logged, never raised.
        """
        handlers = {
            'catbox': lambda: self._upload_catbox(file_data, filename, mime_type),
            'litterbox': lambda: self._upload_litterbox(file_data, filename, mime_type),
            '0x0': lambda: self._upload_0x0(file_data, filename, mime_type),
            'imgur': lambda: self._upload_imgur(file_data),
            'imgbb': lambda: self._upload_imgbb(file_data),
            'envs': lambda: self._upload_envs(file_data, filename, mime_type),
            'uguu': lambda: self._upload_uguu(file_data, filename, mime_type),
            'xmpp': lambda: self._upload_xmpp(file_data, filename, mime_type),
        }
        try:
            handler = handlers.get(self.service)
            if handler is None:
                logging.error(f"Unknown hosting service: {self.service}")
                return None
            return handler()
        except Exception as e:
            logging.error(f"File upload failed: {e}")
            return None

    def _upload_xmpp(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]:
        """Upload through the user's XMPP server (XEP-0363) and return the URL.

        The upload coroutine is scheduled on ``xmpp_client.loop`` and this
        method blocks for up to 70 s waiting for it, so it must be called from
        a thread other than that loop's thread.
        """
        # Before Python 3.11 this is a different class from asyncio.TimeoutError,
        # and it is what Future.result(timeout=...) raises.
        from concurrent.futures import TimeoutError as FutureTimeoutError

        if self.xmpp_client is None:
            logging.error("XMPP client not available for upload. Ensure 'xmpp' is selected as file_host.")
            return None

        future = None
        try:
            loop = self.xmpp_client.loop
            try:
                running_loop = asyncio.get_running_loop()
            except RuntimeError:
                running_loop = None
            if running_loop is loop:
                # Blocking on future.result() here would stop the loop from ever
                # running the upload coroutine: a guaranteed deadlock until timeout.
                logging.error("XMPP file upload cannot block the XMPP event loop thread; "
                              "call it from a worker thread instead.")
                return None

            logging.info(f"Starting XMPP HTTP File Upload for {filename} ({len(file_data)} bytes)...")
            future = asyncio.run_coroutine_threadsafe(
                self.xmpp_client['xep_0363'].upload_file(
                    filename=filename,
                    size=len(file_data),
                    input_file=BytesIO(file_data),
                    content_type=mime_type,
                    timeout=60,
                ),
                loop,
            )
            result = future.result(timeout=70)
            logging.info(f"Uploaded to XMPP server: {result}")
            return result
        except (asyncio.TimeoutError, FutureTimeoutError):
            # Stop the still-running upload instead of leaking it on the loop.
            if future is not None:
                future.cancel()
            logging.error("XMPP file upload timed out.")
            return None
        except Exception as e:
            logging.error(f"XMPP file upload failed: {e}")
            return None

    def _post_for_text_url(self, label: str, url: str, files: dict,
                           data: Optional[dict] = None) -> Optional[str]:
        """POST a multipart upload whose response body is the plain-text https URL.

        Returns the URL, or None (with a warning) when the body is not an https URL.
        Raises requests exceptions on transport or HTTP errors.
        """
        response = requests.post(url, files=files, data=data, timeout=30)
        response.raise_for_status()

        result = response.text.strip()
        if result.startswith('https://'):
            logging.info(f"Uploaded to {label}: {result}")
            return result
        logging.warning(f"Unexpected response from {label}: {result[:200]!r}")
        return None

    def _upload_catbox(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]:
        """Upload to catbox.moe; ``api_key`` (if set) is sent as the userhash."""
        files = {'fileToUpload': (filename, BytesIO(file_data), mime_type)}
        data = {'reqtype': 'fileupload'}
        if self.api_key:
            data['userhash'] = self.api_key
        return self._post_for_text_url("catbox", "https://catbox.moe/user/api.php", files, data)

    def _upload_litterbox(self, file_data: bytes, filename: str, mime_type: str,
                          expiry: str = "72h") -> Optional[str]:
        """Upload to litterbox (temporary catbox); ``expiry`` is the retention time."""
        files = {'fileToUpload': (filename, BytesIO(file_data), mime_type)}
        data = {'reqtype': 'fileupload', 'time': expiry}
        return self._post_for_text_url(
            "litterbox", "https://litterbox.catbox.moe/resources/internals/api.php", files, data)

    def _upload_0x0(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]:
        """Upload to 0x0.st."""
        files = {'file': (filename, BytesIO(file_data), mime_type)}
        return self._post_for_text_url("0x0.st", "https://0x0.st", files)

    def _upload_imgur(self, file_data: bytes) -> Optional[str]:
        """Upload an image to imgur (anonymous, needs a Client-ID in ``api_key``)."""
        if not self.api_key:
            logging.error("Imgur requires an API key (Client-ID)")
            return None

        url = "https://api.imgur.com/3/image"
        headers = {'Authorization': f'Client-ID {self.api_key}'}
        data = {
            'image': base64.b64encode(file_data).decode('utf-8'),
            'type': 'base64',
        }

        response = requests.post(url, headers=headers, data=data, timeout=30)
        response.raise_for_status()

        result = response.json()
        if result.get('success'):
            link = result['data']['link']
            logging.info(f"Uploaded to imgur: {link}")
            return link
        logging.warning(f"Unexpected response from imgur: {result!r}")
        return None

    def _upload_imgbb(self, file_data: bytes) -> Optional[str]:
        """Upload an image to imgbb (needs an API key in ``api_key``)."""
        if not self.api_key:
            logging.error("imgbb requires an API key")
            return None

        url = "https://api.imgbb.com/1/upload"
        data = {
            'key': self.api_key,
            'image': base64.b64encode(file_data).decode('utf-8'),
        }

        response = requests.post(url, data=data, timeout=30)
        response.raise_for_status()

        result = response.json()
        if result.get('success'):
            link = result['data']['url']
            logging.info(f"Uploaded to imgbb: {link}")
            return link
        logging.warning(f"Unexpected response from imgbb: {result!r}")
        return None

    def _upload_envs(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]:
        """Upload to envs.sh."""
        files = {'file': (filename, BytesIO(file_data), mime_type)}
        return self._post_for_text_url("envs.sh", "https://envs.sh", files)

    def _upload_uguu(self, file_data: bytes, filename: str, mime_type: str) -> Optional[str]:
        """Upload to uguu.se; the JSON response lists the uploaded file URLs."""
        url = "https://uguu.se/upload.php"
        files = {'files[]': (filename, BytesIO(file_data), mime_type)}

        response = requests.post(url, files=files, timeout=30)
        response.raise_for_status()

        result = response.json()
        if result.get('success') and result.get('files'):
            link = result['files'][0]['url']
            logging.info(f"Uploaded to uguu.se: {link}")
            return link
        logging.warning(f"Unexpected response from uguu.se: {result!r}")
        return None
|
|
|
|
|
|
class StorageImpl(Storage):
    """OMEMO key/value storage kept in memory and mirrored to one JSON file.

    The whole mapping is loaded once at construction and rewritten on every
    store/delete.
    """

    def __init__(self, json_file_path: Path) -> None:
        super().__init__()
        self.__json_file_path = json_file_path
        self.__data: Dict[str, JSONType] = {}
        try:
            with open(self.__json_file_path, encoding="utf8") as f:
                self.__data = json.load(f)
        except FileNotFoundError:
            # First run: nothing has been persisted yet.
            pass
        except Exception as e:
            # Best effort as before (start empty), but make it visible: the next
            # write replaces the unreadable file, discarding its contents.
            logging.error(f"Failed to load OMEMO storage from {self.__json_file_path}: {e}")

    async def _load(self, key: str) -> Maybe[JSONType]:
        """Return Just(value) for a stored key, Nothing() otherwise."""
        if key in self.__data:
            return Just(self.__data[key])
        return Nothing()

    async def _store(self, key: str, value: JSONType) -> None:
        """Set ``key`` and persist the whole mapping."""
        self.__data[key] = value
        self.__persist()

    async def _delete(self, key: str) -> None:
        """Remove ``key`` if present and persist the whole mapping."""
        self.__data.pop(key, None)
        self.__persist()

    def __persist(self) -> None:
        """Write the mapping to a temp file, then atomically swap it into place.

        A crash mid-write therefore leaves the previous file intact instead of
        a truncated one.
        """
        tmp_path = Path(f"{self.__json_file_path}.tmp")
        with open(tmp_path, "w", encoding="utf8") as f:
            json.dump(self.__data, f)
        os.replace(tmp_path, self.__json_file_path)
|
|
|
|
|
class MemoryStorage:
    """Per-room conversation history persisted as a JSON object on disk.

    ``data`` maps a room JID to a list of ``{"role": ..., "content": ...}``
    dicts.  Every mutation rewrites the whole file.
    """

    def __init__(self, file_path: str):
        self.file_path = Path(file_path)
        self.data: Dict[str, list] = {}
        self._load()

    def _load(self):
        """Load history from disk; fall back to empty on a missing or invalid file."""
        try:
            if self.file_path.exists():
                with open(self.file_path, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                # A non-object top level (e.g. a list) would break every
                # dict lookup later, so treat it like a corrupt file.
                if not isinstance(loaded, dict):
                    raise ValueError(f"expected a JSON object, got {type(loaded).__name__}")
                self.data = loaded
                logging.info(f"Loaded persistent memory from {self.file_path}")
        except Exception as e:
            logging.error(f"Failed to load memory from {self.file_path}: {e}")
            self.data = {}

    def _save(self):
        """Write history to a temp file, then atomically replace the real file.

        Errors are logged and not raised.
        """
        tmp_path = self.file_path.with_name(self.file_path.name + '.tmp')
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.file_path)
        except Exception as e:
            logging.error(f"Failed to save memory to {self.file_path}: {e}")
            try:
                tmp_path.unlink()
            except OSError:
                pass

    def get_history(self, room_jid: str) -> list:
        """Return a shallow copy of the room's history (empty list if none)."""
        return self.data.get(room_jid, [])[:]

    def set_history(self, room_jid: str, history: list):
        """Replace the room's history and persist it.

        The list is copied so later changes by the caller are not stored
        implicitly.
        """
        self.data[room_jid] = list(history)
        self._save()

    def append_to_history(self, room_jid: str, user_msg: str, assistant_msg: str, limit: int = 0):
        """Append one user/assistant exchange, keep the last ``limit`` exchanges, persist.

        ``limit <= 0`` disables trimming.
        """
        history = self.data.setdefault(room_jid, [])
        history.extend([
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_msg},
        ])

        # Each exchange is two entries.
        if limit > 0 and len(history) > limit * 2:
            self.data[room_jid] = history[-limit * 2:]

        self._save()

    def clear_history(self, room_jid: str):
        """Forget the room's history; persists only if something was removed."""
        if room_jid in self.data:
            del self.data[room_jid]
            self._save()
|
|
|
|
|
|
class PluginCouldNotLoad(Exception):
    """Raised when the OMEMO plugin cannot start, e.g. no JSON storage path is configured."""
|
|
|
|
|
|
class XEP_0384Impl(XEP_0384):
    """OMEMO (XEP-0384) plugin with JSON-file storage and automatic trust.

    BTBV (blind trust before verification) is switched on.  Every device the
    library would ask about manually is set to TRUSTED without user input.
    NOTE(review): this accepts any new device key, which is a deliberate
    security trade-off for an unattended bot.
    """

    default_config = {
        "fallback_message": "This message is OMEMO encrypted.",
        "json_file_path": None
    }

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Declared only; assigned in plugin_init once the config is available.
        self.__storage: Storage

    def plugin_init(self) -> None:
        """Create the storage, then let the base plugin initialise.

        The order matters: the storage is set before ``super().plugin_init()``
        (presumably the base class reads ``storage`` there — confirm).
        """
        store_path = self.json_file_path
        if not store_path:
            raise PluginCouldNotLoad("JSON file path not specified.")
        self.__storage = StorageImpl(Path(store_path))
        super().plugin_init()

    @property
    def storage(self) -> Storage:
        """The JSON-backed storage created in plugin_init."""
        return self.__storage

    @property
    def _btbv_enabled(self) -> bool:
        """Always enable blind trust before verification."""
        return True

    async def _devices_blindly_trusted(
        self,
        blindly_trusted: FrozenSet[DeviceInformation],
        identifier: Optional[str]
    ) -> None:
        """Only log the devices that were trusted blindly."""
        logging.info(f"[{identifier}] Devices trusted blindly: {blindly_trusted}")

    async def _prompt_manual_trust(
        self,
        manually_trusted: FrozenSet[DeviceInformation],
        identifier: Optional[str]
    ) -> None:
        """Instead of prompting, mark every listed device as TRUSTED."""
        manager = await self.get_session_manager()
        for info in manually_trusted:
            logging.info(f"Auto-trusting device: {info}")
            await manager.set_trust(
                info.bare_jid,
                info.identity_key,
                TrustLevel.TRUSTED.value
            )
|
|
|
|
|
|
# Make the custom OMEMO plugin known to slixmpp's plugin registry; LLMBot
# loads it as "xep_0384" (from this module) when OMEMO is enabled.
register_plugin(XEP_0384Impl)
|
|
|
|
|
|
class LLMBot(slixmpp.ClientXMPP):
    """XMPP client that answers MUC and direct messages using an LLM backend."""

    # Path suffixes recognised as audio attachments (matched with endswith, in order).
    AUDIO_EXTENSIONS = [
        '.wav', '.mp3', '.ogg', '.opus', '.aac', '.flac',
        '.m4a', '.wma', '.amr', '.pcm', '.aiff',
    ]

    # Content-Type value -> short audio format name.
    AUDIO_MIME_TYPES = {
        # WAV and its aliases
        'audio/wav': 'wav',
        'audio/x-wav': 'wav',
        'audio/wave': 'wav',
        # MP3
        'audio/mp3': 'mp3',
        'audio/mpeg': 'mp3',
        # Ogg / Opus
        'audio/ogg': 'ogg',
        'audio/opus': 'opus',
        # AAC, including MP4/M4A containers
        'audio/aac': 'aac',
        'audio/flac': 'flac',
        'audio/m4a': 'aac',
        'audio/x-m4a': 'aac',
        'audio/mp4': 'aac',
        # Other formats
        'audio/amr': 'amr',
        'audio/pcm': 'pcm',
        'audio/webm': 'webm',
        'audio/aiff': 'aiff',
        'audio/x-aiff': 'aiff',
    }
|
|
|
|
def __init__(self, jid, password, rooms, room_nicknames, trigger, mentions, rate_limit_calls, rate_limit_period,
             max_length, nickname, api_url, privileged_users, max_retries, system_prompts,
             remember_conversations, history_per_room,
             quote_reply=True, mention_reply=True, skip_thinking=False,
             request_timeout=20, allow_dm=True, dm_mode='whitelist', dm_list=None,
             use_openai_api=False, api_token=None, openai_model="gpt-4",
             enable_omemo=True, omemo_store_path="omemo_store.json", omemo_only=False,
             answer_to_links=False, fetch_link_content=False, support_images=False,
             join_retry_attempts=5, join_retry_delay=10,
             persistent_memory=False, memory_file_path="memory.json",
             imagen_trigger="!imagen",
             cf_account_id=None, cf_api_token=None,
             support_audio=False,
             enable_url_context=False,
             file_host='catbox', file_host_api_key=None,
             tts_trigger="!tts", tts_enabled=False, tts_voice_name="Kore",
             tts_model="gemini-2.5-flash-preview-tts", tts_auto_reply=False,
             loop=None):
    """Store the bot configuration, create LLM clients and register XMPP plugins/handlers.

    The Google GenAI client is created when that library is importable and
    ``api_token`` is set; the OpenAI client additionally needs
    ``use_openai_api``.  Client-creation failures are logged, not raised.
    With OMEMO enabled, direct chat messages go to ``direct_message_async``
    through an XPath handler; otherwise to ``direct_message`` via the
    ``message`` event.  ``privileged_users`` and ``dm_list`` entries are
    lower-cased for case-insensitive matching.
    """
    self.request_timeout = request_timeout
    super().__init__(jid, password, loop=loop, sasl_mech='PLAIN')
    self.enable_direct_tls = True

    # Rooms and addressing
    self.rooms = rooms
    self.room_nicknames = room_nicknames or {}
    self.trigger = trigger
    self.mentions = mentions
    self.max_length = max_length
    self.nickname = nickname

    # LLM backend
    self.api_url = api_url
    self.use_openai_api = use_openai_api
    self.api_token = api_token
    self.openai_model = openai_model

    # Tolerate a missing list instead of crashing on iteration of None.
    self.privileged_users = {u.lower() for u in (privileged_users or [])}

    self.max_retries = max_retries
    self.system_prompts = system_prompts

    # Conversation memory
    self.remember = remember_conversations
    self.history_limit = history_per_room
    self.history = {}

    self.persistent_memory = persistent_memory
    self.memory_file_path = memory_file_path
    self.memory_storage = None
    if self.persistent_memory:
        self.memory_storage = MemoryStorage(self.memory_file_path)
        logging.info(f"Persistent memory enabled, using {self.memory_file_path}")

    # Reply formatting
    self.quote_reply = quote_reply
    self.mention_reply = mention_reply
    self.skip_thinking = skip_thinking

    # Direct messages
    self.allow_dm = allow_dm
    self.dm_mode = (dm_mode or 'whitelist').lower()
    self.dm_list = {x.lower() for x in (dm_list or [])}

    # OMEMO
    self.enable_omemo = enable_omemo
    self.omemo_store_path = omemo_store_path
    self.omemo_only = omemo_only

    # Links and media
    self.answer_to_links = answer_to_links
    self.fetch_link_content = fetch_link_content
    self.support_images = support_images
    self.support_audio = support_audio
    self.join_retry_attempts = join_retry_attempts
    self.join_retry_delay = join_retry_delay

    # Image generation
    self.imagen_trigger = imagen_trigger
    self.cf_account_id = cf_account_id
    self.cf_api_token = cf_api_token

    self.enable_url_context = enable_url_context

    self.file_uploader = FileUploader(service=file_host, api_key=file_host_api_key, xmpp_client=self)

    # Text-to-speech
    self.tts_trigger = tts_trigger
    self.tts_enabled = tts_enabled
    self.tts_voice_name = tts_voice_name
    self.tts_model = tts_model
    self.tts_auto_reply = tts_auto_reply

    self.genai_client = None
    if GOOGLE_GENAI_AVAILABLE and self.api_token:
        try:
            self.genai_client = genai.Client(api_key=self.api_token)
            logging.info("Initialized Google GenAI client")
        except Exception as e:
            logging.error(f"Failed to initialize Google GenAI client: {e}")

    self.openai_client = None
    if self.use_openai_api and OPENAI_AVAILABLE and self.api_token:
        try:
            # The OpenAI client wants the API root and adds the endpoint path
            # itself, so drop a trailing ".../chat/completions" (this also
            # covers ".../v1/chat/completions", keeping the "/v1" root).
            base_url = self.api_url.rstrip('/')
            endpoint_suffix = '/chat/completions'
            if base_url.endswith(endpoint_suffix):
                base_url = base_url[:-len(endpoint_suffix)]

            self.openai_client = OpenAI(api_key=self.api_token, base_url=base_url)
            logging.info(f"Initialized OpenAI client with base URL: {base_url}")
        except Exception as e:
            logging.error(f"Failed to initialize OpenAI client: {e}")

    self.url_pattern = re.compile(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    )
    self.aesgcm_pattern = re.compile(r'aesgcm://[^\s]+')

    self.register_plugin('xep_0030')
    self.register_plugin('xep_0045')
    self.register_plugin('xep_0199')
    self.register_plugin('xep_0066')
    self.register_plugin('xep_0363')

    if self.enable_omemo:
        self.register_plugin('xep_0085')
        self.register_plugin('xep_0380')

        import sys
        # Load XEP_0384Impl from this module rather than slixmpp's plugin package.
        self.register_plugin(
            "xep_0384",
            {"json_file_path": self.omemo_store_path},
            module=sys.modules[__name__]
        )
        logging.info(f"OMEMO support enabled (OMEMO Only Mode: {self.omemo_only})")

    self.add_event_handler("session_start", self.start)
    self.add_event_handler("groupchat_message", self.groupchat_message)

    if self.enable_omemo:
        self.register_handler(CoroutineCallback(
            "DirectMessages",
            MatchXPath(f"{{{self.default_ns}}}message[@type='chat']"),
            self.direct_message_async
        ))
    else:
        self.add_event_handler("message", self.direct_message)

    # Blocks (sleeps) when the rate limit is exceeded instead of raising.
    self.rate_limited_send = sleep_and_retry(
        limits(calls=rate_limit_calls, period=rate_limit_period)(self.send_to_llm)
    )
|
|
|
|
def is_gemini_api(self):
    """Return True when ``self.api_url`` points at Google's Generative Language API host."""
    gemini_host = "generativelanguage.googleapis.com"
    return gemini_host in self.api_url
|
|
|
|
def start(self, event):
    """``session_start`` handler: announce presence, request the roster, join every room.

    Each room is joined with its per-room nickname, falling back to
    ``self.nickname``.  ``event`` is not used.
    """
    self.send_presence()
    self.get_roster()

    nick_for = self.room_nicknames.get
    for room_jid in self.rooms:
        self.join_room_with_retry(room_jid, nick_for(room_jid, self.nickname))
|
|
|
|
def join_room_with_retry(self, room, nick):
    """Join MUC ``room`` as ``nick``, retrying when ``join_muc`` raises.

    Makes up to ``self.join_retry_attempts`` attempts, sleeping
    ``self.join_retry_delay`` seconds between them.  Returns True after the
    first attempt that does not raise, and False when every attempt fails or
    the attempt count is below 1.

    NOTE(review): ``time.sleep`` blocks the calling thread.  This is called
    from ``start`` (the session_start handler), presumably on the event loop,
    so the client stalls during retry delays — consider an async variant.
    """
    attempts = self.join_retry_attempts
    if attempts < 1:
        logging.error(f"Not joining {room}: join_retry_attempts is {attempts}")
        return False

    for attempt in range(attempts):
        try:
            self.plugin['xep_0045'].join_muc(room, nick)
            logging.info(f"Successfully joined room: {room} as {nick}")
            return True
        except Exception as e:
            logging.error(f"Failed to join {room} (attempt {attempt+1}/{attempts}): {e}")
            if attempt < attempts - 1:
                logging.info(f"Retrying in {self.join_retry_delay} seconds...")
                time.sleep(self.join_retry_delay)

    logging.error(f"Failed to join {room} after {attempts} attempts")
    return False
|
|
|
|
def extract_urls(self, text):
    """Return all ``self.url_pattern`` matches in ``text``, in order of appearance."""
    matcher = self.url_pattern
    found = matcher.findall(text)
    return found
|
|
|
|
def clean_aesgcm_urls(self, text):
    """Remove every ``aesgcm://`` link from ``text`` and trim surrounding whitespace."""
    without_links = self.aesgcm_pattern.sub('', text)
    return without_links.strip()
|
|
|
|
def fetch_url_content(self, url, max_length=5000, max_bytes=2_000_000):
    """Download ``url`` and return a text summary suitable for an LLM prompt.

    HTML is reduced to visible text (script/style/nav/footer/header removed,
    blank lines dropped); text/plain is passed through; any other content
    type is only reported, without downloading its body.  The text is cut to
    ``max_length`` characters, and at most ``max_bytes`` of the body are read.
    Errors are logged and returned as a message string, never raised.

    NOTE(review): ``url`` comes from chat messages and is fetched without any
    host check, so internal addresses are reachable (SSRF) — consider an
    allow/deny list.
    """
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36'
        }
        # stream=True: headers first, so large or binary bodies are never fully downloaded.
        with requests.get(url, headers=headers, timeout=10, allow_redirects=True, stream=True) as response:
            response.raise_for_status()

            content_type = response.headers.get('content-type', '').lower()
            is_html = 'text/html' in content_type
            if not is_html and 'text/plain' not in content_type:
                return f"URL {url} contains non-text content ({content_type})"

            chunks = []
            received = 0
            for chunk in response.iter_content(chunk_size=65536):
                chunks.append(chunk)
                received += len(chunk)
                if received >= max_bytes:
                    break
            raw = b''.join(chunks)[:max_bytes]
            encoding = response.encoding

        if is_html:
            soup = BeautifulSoup(raw, 'html.parser')
            for element in soup(["script", "style", "nav", "footer", "header"]):
                element.decompose()
            text = soup.get_text(separator='\n', strip=True)
            text = '\n'.join(line.strip() for line in text.splitlines() if line.strip())
        else:
            try:
                text = raw.decode(encoding or 'utf-8', errors='replace')
            except LookupError:
                # Unknown charset label from the server.
                text = raw.decode('utf-8', errors='replace')

        if len(text) > max_length:
            text = text[:max_length] + "... (truncated)"
        return f"Content from {url}:\n{text}"

    except Exception as e:
        logging.error(f"Error fetching URL {url}: {e}")
        return f"Error fetching content from {url}: {str(e)}"
|
|
|
|
def decrypt_aesgcm_url(self, url):
    """Download and decrypt an ``aesgcm://`` URL (OMEMO media sharing, XEP-0454).

    The fragment is hex: IV followed by a 32-byte key.  44 bytes means a
    12-byte IV; 48 bytes is the legacy 16-byte IV form.  Longer values fall
    back to the 12-byte-IV layout and ignore trailing bytes.  The ciphertext
    is fetched from the same URL with scheme ``https`` and no fragment.

    Returns the plaintext bytes, or None on any error (logged).
    """
    try:
        parsed = urlparse(url)
        if parsed.scheme != 'aesgcm':
            return None
        if not parsed.fragment:
            logging.error("AESGCM URL missing fragment with key/IV")
            return None

        # The fragment carries the secret, so it must never be sent to the server.
        https_url = urlunparse(('https', parsed.netloc, parsed.path, parsed.params, parsed.query, ''))

        try:
            key_iv = bytes.fromhex(parsed.fragment)
        except ValueError as e:
            logging.error(f"Invalid hex in AESGCM fragment: {e}")
            return None

        key_iv_len = len(key_iv)
        if key_iv_len == 48:
            # Legacy 16-byte IV: the IV still comes first, then the key.
            iv, key = key_iv[:16], key_iv[16:48]
        elif key_iv_len >= 44:
            iv, key = key_iv[:12], key_iv[12:44]
        else:
            logging.error(f"AESGCM key/IV unexpected length: {key_iv_len} bytes")
            return None

        logging.info(f"Decrypting AESGCM URL: {https_url} (key_iv_len={key_iv_len})")

        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36'
        }
        response = requests.get(https_url, timeout=30, headers=headers, allow_redirects=True)
        response.raise_for_status()

        decrypted_data = AESGCM(key).decrypt(iv, response.content, None)
        logging.info(f"Successfully decrypted AESGCM data: {len(decrypted_data)} bytes")
        return decrypted_data

    except Exception as e:
        logging.error(f"Error decrypting AESGCM URL: {e}", exc_info=True)
        return None
|
|
|
|
def is_audio_url(self, url):
    """Return True if ``url`` names an audio file by one of ``AUDIO_EXTENSIONS``.

    The URL path is checked, so query strings (``a.mp3?token=1``) and the
    ``aesgcm://`` key fragment do not hide the extension.  For non-aesgcm URLs
    the whole URL is also checked, so links such as ``...?file=a.mp3`` still
    match as before.
    """
    extensions = tuple(self.AUDIO_EXTENSIONS)
    try:
        path = urlparse(url.replace('aesgcm://', 'https://')).path.lower()
    except ValueError:
        # Malformed URL (e.g. broken IPv6 brackets); fall back to the raw string.
        path = ''
    if path.endswith(extensions):
        return True
    if url.startswith('aesgcm://'):
        return False
    return url.lower().endswith(extensions)
|
|
|
|
def get_audio_format_from_mime(self, mime_type):
    """Map a Content-Type value to a short audio format name ('wav' if unknown).

    Parameters such as ``; codecs=opus`` are ignored, so a raw header like
    ``audio/ogg; codecs=opus`` maps to 'ogg'.  Matching is case-insensitive.
    """
    base_type = mime_type.split(';', 1)[0].strip().lower()
    return self.AUDIO_MIME_TYPES.get(base_type, 'wav')
|
|
|
|
def get_audio_format_from_extension(self, url):
    """Return the first ``AUDIO_EXTENSIONS`` suffix of the URL path, without the dot.

    ``aesgcm://`` URLs are parsed as https.  Defaults to 'wav' when nothing matches.
    """
    lowered_path = urlparse(url.replace('aesgcm://', 'https://')).path.lower()
    matching = (suffix[1:] for suffix in self.AUDIO_EXTENSIONS if lowered_path.endswith(suffix))
    return next(matching, 'wav')
|
|
|
|
def get_audio_mime_from_format(self, fmt):
    """Return the MIME type for a short audio format name ('audio/wav' if unknown).

    Covers every format that ``get_audio_format_from_extension`` can return
    for ``AUDIO_EXTENSIONS``, including 'm4a' and 'wma'.  'm4a' maps to
    'audio/aac', matching how ``AUDIO_MIME_TYPES`` treats m4a/mp4 audio as 'aac'.
    """
    mime_map = {
        'wav': 'audio/wav',
        'mp3': 'audio/mp3',
        'ogg': 'audio/ogg',
        'opus': 'audio/opus',
        'aac': 'audio/aac',
        'm4a': 'audio/aac',
        'flac': 'audio/flac',
        'aiff': 'audio/aiff',
        'amr': 'audio/amr',
        'pcm': 'audio/pcm',
        'webm': 'audio/webm',
        'wma': 'audio/x-ms-wma',
    }
    return mime_map.get(fmt, 'audio/wav')
|
|
|
|
def fetch_audio_from_url(self, url):
    """Download (and for ``aesgcm://``, decrypt) audio at ``url``.

    Returns ``{'format', 'data', 'mime_type'}`` or None on failure (logged).
    For plain URLs the format comes from an audio Content-Type when present
    (parameters such as ``; codecs=opus`` are ignored).  Otherwise it comes
    from the file extension, provided the URL has an audio extension at all;
    non-audio responses without one are rejected.
    """
    try:
        logging.info(f"Fetching audio from URL: {url}")

        if url.startswith('aesgcm://'):
            decrypted_data = self.decrypt_aesgcm_url(url)
            if not decrypted_data:
                logging.error(f"Failed to decrypt AESGCM audio URL: {url}")
                return None

            audio_format = self.get_audio_format_from_extension(url)
            return {
                'format': audio_format,
                'data': decrypted_data,
                'mime_type': self.get_audio_mime_from_format(audio_format),
            }

        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36'
        }
        response = requests.get(url, timeout=30, headers=headers, allow_redirects=True)
        response.raise_for_status()

        content_type = response.headers.get('content-type', '').lower()
        # Strip parameters so the exact-key MIME lookup can match.
        base_type = content_type.split(';', 1)[0].strip()

        if any(audio_type in content_type for audio_type in self.AUDIO_MIME_TYPES):
            audio_format = self.get_audio_format_from_mime(base_type)
        elif self.is_audio_url(url):
            # Upload hosts often answer with application/octet-stream; trust the extension.
            audio_format = self.get_audio_format_from_extension(url)
        else:
            logging.warning(f"URL returned non-audio content-type: {content_type}")
            return None

        logging.info(f"Successfully fetched audio: {audio_format}, size: {len(response.content)} bytes")

        return {
            'format': audio_format,
            'data': response.content,
            'mime_type': self.get_audio_mime_from_format(audio_format),
        }
    except Exception as e:
        logging.error(f"Error fetching audio from {url}: {e}")
        return None
|
|
|
|
def extract_audio_from_message(self, msg, decrypted_body=None):
    """Find and download the first audio attachment referenced by ``msg``.

    Sources are tried in order:
    1. the XEP-0066 OOB ``<url>``, if it looks like audio;
    2. audio links in ``decrypted_body`` (decrypted message text, if given);
    3. audio links in the stanza ``<body>``, unless it contains an encryption
       fallback notice ("doesn't seem to support that" / "OMEMO").
    Within each source every candidate is tried until one downloads, and a
    URL that already failed is not fetched again.  Returns the dict from
    ``fetch_audio_from_url`` or None.
    """
    try:
        if hasattr(msg, 'xml'):
            xml_elem = msg.xml
        elif hasattr(msg, '_get_stanza_values'):
            xml_elem = msg
        else:
            logging.debug("Message object doesn't have accessible XML")
            return None

        tried = set()

        def first_audio(urls):
            # Return the first successfully fetched audio, skipping URLs already tried.
            for candidate in urls:
                if candidate in tried:
                    continue
                tried.add(candidate)
                audio = self.fetch_audio_from_url(candidate)
                if audio:
                    return audio
            return None

        oob = xml_elem.find('.//{jabber:x:oob}url')
        if oob is not None and oob.text:
            audio_url = oob.text.strip()
            if self.is_audio_url(audio_url):
                logging.info(f"Found OOB audio URL: {audio_url}")
                audio_data = first_audio([audio_url])
                if audio_data:
                    return audio_data

        if decrypted_body:
            audio_urls = self.extract_urls_and_media(decrypted_body)['audio_urls']
            if audio_urls:
                logging.info(f"Found {len(audio_urls)} audio URLs in decrypted body")
                audio_data = first_audio(audio_urls)
                if audio_data:
                    return audio_data

        body_elem = xml_elem.find('.//{jabber:client}body')
        if body_elem is None:
            body_elem = xml_elem.find('.//body')

        if body_elem is not None and body_elem.text:
            body_text = body_elem.text
            if "doesn't seem to support that" not in body_text and "OMEMO" not in body_text:
                audio_urls = self.extract_urls_and_media(body_text)['audio_urls']
                if audio_urls:
                    logging.info(f"Found {len(audio_urls)} audio URLs in body")
                    return first_audio(audio_urls)

    except Exception as e:
        logging.error(f"Error extracting audio: {e}", exc_info=True)

    return None
|
|
|
|
def extract_urls_and_media(self, text):
    """Split the links in ``text`` into image, audio and other URLs.

    ``aesgcm://`` links come first, then http(s) links; each link goes to the
    first matching bucket (image, then audio, else regular).  ``all_urls`` is
    images + audio + regular, in that order.
    """
    buckets = {'image_urls': [], 'audio_urls': [], 'regular_urls': []}
    candidates = self.aesgcm_pattern.findall(text) + self.url_pattern.findall(text)

    for link in candidates:
        if self.is_image_url(link):
            bucket = 'image_urls'
        elif self.is_audio_url(link):
            bucket = 'audio_urls'
        else:
            bucket = 'regular_urls'
        buckets[bucket].append(link)

    buckets['all_urls'] = buckets['image_urls'] + buckets['audio_urls'] + buckets['regular_urls']
    return buckets
|
|
|
|
def extract_image_from_message(self, msg, decrypted_body=None):
    """Find the first image attached to or linked from ``msg``.

    Sources are tried in order:
    1. an inline XEP-0231 ``<data>`` element (base64, decoded here);
    2. the XEP-0066 OOB ``<url>``, if it looks like an image;
    3. image links in ``decrypted_body`` (decrypted message text, if given);
    4. image links in the stanza ``<body>``, unless it contains an encryption
       fallback notice ("doesn't seem to support that" / "OMEMO").
    Every URL candidate is tried until one downloads, and a URL that already
    failed is not fetched again.  Returns ``{'mime_type', 'data'}`` or None.
    """
    try:
        if hasattr(msg, 'xml'):
            xml_elem = msg.xml
        elif hasattr(msg, '_get_stanza_values'):
            xml_elem = msg
        else:
            logging.debug("Message object doesn't have accessible XML")
            return None

        bob = xml_elem.find('.//{urn:xmpp:bob}data')
        if bob is not None:
            mime_type = bob.get('type', 'image/jpeg')
            image_data = bob.text
            if image_data:
                logging.info(f"Found BoB image: {mime_type}")
                return {
                    'mime_type': mime_type,
                    'data': base64.b64decode(image_data.strip()),
                }

        tried = set()

        def first_image(urls):
            # Return the first successfully fetched image, skipping URLs already tried.
            for candidate in urls:
                if candidate in tried:
                    continue
                tried.add(candidate)
                image = self.fetch_image_from_url(candidate)
                if image:
                    return image
            return None

        oob = xml_elem.find('.//{jabber:x:oob}url')
        if oob is not None and oob.text:
            image_url = oob.text.strip()
            logging.info(f"Found OOB URL: {image_url}")
            if self.is_image_url(image_url):
                logging.info("OOB URL is an image, fetching...")
                image_data = first_image([image_url])
                if image_data:
                    return image_data

        if decrypted_body:
            logging.debug("Parsing decrypted body for image URLs...")
            image_urls = self.extract_urls_and_media(decrypted_body)['image_urls']
            if image_urls:
                logging.info(f"Found {len(image_urls)} image URLs in decrypted OMEMO body")
                image_data = first_image(image_urls)
                if image_data:
                    return image_data

        body_elem = xml_elem.find('.//{jabber:client}body')
        if body_elem is None:
            body_elem = xml_elem.find('.//body')

        if body_elem is not None and body_elem.text:
            body_text = body_elem.text
            if "doesn't seem to support that" not in body_text and "OMEMO" not in body_text:
                image_urls = self.extract_urls_and_media(body_text)['image_urls']
                if image_urls:
                    return first_image(image_urls)

    except Exception as e:
        logging.error(f"Error extracting image: {e}", exc_info=True)

    return None
|
|
|
|
def fetch_image_from_url(self, url):
    """Download (and for ``aesgcm://``, decrypt) an image.

    Returns ``{'mime_type', 'data'}`` or None on failure (logged).  For
    aesgcm data the MIME type is sniffed from magic bytes (PNG, GIF, JPEG,
    WebP, BMP; default image/jpeg).  Plain URLs must answer with an image
    Content-Type (jpeg, png, gif, webp or bmp).
    """
    try:
        logging.info(f"Fetching image from URL: {url}")

        if url.startswith('aesgcm://'):
            decrypted_data = self.decrypt_aesgcm_url(url)
            if not decrypted_data:
                return None

            mime_type = 'image/jpeg'
            if decrypted_data[:4] == b'\x89PNG':
                mime_type = 'image/png'
            elif decrypted_data[:3] == b'GIF':
                mime_type = 'image/gif'
            elif decrypted_data[:2] == b'\xff\xd8':
                mime_type = 'image/jpeg'
            elif len(decrypted_data) > 12 and decrypted_data[:4] == b'RIFF' and decrypted_data[8:12] == b'WEBP':
                mime_type = 'image/webp'
            elif decrypted_data[:2] == b'BM':
                mime_type = 'image/bmp'

            logging.info(f"Successfully processed AESGCM image: {mime_type}, size: {len(decrypted_data)} bytes")
            return {
                'mime_type': mime_type,
                'data': decrypted_data,
            }

        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36'
        }
        response = requests.get(url, timeout=10, headers=headers, allow_redirects=True)
        response.raise_for_status()

        content_type = response.headers.get('content-type', '').lower()

        valid_types = ['image/jpeg', 'image/jpg', 'image/png', 'image/gif', 'image/webp', 'image/bmp']
        if not any(img_type in content_type for img_type in valid_types):
            logging.warning(f"URL returned non-image content-type: {content_type}")
            return None

        # Normalise to a canonical MIME type; BMP is reported as BMP, not JPEG.
        if 'jpeg' in content_type or 'jpg' in content_type:
            mime_type = 'image/jpeg'
        elif 'png' in content_type:
            mime_type = 'image/png'
        elif 'gif' in content_type:
            mime_type = 'image/gif'
        elif 'webp' in content_type:
            mime_type = 'image/webp'
        elif 'bmp' in content_type:
            mime_type = 'image/bmp'
        else:
            mime_type = 'image/jpeg'

        logging.info(f"Successfully fetched image: {mime_type}, size: {len(response.content)} bytes")
        return {
            'mime_type': mime_type,
            'data': response.content,
        }
    except Exception as e:
        logging.error(f"Error fetching image from {url}: {e}")
        return None
|
|
|
|
def sanitize_input(self, text):
    """Normalise user text before sending it to the LLM.

    Non-strings become "".  NUL and other non-printable characters are
    dropped (newline, carriage return and tab are kept).  Each line is
    trimmed, runs of 3+ newlines collapse to one blank line, and the result
    is stripped.
    """
    if not isinstance(text, str):
        return ""

    kept_chars = [ch for ch in text.replace('\x00', '') if ch.isprintable() or ch in '\n\r\t']
    trimmed_lines = (segment.strip() for segment in ''.join(kept_chars).split('\n'))
    collapsed = re.sub(r'\n{3,}', '\n\n', '\n'.join(trimmed_lines))
    return collapsed.strip()
|
|
|
|
def extract_quoted_text(self, body):
    """Split ``body`` into (quoted, unquoted) text.

    A line is quoted when, after trimming, it starts with '>'.  Exactly one
    '>' is removed and the remainder is trimmed; unquoted lines are kept
    verbatim.  Both parts are joined with newlines.
    """
    rows = body.split('\n')
    quote_flags = [row.strip().startswith('>') for row in rows]

    quoted = [row.strip()[1:].strip() for row, is_quote in zip(rows, quote_flags) if is_quote]
    plain = [row for row, is_quote in zip(rows, quote_flags) if not is_quote]

    return '\n'.join(quoted), '\n'.join(plain)
|
|
|
|
def is_replying_to_bot(self, msg):
    """Return True when a quoted line in ``msg`` begins with the bot's room nickname.

    The nickname (per-room, else ``self.nickname``) must be followed by one of
    ':', ' ', ',', ';', '-' or '—'.  Nested '>' markers are peeled off first,
    and the match is tried both as-is and case-insensitively.
    """
    own_nick = self.room_nicknames.get(msg['from'].bare, self.nickname)
    quoted_block, _unused = self.extract_quoted_text(msg['body'])
    if not quoted_block:
        return False

    separators = (':', ' ', ',', ';', '-', '—')
    for quoted_line in quoted_block.split('\n'):
        candidate = quoted_line.strip()
        while candidate.startswith('>'):
            candidate = candidate[1:].strip()

        variants = ((candidate, own_nick), (candidate.lower(), own_nick.lower()))
        if any(text.startswith(f"{name}{sep}") for text, name in variants for sep in separators):
            return True

    return False
|
|
|
|
def contains_links(self, text):
    """Return True if ``text`` contains an http(s) link or an ``aesgcm://`` link."""
    return any(pattern.search(text) for pattern in (self.url_pattern, self.aesgcm_pattern))
|
|
|
|
def is_image_url(self, url):
    """Return True when *url* points at a raster image, judged by its path extension.

    Plain http(s) URLs and ``aesgcm://`` shares are parsed the same way. The
    query string and fragment are therefore ignored: ``https://h/a.png?w=200``
    counts as an image, while ``https://h/page?img=a.png`` does not. If the
    URL cannot be parsed, the raw lower-cased string is checked instead.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
    try:
        path = urlparse(url).path.lower()
    except ValueError:
        # e.g. a malformed IPv6 netloc; fall back to the raw string
        path = url.lower()
    return path.endswith(image_extensions)
|
|
|
|
def extract_urls_and_images(self, text):
    """Classify the URLs found in *text* into image and regular links.

    aesgcm URLs are listed first, then plain URLs. URLs that ``is_audio_url``
    accepts are left out of both lists.

    Returns:
        dict with 'image_urls', 'regular_urls' and 'all_urls'
        (images followed by regular links).
    """
    candidates = self.aesgcm_pattern.findall(text) + self.url_pattern.findall(text)

    image_urls, regular_urls = [], []
    for candidate in candidates:
        if self.is_image_url(candidate):
            image_urls.append(candidate)
        elif not self.is_audio_url(candidate):
            regular_urls.append(candidate)

    return {
        'image_urls': image_urls,
        'regular_urls': regular_urls,
        'all_urls': image_urls + regular_urls,
    }
|
|
|
|
def _create_wave_file(self, pcm_data, channels=1, rate=24000, sample_width=2):
|
|
buffer = BytesIO()
|
|
with wave.open(buffer, "wb") as wf:
|
|
wf.setnchannels(channels)
|
|
wf.setsampwidth(sample_width)
|
|
wf.setframerate(rate)
|
|
wf.writeframes(pcm_data)
|
|
return buffer.getvalue()
|
|
|
|
def _encrypt_and_upload(self, data: bytes, filename: str, should_encrypt: bool = False) -> Optional[str]:
|
|
""""""
|
|
|
|
|
|
|
|
upload_mime = "application/octet-stream"
|
|
|
|
if not should_encrypt:
|
|
if filename.endswith('.png'): upload_mime = "image/png"
|
|
elif filename.endswith('.jpg'): upload_mime = "image/jpeg"
|
|
elif filename.endswith('.wav'): upload_mime = "audio/wav"
|
|
elif filename.endswith('.ogg'): upload_mime = "audio/ogg"
|
|
|
|
if should_encrypt:
|
|
try:
|
|
|
|
key = os.urandom(32)
|
|
iv = os.urandom(12)
|
|
|
|
|
|
|
|
aesgcm = AESGCM(key)
|
|
encrypted_data = aesgcm.encrypt(iv, data, None)
|
|
|
|
|
|
|
|
|
|
upload_url = self.file_uploader.upload(encrypted_data, filename, "application/octet-stream")
|
|
|
|
if not upload_url:
|
|
return None
|
|
|
|
|
|
|
|
fragment = iv.hex() + key.hex()
|
|
|
|
parsed = urlparse(upload_url)
|
|
|
|
|
|
|
|
new_url = urlunparse(('aesgcm', parsed.netloc, parsed.path, '', '', fragment))
|
|
|
|
logging.info(f"Encrypted upload successful: {new_url}")
|
|
return new_url
|
|
|
|
except Exception as e:
|
|
logging.error(f"Encryption failed: {e}")
|
|
return None
|
|
else:
|
|
return self.file_uploader.upload(data, filename, upload_mime)
|
|
|
|
def synthesize_speech(self, text, should_encrypt=False, max_retries=3):
    """Turn *text* into a voice clip with Gemini TTS and publish it.

    The audio returned by the model is encoded by ``_tts_pcm_to_file`` and
    uploaded through ``_encrypt_and_upload``.

    Returns:
        {"type": "url", "content": url} after a successful upload,
        {"type": "base64", "content": b64} when the upload failed,
        None when TTS is disabled or has no client, when the response
        carries no audio, or when all *max_retries* attempts raised.
    """
    if not self.tts_enabled:
        logging.error("TTS not enabled")
        return None
    if not self.genai_client:
        logging.error("GenAI client not initialized for TTS")
        return None

    for attempt_no in range(1, max_retries + 1):
        try:
            logging.info(f"Synthesizing speech with Gemini TTS: {text[:50]}...")

            voice = types.VoiceConfig(
                prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=self.tts_voice_name)
            )
            tts_config = types.GenerateContentConfig(
                response_modalities=["AUDIO"],
                speech_config=types.SpeechConfig(voice_config=voice),
            )
            response = self.genai_client.models.generate_content(
                model=self.tts_model,
                contents=text,
                config=tts_config,
            )

            # Walk candidates[0].content.parts[0].inline_data, tolerating gaps.
            candidates = response.candidates
            first_content = candidates[0].content if candidates else None
            parts = first_content.parts if first_content else None
            inline = getattr(parts[0], 'inline_data', None) if parts else None

            if inline:
                final_data, filename = self._tts_pcm_to_file(inline.data)
                upload_url = self._encrypt_and_upload(final_data, filename, should_encrypt)
                if upload_url:
                    return {"type": "url", "content": upload_url}
                # Upload failed: hand the caller the encoded audio instead.
                return {"type": "base64", "content": base64.b64encode(final_data).decode('utf-8')}

            logging.error("No audio data in TTS response")
            return None
        except Exception as e:
            logging.error(f"TTS error (attempt {attempt_no}): {e}")
            time.sleep(1)

    return None


def _tts_pcm_to_file(self, raw_pcm_data):
    """Encode TTS PCM (treated as 16-bit mono at 24 kHz) into an uploadable file.

    Returns (bytes, filename): Ogg/Opus ("tts.ogg") when pydub is available
    and the export succeeds, otherwise a WAV container ("tts.wav").
    """
    if PYDUB_AVAILABLE:
        try:
            segment = AudioSegment(
                data=raw_pcm_data,
                sample_width=2,
                frame_rate=24000,
                channels=1,
            )
            sink = BytesIO()
            segment.export(sink, format="ogg", codec="libopus")
            encoded = sink.getvalue()
            logging.info(f"Converted TTS to Ogg Opus, size: {len(encoded)} bytes")
            return encoded, "tts.ogg"
        except Exception as e:
            logging.error(f"Audio conversion failed (missing ffmpeg?): {e}")

    logging.info("Falling back to WAV format")
    return self._create_wave_file(raw_pcm_data), "tts.wav"
|
|
|
|
def generate_image(self, prompt, should_encrypt=False, max_retries=3):
    """Generate a 1024x1024 image for *prompt* with Cloudflare Workers AI (flux-2-dev, 20 steps).

    The base64 image from the response is decoded and uploaded via
    ``_encrypt_and_upload`` as "generated.png".

    Returns:
        {"type": "url", "content": url} after a successful upload,
        {"type": "base64", "content": b64} with the model output if the
        upload failed, or None when credentials are missing, the response
        holds no image, or all *max_retries* attempts raised.
    """
    if not (self.cf_account_id and self.cf_api_token):
        logging.error("Cloudflare credentials not configured for image generation")
        return None

    endpoint = ("https://api.cloudflare.com/client/v4/accounts/"
                f"{self.cf_account_id}/ai/run/@cf/black-forest-labs/flux-2-dev")
    auth_headers = {"Authorization": f"Bearer {self.cf_api_token}"}
    # multipart/form-data fields; a (None, value) tuple means "no filename"
    form_fields = {
        name: (None, value)
        for name, value in (('prompt', prompt), ('width', '1024'), ('height', '1024'), ('steps', '20'))
    }

    attempt = 0
    while attempt < max_retries:
        try:
            logging.info(f"Generating image with Cloudflare Workers AI: {prompt[:50]}...")

            http_response = requests.post(endpoint, headers=auth_headers, files=form_fields, timeout=60)
            http_response.raise_for_status()
            payload = http_response.json()

            # The image may sit under payload["result"]["image"] or payload["image"].
            inner = payload['result'] if 'result' in payload else None
            if isinstance(inner, dict) and 'image' in inner:
                b64_image = inner['image']
            elif 'image' in payload:
                b64_image = payload['image']
            else:
                b64_image = None

            if not b64_image:
                logging.error(f"No image in Cloudflare response: {payload}")
                return None

            image_bytes = base64.b64decode(b64_image)
            logging.info(f"Image generated, size: {len(image_bytes)} bytes")

            upload_url = self._encrypt_and_upload(image_bytes, "generated.png", should_encrypt)
            if upload_url:
                return {"type": "url", "content": upload_url}
            return {"type": "base64", "content": b64_image}
        except Exception as e:
            logging.error(f"Image generation error (attempt {attempt+1}): {e}")
            time.sleep(2)
        attempt += 1

    return None
|
|
|
|
def direct_message(self, msg):
    """Answer a plaintext one-to-one message (the non-OMEMO handler).

    Messages are ignored in these cases:
      - DMs are disabled, or the stanza type is not chat/normal;
      - the message is our own;
      - the sender is filtered by the whitelist/blacklist;
      - OMEMO-only mode is on;
      - the body is empty or too long.

    The image-generation and TTS trigger commands are answered directly.
    Anything else goes to the LLM with optional image/audio attachments and
    up to three fetched link pages as extra context.
    """
    if not self.allow_dm:
        return
    if msg['type'] not in ('chat', 'normal'):
        return
    if msg['from'].bare == self.boundjid.bare:
        return

    sender = msg['from'].bare.lower()
    if (self.dm_mode == 'whitelist' and sender not in self.dm_list) or \
            (self.dm_mode == 'blacklist' and sender in self.dm_list):
        return

    if self.enable_omemo and self.omemo_only:
        logging.warning(f"Ignoring plaintext DM from {sender} (OMEMO only mode active).")
        return

    body = self.sanitize_input(msg['body'])
    if not body or len(body) > self.max_length:
        return

    logging.info(f"Direct message from {sender}: {body[:50]}...")

    def send_media(result, base64_notice, failure_notice):
        # URL results are sent as text plus an out-of-band ('oob') URL.
        if not result:
            msg.reply(failure_notice).send()
            return
        if result["type"] != "url":
            msg.reply(base64_notice).send()
            return
        link = result['content']
        media_reply = msg.reply(link)
        media_reply['oob']['url'] = link
        media_reply.send()

    if body.startswith(self.imagen_trigger):
        prompt = body[len(self.imagen_trigger):].strip()
        if prompt:
            send_media(self.generate_image(prompt, should_encrypt=False),
                       "Image generated (base64)", "Failed to generate image.")
        return

    if self.tts_enabled and body.startswith(self.tts_trigger):
        text = body[len(self.tts_trigger):].strip()
        if text:
            send_media(self.synthesize_speech(text, should_encrypt=False),
                       "Speech synthesized (base64)", "Failed to synthesize speech.")
        return

    # Prefer the sender's own words over any text they quoted.
    _, non_quoted_text = self.extract_quoted_text(body)
    query = non_quoted_text.strip() or body
    has_links = self.contains_links(query)

    image_data = None
    if self.support_images:
        image_data = self.extract_image_from_message(msg)
        if image_data:
            logging.info(f"Image detected in DM from {sender}")

    audio_data = None
    if self.support_audio:
        audio_data = self.extract_audio_from_message(msg)
        if audio_data:
            logging.info(f"Audio detected in DM from {sender}")

    query = self.clean_aesgcm_urls(query)

    if self.fetch_link_content and has_links:
        fetched = []
        for url in self.extract_urls_and_media(query)['regular_urls'][:3]:
            if url.startswith('aesgcm://'):
                continue
            content = self.fetch_url_content(url)
            if content:
                fetched.append(f"\n\n{content}\n")
        if fetched:
            query = f"{query}\n\n--- Additional context from links ---{''.join(fetched)}"

    response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=sender,
                                      image_data=image_data, audio_data=audio_data)
    if not response:
        return

    msg.reply(response).send()
    logging.info(f"Replied to {sender}")

    if self.tts_auto_reply and self.tts_enabled:
        tts_result = self.synthesize_speech(response, should_encrypt=False)
        if tts_result and tts_result["type"] == "url":
            voice_reply = msg.reply(tts_result['content'])
            voice_reply['oob']['url'] = tts_result['content']
            voice_reply.send()
|
|
|
|
async def direct_message_async(self, stanza: "Message") -> None:
    """Handle a one-to-one message on the OMEMO-capable path.

    OMEMO payloads are decrypted first. Replies go back encrypted when the
    incoming message was encrypted, and in plaintext otherwise.

    The image-generation and TTS triggers are answered directly. Anything
    else goes to the LLM with optional media and fetched link context.

    Synchronous, potentially slow work runs via ``run_in_executor`` so the
    asyncio event loop is not blocked. This covers image generation, TTS,
    the LLM call and fetching link pages.
    """
    if not self.allow_dm:
        return

    mfrom = stanza["from"]
    mtype = stanza["type"]
    if mtype not in {"chat", "normal"}:
        return
    if mfrom.bare == self.boundjid.bare:
        return

    sender = mfrom.bare.lower()
    if self.dm_mode == 'whitelist' and sender not in self.dm_list:
        return
    if self.dm_mode == 'blacklist' and sender in self.dm_list:
        return

    xep_0384 = self["xep_0384"]
    body = None
    is_encrypted = False
    decrypted_body = None

    if xep_0384.is_encrypted(stanza):
        logging.debug(f"Encrypted message received from {mfrom}")
        try:
            decrypted_msg, _device_info = await xep_0384.decrypt_message(stanza)
            if decrypted_msg.get("body"):
                body = self.sanitize_input(decrypted_msg["body"])
                decrypted_body = body
                is_encrypted = True
                logging.info(f"Decrypted message from {sender}: {body[:50]}...")
        except Exception as e:
            logging.error(f"Decryption failed: {e}")
            await self._plain_reply(mfrom, mtype, f"Error decrypting message: {e}")
            return
    elif stanza["body"]:
        body = self.sanitize_input(stanza["body"])
        logging.info(f"Plaintext message from {sender}: {body[:50]}...")

    if self.omemo_only and not is_encrypted:
        logging.warning(f"Ignoring plaintext DM from {sender} (OMEMO only mode active).")
        return
    if not body or len(body) > self.max_length:
        return

    loop = asyncio.get_running_loop()

    async def send_reply(text):
        # Encrypted in, encrypted out.
        if is_encrypted:
            await self._encrypted_reply(mfrom, mtype, text)
        else:
            await self._plain_reply(mfrom, mtype, text)

    if body.startswith(self.imagen_trigger):
        prompt = body[len(self.imagen_trigger):].strip()
        if prompt:
            result = await loop.run_in_executor(None, self.generate_image, prompt, is_encrypted)
            if result:
                response = result['content'] if result["type"] == "url" else "Image generated (base64)"
            else:
                response = "Failed to generate image."
            await send_reply(response)
        return

    if self.tts_enabled and body.startswith(self.tts_trigger):
        text = body[len(self.tts_trigger):].strip()
        if text:
            result = await loop.run_in_executor(None, self.synthesize_speech, text, is_encrypted)
            if result:
                response = result['content'] if result["type"] == "url" else "Speech synthesized (base64)"
            else:
                response = "Failed to synthesize speech."
            await send_reply(response)
        return

    _, non_quoted_text = self.extract_quoted_text(body)
    query = non_quoted_text.strip() or body

    has_links = self.contains_links(query)
    query = self.clean_aesgcm_urls(query)

    link_context = ""
    if self.fetch_link_content and has_links:
        # Same selection as the other handlers: only non-media links.
        for url in self.extract_urls_and_media(query)['regular_urls'][:3]:
            if url.startswith('aesgcm://'):
                continue
            # fetch_url_content is synchronous; keep it off the event loop.
            content = await loop.run_in_executor(None, self.fetch_url_content, url)
            if content:
                link_context += f"\n\n{content}\n"

    if link_context:
        query = f"{query}\n\n--- Additional context from links ---{link_context}"

    image_data = None
    audio_data = None

    if self.support_images:
        image_data = self.extract_image_from_message(stanza, decrypted_body)
        if image_data:
            logging.info(f"Image detected in DM from {sender}")

    if self.support_audio:
        audio_data = self.extract_audio_from_message(stanza, decrypted_body)
        if audio_data:
            logging.info(f"Audio detected in DM from {sender}")

    response = await loop.run_in_executor(
        None,
        self.rate_limited_send,
        query,
        self.max_retries,
        sender,
        image_data,
        audio_data,
    )
    if not response:
        return

    await send_reply(response)
    logging.info(f"Replied to {sender}")

    if self.tts_auto_reply and self.tts_enabled:
        tts_result = await loop.run_in_executor(None, self.synthesize_speech, response, is_encrypted)
        if tts_result and tts_result["type"] == "url":
            await send_reply(tts_result['content'])
|
|
|
|
async def _plain_reply(self, mto: JID, mtype: str, reply_text: str) -> None:
|
|
msg = self.make_message(mto=mto, mtype=mtype)
|
|
msg["body"] = reply_text
|
|
if reply_text.startswith("http") and not reply_text.startswith("aesgcm"):
|
|
msg['oob']['url'] = reply_text
|
|
msg.send()
|
|
|
|
async def _encrypted_reply(self, mto: JID, mtype: str, reply_text: str) -> None:
    """Send *reply_text* OMEMO-encrypted to the bare JID of *mto*.

    Each stanza returned by ``encrypt_message`` is labelled with its eme
    namespace and mechanism name (from xep_0380) before sending. If anything
    raises, a plaintext error notice is sent to *mto* instead.
    """
    xep_0384 = self["xep_0384"]

    recipient = mto if isinstance(mto, JID) else JID(mto)
    target_jid = recipient.bare

    outgoing = self.make_message(mto=target_jid, mtype=mtype)
    outgoing["body"] = reply_text
    encrypt_for: Set[JID] = {JID(target_jid)}

    try:
        messages, encryption_errors = await xep_0384.encrypt_message(outgoing, encrypt_for)
        if encryption_errors:
            logging.warning(f"Encryption errors: {encryption_errors}")

        for namespace, message in messages.items():
            message["eme"]["namespace"] = namespace
            message["eme"]["name"] = self["xep_0380"].mechanisms[namespace]
            message.send()
            logging.debug(f"Sent encrypted message to {target_jid}")
    except Exception as e:
        logging.error(f"Failed to send encrypted reply: {e}")
        await self._plain_reply(mto, mtype, f"Error encrypting reply: {e}")
|
|
|
|
def groupchat_message(self, msg):
    """Handle a MUC message and reply when the bot is addressed.

    The bot responds in these cases:
      - the image/TTS trigger commands;
      - the chat trigger prefix;
      - an ``@nick`` mention or a ``nick<sep>`` address prefix
        (sep one of ``: , ; - —`` or space);
      - a quote of a bot-addressed line;
      - (optionally) any message containing links;
      - every message from a privileged nick.

    The reply can be prefixed with the sender's nick (``mention_reply``) and
    the original text (``quote_reply``), and optionally followed by a TTS
    clip.
    """
    if msg['type'] != 'groupchat':
        return

    room_jid = msg['from'].bare
    bot_nick = self.room_nicknames.get(room_jid, self.nickname)
    if msg['mucnick'] == bot_nick:
        return

    body = self.sanitize_input(msg['body'])
    # An empty body can never yield a query; bail out early like the DM handlers.
    if not body or len(body) > self.max_length:
        return

    sender_nick = msg['mucnick']
    is_privileged = sender_nick.lower() in self.privileged_users

    def send_media(result, base64_notice, failure_notice):
        # URL results are sent as text plus an out-of-band ('oob') URL.
        if not result:
            msg.reply(f"{sender_nick}: {failure_notice}").send()
        elif result["type"] == "url":
            reply = msg.reply(f"{sender_nick}: {result['content']}")
            reply['oob']['url'] = result['content']
            reply.send()
        else:
            msg.reply(f"{sender_nick}: {base64_notice}").send()

    if body.startswith(self.imagen_trigger):
        prompt = body[len(self.imagen_trigger):].strip()
        if prompt:
            send_media(self.generate_image(prompt, should_encrypt=False),
                       "Image generated (base64)", "Failed to generate image.")
        return

    if self.tts_enabled and body.startswith(self.tts_trigger):
        text = body[len(self.tts_trigger):].strip()
        if text:
            send_media(self.synthesize_speech(text, should_encrypt=False),
                       "Speech synthesized (base64)", "Failed to synthesize speech.")
        return

    query = None
    is_reply = self.is_replying_to_bot(msg)
    triggered_by_reply_only = False
    triggered_by_links = False
    has_links = self.contains_links(body)

    # Separator echoed back in the mention prefix. It is stored WITHOUT a
    # trailing space: the formatter below adds one (previously ': ' plus that
    # space produced "nick:  reply").
    mention_separator = ':'
    address_separators = (':', ',', ';', '-', '—', ' ')
    addressed_with = None
    if self.mentions:
        body_lower = body.lower()
        nick_lower = bot_nick.lower()
        addressed_with = next(
            (sep for sep in address_separators if body_lower.startswith(nick_lower + sep)),
            None,
        )

    if body.startswith(self.trigger):
        query = body[len(self.trigger):].strip()
    elif self.mentions and f"@{bot_nick}" in body:
        query = body.replace(f"@{bot_nick}", "").strip()
    elif addressed_with is not None:
        query = body[len(bot_nick) + 1:].strip()
        mention_separator = addressed_with
    elif is_reply:
        _, non_quoted = self.extract_quoted_text(body)
        query = non_quoted.strip()
        triggered_by_reply_only = True
        logging.info(f"Detected reply to bot from {sender_nick}")
    elif self.answer_to_links and has_links:
        query = body
        triggered_by_links = True
        logging.info(f"Detected message with links from {sender_nick}")
    elif is_privileged:
        query = body

    if not query:
        return

    logging.info(f"Query from {sender_nick} in {room_jid}: {query[:50]}...")

    image_data = None
    audio_data = None

    if self.support_images:
        image_data = self.extract_image_from_message(msg)
        if not image_data:
            # Fall back to the first image link that can be downloaded.
            for img_url in self.extract_urls_and_media(query)['image_urls']:
                image_data = self.fetch_image_from_url(img_url)
                if image_data:
                    break

    if self.support_audio:
        audio_data = self.extract_audio_from_message(msg)
        if not audio_data:
            for audio_url in self.extract_urls_and_media(query)['audio_urls']:
                audio_data = self.fetch_audio_from_url(audio_url)
                if audio_data:
                    break

    query = self.clean_aesgcm_urls(query)

    link_context = ""
    if self.fetch_link_content and has_links:
        for url in self.extract_urls_and_media(body)['regular_urls'][:3]:
            if url.startswith('aesgcm://'):
                continue
            content = self.fetch_url_content(url)
            if content:
                link_context += f"\n\n{content}\n"
    if link_context:
        query = f"{query}\n\n--- Additional context from links ---{link_context}"

    response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=room_jid,
                                      image_data=image_data, audio_data=audio_data)
    if not response:
        return

    if self.mention_reply:
        if mention_separator == ' ':
            response = f"{sender_nick} {response}"
        else:
            response = f"{sender_nick}{mention_separator} {response}"

    if self.quote_reply:
        _, non_quoted_original = self.extract_quoted_text(body)
        message_to_quote = non_quoted_original.strip() or body
        if message_to_quote:
            lines = message_to_quote.split('\n')
            first_line = lines[0].strip() if lines else ""
            already_has_bot_name = any(
                first_line.lower().startswith(f"{bot_nick.lower()}{sep}")
                for sep in address_separators
            )
            # Make implicitly-triggered quotes read as addressed to the bot.
            if (triggered_by_reply_only or triggered_by_links) and lines and not already_has_bot_name:
                lines[0] = f"{bot_nick}: {lines[0]}"
            quoted = '\n'.join(f"> {line}" for line in lines)
            response = f"{quoted}\n{response}"

    msg.reply(response).send()
    logging.info(f"Replied in {room_jid}")

    if self.tts_auto_reply and self.tts_enabled:
        tts_text = response
        if self.quote_reply:
            # Do not read the quoted original aloud.
            tts_text = '\n'.join(l for l in response.split('\n') if not l.startswith('>')).strip()
        tts_result = self.synthesize_speech(tts_text, should_encrypt=False)
        if tts_result and tts_result["type"] == "url":
            reply = msg.reply(tts_result['content'])
            reply['oob']['url'] = tts_result['content']
            reply.send()
|
|
|
|
def send_to_llm(self, message, max_retries, room_jid=None, image_data=None, audio_data=None):
    """Route *message* to the configured backend and return its reply (or None).

    Backend selection, in priority order:
      - native Gemini, if the API is Gemini and a GenAI client exists;
      - the custom endpoint, when the OpenAI API is disabled;
      - the OpenAI SDK, when a client exists;
      - raw ``requests`` calls otherwise.
    """
    call_args = (message, max_retries, room_jid, image_data, audio_data)

    if self.is_gemini_api() and self.genai_client:
        backend = self._send_to_gemini_native
    elif not self.use_openai_api:
        backend = self._send_to_custom_api
    elif self.openai_client:
        backend = self._send_to_openai_library
    else:
        backend = self._send_to_openai_requests

    return backend(*call_args)
|
|
|
|
def _send_to_gemini_native(self, message, max_retries, room_jid=None, image_data=None, audio_data=None):
|
|
if not self.genai_client:
|
|
logging.error("GenAI client not initialized")
|
|
return None
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
contents = []
|
|
|
|
if self.remember and room_jid:
|
|
history = []
|
|
if self.persistent_memory and self.memory_storage:
|
|
history = self.memory_storage.get_history(room_jid)
|
|
elif room_jid in self.history:
|
|
history = self.history[room_jid][:]
|
|
|
|
for item in history:
|
|
role = "user" if item["role"] == "user" else "model"
|
|
contents.append(
|
|
types.Content(
|
|
role=role,
|
|
parts=[types.Part.from_text(text=item["content"])]
|
|
)
|
|
)
|
|
|
|
user_parts = [types.Part.from_text(text=message)]
|
|
|
|
if image_data and self.support_images:
|
|
img_bytes = image_data['data'] if isinstance(image_data['data'], bytes) else base64.b64decode(image_data['data'])
|
|
user_parts.append(
|
|
types.Part.from_bytes(
|
|
data=img_bytes,
|
|
mime_type=image_data['mime_type']
|
|
)
|
|
)
|
|
|
|
if audio_data and self.support_audio:
|
|
audio_bytes = audio_data['data'] if isinstance(audio_data['data'], bytes) else base64.b64decode(audio_data['data'])
|
|
user_parts.append(
|
|
types.Part.from_bytes(
|
|
data=audio_bytes,
|
|
mime_type=audio_data['mime_type']
|
|
)
|
|
)
|
|
|
|
contents.append(
|
|
types.Content(role="user", parts=user_parts)
|
|
)
|
|
|
|
config_kwargs = {}
|
|
|
|
system_prompt = self.system_prompts.get(room_jid) or self.system_prompts.get("global")
|
|
if system_prompt:
|
|
config_kwargs['system_instruction'] = system_prompt
|
|
|
|
if self.enable_url_context:
|
|
config_kwargs['tools'] = [types.Tool(url_context=types.UrlContext())]
|
|
|
|
config = types.GenerateContentConfig(**config_kwargs) if config_kwargs else None
|
|
|
|
response = self.genai_client.models.generate_content(
|
|
model=self.openai_model,
|
|
contents=contents,
|
|
config=config
|
|
)
|
|
|
|
content = response.text
|
|
|
|
if not content:
|
|
logging.warning("Empty response from Gemini")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
if len(content) > self.max_length:
|
|
logging.warning("Response too long, retrying...")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
if self.skip_thinking and content.lower().startswith("<thinking"):
|
|
logging.info("Skipping thinking response")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
if hasattr(response, 'candidates') and response.candidates:
|
|
candidate = response.candidates[0]
|
|
if hasattr(candidate, 'url_context_metadata') and candidate.url_context_metadata:
|
|
url_metadata = candidate.url_context_metadata.url_metadata or []
|
|
sources = []
|
|
for meta in url_metadata:
|
|
if meta.url_retrieval_status == "URL_RETRIEVAL_STATUS_SUCCESS":
|
|
sources.append(meta.retrieved_url)
|
|
if sources:
|
|
content += "\n\nSources:\n" + "\n".join(f"- {s}" for s in sources[:5])
|
|
|
|
self._update_history(room_jid, message, content)
|
|
|
|
return content
|
|
|
|
except Exception as e:
|
|
logging.error(f"Gemini native error (attempt {attempt+1}): {e}")
|
|
time.sleep(1)
|
|
|
|
logging.error("All retry attempts exhausted")
|
|
return None
|
|
|
|
def _send_to_openai_library(self, message, max_retries, room_jid=None, image_data=None, audio_data=None):
|
|
if not self.openai_client:
|
|
logging.error("OpenAI client not initialized")
|
|
return self._send_to_openai_requests(message, max_retries, room_jid, image_data, audio_data)
|
|
|
|
messages = self._build_message_history_openai(message, room_jid, image_data, audio_data)
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
response = self.openai_client.chat.completions.create(
|
|
model=self.openai_model,
|
|
messages=messages
|
|
)
|
|
|
|
message_obj = response.choices[0].message
|
|
content = message_obj.content.strip()
|
|
|
|
if self.openai_model.startswith("groq/compound"):
|
|
executed_tools = getattr(message_obj, 'executed_tools', None)
|
|
if executed_tools:
|
|
content += "\n\n--- Executed Tools ---"
|
|
for tool in executed_tools:
|
|
if isinstance(tool, dict):
|
|
tool_type = tool.get('type', 'unknown')
|
|
tool_args = tool.get('arguments', '')
|
|
else:
|
|
tool_type = getattr(tool, 'type', 'unknown')
|
|
tool_args = getattr(tool, 'arguments', '')
|
|
content += f"\n- {tool_type}: {tool_args}"
|
|
|
|
if len(content) > self.max_length:
|
|
logging.warning("Response too long, retrying...")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
if self.skip_thinking and content.lower().startswith("<thinking"):
|
|
logging.info("Skipping thinking response")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
self._update_history(room_jid, message, content)
|
|
|
|
return content
|
|
|
|
except Exception as e:
|
|
logging.error(f"OpenAI library error (attempt {attempt+1}): {e}")
|
|
time.sleep(1)
|
|
|
|
logging.error("All retry attempts exhausted")
|
|
return None
|
|
|
|
def _send_to_openai_requests(self, message, max_retries, room_jid=None, image_data=None, audio_data=None):
|
|
if not self.api_token:
|
|
logging.error("API token not configured for OpenAI API")
|
|
return None
|
|
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"Authorization": f"Bearer {self.api_token}"
|
|
}
|
|
|
|
messages = self._build_message_history_openai(message, room_jid, image_data, audio_data)
|
|
|
|
data = {
|
|
"model": self.openai_model,
|
|
"messages": messages
|
|
}
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
resp = requests.post(
|
|
self.api_url,
|
|
headers=headers,
|
|
json=data,
|
|
timeout=self.request_timeout,
|
|
allow_redirects=False
|
|
)
|
|
|
|
if resp.status_code == 500:
|
|
logging.warning(f"Server error (500), retrying... (attempt {attempt+1})")
|
|
time.sleep(1)
|
|
continue
|
|
|
|
resp.raise_for_status()
|
|
|
|
result = json.loads(resp.text)
|
|
message_obj = result["choices"][0]["message"]
|
|
content = message_obj["content"].strip()
|
|
|
|
if self.openai_model.startswith("groq/compound"):
|
|
executed_tools = message_obj.get("executed_tools")
|
|
if executed_tools:
|
|
content += "\n\n--- Executed Tools ---"
|
|
for tool in executed_tools:
|
|
tool_type = tool.get('type', 'unknown')
|
|
tool_args = tool.get('arguments', '')
|
|
content += f"\n- {tool_type}: {tool_args}"
|
|
|
|
if len(content) > self.max_length:
|
|
logging.warning("Response too long, retrying...")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
if self.skip_thinking and content.lower().startswith("<thinking"):
|
|
logging.info("Skipping thinking response")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
self._update_history(room_jid, message, content)
|
|
|
|
return content
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
logging.error(f"Request error (attempt {attempt+1}): {e}")
|
|
time.sleep(1)
|
|
except (json.JSONDecodeError, KeyError) as e:
|
|
logging.error(f"Response parsing error (attempt {attempt+1}): {e}")
|
|
time.sleep(1)
|
|
except Exception as e:
|
|
logging.error(f"Unexpected error (attempt {attempt+1}): {e}")
|
|
time.sleep(1)
|
|
|
|
logging.error("All retry attempts exhausted")
|
|
return None
|
|
|
|
def _send_to_custom_api(self, message, max_retries, room_jid=None, image_data=None, audio_data=None):
|
|
from urllib.parse import urlparse
|
|
|
|
try:
|
|
parsed_url = urlparse(self.api_url)
|
|
if parsed_url.scheme not in ['http', 'https']:
|
|
logging.error(f"Invalid URL scheme: {parsed_url.scheme}")
|
|
return None
|
|
|
|
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
|
except Exception as e:
|
|
logging.error(f"URL parsing error: {e}")
|
|
return None
|
|
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"User-Agent": "Mozilla/5.0",
|
|
"Accept": "*/*",
|
|
"Origin": base_url,
|
|
"Referer": f"{base_url}/"
|
|
}
|
|
|
|
messages = self._build_message_history_openai(message, room_jid, image_data, audio_data)
|
|
data = {"messages": messages}
|
|
|
|
for attempt in range(max_retries):
|
|
try:
|
|
resp = requests.post(
|
|
self.api_url,
|
|
headers=headers,
|
|
json=data,
|
|
timeout=self.request_timeout,
|
|
allow_redirects=False
|
|
)
|
|
|
|
if resp.status_code == 500:
|
|
logging.warning(f"Server error (500), retrying... (attempt {attempt+1})")
|
|
time.sleep(1)
|
|
continue
|
|
|
|
resp.raise_for_status()
|
|
|
|
content = json.loads(resp.text)["content"].strip()
|
|
|
|
if len(content) > self.max_length:
|
|
logging.warning("Response too long, retrying...")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
if self.skip_thinking and content.lower().startswith("<thinking"):
|
|
logging.info("Skipping thinking response")
|
|
time.sleep(0.5)
|
|
continue
|
|
|
|
self._update_history(room_jid, message, content)
|
|
|
|
return content
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
logging.error(f"Request error (attempt {attempt+1}): {e}")
|
|
time.sleep(1)
|
|
except (json.JSONDecodeError, KeyError) as e:
|
|
logging.error(f"Response parsing error (attempt {attempt+1}): {e}")
|
|
time.sleep(1)
|
|
except Exception as e:
|
|
logging.error(f"Unexpected error (attempt {attempt+1}): {e}")
|
|
time.sleep(1)
|
|
|
|
logging.error("All retry attempts exhausted")
|
|
return None
|
|
|
|
def _build_message_history_openai(self, message, room_jid, image_data=None, audio_data=None):
|
|
messages = []
|
|
|
|
if self.remember and room_jid:
|
|
if self.persistent_memory and self.memory_storage:
|
|
messages = self.memory_storage.get_history(room_jid)
|
|
elif room_jid in self.history:
|
|
messages = self.history[room_jid][:]
|
|
|
|
system_prompt = self.system_prompts.get(room_jid) or self.system_prompts.get("global")
|
|
|
|
if system_prompt:
|
|
if messages and messages[0]["role"] == "system":
|
|
messages[0] = {"role": "system", "content": system_prompt}
|
|
else:
|
|
messages.insert(0, {"role": "system", "content": system_prompt})
|
|
|
|
has_multimodal = (image_data and self.support_images) or (audio_data and self.support_audio)
|
|
|
|
if has_multimodal:
|
|
content_parts = [{"type": "text", "text": message}]
|
|
|
|
if image_data and self.support_images:
|
|
img_b64 = image_data['data'] if isinstance(image_data['data'], str) else base64.b64decode(image_data['data']).decode('utf-8')
|
|
content_parts.append({
|
|
"type": "image_url",
|
|
"image_url": {
|
|
"url": f"data:{image_data['mime_type']};base64,{img_b64}"
|
|
}
|
|
})
|
|
|
|
if audio_data and self.support_audio:
|
|
audio_b64 = audio_data['data'] if isinstance(audio_data['data'], str) else base64.b64encode(audio_data['data']).decode('utf-8')
|
|
content_parts.append({
|
|
"type": "input_audio",
|
|
"input_audio": {
|
|
"data": audio_b64,
|
|
"format": audio_data['format']
|
|
}
|
|
})
|
|
|
|
user_message = {"role": "user", "content": content_parts}
|
|
else:
|
|
user_message = {"role": "user", "content": message}
|
|
|
|
messages.append(user_message)
|
|
|
|
return messages
|
|
|
|
def _update_history(self, room_jid, user_message, assistant_message):
|
|
if self.remember and room_jid:
|
|
if self.persistent_memory and self.memory_storage:
|
|
self.memory_storage.append_to_history(
|
|
room_jid,
|
|
user_message,
|
|
assistant_message,
|
|
self.history_limit
|
|
)
|
|
else:
|
|
if room_jid not in self.history:
|
|
self.history[room_jid] = []
|
|
|
|
self.history[room_jid].extend([
|
|
{"role": "user", "content": user_message},
|
|
{"role": "assistant", "content": assistant_message}
|
|
])
|
|
|
|
if self.history_limit > 0 and len(self.history[room_jid]) > self.history_limit * 2:
|
|
self.history[room_jid] = self.history[room_jid][-self.history_limit * 2:]
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: read config.ini, build an LLMBot from it and run until disconnected.
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips a missing file and returns the list of
    # files it actually parsed; report that clearly instead of the bare KeyError
    # that config['XMPP'] / config['Bot'] would raise further down.
    if not config.read('config.ini'):
        raise SystemExit("Could not read config.ini")
    for required_section in ('XMPP', 'Bot'):
        if not config.has_section(required_section):
            raise SystemExit(f"config.ini is missing the [{required_section}] section")

    jid = config['XMPP']['jid']
    password = config['XMPP']['password']
    # Skip empty entries so a trailing comma does not yield a blank room JID
    # (the other comma-separated lists below are filtered the same way).
    rooms = [r.strip() for r in config['XMPP']['rooms'].split(',') if r.strip()]

    api_url = config['Bot']['api_url']
    use_openai_api = config.getboolean('Bot', 'use_openai_api', fallback=False)
    api_token = config.get('Bot', 'api_token', fallback='')
    openai_model = config.get('Bot', 'openai_model', fallback='gpt-4')

    trigger = config.get('Bot', 'trigger', fallback='!aibot')
    mentions = config.getboolean('Bot', 'mentions', fallback=True)
    rate_limit_calls = config.getint('Bot', 'rate_limit_calls', fallback=20)
    rate_limit_period = config.getint('Bot', 'rate_limit_period', fallback=60)
    max_length = config.getint('Bot', 'max_length', fallback=4000)
    nickname = config.get('Bot', 'nickname', fallback='LLMBot')
    privileged_users = [u.strip() for u in config.get('Bot', 'privileged_users', fallback='').split(',') if u.strip()]
    max_retries = config.getint('Bot', 'max_retries', fallback=20)

    remember_conversations = config.getboolean('Bot', 'remember_conversations', fallback=False)
    history_per_room = config.getint('Bot', 'history_per_room', fallback=10)

    persistent_memory = config.getboolean('Bot', 'persistent_memory', fallback=False)
    memory_file_path = config.get('Bot', 'memory_file_path', fallback='memory.json')

    quote_reply = config.getboolean('Bot', 'quote_reply', fallback=True)
    mention_reply = config.getboolean('Bot', 'mention_reply', fallback=True)
    skip_thinking = config.getboolean('Bot', 'skip_thinking_models', fallback=False)
    request_timeout = config.getint('Bot', 'request_timeout', fallback=20)

    allow_dm = config.getboolean('Bot', 'allow_dm', fallback=True)
    dm_mode = config.get('Bot', 'dm_mode', fallback='none')
    dm_list = [x.strip() for x in config.get('Bot', 'dm_list', fallback='').split(',') if x.strip()]

    enable_omemo = config.getboolean('Bot', 'enable_omemo', fallback=True)
    omemo_store_path = config.get('Bot', 'omemo_store_path', fallback='omemo_store.json')
    omemo_only = config.getboolean('Bot', 'omemo_only', fallback=False)

    answer_to_links = config.getboolean('Bot', 'answer_to_links', fallback=False)
    fetch_link_content = config.getboolean('Bot', 'fetch_link_content', fallback=False)
    support_images = config.getboolean('Bot', 'support_images', fallback=False)
    support_audio = config.getboolean('Bot', 'support_audio', fallback=False)
    join_retry_attempts = config.getint('Bot', 'join_retry_attempts', fallback=5)
    join_retry_delay = config.getint('Bot', 'join_retry_delay', fallback=10)

    imagen_trigger = config.get('Bot', 'imagen_trigger', fallback='!imagen')
    cf_account_id = config.get('Bot', 'cloudflare_account_id', fallback=None)
    cf_api_token = config.get('Bot', 'cloudflare_api_token', fallback=None)

    enable_url_context = config.getboolean('Bot', 'enable_url_context', fallback=False)

    file_host = config.get('Bot', 'file_host', fallback='catbox')
    file_host_api_key = config.get('Bot', 'file_host_api_key', fallback='')

    tts_trigger = config.get('Bot', 'tts_trigger', fallback='!tts')
    tts_enabled = config.getboolean('Bot', 'tts_enabled', fallback=False)
    tts_voice_name = config.get('Bot', 'tts_voice_name', fallback='Kore')
    tts_model = config.get('Bot', 'tts_model', fallback='gemini-2.5-flash-preview-tts')
    tts_auto_reply = config.getboolean('Bot', 'tts_auto_reply', fallback=False)

    # System prompts: `system_prompt` is the global default and
    # `system_prompt.<room>` overrides it for one room.
    # NOTE(review): ConfigParser lower-cases option names by default, so <room>
    # arrives lower-cased — confirm room JIDs are looked up in lower case.
    system_prompts = {}
    global_prompt = config.get('Bot', 'system_prompt', fallback='').strip()
    if global_prompt:
        system_prompts["global"] = global_prompt

    for key in config['Bot']:
        if key.startswith('system_prompt.'):
            room = key.split('.', 1)[1].strip()
            prompt = config['Bot'][key].strip()
            if prompt:
                system_prompts[room] = prompt

    # Per-room nickname overrides: `nickname.<room> = Nick` (same lower-casing caveat).
    room_nicknames = {}
    for key in config['Bot']:
        if key.startswith('nickname.'):
            room = key.split('.', 1)[1].strip()
            nick = config['Bot'][key].strip()
            if nick:
                room_nicknames[room] = nick

    async def main():
        # force=True: if an optional library failed to import, the module-level
        # logging.warning() call already installed a default WARNING-level root
        # handler, which would turn a plain basicConfig() into a silent no-op.
        logging.basicConfig(level=logging.DEBUG, format="%(levelname)-8s %(message)s", force=True)

        loop = asyncio.get_running_loop()

        bot = LLMBot(jid, password, rooms, room_nicknames, trigger, mentions, rate_limit_calls, rate_limit_period,
                     max_length, nickname, api_url, privileged_users, max_retries, system_prompts,
                     remember_conversations, history_per_room,
                     quote_reply=quote_reply, mention_reply=mention_reply,
                     skip_thinking=skip_thinking,
                     request_timeout=request_timeout, allow_dm=allow_dm,
                     dm_mode=dm_mode, dm_list=dm_list,
                     use_openai_api=use_openai_api, api_token=api_token,
                     openai_model=openai_model,
                     enable_omemo=enable_omemo, omemo_store_path=omemo_store_path, omemo_only=omemo_only,
                     answer_to_links=answer_to_links, fetch_link_content=fetch_link_content,
                     support_images=support_images, support_audio=support_audio,
                     join_retry_attempts=join_retry_attempts, join_retry_delay=join_retry_delay,
                     persistent_memory=persistent_memory, memory_file_path=memory_file_path,
                     imagen_trigger=imagen_trigger,
                     cf_account_id=cf_account_id, cf_api_token=cf_api_token,
                     enable_url_context=enable_url_context,
                     file_host=file_host, file_host_api_key=file_host_api_key,
                     tts_trigger=tts_trigger, tts_enabled=tts_enabled,
                     tts_voice_name=tts_voice_name, tts_model=tts_model,
                     tts_auto_reply=tts_auto_reply,
                     loop=loop)

        bot.connect()
        logging.info("Bot connected, waiting for messages...")

        # Keep the process alive until the XMPP stream reports disconnection.
        await bot.disconnected

    asyncio.run(main())
|
|
|