Add video support
This commit is contained in:
241
bot.py
241
bot.py
@@ -389,6 +389,18 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
'audio/x-aiff': 'aiff'
|
||||
}
|
||||
|
||||
VIDEO_EXTENSIONS = ['.mp4', '.mpeg', '.mov', '.avi', '.flv', '.mpg', '.webm', '.wmv', '.3gp']
|
||||
VIDEO_MIME_TYPES = {
|
||||
'video/mp4': 'mp4',
|
||||
'video/mpeg': 'mpeg',
|
||||
'video/quicktime': 'mov',
|
||||
'video/x-msvideo': 'avi',
|
||||
'video/x-flv': 'flv',
|
||||
'video/webm': 'webm',
|
||||
'video/x-ms-wmv': 'wmv',
|
||||
'video/3gpp': '3gp'
|
||||
}
|
||||
|
||||
def __init__(self, jid, password, rooms, room_nicknames, trigger, mentions, rate_limit_calls, rate_limit_period,
|
||||
max_length, nickname, api_url, privileged_users, max_retries, system_prompts,
|
||||
remember_conversations, history_per_room,
|
||||
@@ -396,12 +408,12 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
request_timeout=20, allow_dm=True, dm_mode='whitelist', dm_list=None,
|
||||
use_openai_api=False, api_token=None, openai_model="gpt-4",
|
||||
enable_omemo=True, omemo_store_path="omemo_store.json", omemo_only=False,
|
||||
answer_to_links=False, fetch_link_content=False, support_images=False,
|
||||
answer_to_links=False, fetch_link_content=False,
|
||||
support_images=False, support_audio=False, support_video=False,
|
||||
join_retry_attempts=5, join_retry_delay=10,
|
||||
persistent_memory=False, memory_file_path="memory.json",
|
||||
imagen_trigger="!imagen",
|
||||
cf_account_id=None, cf_api_token=None,
|
||||
support_audio=False,
|
||||
enable_url_context=False,
|
||||
file_host='catbox', file_host_api_key=None,
|
||||
tts_trigger="!tts", tts_enabled=False, tts_voice_name="Kore",
|
||||
@@ -455,6 +467,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
self.fetch_link_content = fetch_link_content
|
||||
self.support_images = support_images
|
||||
self.support_audio = support_audio
|
||||
self.support_video = support_video
|
||||
self.join_retry_attempts = join_retry_attempts
|
||||
self.join_retry_delay = join_retry_delay
|
||||
|
||||
@@ -494,10 +507,17 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize OpenAI client: {e}")
|
||||
|
||||
if self.support_video and not (self.is_gemini_api() and self.genai_client):
|
||||
logging.warning("Video support disabled: Video understanding is only available with Google Gemini API.")
|
||||
self.support_video = False
|
||||
|
||||
self.url_pattern = re.compile(
|
||||
r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
|
||||
)
|
||||
self.aesgcm_pattern = re.compile(r'aesgcm://[^\s]+')
|
||||
self.youtube_pattern = re.compile(
|
||||
r'(https?://)?(www\.)?(youtube|youtu|youtube-nocookie)\.(com|be)/(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})'
|
||||
)
|
||||
|
||||
self.register_plugin('xep_0030')
|
||||
self.register_plugin('xep_0045')
|
||||
@@ -665,6 +685,15 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
return any(path.endswith(ext) for ext in self.AUDIO_EXTENSIONS)
|
||||
return any(url.lower().endswith(ext) for ext in self.AUDIO_EXTENSIONS)
|
||||
|
||||
def is_video_url(self, url):
|
||||
if url.startswith('aesgcm://'):
|
||||
path = urlparse(url.replace('aesgcm://', 'https://')).path.lower()
|
||||
return any(path.endswith(ext) for ext in self.VIDEO_EXTENSIONS)
|
||||
return any(url.lower().endswith(ext) for ext in self.VIDEO_EXTENSIONS)
|
||||
|
||||
def is_youtube_url(self, url):
|
||||
return bool(self.youtube_pattern.search(url))
|
||||
|
||||
def get_audio_format_from_mime(self, mime_type):
|
||||
mime_lower = mime_type.lower()
|
||||
return self.AUDIO_MIME_TYPES.get(mime_lower, 'wav')
|
||||
@@ -690,6 +719,15 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
'webm': 'audio/webm'
|
||||
}
|
||||
return mime_map.get(fmt, 'audio/wav')
|
||||
|
||||
def get_video_mime_from_extension(self, url):
|
||||
path = urlparse(url.replace('aesgcm://', 'https://')).path.lower()
|
||||
for ext in self.VIDEO_EXTENSIONS:
|
||||
if path.endswith(ext):
|
||||
for mime, extension_val in self.VIDEO_MIME_TYPES.items():
|
||||
if extension_val == ext[1:]:
|
||||
return mime
|
||||
return 'video/mp4'
|
||||
|
||||
def fetch_audio_from_url(self, url):
|
||||
try:
|
||||
@@ -738,6 +776,76 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
logging.error(f"Error fetching audio from {url}: {e}")
|
||||
return None
|
||||
|
||||
def fetch_video_from_url(self, url):
|
||||
try:
|
||||
if self.is_youtube_url(url):
|
||||
logging.info(f"Identified YouTube URL: {url}")
|
||||
return {
|
||||
'type': 'uri',
|
||||
'data': url,
|
||||
'mime_type': 'video/mp4'
|
||||
}
|
||||
|
||||
logging.info(f"Fetching video from URL: {url}")
|
||||
|
||||
if url.startswith('aesgcm://'):
|
||||
decrypted_data = self.decrypt_aesgcm_url(url)
|
||||
if not decrypted_data:
|
||||
logging.error(f"Failed to decrypt AESGCM video URL: {url}")
|
||||
return None
|
||||
|
||||
if len(decrypted_data) > 20 * 1024 * 1024:
|
||||
logging.warning(f"Video size ({len(decrypted_data)} bytes) exceeds inline limit for GenAI")
|
||||
return None
|
||||
|
||||
mime_type = self.get_video_mime_from_extension(url)
|
||||
|
||||
return {
|
||||
'type': 'bytes',
|
||||
'data': decrypted_data,
|
||||
'mime_type': mime_type
|
||||
}
|
||||
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/142.0.0.0 Safari/537.36'
|
||||
}
|
||||
|
||||
head_resp = requests.head(url, timeout=10, headers=headers, allow_redirects=True)
|
||||
if 'content-length' in head_resp.headers:
|
||||
size = int(head_resp.headers['content-length'])
|
||||
if size > 20 * 1024 * 1024:
|
||||
logging.warning(f"Video at {url} is too large ({size} bytes) for inline processing")
|
||||
return None
|
||||
|
||||
response = requests.get(url, timeout=60, headers=headers, allow_redirects=True)
|
||||
response.raise_for_status()
|
||||
|
||||
if len(response.content) > 20 * 1024 * 1024:
|
||||
logging.warning(f"Video downloaded from {url} is too large ({len(response.content)} bytes)")
|
||||
return None
|
||||
|
||||
content_type = response.headers.get('content-type', '').lower()
|
||||
|
||||
is_video = any(vtype in content_type for vtype in self.VIDEO_MIME_TYPES.keys())
|
||||
if not is_video:
|
||||
if not self.is_video_url(url):
|
||||
logging.warning(f"URL returned non-video content-type: {content_type}")
|
||||
return None
|
||||
else:
|
||||
content_type = self.get_video_mime_from_extension(url)
|
||||
|
||||
logging.info(f"Successfully fetched video: {content_type}, size: {len(response.content)} bytes")
|
||||
|
||||
return {
|
||||
'type': 'bytes',
|
||||
'data': response.content,
|
||||
'mime_type': content_type
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error fetching video from {url}: {e}")
|
||||
return None
|
||||
|
||||
def extract_audio_from_message(self, msg, decrypted_body=None):
|
||||
try:
|
||||
if hasattr(msg, 'xml'):
|
||||
@@ -782,35 +890,74 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
|
||||
return None
|
||||
|
||||
def extract_video_from_message(self, msg, decrypted_body=None):
|
||||
try:
|
||||
if hasattr(msg, 'xml'):
|
||||
xml_elem = msg.xml
|
||||
elif hasattr(msg, '_get_stanza_values'):
|
||||
xml_elem = msg
|
||||
else:
|
||||
return None
|
||||
|
||||
oob = xml_elem.find('.//{jabber:x:oob}url')
|
||||
if oob is not None and oob.text:
|
||||
vid_url = oob.text.strip()
|
||||
if self.is_video_url(vid_url) or self.is_youtube_url(vid_url):
|
||||
logging.info(f"Found OOB video URL: {vid_url}")
|
||||
return self.fetch_video_from_url(vid_url)
|
||||
|
||||
if decrypted_body:
|
||||
url_data = self.extract_urls_and_media(decrypted_body)
|
||||
if url_data['video_urls']:
|
||||
for vid_url in url_data['video_urls']:
|
||||
return self.fetch_video_from_url(vid_url)
|
||||
|
||||
body_elem = xml_elem.find('.//{jabber:client}body')
|
||||
if body_elem is None:
|
||||
body_elem = xml_elem.find('.//body')
|
||||
|
||||
if body_elem is not None and body_elem.text:
|
||||
body_text = body_elem.text
|
||||
if "doesn't seem to support that" not in body_text and "OMEMO" not in body_text:
|
||||
url_data = self.extract_urls_and_media(body_text)
|
||||
if url_data['video_urls']:
|
||||
for vid_url in url_data['video_urls']:
|
||||
return self.fetch_video_from_url(vid_url)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error extracting video: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def extract_urls_and_media(self, text):
|
||||
image_urls = []
|
||||
audio_urls = []
|
||||
video_urls = []
|
||||
regular_urls = []
|
||||
|
||||
aesgcm_urls = self.aesgcm_pattern.findall(text)
|
||||
|
||||
for url in aesgcm_urls:
|
||||
if self.is_image_url(url):
|
||||
image_urls.append(url)
|
||||
elif self.is_audio_url(url):
|
||||
audio_urls.append(url)
|
||||
def classify_url(u):
|
||||
if self.is_image_url(u):
|
||||
image_urls.append(u)
|
||||
elif self.is_audio_url(u):
|
||||
audio_urls.append(u)
|
||||
elif self.is_video_url(u) or self.is_youtube_url(u):
|
||||
video_urls.append(u)
|
||||
else:
|
||||
regular_urls.append(url)
|
||||
regular_urls.append(u)
|
||||
|
||||
aesgcm_urls = self.aesgcm_pattern.findall(text)
|
||||
for url in aesgcm_urls:
|
||||
classify_url(url)
|
||||
|
||||
http_urls = self.url_pattern.findall(text)
|
||||
for url in http_urls:
|
||||
if self.is_image_url(url):
|
||||
image_urls.append(url)
|
||||
elif self.is_audio_url(url):
|
||||
audio_urls.append(url)
|
||||
else:
|
||||
regular_urls.append(url)
|
||||
classify_url(url)
|
||||
|
||||
return {
|
||||
'image_urls': image_urls,
|
||||
'audio_urls': audio_urls,
|
||||
'video_urls': video_urls,
|
||||
'regular_urls': regular_urls,
|
||||
'all_urls': image_urls + audio_urls + regular_urls
|
||||
'all_urls': image_urls + audio_urls + video_urls + regular_urls
|
||||
}
|
||||
|
||||
def extract_image_from_message(self, msg, decrypted_body=None):
|
||||
@@ -1005,14 +1152,14 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
for url in aesgcm_urls:
|
||||
if self.is_image_url(url):
|
||||
image_urls.append(url)
|
||||
elif not self.is_audio_url(url):
|
||||
elif not self.is_audio_url(url) and not self.is_video_url(url):
|
||||
regular_urls.append(url)
|
||||
|
||||
http_urls = self.url_pattern.findall(text)
|
||||
for url in http_urls:
|
||||
if self.is_image_url(url):
|
||||
image_urls.append(url)
|
||||
elif not self.is_audio_url(url):
|
||||
elif not self.is_audio_url(url) and not self.is_video_url(url):
|
||||
regular_urls.append(url)
|
||||
|
||||
return {
|
||||
@@ -1279,6 +1426,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
|
||||
image_data = None
|
||||
audio_data = None
|
||||
video_data = None
|
||||
|
||||
if self.support_images:
|
||||
image_data = self.extract_image_from_message(msg)
|
||||
@@ -1289,6 +1437,11 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
audio_data = self.extract_audio_from_message(msg)
|
||||
if audio_data:
|
||||
logging.info(f"Audio detected in DM from {sender}")
|
||||
|
||||
if self.support_video:
|
||||
video_data = self.extract_video_from_message(msg)
|
||||
if video_data:
|
||||
logging.info(f"Video detected in DM from {sender}")
|
||||
|
||||
query = self.clean_aesgcm_urls(query)
|
||||
|
||||
@@ -1305,7 +1458,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
if link_context:
|
||||
query = f"{query}\n\n--- Additional context from links ---{link_context}"
|
||||
|
||||
response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=sender, image_data=image_data, audio_data=audio_data)
|
||||
response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=sender, image_data=image_data, audio_data=audio_data, video_data=video_data)
|
||||
|
||||
if response:
|
||||
msg.reply(response).send()
|
||||
@@ -1430,6 +1583,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
|
||||
image_data = None
|
||||
audio_data = None
|
||||
video_data = None
|
||||
|
||||
if self.support_images:
|
||||
image_data = self.extract_image_from_message(stanza, decrypted_body)
|
||||
@@ -1440,6 +1594,11 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
audio_data = self.extract_audio_from_message(stanza, decrypted_body)
|
||||
if audio_data:
|
||||
logging.info(f"Audio detected in DM from {sender}")
|
||||
|
||||
if self.support_video:
|
||||
video_data = self.extract_video_from_message(stanza, decrypted_body)
|
||||
if video_data:
|
||||
logging.info(f"Video detected in DM from {sender}")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
response = await loop.run_in_executor(
|
||||
@@ -1449,7 +1608,8 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
self.max_retries,
|
||||
sender,
|
||||
image_data,
|
||||
audio_data
|
||||
audio_data,
|
||||
video_data
|
||||
)
|
||||
|
||||
if response:
|
||||
@@ -1613,6 +1773,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
|
||||
image_data = None
|
||||
audio_data = None
|
||||
video_data = None
|
||||
|
||||
if self.support_images:
|
||||
image_data = self.extract_image_from_message(msg)
|
||||
@@ -1635,6 +1796,17 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
audio_data = self.fetch_audio_from_url(audio_url)
|
||||
if audio_data:
|
||||
break
|
||||
|
||||
if self.support_video:
|
||||
video_data = self.extract_video_from_message(msg)
|
||||
|
||||
if not video_data:
|
||||
url_data = self.extract_urls_and_media(query)
|
||||
if url_data['video_urls']:
|
||||
for vid_url in url_data['video_urls']:
|
||||
video_data = self.fetch_video_from_url(vid_url)
|
||||
if video_data:
|
||||
break
|
||||
|
||||
query = self.clean_aesgcm_urls(query)
|
||||
|
||||
@@ -1651,7 +1823,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
if link_context:
|
||||
query = f"{query}\n\n--- Additional context from links ---{link_context}"
|
||||
|
||||
response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=room_jid, image_data=image_data, audio_data=audio_data)
|
||||
response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=room_jid, image_data=image_data, audio_data=audio_data, video_data=video_data)
|
||||
|
||||
if response:
|
||||
if self.mention_reply:
|
||||
@@ -1695,9 +1867,9 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
reply['oob']['url'] = tts_result['content']
|
||||
reply.send()
|
||||
|
||||
def send_to_llm(self, message, max_retries, room_jid=None, image_data=None, audio_data=None):
|
||||
def send_to_llm(self, message, max_retries, room_jid=None, image_data=None, audio_data=None, video_data=None):
|
||||
if self.is_gemini_api() and self.genai_client:
|
||||
return self._send_to_gemini_native(message, max_retries, room_jid, image_data, audio_data)
|
||||
return self._send_to_gemini_native(message, max_retries, room_jid, image_data, audio_data, video_data)
|
||||
elif self.use_openai_api:
|
||||
if self.openai_client:
|
||||
return self._send_to_openai_library(message, max_retries, room_jid, image_data, audio_data)
|
||||
@@ -1706,7 +1878,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
else:
|
||||
return self._send_to_custom_api(message, max_retries, room_jid, image_data, audio_data)
|
||||
|
||||
def _send_to_gemini_native(self, message, max_retries, room_jid=None, image_data=None, audio_data=None):
|
||||
def _send_to_gemini_native(self, message, max_retries, room_jid=None, image_data=None, audio_data=None, video_data=None):
|
||||
if not self.genai_client:
|
||||
logging.error("GenAI client not initialized")
|
||||
return None
|
||||
@@ -1750,6 +1922,22 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
mime_type=audio_data['mime_type']
|
||||
)
|
||||
)
|
||||
|
||||
if video_data and self.support_video:
|
||||
if video_data.get('type') == 'uri':
|
||||
user_parts.append(
|
||||
types.Part(
|
||||
file_data=types.FileData(file_uri=video_data['data'], mime_type=video_data['mime_type'])
|
||||
)
|
||||
)
|
||||
elif video_data.get('type') == 'bytes':
|
||||
video_bytes = video_data['data'] if isinstance(video_data['data'], bytes) else base64.b64decode(video_data['data'])
|
||||
user_parts.append(
|
||||
types.Part.from_bytes(
|
||||
data=video_bytes,
|
||||
mime_type=video_data['mime_type']
|
||||
)
|
||||
)
|
||||
|
||||
contents.append(
|
||||
types.Content(role="user", parts=user_parts)
|
||||
@@ -2123,6 +2311,8 @@ if __name__ == '__main__':
|
||||
fetch_link_content = config.getboolean('Bot', 'fetch_link_content', fallback=False)
|
||||
support_images = config.getboolean('Bot', 'support_images', fallback=False)
|
||||
support_audio = config.getboolean('Bot', 'support_audio', fallback=False)
|
||||
support_video = config.getboolean('Bot', 'support_video', fallback=False)
|
||||
|
||||
join_retry_attempts = config.getint('Bot', 'join_retry_attempts', fallback=5)
|
||||
join_retry_delay = config.getint('Bot', 'join_retry_delay', fallback=10)
|
||||
|
||||
@@ -2179,6 +2369,7 @@ if __name__ == '__main__':
|
||||
enable_omemo=enable_omemo, omemo_store_path=omemo_store_path, omemo_only=omemo_only,
|
||||
answer_to_links=answer_to_links, fetch_link_content=fetch_link_content,
|
||||
support_images=support_images, support_audio=support_audio,
|
||||
support_video=support_video,
|
||||
join_retry_attempts=join_retry_attempts, join_retry_delay=join_retry_delay,
|
||||
persistent_memory=persistent_memory, memory_file_path=memory_file_path,
|
||||
imagen_trigger=imagen_trigger,
|
||||
|
||||
Reference in New Issue
Block a user