diff --git a/bot.py b/bot.py index 6dbe3d8..f1899ba 100644 --- a/bot.py +++ b/bot.py @@ -47,16 +47,14 @@ async def discord_send(channel, text, name, avatar=DEFAULT_AVATAR): messages.append(message) return messages - -# --- Data Storage --- -# Keyed by channel ID -conversations = {} -_webhooks = {} async def webhook(channel): - if channel.id not in _webhooks: - hook = await channel.create_webhook(name=f'aoi-{channel.id}') - _webhooks[channel.id] = hook - return _webhooks[channel.id] + name = f'aoi-{channel.id}' + channel_hooks = [ + hook for hook in (await channel.webhooks()) if hook.name == name + ] + if not channel_hooks: + return await channel.create_webhook(name=f'aoi-{channel.id}') + return channel_hooks[0] # --- Bot Events --- @@ -75,9 +73,7 @@ async def on_message(message): bot_tag = f'<@{bot.user.id}>' channel = message.channel - if channel.id not in conversations: - conversations[channel.id] = await Conversation.create(args.base_url) - conversation = conversations[channel.id] + conversation = await Conversation.get(channel.id) user_message = message.content if user_message.startswith(bot_tag): user_message = user_message[len(bot_tag):] @@ -92,9 +88,9 @@ async def on_message(message): try: async with channel.typing(): response = await conversation.generate(user_message, media) - conversation.last_messages = await discord_send( - channel, response, conversation.bot_name, - ) + conversation.last_messages = await discord_send( + channel, response, conversation.bot_name, + ) except Exception as e: print(f"An error occurred: {e}") await message.reply("Sorry, I had a little hiccup. Baka!") @@ -105,7 +101,7 @@ async def on_reaction_add(reaction, user): return message = reaction.message channel = message.channel - conversation = conversations[channel.id] + conversation = await Conversation.get(channel.id) if message not in conversation.last_messages: await reaction.clear() return @@ -132,8 +128,7 @@ async def on_reaction_add(reaction, user): async def newchat(interaction: discord.Interaction, prompt: str = None): await interaction.response.defer() channel_id = interaction.channel_id - conversation = await Conversation.create(args.base_url, prompt) - conversations[channel_id] = conversation + conversation = await Conversation.create(channel_id, args.base_url, prompt) await interaction.followup.send( f'Starting a new chat with {conversation.bot_name}: "{prompt}"' ) diff --git a/llm_client.py b/llm_client.py index dcfc2ed..6a5d70e 100644 --- a/llm_client.py +++ b/llm_client.py @@ -10,6 +10,8 @@ DEFAULT_SYSTEM_PROMPT = ( ) NAME_PROMPT = "reply with your name, nothing else, no punctuation" +conversations = {} + class Conversation: def __init__(self, client, name, prompt): @@ -19,11 +21,20 @@ class Conversation: self.client = client @classmethod - async def create(cls, base_url, prompt=None): + async def get(cls, key): + if key not in conversations: + conversations[key] = await Conversation.create(args.base_url) + return conversations[key] + + @classmethod + async def create(cls, channel_id, base_url, prompt=None): client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) if not prompt: - return cls(client, DEFAULT_NAME, DEFAULT_SYSTEM_PROMPT) - return cls(client, await cls.get_name(client, prompt), prompt) + convo = cls(client, DEFAULT_NAME, DEFAULT_SYSTEM_PROMPT) + else: + convo = cls(client, await cls.get_name(client, prompt), prompt) + conversations[channel_id] = convo + return convo @classmethod async def get_name(self, client, system_prompt):