Add reply support
This commit is contained in:
120
bot.py
120
bot.py
@@ -418,6 +418,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
file_host='catbox', file_host_api_key=None,
|
file_host='catbox', file_host_api_key=None,
|
||||||
tts_trigger="!tts", tts_enabled=False, tts_voice_name="Kore",
|
tts_trigger="!tts", tts_enabled=False, tts_voice_name="Kore",
|
||||||
tts_model="gemini-2.5-flash-preview-tts", tts_auto_reply=False,
|
tts_model="gemini-2.5-flash-preview-tts", tts_auto_reply=False,
|
||||||
|
use_proper_replies=True,
|
||||||
loop=None):
|
loop=None):
|
||||||
self.request_timeout = request_timeout
|
self.request_timeout = request_timeout
|
||||||
super().__init__(jid, password, loop=loop, sasl_mech='PLAIN')
|
super().__init__(jid, password, loop=loop, sasl_mech='PLAIN')
|
||||||
@@ -485,6 +486,8 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
self.tts_model = tts_model
|
self.tts_model = tts_model
|
||||||
self.tts_auto_reply = tts_auto_reply
|
self.tts_auto_reply = tts_auto_reply
|
||||||
|
|
||||||
|
self.use_proper_replies = use_proper_replies
|
||||||
|
|
||||||
self.genai_client = None
|
self.genai_client = None
|
||||||
if GOOGLE_GENAI_AVAILABLE and self.api_token:
|
if GOOGLE_GENAI_AVAILABLE and self.api_token:
|
||||||
try:
|
try:
|
||||||
@@ -524,6 +527,8 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
self.register_plugin('xep_0199')
|
self.register_plugin('xep_0199')
|
||||||
self.register_plugin('xep_0066')
|
self.register_plugin('xep_0066')
|
||||||
self.register_plugin('xep_0363')
|
self.register_plugin('xep_0363')
|
||||||
|
self.register_plugin('xep_0359')
|
||||||
|
self.register_plugin('xep_0461')
|
||||||
|
|
||||||
if self.enable_omemo:
|
if self.enable_omemo:
|
||||||
self.register_plugin('xep_0085')
|
self.register_plugin('xep_0085')
|
||||||
@@ -1355,6 +1360,48 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _create_reply_stanza(self, original_msg, body_text):
|
||||||
|
if self.use_proper_replies:
|
||||||
|
try:
|
||||||
|
reply_id = None
|
||||||
|
try:
|
||||||
|
reply_id = original_msg['stanza_id']['id']
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not reply_id:
|
||||||
|
try:
|
||||||
|
reply_id = original_msg['origin_id']['id']
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not reply_id:
|
||||||
|
reply_id = original_msg['id']
|
||||||
|
|
||||||
|
if not reply_id:
|
||||||
|
return original_msg.reply(body_text)
|
||||||
|
|
||||||
|
reply_to = original_msg['from']
|
||||||
|
mtype = original_msg['type']
|
||||||
|
|
||||||
|
if mtype == 'groupchat':
|
||||||
|
mto = reply_to.bare
|
||||||
|
else:
|
||||||
|
mto = reply_to
|
||||||
|
|
||||||
|
return self['xep_0461'].make_reply(
|
||||||
|
reply_to=reply_to,
|
||||||
|
reply_id=reply_id,
|
||||||
|
mto=mto,
|
||||||
|
mbody=body_text,
|
||||||
|
mtype=mtype
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error creating proper reply stanza: {e}")
|
||||||
|
return original_msg.reply(body_text)
|
||||||
|
else:
|
||||||
|
return original_msg.reply(body_text)
|
||||||
|
|
||||||
def direct_message(self, msg):
|
def direct_message(self, msg):
|
||||||
if not self.allow_dm:
|
if not self.allow_dm:
|
||||||
return
|
return
|
||||||
@@ -1391,13 +1438,13 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
if result:
|
if result:
|
||||||
if result["type"] == "url":
|
if result["type"] == "url":
|
||||||
response = result['content']
|
response = result['content']
|
||||||
reply = msg.reply(response)
|
reply = self._create_reply_stanza(msg, response)
|
||||||
reply['oob']['url'] = response
|
reply['oob']['url'] = response
|
||||||
reply.send()
|
reply.send()
|
||||||
else:
|
else:
|
||||||
msg.reply("Image generated (base64)").send()
|
self._create_reply_stanza(msg, "Image generated (base64)").send()
|
||||||
else:
|
else:
|
||||||
msg.reply("Failed to generate image.").send()
|
self._create_reply_stanza(msg, "Failed to generate image.").send()
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.tts_enabled and body.startswith(self.tts_trigger):
|
if self.tts_enabled and body.startswith(self.tts_trigger):
|
||||||
@@ -1407,13 +1454,13 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
if result:
|
if result:
|
||||||
if result["type"] == "url":
|
if result["type"] == "url":
|
||||||
response = result['content']
|
response = result['content']
|
||||||
reply = msg.reply(response)
|
reply = self._create_reply_stanza(msg, response)
|
||||||
reply['oob']['url'] = response
|
reply['oob']['url'] = response
|
||||||
reply.send()
|
reply.send()
|
||||||
else:
|
else:
|
||||||
msg.reply("Speech synthesized (base64)").send()
|
self._create_reply_stanza(msg, "Speech synthesized (base64)").send()
|
||||||
else:
|
else:
|
||||||
msg.reply("Failed to synthesize speech.").send()
|
self._create_reply_stanza(msg, "Failed to synthesize speech.").send()
|
||||||
return
|
return
|
||||||
|
|
||||||
quoted_text, non_quoted_text = self.extract_quoted_text(body)
|
quoted_text, non_quoted_text = self.extract_quoted_text(body)
|
||||||
@@ -1461,13 +1508,14 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=sender, image_data=image_data, audio_data=audio_data, video_data=video_data)
|
response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=sender, image_data=image_data, audio_data=audio_data, video_data=video_data)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
msg.reply(response).send()
|
reply_msg = self._create_reply_stanza(msg, response)
|
||||||
|
reply_msg.send()
|
||||||
logging.info(f"Replied to {sender}")
|
logging.info(f"Replied to {sender}")
|
||||||
|
|
||||||
if self.tts_auto_reply and self.tts_enabled:
|
if self.tts_auto_reply and self.tts_enabled:
|
||||||
tts_result = self.synthesize_speech(response, should_encrypt=False)
|
tts_result = self.synthesize_speech(response, should_encrypt=False)
|
||||||
if tts_result and tts_result["type"] == "url":
|
if tts_result and tts_result["type"] == "url":
|
||||||
reply = msg.reply(tts_result['content'])
|
reply = self._create_reply_stanza(msg, tts_result['content'])
|
||||||
reply['oob']['url'] = tts_result['content']
|
reply['oob']['url'] = tts_result['content']
|
||||||
reply.send()
|
reply.send()
|
||||||
|
|
||||||
@@ -1509,7 +1557,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
logging.info(f"Decrypted message from {sender}: {body[:50]}...")
|
logging.info(f"Decrypted message from {sender}: {body[:50]}...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Decryption failed: {e}")
|
logging.error(f"Decryption failed: {e}")
|
||||||
await self._plain_reply(mfrom, mtype, f"Error decrypting message: {e}")
|
await self._plain_reply(stanza, mtype, f"Error decrypting message: {e}")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
if stanza["body"]:
|
if stanza["body"]:
|
||||||
@@ -1537,9 +1585,9 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
response = "Failed to generate image."
|
response = "Failed to generate image."
|
||||||
|
|
||||||
if is_encrypted:
|
if is_encrypted:
|
||||||
await self._encrypted_reply(mfrom, mtype, response)
|
await self._encrypted_reply(stanza, mtype, response)
|
||||||
else:
|
else:
|
||||||
await self._plain_reply(mfrom, mtype, response)
|
await self._plain_reply(stanza, mtype, response)
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.tts_enabled and body.startswith(self.tts_trigger):
|
if self.tts_enabled and body.startswith(self.tts_trigger):
|
||||||
@@ -1554,9 +1602,9 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
response = "Failed to synthesize speech."
|
response = "Failed to synthesize speech."
|
||||||
|
|
||||||
if is_encrypted:
|
if is_encrypted:
|
||||||
await self._encrypted_reply(mfrom, mtype, response)
|
await self._encrypted_reply(stanza, mtype, response)
|
||||||
else:
|
else:
|
||||||
await self._plain_reply(mfrom, mtype, response)
|
await self._plain_reply(stanza, mtype, response)
|
||||||
return
|
return
|
||||||
|
|
||||||
quoted_text, non_quoted_text = self.extract_quoted_text(body)
|
quoted_text, non_quoted_text = self.extract_quoted_text(body)
|
||||||
@@ -1614,37 +1662,32 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
|
|
||||||
if response:
|
if response:
|
||||||
if is_encrypted:
|
if is_encrypted:
|
||||||
await self._encrypted_reply(mfrom, mtype, response)
|
await self._encrypted_reply(stanza, mtype, response)
|
||||||
else:
|
else:
|
||||||
await self._plain_reply(mfrom, mtype, response)
|
await self._plain_reply(stanza, mtype, response)
|
||||||
logging.info(f"Replied to {sender}")
|
logging.info(f"Replied to {sender}")
|
||||||
|
|
||||||
if self.tts_auto_reply and self.tts_enabled:
|
if self.tts_auto_reply and self.tts_enabled:
|
||||||
tts_result = await loop.run_in_executor(None, self.synthesize_speech, response, is_encrypted)
|
tts_result = await loop.run_in_executor(None, self.synthesize_speech, response, is_encrypted)
|
||||||
if tts_result and tts_result["type"] == "url":
|
if tts_result and tts_result["type"] == "url":
|
||||||
if is_encrypted:
|
if is_encrypted:
|
||||||
await self._encrypted_reply(mfrom, mtype, tts_result['content'])
|
await self._encrypted_reply(stanza, mtype, tts_result['content'])
|
||||||
else:
|
else:
|
||||||
await self._plain_reply(mfrom, mtype, tts_result['content'])
|
await self._plain_reply(stanza, mtype, tts_result['content'])
|
||||||
|
|
||||||
async def _plain_reply(self, mto: JID, mtype: str, reply_text: str) -> None:
|
async def _plain_reply(self, original_msg: Message, mtype: str, reply_text: str) -> None:
|
||||||
msg = self.make_message(mto=mto, mtype=mtype)
|
msg = self._create_reply_stanza(original_msg, reply_text)
|
||||||
msg["body"] = reply_text
|
|
||||||
if reply_text.startswith("http") and not reply_text.startswith("aesgcm"):
|
if reply_text.startswith("http") and not reply_text.startswith("aesgcm"):
|
||||||
msg['oob']['url'] = reply_text
|
msg['oob']['url'] = reply_text
|
||||||
msg.send()
|
msg.send()
|
||||||
|
|
||||||
async def _encrypted_reply(self, mto: JID, mtype: str, reply_text: str) -> None:
|
async def _encrypted_reply(self, original_msg: Message, mtype: str, reply_text: str) -> None:
|
||||||
xep_0384 = self["xep_0384"]
|
xep_0384 = self["xep_0384"]
|
||||||
|
|
||||||
|
|
||||||
if isinstance(mto, JID):
|
msg = self._create_reply_stanza(original_msg, reply_text)
|
||||||
target_jid = mto.bare
|
target_jid = JID(msg['to']).bare
|
||||||
else:
|
msg['to'] = target_jid
|
||||||
target_jid = JID(mto).bare
|
|
||||||
|
|
||||||
msg = self.make_message(mto=target_jid, mtype=mtype)
|
|
||||||
msg["body"] = reply_text
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1663,7 +1706,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
logging.debug(f"Sent encrypted message to {target_jid}")
|
logging.debug(f"Sent encrypted message to {target_jid}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to send encrypted reply: {e}")
|
logging.error(f"Failed to send encrypted reply: {e}")
|
||||||
await self._plain_reply(mto, mtype, f"Error encrypting reply: {e}")
|
await self._plain_reply(original_msg, mtype, f"Error encrypting reply: {e}")
|
||||||
|
|
||||||
def groupchat_message(self, msg):
|
def groupchat_message(self, msg):
|
||||||
if msg['type'] != 'groupchat':
|
if msg['type'] != 'groupchat':
|
||||||
@@ -1691,13 +1734,13 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
if result:
|
if result:
|
||||||
if result["type"] == "url":
|
if result["type"] == "url":
|
||||||
response = f"{sender_nick}: {result['content']}"
|
response = f"{sender_nick}: {result['content']}"
|
||||||
reply = msg.reply(response)
|
reply = self._create_reply_stanza(msg, response)
|
||||||
reply['oob']['url'] = result['content']
|
reply['oob']['url'] = result['content']
|
||||||
reply.send()
|
reply.send()
|
||||||
else:
|
else:
|
||||||
msg.reply(f"{sender_nick}: Image generated (base64)").send()
|
self._create_reply_stanza(msg, f"{sender_nick}: Image generated (base64)").send()
|
||||||
else:
|
else:
|
||||||
msg.reply(f"{sender_nick}: Failed to generate image.").send()
|
self._create_reply_stanza(msg, f"{sender_nick}: Failed to generate image.").send()
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.tts_enabled and body.startswith(self.tts_trigger):
|
if self.tts_enabled and body.startswith(self.tts_trigger):
|
||||||
@@ -1707,13 +1750,13 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
if result:
|
if result:
|
||||||
if result["type"] == "url":
|
if result["type"] == "url":
|
||||||
response = f"{sender_nick}: {result['content']}"
|
response = f"{sender_nick}: {result['content']}"
|
||||||
reply = msg.reply(response)
|
reply = self._create_reply_stanza(msg, response)
|
||||||
reply['oob']['url'] = result['content']
|
reply['oob']['url'] = result['content']
|
||||||
reply.send()
|
reply.send()
|
||||||
else:
|
else:
|
||||||
msg.reply(f"{sender_nick}: Speech synthesized (base64)").send()
|
self._create_reply_stanza(msg, f"{sender_nick}: Speech synthesized (base64)").send()
|
||||||
else:
|
else:
|
||||||
msg.reply(f"{sender_nick}: Failed to synthesize speech.").send()
|
self._create_reply_stanza(msg, f"{sender_nick}: Failed to synthesize speech.").send()
|
||||||
return
|
return
|
||||||
|
|
||||||
query = None
|
query = None
|
||||||
@@ -1852,7 +1895,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
quoted = '\n'.join(f"> {line}" for line in lines)
|
quoted = '\n'.join(f"> {line}" for line in lines)
|
||||||
response = f"{quoted}\n{response}"
|
response = f"{quoted}\n{response}"
|
||||||
|
|
||||||
msg.reply(response).send()
|
self._create_reply_stanza(msg, response).send()
|
||||||
logging.info(f"Replied in {room_jid}")
|
logging.info(f"Replied in {room_jid}")
|
||||||
|
|
||||||
if self.tts_auto_reply and self.tts_enabled:
|
if self.tts_auto_reply and self.tts_enabled:
|
||||||
@@ -1863,7 +1906,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
|||||||
|
|
||||||
tts_result = self.synthesize_speech(tts_text, should_encrypt=False)
|
tts_result = self.synthesize_speech(tts_text, should_encrypt=False)
|
||||||
if tts_result and tts_result["type"] == "url":
|
if tts_result and tts_result["type"] == "url":
|
||||||
reply = msg.reply(tts_result['content'])
|
reply = self._create_reply_stanza(msg, tts_result['content'])
|
||||||
reply['oob']['url'] = tts_result['content']
|
reply['oob']['url'] = tts_result['content']
|
||||||
reply.send()
|
reply.send()
|
||||||
|
|
||||||
@@ -2332,6 +2375,8 @@ if __name__ == '__main__':
|
|||||||
tts_model = config.get('Bot', 'tts_model', fallback='gemini-2.5-flash-preview-tts')
|
tts_model = config.get('Bot', 'tts_model', fallback='gemini-2.5-flash-preview-tts')
|
||||||
tts_auto_reply = config.getboolean('Bot', 'tts_auto_reply', fallback=False)
|
tts_auto_reply = config.getboolean('Bot', 'tts_auto_reply', fallback=False)
|
||||||
|
|
||||||
|
use_proper_replies = config.getboolean('Bot', 'use_proper_replies', fallback=True)
|
||||||
|
|
||||||
system_prompts = {}
|
system_prompts = {}
|
||||||
global_prompt = config.get('Bot', 'system_prompt', fallback='').strip()
|
global_prompt = config.get('Bot', 'system_prompt', fallback='').strip()
|
||||||
if global_prompt:
|
if global_prompt:
|
||||||
@@ -2379,6 +2424,7 @@ if __name__ == '__main__':
|
|||||||
tts_trigger=tts_trigger, tts_enabled=tts_enabled,
|
tts_trigger=tts_trigger, tts_enabled=tts_enabled,
|
||||||
tts_voice_name=tts_voice_name, tts_model=tts_model,
|
tts_voice_name=tts_voice_name, tts_model=tts_model,
|
||||||
tts_auto_reply=tts_auto_reply,
|
tts_auto_reply=tts_auto_reply,
|
||||||
|
use_proper_replies=use_proper_replies,
|
||||||
loop=loop)
|
loop=loop)
|
||||||
|
|
||||||
bot.connect()
|
bot.connect()
|
||||||
|
|||||||
Reference in New Issue
Block a user