diff --git a/bot.py b/bot.py index 4211829..a2ad2c2 100644 --- a/bot.py +++ b/bot.py @@ -6,6 +6,7 @@ import argparse from typing import List, Dict, Any from llm_client import Conversation +from database import Database # --- Configuration --- DEFAULT_AVATAR = "https://cdn.discordapp.com/avatars/1406466525858369716/f1dfeaf2a1c361dbf981e2e899c7f981?size=256" @@ -23,10 +24,15 @@ parser.add_argument( args = parser.parse_args() # --- Bot Setup --- +class AoiBot(commands.Bot): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + async def setup_hook(self): + self.db = await Database.get() intents = discord.Intents.default() intents.messages = True intents.message_content = True -bot = commands.Bot(command_prefix="/", intents=intents) +bot = AoiBot(command_prefix="/", intents=intents) # --- Helpers --- @@ -75,7 +81,7 @@ async def on_message(message): bot_tag = f'<@{bot.user.id}>' channel = message.channel - conversation = await Conversation.get(channel.id, args.base_url) + conversation = await Conversation.get(channel.id, args.base_url, bot.db) user_message = message.content if user_message.startswith(bot_tag): user_message = user_message[len(bot_tag):] @@ -91,13 +97,16 @@ async def on_message(message): async with channel.typing(): response = await conversation.generate(user_message, media) for old_message_id in conversation.last_messages: - old_message = await channel.fetch_message(old_message_id) - await old_message.clear_reaction("🔁") - await old_message.clear_reaction("❌") + try: + old_message = await channel.fetch_message(old_message_id) + await old_message.clear_reaction("🔁") + await old_message.clear_reaction("❌") + except (discord.NotFound, discord.Forbidden): + pass # Ignore if message is not found or we don't have perms conversation.last_messages = await discord_send( channel, response, conversation.bot_name, ) - conversation.save() + await conversation.save() except Exception as e: print(f"An error occurred: {e}") await message.reply("Sorry, I had a little hiccup. Baka!") @@ -108,7 +117,7 @@ async def on_reaction_add(reaction, user): return message = reaction.message channel = message.channel - conversation = await Conversation.get(channel.id, args.base_url) + conversation = await Conversation.get(channel.id, args.base_url, bot.db) if message.id not in conversation.last_messages: await reaction.clear() return @@ -124,17 +133,18 @@ async def on_reaction_add(reaction, user): except (discord.NotFound, discord.Forbidden) as e: # don't do anything if any message in the list is not found await reaction.clear() + return for message in messages: await message.delete() if reaction.emoji == "❌": - conversation.pop() + await conversation.pop() elif reaction.emoji == "🔁": response = await conversation.regenerate() conversation.last_messages = await discord_send( channel, response, conversation.bot_name, ) - conversation.save() + await conversation.save() except Exception as e: print(f"An error occurred: {e}") await message.reply("Sorry, I had a little hiccup. Baka!") @@ -148,7 +158,9 @@ async def on_reaction_add(reaction, user): async def newchat(interaction: discord.Interaction, prompt: str = None): await interaction.response.defer() channel_id = interaction.channel_id - conversation = await Conversation.create(channel_id, args.base_url, prompt) + conversation = await Conversation.create( + channel_id, args.base_url, bot.db, prompt + ) await interaction.followup.send( f'Starting a new chat with {conversation.bot_name}: "{prompt}"' ) @@ -157,4 +169,3 @@ async def newchat(interaction: discord.Interaction, prompt: str = None): # --- Running the Bot --- if __name__ == "__main__": bot.run(args.discord_token) - diff --git a/database.py b/database.py index a375588..ba79395 100644 --- a/database.py +++ b/database.py @@ -1,28 +1,35 @@ -import sqlite3 +import aiosqlite import json class Database: def __init__(self, db_path='conversations.db'): self.db_path = db_path - self.conn = sqlite3.connect(self.db_path) - self.create_table() + self.conn = None - def create_table(self): - with self.conn: - self.conn.execute(""" - CREATE TABLE IF NOT EXISTS conversations ( - id TEXT PRIMARY KEY, - history TEXT NOT NULL, - bot_name TEXT NOT NULL, - last_messages TEXT NOT NULL - ) - """) + @staticmethod + async def get(db_path='conversations.db'): + """ + Asynchronously creates and returns a connected Database instance. + """ + db = Database(db_path) + db.conn = await aiosqlite.connect(db.db_path) + await db._create_table() + return db - def get(self, conversation_id): - with self.conn: - cursor = self.conn.cursor() - cursor.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,)) - row = cursor.fetchone() + async def _create_table(self): + await self.conn.execute(""" + CREATE TABLE IF NOT EXISTS conversations ( + id TEXT PRIMARY KEY, + history TEXT NOT NULL, + bot_name TEXT NOT NULL, + last_messages TEXT NOT NULL + ) + """) + await self.conn.commit() + + async def get_conversation(self, conversation_id): + async with self.conn.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,)) as cursor: + row = await cursor.fetchone() if row: history = json.loads(row[0]) bot_name = row[1] @@ -30,15 +37,31 @@ class Database: return history, bot_name, last_messages return None - def save(self, conversation_id, history, bot_name, last_messages): - with self.conn: - self.conn.execute( - "INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)", - (conversation_id, json.dumps(history), bot_name, json.dumps(last_messages)) - ) + async def save(self, conversation_id, history, bot_name, last_messages): + await self.conn.execute( + "INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)", + (conversation_id, json.dumps(history), bot_name, json.dumps(last_messages)) + ) + await self.conn.commit() - def delete(self, conversation_id): - with self.conn: - self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,)) + async def delete(self, conversation_id): + await self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,)) + await self.conn.commit() -db = Database() + async def close(self): + if self.conn: + await self.conn.close() + self.conn = None + +# To use this, you would typically do this in your main application file: +# +# import asyncio +# from database import Database +# +# async def main(): +# db = await Database.get() +# # now you can use db.get_conversation, db.save, etc. +# await db.close() +# +# if __name__ == "__main__": +# asyncio.run(main()) \ No newline at end of file diff --git a/llm_client.py b/llm_client.py index ee07637..e9a6c68 100644 --- a/llm_client.py +++ b/llm_client.py @@ -1,7 +1,7 @@ import aiohttp import base64 from openai import AsyncOpenAI -from database import db +from database import Database API_KEY = "eh" MODEL = "p620" @@ -13,12 +13,13 @@ NAME_PROMPT = "reply with your name, nothing else, no punctuation" class Conversation: - def __init__(self, client, name, prompt, convo_id): + def __init__(self, client, name, prompt, convo_id, db): self.history = [{"role": "system", "content": prompt}] self.bot_name = name self.last_messages = [] self.client = client self.id = convo_id + self.db = db def __str__(self): return ( @@ -26,29 +27,32 @@ class Conversation: f"{self.history}" ) - def save(self): - db.save(self.id, self.history, self.bot_name, self.last_messages) + async def save(self): + await self.db.save( + self.id, self.history, self.bot_name, self.last_messages + ) @classmethod - async def get(cls, key, base_url): - convo_data = db.get(key) + async def get(cls, key, base_url, db): + convo_data = await db.get_conversation(key) if convo_data: history, bot_name, last_messages = convo_data client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) - convo = cls(client, bot_name, history[0]['content'], key) + convo = cls(client, bot_name, history[0]['content'], key, db) convo.history = history convo.last_messages = last_messages return convo - return await Conversation.create(key, base_url) + return await Conversation.create(key, base_url, db) @classmethod - async def create(cls, key, base_url, prompt=None): + async def create(cls, key, base_url, db, prompt=None): client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) if not prompt: - convo = cls(client, DEFAULT_NAME, DEFAULT_SYSTEM_PROMPT, key) + convo = cls(client, DEFAULT_NAME, DEFAULT_SYSTEM_PROMPT, key, db) else: - convo = cls(client, await cls.get_name(client, prompt), prompt, key) - convo.save() + name = await cls.get_name(client, prompt) + convo = cls(client, name, prompt, key, db) + await convo.save() return convo @classmethod @@ -68,10 +72,10 @@ class Conversation: {"role": "assistant", "content": assistant}, ]) - def pop(self): + async def pop(self): if len(self.history) >= 3: self.history = self.history[:-2] - self.save() + await self.save() async def generate(self, text, media=tuple()): # prepare text part @@ -115,4 +119,3 @@ class Conversation: response = llm_response.choices[0].message.content self.history[-1] = {"role": "assistant", "content": response} return response -