Add reply support
This commit is contained in:
120
bot.py
120
bot.py
@@ -418,6 +418,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
file_host='catbox', file_host_api_key=None,
|
||||
tts_trigger="!tts", tts_enabled=False, tts_voice_name="Kore",
|
||||
tts_model="gemini-2.5-flash-preview-tts", tts_auto_reply=False,
|
||||
use_proper_replies=True,
|
||||
loop=None):
|
||||
self.request_timeout = request_timeout
|
||||
super().__init__(jid, password, loop=loop, sasl_mech='PLAIN')
|
||||
@@ -484,6 +485,8 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
self.tts_voice_name = tts_voice_name
|
||||
self.tts_model = tts_model
|
||||
self.tts_auto_reply = tts_auto_reply
|
||||
|
||||
self.use_proper_replies = use_proper_replies
|
||||
|
||||
self.genai_client = None
|
||||
if GOOGLE_GENAI_AVAILABLE and self.api_token:
|
||||
@@ -524,6 +527,8 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
self.register_plugin('xep_0199')
|
||||
self.register_plugin('xep_0066')
|
||||
self.register_plugin('xep_0363')
|
||||
self.register_plugin('xep_0359')
|
||||
self.register_plugin('xep_0461')
|
||||
|
||||
if self.enable_omemo:
|
||||
self.register_plugin('xep_0085')
|
||||
@@ -1355,6 +1360,48 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
|
||||
return None
|
||||
|
||||
def _create_reply_stanza(self, original_msg, body_text):
|
||||
if self.use_proper_replies:
|
||||
try:
|
||||
reply_id = None
|
||||
try:
|
||||
reply_id = original_msg['stanza_id']['id']
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
|
||||
if not reply_id:
|
||||
try:
|
||||
reply_id = original_msg['origin_id']['id']
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
|
||||
if not reply_id:
|
||||
reply_id = original_msg['id']
|
||||
|
||||
if not reply_id:
|
||||
return original_msg.reply(body_text)
|
||||
|
||||
reply_to = original_msg['from']
|
||||
mtype = original_msg['type']
|
||||
|
||||
if mtype == 'groupchat':
|
||||
mto = reply_to.bare
|
||||
else:
|
||||
mto = reply_to
|
||||
|
||||
return self['xep_0461'].make_reply(
|
||||
reply_to=reply_to,
|
||||
reply_id=reply_id,
|
||||
mto=mto,
|
||||
mbody=body_text,
|
||||
mtype=mtype
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error creating proper reply stanza: {e}")
|
||||
return original_msg.reply(body_text)
|
||||
else:
|
||||
return original_msg.reply(body_text)
|
||||
|
||||
def direct_message(self, msg):
|
||||
if not self.allow_dm:
|
||||
return
|
||||
@@ -1391,13 +1438,13 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
if result:
|
||||
if result["type"] == "url":
|
||||
response = result['content']
|
||||
reply = msg.reply(response)
|
||||
reply = self._create_reply_stanza(msg, response)
|
||||
reply['oob']['url'] = response
|
||||
reply.send()
|
||||
else:
|
||||
msg.reply("Image generated (base64)").send()
|
||||
self._create_reply_stanza(msg, "Image generated (base64)").send()
|
||||
else:
|
||||
msg.reply("Failed to generate image.").send()
|
||||
self._create_reply_stanza(msg, "Failed to generate image.").send()
|
||||
return
|
||||
|
||||
if self.tts_enabled and body.startswith(self.tts_trigger):
|
||||
@@ -1407,13 +1454,13 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
if result:
|
||||
if result["type"] == "url":
|
||||
response = result['content']
|
||||
reply = msg.reply(response)
|
||||
reply = self._create_reply_stanza(msg, response)
|
||||
reply['oob']['url'] = response
|
||||
reply.send()
|
||||
else:
|
||||
msg.reply("Speech synthesized (base64)").send()
|
||||
self._create_reply_stanza(msg, "Speech synthesized (base64)").send()
|
||||
else:
|
||||
msg.reply("Failed to synthesize speech.").send()
|
||||
self._create_reply_stanza(msg, "Failed to synthesize speech.").send()
|
||||
return
|
||||
|
||||
quoted_text, non_quoted_text = self.extract_quoted_text(body)
|
||||
@@ -1461,13 +1508,14 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
response = self.rate_limited_send(query, max_retries=self.max_retries, room_jid=sender, image_data=image_data, audio_data=audio_data, video_data=video_data)
|
||||
|
||||
if response:
|
||||
msg.reply(response).send()
|
||||
reply_msg = self._create_reply_stanza(msg, response)
|
||||
reply_msg.send()
|
||||
logging.info(f"Replied to {sender}")
|
||||
|
||||
if self.tts_auto_reply and self.tts_enabled:
|
||||
tts_result = self.synthesize_speech(response, should_encrypt=False)
|
||||
if tts_result and tts_result["type"] == "url":
|
||||
reply = msg.reply(tts_result['content'])
|
||||
reply = self._create_reply_stanza(msg, tts_result['content'])
|
||||
reply['oob']['url'] = tts_result['content']
|
||||
reply.send()
|
||||
|
||||
@@ -1509,7 +1557,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
logging.info(f"Decrypted message from {sender}: {body[:50]}...")
|
||||
except Exception as e:
|
||||
logging.error(f"Decryption failed: {e}")
|
||||
await self._plain_reply(mfrom, mtype, f"Error decrypting message: {e}")
|
||||
await self._plain_reply(stanza, mtype, f"Error decrypting message: {e}")
|
||||
return
|
||||
else:
|
||||
if stanza["body"]:
|
||||
@@ -1537,9 +1585,9 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
response = "Failed to generate image."
|
||||
|
||||
if is_encrypted:
|
||||
await self._encrypted_reply(mfrom, mtype, response)
|
||||
await self._encrypted_reply(stanza, mtype, response)
|
||||
else:
|
||||
await self._plain_reply(mfrom, mtype, response)
|
||||
await self._plain_reply(stanza, mtype, response)
|
||||
return
|
||||
|
||||
if self.tts_enabled and body.startswith(self.tts_trigger):
|
||||
@@ -1554,9 +1602,9 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
response = "Failed to synthesize speech."
|
||||
|
||||
if is_encrypted:
|
||||
await self._encrypted_reply(mfrom, mtype, response)
|
||||
await self._encrypted_reply(stanza, mtype, response)
|
||||
else:
|
||||
await self._plain_reply(mfrom, mtype, response)
|
||||
await self._plain_reply(stanza, mtype, response)
|
||||
return
|
||||
|
||||
quoted_text, non_quoted_text = self.extract_quoted_text(body)
|
||||
@@ -1614,37 +1662,32 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
|
||||
if response:
|
||||
if is_encrypted:
|
||||
await self._encrypted_reply(mfrom, mtype, response)
|
||||
await self._encrypted_reply(stanza, mtype, response)
|
||||
else:
|
||||
await self._plain_reply(mfrom, mtype, response)
|
||||
await self._plain_reply(stanza, mtype, response)
|
||||
logging.info(f"Replied to {sender}")
|
||||
|
||||
if self.tts_auto_reply and self.tts_enabled:
|
||||
tts_result = await loop.run_in_executor(None, self.synthesize_speech, response, is_encrypted)
|
||||
if tts_result and tts_result["type"] == "url":
|
||||
if is_encrypted:
|
||||
await self._encrypted_reply(mfrom, mtype, tts_result['content'])
|
||||
await self._encrypted_reply(stanza, mtype, tts_result['content'])
|
||||
else:
|
||||
await self._plain_reply(mfrom, mtype, tts_result['content'])
|
||||
await self._plain_reply(stanza, mtype, tts_result['content'])
|
||||
|
||||
async def _plain_reply(self, mto: JID, mtype: str, reply_text: str) -> None:
|
||||
msg = self.make_message(mto=mto, mtype=mtype)
|
||||
msg["body"] = reply_text
|
||||
async def _plain_reply(self, original_msg: Message, mtype: str, reply_text: str) -> None:
|
||||
msg = self._create_reply_stanza(original_msg, reply_text)
|
||||
if reply_text.startswith("http") and not reply_text.startswith("aesgcm"):
|
||||
msg['oob']['url'] = reply_text
|
||||
msg.send()
|
||||
|
||||
async def _encrypted_reply(self, mto: JID, mtype: str, reply_text: str) -> None:
|
||||
async def _encrypted_reply(self, original_msg: Message, mtype: str, reply_text: str) -> None:
|
||||
xep_0384 = self["xep_0384"]
|
||||
|
||||
|
||||
if isinstance(mto, JID):
|
||||
target_jid = mto.bare
|
||||
else:
|
||||
target_jid = JID(mto).bare
|
||||
|
||||
msg = self.make_message(mto=target_jid, mtype=mtype)
|
||||
msg["body"] = reply_text
|
||||
msg = self._create_reply_stanza(original_msg, reply_text)
|
||||
target_jid = JID(msg['to']).bare
|
||||
msg['to'] = target_jid
|
||||
|
||||
|
||||
|
||||
@@ -1663,7 +1706,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
logging.debug(f"Sent encrypted message to {target_jid}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to send encrypted reply: {e}")
|
||||
await self._plain_reply(mto, mtype, f"Error encrypting reply: {e}")
|
||||
await self._plain_reply(original_msg, mtype, f"Error encrypting reply: {e}")
|
||||
|
||||
def groupchat_message(self, msg):
|
||||
if msg['type'] != 'groupchat':
|
||||
@@ -1691,13 +1734,13 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
if result:
|
||||
if result["type"] == "url":
|
||||
response = f"{sender_nick}: {result['content']}"
|
||||
reply = msg.reply(response)
|
||||
reply = self._create_reply_stanza(msg, response)
|
||||
reply['oob']['url'] = result['content']
|
||||
reply.send()
|
||||
else:
|
||||
msg.reply(f"{sender_nick}: Image generated (base64)").send()
|
||||
self._create_reply_stanza(msg, f"{sender_nick}: Image generated (base64)").send()
|
||||
else:
|
||||
msg.reply(f"{sender_nick}: Failed to generate image.").send()
|
||||
self._create_reply_stanza(msg, f"{sender_nick}: Failed to generate image.").send()
|
||||
return
|
||||
|
||||
if self.tts_enabled and body.startswith(self.tts_trigger):
|
||||
@@ -1707,13 +1750,13 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
if result:
|
||||
if result["type"] == "url":
|
||||
response = f"{sender_nick}: {result['content']}"
|
||||
reply = msg.reply(response)
|
||||
reply = self._create_reply_stanza(msg, response)
|
||||
reply['oob']['url'] = result['content']
|
||||
reply.send()
|
||||
else:
|
||||
msg.reply(f"{sender_nick}: Speech synthesized (base64)").send()
|
||||
self._create_reply_stanza(msg, f"{sender_nick}: Speech synthesized (base64)").send()
|
||||
else:
|
||||
msg.reply(f"{sender_nick}: Failed to synthesize speech.").send()
|
||||
self._create_reply_stanza(msg, f"{sender_nick}: Failed to synthesize speech.").send()
|
||||
return
|
||||
|
||||
query = None
|
||||
@@ -1852,7 +1895,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
quoted = '\n'.join(f"> {line}" for line in lines)
|
||||
response = f"{quoted}\n{response}"
|
||||
|
||||
msg.reply(response).send()
|
||||
self._create_reply_stanza(msg, response).send()
|
||||
logging.info(f"Replied in {room_jid}")
|
||||
|
||||
if self.tts_auto_reply and self.tts_enabled:
|
||||
@@ -1863,7 +1906,7 @@ class LLMBot(slixmpp.ClientXMPP):
|
||||
|
||||
tts_result = self.synthesize_speech(tts_text, should_encrypt=False)
|
||||
if tts_result and tts_result["type"] == "url":
|
||||
reply = msg.reply(tts_result['content'])
|
||||
reply = self._create_reply_stanza(msg, tts_result['content'])
|
||||
reply['oob']['url'] = tts_result['content']
|
||||
reply.send()
|
||||
|
||||
@@ -2332,6 +2375,8 @@ if __name__ == '__main__':
|
||||
tts_model = config.get('Bot', 'tts_model', fallback='gemini-2.5-flash-preview-tts')
|
||||
tts_auto_reply = config.getboolean('Bot', 'tts_auto_reply', fallback=False)
|
||||
|
||||
use_proper_replies = config.getboolean('Bot', 'use_proper_replies', fallback=True)
|
||||
|
||||
system_prompts = {}
|
||||
global_prompt = config.get('Bot', 'system_prompt', fallback='').strip()
|
||||
if global_prompt:
|
||||
@@ -2379,6 +2424,7 @@ if __name__ == '__main__':
|
||||
tts_trigger=tts_trigger, tts_enabled=tts_enabled,
|
||||
tts_voice_name=tts_voice_name, tts_model=tts_model,
|
||||
tts_auto_reply=tts_auto_reply,
|
||||
use_proper_replies=use_proper_replies,
|
||||
loop=loop)
|
||||
|
||||
bot.connect()
|
||||
|
||||
Reference in New Issue
Block a user