From 6f001f29242378b668d98a5d0027479cf350e78f Mon Sep 17 00:00:00 2001 From: Dory Date: Sat, 23 Aug 2025 23:34:31 -0700 Subject: [PATCH] refactor all parameters to commandline arguments --- bot.py | 77 ++++++++++++-------- llm_client.py => conversations.py | 114 ++++++++++++++++-------------- database.py | 20 ++++-- 3 files changed, 122 insertions(+), 89 deletions(-) rename llm_client.py => conversations.py (52%) diff --git a/bot.py b/bot.py index 2de3095..47d0963 100644 --- a/bot.py +++ b/bot.py @@ -1,21 +1,40 @@ import collections import discord from discord.ext import commands +from openai import AsyncOpenAI import os import argparse from typing import List, Dict, Any -from llm_client import Conversation +from conversations import Conversation, ConversationManager from database import Database # --- Configuration --- DEFAULT_AVATAR = "https://cdn.discordapp.com/avatars/1406466525858369716/f1dfeaf2a1c361dbf981e2e899c7f981?size=256" +DEFAULT_SYSTEM_PROMPT = "you are a catboy named Aoi with dark blue fur and is a tsundere" +DEFAULT_DB = "conversations.db" # --- Command Line Arguments --- parser = argparse.ArgumentParser(description="Aoi Discord Bot") parser.add_argument( - '--base_url', type=str, required=True, - help='The base URL for the OpenAI API.', + '--base_url', type=str, default='http://localhost:8080/v1', + help='The base URL for the OpenAI API server.', +) +parser.add_argument( + '--api_key', type=str, default='', + help='The API key for OpenAI API.', +) +parser.add_argument( + '--model', type=str, default='', + help='The model to use from OpenAI API.', +) +parser.add_argument( + '--default_prompt', type=str, default=DEFAULT_SYSTEM_PROMPT, + help='Default system prompt when not given in chat.', +) +parser.add_argument( + '--db', type=str, default=DEFAULT_DB, + help='SQLite DB to use.', ) parser.add_argument( '--discord_token', type=str, required=True, @@ -24,11 +43,16 @@ parser.add_argument( args = parser.parse_args() # --- Bot Setup --- + class AoiBot(commands.Bot): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def setup_hook(self): - self.db = Database.get() + db = Database.get(args.db) + openai = AsyncOpenAI(base_url=args.base_url, api_key=args.api_key) + self.manager = ConversationManager( + openai, args.model, db, args.default_prompt, + ) intents = discord.Intents.default() intents.messages = True intents.message_content = True @@ -90,7 +114,7 @@ async def on_message(message): bot_tag = f'<@{bot.user.id}>' channel = message.channel - conversation = await Conversation.get(channel.id, args.base_url, bot.db) + conversation = await bot.manager.get(channel.id) user_message = message.content if user_message.startswith(bot_tag): user_message = user_message[len(bot_tag):] @@ -120,29 +144,29 @@ async def on_reaction_add(reaction, user): return message = reaction.message channel = message.channel - conversation = await Conversation.get(channel.id, args.base_url, bot.db) + conversation = await bot.manager.get(channel.id) if message.id not in conversation.last_messages: await reaction.clear() return print(f"_ {user}: {reaction}") try: - async with channel.typing(): - try: - messages = [ - await channel.fetch_message(message_id) - for message_id in conversation.last_messages - ] - except (discord.NotFound, discord.Forbidden) as e: - # don't do anything if any message in the list is not found - await reaction.clear() - return - for message in messages: - await message.delete() + try: + messages = [ + await channel.fetch_message(message_id) + for message_id in conversation.last_messages + ] + except (discord.NotFound, discord.Forbidden): + # don't do anything if any message in the list is not found + await reaction.clear() + return + for message in messages: + await message.delete() - if reaction.emoji == "❌": - await conversation.pop() - elif reaction.emoji == "🔁": + if reaction.emoji == "❌": + await conversation.pop() + elif reaction.emoji == "🔁": + async with channel.typing(): response = await conversation.regenerate() conversation.last_messages = await discord_send( channel, response, conversation.bot_name, @@ -161,14 +185,10 @@ async def on_reaction_add(reaction, user): async def newchat(interaction: discord.Interaction, prompt: str = None): await interaction.response.defer() channel_id = interaction.channel_id - old_convo = await Conversation.get( - channel_id, args.base_url, bot.db, create_if_not_exist=False, - ) + old_convo = await bot.manager.get(channel_id, create_if_missing=False) if old_convo: await clear_reactions(interaction.channel, old_convo.last_messages) - conversation = await Conversation.create( - channel_id, args.base_url, bot.db, prompt - ) + conversation = await bot.manager.new_conversation(channel_id, prompt) await interaction.followup.send( f'Starting a new chat with {conversation.bot_name}: ' f'"{conversation.history[0]["content"]}"' @@ -181,7 +201,7 @@ async def newchat(interaction: discord.Interaction, prompt: str = None): async def changeprompt(interaction: discord.Interaction, prompt: str): await interaction.response.defer() channel_id = interaction.channel_id - conversation = await Conversation.get(channel_id, args.base_url, bot.db) + conversation = await bot.manager.get(channel_id) await conversation.update_prompt(prompt) await interaction.followup.send( f'Now chatting with {conversation.bot_name}: "{prompt}"' @@ -191,4 +211,3 @@ async def changeprompt(interaction: discord.Interaction, prompt: str): # --- Running the Bot --- if __name__ == "__main__": bot.run(args.discord_token) - diff --git a/llm_client.py b/conversations.py similarity index 52% rename from llm_client.py rename to conversations.py index 60c99c9..0a89163 100644 --- a/llm_client.py +++ b/conversations.py @@ -1,90 +1,97 @@ import aiohttp import base64 -from openai import AsyncOpenAI from database import Database API_KEY = "eh" MODEL = "p620" DEFAULT_NAME = "Aoi" -DEFAULT_SYSTEM_PROMPT = ( - "you are a catboy named Aoi with dark blue fur and is a tsundere" -) NAME_PROMPT = "reply with your name, nothing else, no punctuation" -class Conversation: - def __init__(self, client, name, prompt, convo_id, db): - self.history = [{"role": "system", "content": prompt}] - self.bot_name = name - self.last_messages = [] - self.client = client - self.id = convo_id +async def get_name(client, model, prompt): + """Generates an assistant name for the given prompt.""" + name_response = await client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": prompt}, + {"role": "user", "content": NAME_PROMPT} + ], + ) + return name_response.choices[0].message.content.split('\n')[0] + + +class ConversationManager: + """Creates and retrieves Conversations.""" + def __init__(self, openai_client, model, db, default_prompt): + self.model = model + self.client = openai_client self.db = db + self.default_prompt = default_prompt - def __str__(self): - return ( - f"Conversation({self.bot_name}, {self.last_messages}, " - f"{self.history}" - ) - - async def save(self): - self.db.save( - self.id, self.history, self.bot_name, self.last_messages - ) - - @classmethod - async def get(cls, key, base_url, db, create_if_not_exist=True): - convo_data = db.get_conversation(key) + async def get(self, key, create_if_missing=True): + """Gets a conversation based on |key|, optionally create when not found.""" + convo_data = self.db.get_conversation(key) if convo_data: history, bot_name, last_messages = convo_data - client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) - convo = cls(client, bot_name, history[0]['content'], key, db) - convo.history = history - convo.last_messages = last_messages - return convo - if create_if_not_exists: - return await Conversation.create(key, base_url, db) + return Conversation( + key, bot_name, history, last_messages, + self.client, self.model, self.db, + ) + if create_if_missing: + return await self.new_conversation(key, self.default_prompt) return None - @classmethod - async def create(cls, key, base_url, db, prompt=None): - client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) - if not prompt: - convo = cls(client, DEFAULT_NAME, DEFAULT_SYSTEM_PROMPT, key, db) - else: - name = await cls.get_name(client, prompt) - convo = cls(client, name, prompt, key, db) + async def new_conversation(self, key, prompt = None): + """Creates a new Conversation with key based on given prompt.""" + prompt = prompt or self.default_prompt + name = await get_name(self.client, self.model, prompt) + history = [{"role": "system", "content": prompt}] + last_messages = [] + convo = Conversation( + key, name, history, last_messages, self.client, self.model, self.db, + ) await convo.save() return convo - @classmethod - async def get_name(self, client, system_prompt): - name_response = await client.chat.completions.create( - model=MODEL, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": NAME_PROMPT} - ], - ) - return name_response.choices[0].message.content.split('\n')[0] + +class Conversation: + """Holds data about a conversation thread.""" + def __init__( + self, convo_id, name, history, last_messages, api_client, model, db, + ): + self.id = convo_id + self.bot_name = name + self.history = history + self.last_messages = last_messages + self.client = api_client + self.model = model + self.db = db + + async def save(self): + """Saves the conversation to the DB.""" + self.db.save(self.id, self.history, self.bot_name, self.last_messages) def add_message_pair(self, user, assistant): + """Adds a user/assistant convesation turn pair.""" self.history.extend([ {"role": "user", "content": user}, {"role": "assistant", "content": assistant}, ]) async def pop(self): + """Removes one user/assistant converation turn pair.""" if len(self.history) >= 3: self.history = self.history[:-2] await self.save() async def update_prompt(self, prompt): + """Changes current prompt to a new one, keeping the rest of history.""" self.history[0] = {"role": "system", "content": prompt} - self.bot_name = await self.get_name(self.client, prompt) + self.bot_name = await get_name(self.client, self.model, prompt) await self.save() async def generate(self, text, media=tuple()): + """Generates next assistant conversation turn.""" # prepare text part if text: openai_content = [{"type": "text", "text": text}] @@ -98,8 +105,7 @@ class Conversation: continue try: async with session.get(url) as resp: - if resp.status != 200: - raise IOError(f"{url} --> {resp.status}") + resp.raise_for_status() image_data = await resp.read() b64_image = base64.b64encode(image_data).decode('utf-8') b64_url = f"data:{content_type};base64,{b64_image}" @@ -120,10 +126,10 @@ class Conversation: return response async def regenerate(self): + """Regenerates the last assistant turn.""" llm_response = await self.client.chat.completions.create( model=MODEL, messages=self.history[:-1] ) response = llm_response.choices[0].message.content self.history[-1] = {"role": "assistant", "content": response} return response - diff --git a/database.py b/database.py index 3e41353..0108873 100644 --- a/database.py +++ b/database.py @@ -7,7 +7,7 @@ class Database: self._create_table() @classmethod - def get(cls, db_path='conversations.db'): + def get(cls, db_path): """Creates and returns a connected Database instance.""" print(f"Initializing DB connection to: {db_path}") return Database(sqlite3.connect(db_path)) @@ -26,7 +26,10 @@ class Database: def get_conversation(self, conversation_id): with self.conn: cursor = self.conn.cursor() - cursor.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,)) + cursor.execute( + "SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", + (conversation_id,) + ) row = cursor.fetchone() if row: history = json.loads(row[0]) @@ -38,11 +41,16 @@ class Database: def save(self, conversation_id, history, bot_name, last_messages): with self.conn: self.conn.execute( - "INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)", - (conversation_id, json.dumps(history), bot_name, json.dumps(last_messages)) + "INSERT OR REPLACE INTO conversations " + "(id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)", + ( + conversation_id, json.dumps(history), + bot_name, json.dumps(last_messages) + ), ) def delete(self, conversation_id): with self.conn: - self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,)) - + self.conn.execute( + "DELETE FROM conversations WHERE id = ?", (conversation_id,), + )