Compare commits

...

2 Commits

Author SHA1 Message Date
4900778a97 change " to ' 2025-08-23 23:42:00 -07:00
6f001f2924 refactor all parameters to commandline arguments 2025-08-23 23:34:31 -07:00
3 changed files with 137 additions and 111 deletions

114
bot.py
View File

@@ -1,38 +1,55 @@
import collections import collections
import discord import discord
from discord.ext import commands from discord.ext import commands
from openai import AsyncOpenAI
import os import os
import argparse import argparse
from typing import List, Dict, Any from typing import List, Dict, Any
from llm_client import Conversation from conversations import Conversation, ConversationManager
from database import Database from database import Database
# --- Configuration --- # --- Configuration ---
DEFAULT_AVATAR = "https://cdn.discordapp.com/avatars/1406466525858369716/f1dfeaf2a1c361dbf981e2e899c7f981?size=256" DEFAULT_AVATAR = 'https://cdn.discordapp.com/avatars/1406466525858369716/f1dfeaf2a1c361dbf981e2e899c7f981?size=256'
# --- Command Line Arguments --- # --- Command Line Arguments ---
parser = argparse.ArgumentParser(description="Aoi Discord Bot") parser = argparse.ArgumentParser(description='Aoi Discord Bot')
parser.add_argument( parser.add_argument(
'--base_url', type=str, required=True, '--base_url', default='http://localhost:8080/v1',
help='The base URL for the OpenAI API.', help='The base URL for the OpenAI API server.',
) )
parser.add_argument( parser.add_argument(
'--discord_token', type=str, required=True, '--api_key', default='', help='The API key for OpenAI API.',
help='The Discord bot token.', )
parser.add_argument(
'--model', default='', help='The model to use from OpenAI API.',
)
parser.add_argument(
'--default_prompt',
default='you are a catboy named Aoi with dark blue fur and is a tsundere',
help='Default system prompt when not given in chat.',
)
parser.add_argument(
'--db', default='conversations.db', help='SQLite DB to use.',
)
parser.add_argument(
'--discord_token', required=True, help='The Discord bot token.',
) )
args = parser.parse_args() args = parser.parse_args()
# --- Bot Setup --- # --- Bot Setup ---
class AoiBot(commands.Bot): class AoiBot(commands.Bot):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def setup_hook(self): async def setup_hook(self):
self.db = Database.get() db = Database.get(args.db)
openai = AsyncOpenAI(base_url=args.base_url, api_key=args.api_key)
self.manager = ConversationManager(
openai, args.model, db, args.default_prompt,
)
intents = discord.Intents.default() intents = discord.Intents.default()
intents.messages = True intents.messages = True
intents.message_content = True intents.message_content = True
bot = AoiBot(command_prefix="/", intents=intents) bot = AoiBot(command_prefix='/', intents=intents)
# --- Helpers --- # --- Helpers ---
@@ -51,8 +68,8 @@ async def discord_send(channel, text, name, avatar=DEFAULT_AVATAR):
else: else:
message = await channel.send(content=chunk) message = await channel.send(content=chunk)
messages.append(message.id) messages.append(message.id)
await message.add_reaction("🔁") await message.add_reaction('🔁')
await message.add_reaction("") await message.add_reaction('')
return messages return messages
async def webhook(channel): async def webhook(channel):
@@ -68,8 +85,8 @@ async def clear_reactions(channel, message_ids):
for message_id in message_ids: for message_id in message_ids:
try: try:
message = await channel.fetch_message(message_id) message = await channel.fetch_message(message_id)
await message.clear_reaction("🔁") await message.clear_reaction('🔁')
await message.clear_reaction("") await message.clear_reaction('')
except (discord.NotFound, discord.Forbidden): except (discord.NotFound, discord.Forbidden):
pass # Ignore if message is not found or we don't have perms pass # Ignore if message is not found or we don't have perms
@@ -90,7 +107,7 @@ async def on_message(message):
bot_tag = f'<@{bot.user.id}>' bot_tag = f'<@{bot.user.id}>'
channel = message.channel channel = message.channel
conversation = await Conversation.get(channel.id, args.base_url, bot.db) conversation = await bot.manager.get(channel.id)
user_message = message.content user_message = message.content
if user_message.startswith(bot_tag): if user_message.startswith(bot_tag):
user_message = user_message[len(bot_tag):] user_message = user_message[len(bot_tag):]
@@ -111,77 +128,73 @@ async def on_message(message):
) )
await conversation.save() await conversation.save()
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f'An error occurred: {e}')
await message.reply("Sorry, I had a little hiccup. Baka!") await message.reply('Sorry, I had a little hiccup. Baka!')
@bot.event @bot.event
async def on_reaction_add(reaction, user): async def on_reaction_add(reaction, user):
if reaction.emoji not in ("🔁", "") or user == bot.user: if reaction.emoji not in ('🔁', '') or user == bot.user:
return return
message = reaction.message message = reaction.message
channel = message.channel channel = message.channel
conversation = await Conversation.get(channel.id, args.base_url, bot.db) conversation = await bot.manager.get(channel.id)
if message.id not in conversation.last_messages: if message.id not in conversation.last_messages:
await reaction.clear() await reaction.clear()
return return
print(f"_ {user}: {reaction}") print(f'_ {user}: {reaction}')
try: try:
async with channel.typing(): try:
try: messages = [
messages = [ await channel.fetch_message(message_id)
await channel.fetch_message(message_id) for message_id in conversation.last_messages
for message_id in conversation.last_messages ]
] except (discord.NotFound, discord.Forbidden):
except (discord.NotFound, discord.Forbidden) as e: # don't do anything if any message in the list is not found
# don't do anything if any message in the list is not found await reaction.clear()
await reaction.clear() return
return for message in messages:
for message in messages: await message.delete()
await message.delete()
if reaction.emoji == "": if reaction.emoji == '':
await conversation.pop() await conversation.pop()
elif reaction.emoji == "🔁": elif reaction.emoji == '🔁':
async with channel.typing():
response = await conversation.regenerate() response = await conversation.regenerate()
conversation.last_messages = await discord_send( conversation.last_messages = await discord_send(
channel, response, conversation.bot_name, channel, response, conversation.bot_name,
) )
await conversation.save() await conversation.save()
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f'An error occurred: {e}')
await message.reply("Sorry, I had a little hiccup. Baka!") await message.reply('Sorry, I had a little hiccup. Baka!')
# --- Slash Commands --- # --- Slash Commands ---
@bot.tree.command( @bot.tree.command(
name="newchat", name='newchat',
description="Start a new chat with an optional system prompt." description='Start a new chat with an optional system prompt.'
) )
async def newchat(interaction: discord.Interaction, prompt: str = None): async def newchat(interaction: discord.Interaction, prompt: str = None):
await interaction.response.defer() await interaction.response.defer()
channel_id = interaction.channel_id channel_id = interaction.channel_id
old_convo = await Conversation.get( old_convo = await bot.manager.get(channel_id, create_if_missing=False)
channel_id, args.base_url, bot.db, create_if_not_exist=False,
)
if old_convo: if old_convo:
await clear_reactions(interaction.channel, old_convo.last_messages) await clear_reactions(interaction.channel, old_convo.last_messages)
conversation = await Conversation.create( conversation = await bot.manager.new_conversation(channel_id, prompt)
channel_id, args.base_url, bot.db, prompt
)
await interaction.followup.send( await interaction.followup.send(
f'Starting a new chat with {conversation.bot_name}: ' f'Starting a new chat with {conversation.bot_name}: '
f'"{conversation.history[0]["content"]}"' f'"{conversation.history[0]["content"]}"'
) )
@bot.tree.command( @bot.tree.command(
name="changeprompt", name='changeprompt',
description="Change the current chat's system prompt." description='Change the system prompt of the current conversation.'
) )
async def changeprompt(interaction: discord.Interaction, prompt: str): async def changeprompt(interaction: discord.Interaction, prompt: str):
await interaction.response.defer() await interaction.response.defer()
channel_id = interaction.channel_id channel_id = interaction.channel_id
conversation = await Conversation.get(channel_id, args.base_url, bot.db) conversation = await bot.manager.get(channel_id)
await conversation.update_prompt(prompt) await conversation.update_prompt(prompt)
await interaction.followup.send( await interaction.followup.send(
f'Now chatting with {conversation.bot_name}: "{prompt}"' f'Now chatting with {conversation.bot_name}: "{prompt}"'
@@ -189,6 +202,5 @@ async def changeprompt(interaction: discord.Interaction, prompt: str):
# --- Running the Bot --- # --- Running the Bot ---
if __name__ == "__main__": if __name__ == '__main__':
bot.run(args.discord_token) bot.run(args.discord_token)

View File

@@ -1,90 +1,97 @@
import aiohttp import aiohttp
import base64 import base64
from openai import AsyncOpenAI
from database import Database from database import Database
API_KEY = "eh" API_KEY = "eh"
MODEL = "p620" MODEL = "p620"
DEFAULT_NAME = "Aoi" DEFAULT_NAME = "Aoi"
DEFAULT_SYSTEM_PROMPT = (
"you are a catboy named Aoi with dark blue fur and is a tsundere"
)
NAME_PROMPT = "reply with your name, nothing else, no punctuation" NAME_PROMPT = "reply with your name, nothing else, no punctuation"
class Conversation: async def get_name(client, model, prompt):
def __init__(self, client, name, prompt, convo_id, db): """Generates an assistant name for the given prompt."""
self.history = [{"role": "system", "content": prompt}] name_response = await client.chat.completions.create(
self.bot_name = name model=model,
self.last_messages = [] messages=[
self.client = client {"role": "system", "content": prompt},
self.id = convo_id {"role": "user", "content": NAME_PROMPT}
],
)
return name_response.choices[0].message.content.split('\n')[0]
class ConversationManager:
"""Creates and retrieves Conversations."""
def __init__(self, openai_client, model, db, default_prompt):
self.model = model
self.client = openai_client
self.db = db self.db = db
self.default_prompt = default_prompt
def __str__(self): async def get(self, key, create_if_missing=True):
return ( """Gets a conversation based on |key|, optionally create when not found."""
f"Conversation({self.bot_name}, {self.last_messages}, " convo_data = self.db.get_conversation(key)
f"{self.history}"
)
async def save(self):
self.db.save(
self.id, self.history, self.bot_name, self.last_messages
)
@classmethod
async def get(cls, key, base_url, db, create_if_not_exist=True):
convo_data = db.get_conversation(key)
if convo_data: if convo_data:
history, bot_name, last_messages = convo_data history, bot_name, last_messages = convo_data
client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) return Conversation(
convo = cls(client, bot_name, history[0]['content'], key, db) key, bot_name, history, last_messages,
convo.history = history self.client, self.model, self.db,
convo.last_messages = last_messages )
return convo if create_if_missing:
if create_if_not_exists: return await self.new_conversation(key, self.default_prompt)
return await Conversation.create(key, base_url, db)
return None return None
@classmethod async def new_conversation(self, key, prompt = None):
async def create(cls, key, base_url, db, prompt=None): """Creates a new Conversation with key based on given prompt."""
client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) prompt = prompt or self.default_prompt
if not prompt: name = await get_name(self.client, self.model, prompt)
convo = cls(client, DEFAULT_NAME, DEFAULT_SYSTEM_PROMPT, key, db) history = [{"role": "system", "content": prompt}]
else: last_messages = []
name = await cls.get_name(client, prompt) convo = Conversation(
convo = cls(client, name, prompt, key, db) key, name, history, last_messages, self.client, self.model, self.db,
)
await convo.save() await convo.save()
return convo return convo
@classmethod
async def get_name(self, client, system_prompt): class Conversation:
name_response = await client.chat.completions.create( """Holds data about a conversation thread."""
model=MODEL, def __init__(
messages=[ self, convo_id, name, history, last_messages, api_client, model, db,
{"role": "system", "content": system_prompt}, ):
{"role": "user", "content": NAME_PROMPT} self.id = convo_id
], self.bot_name = name
) self.history = history
return name_response.choices[0].message.content.split('\n')[0] self.last_messages = last_messages
self.client = api_client
self.model = model
self.db = db
async def save(self):
"""Saves the conversation to the DB."""
self.db.save(self.id, self.history, self.bot_name, self.last_messages)
def add_message_pair(self, user, assistant): def add_message_pair(self, user, assistant):
"""Adds a user/assistant convesation turn pair."""
self.history.extend([ self.history.extend([
{"role": "user", "content": user}, {"role": "user", "content": user},
{"role": "assistant", "content": assistant}, {"role": "assistant", "content": assistant},
]) ])
async def pop(self): async def pop(self):
"""Removes one user/assistant converation turn pair."""
if len(self.history) >= 3: if len(self.history) >= 3:
self.history = self.history[:-2] self.history = self.history[:-2]
await self.save() await self.save()
async def update_prompt(self, prompt): async def update_prompt(self, prompt):
"""Changes current prompt to a new one, keeping the rest of history."""
self.history[0] = {"role": "system", "content": prompt} self.history[0] = {"role": "system", "content": prompt}
self.bot_name = await self.get_name(self.client, prompt) self.bot_name = await get_name(self.client, self.model, prompt)
await self.save() await self.save()
async def generate(self, text, media=tuple()): async def generate(self, text, media=tuple()):
"""Generates next assistant conversation turn."""
# prepare text part # prepare text part
if text: if text:
openai_content = [{"type": "text", "text": text}] openai_content = [{"type": "text", "text": text}]
@@ -98,8 +105,7 @@ class Conversation:
continue continue
try: try:
async with session.get(url) as resp: async with session.get(url) as resp:
if resp.status != 200: resp.raise_for_status()
raise IOError(f"{url} --> {resp.status}")
image_data = await resp.read() image_data = await resp.read()
b64_image = base64.b64encode(image_data).decode('utf-8') b64_image = base64.b64encode(image_data).decode('utf-8')
b64_url = f"data:{content_type};base64,{b64_image}" b64_url = f"data:{content_type};base64,{b64_image}"
@@ -120,10 +126,10 @@ class Conversation:
return response return response
async def regenerate(self): async def regenerate(self):
"""Regenerates the last assistant turn."""
llm_response = await self.client.chat.completions.create( llm_response = await self.client.chat.completions.create(
model=MODEL, messages=self.history[:-1] model=MODEL, messages=self.history[:-1]
) )
response = llm_response.choices[0].message.content response = llm_response.choices[0].message.content
self.history[-1] = {"role": "assistant", "content": response} self.history[-1] = {"role": "assistant", "content": response}
return response return response

View File

@@ -7,7 +7,7 @@ class Database:
self._create_table() self._create_table()
@classmethod @classmethod
def get(cls, db_path='conversations.db'): def get(cls, db_path):
"""Creates and returns a connected Database instance.""" """Creates and returns a connected Database instance."""
print(f"Initializing DB connection to: {db_path}") print(f"Initializing DB connection to: {db_path}")
return Database(sqlite3.connect(db_path)) return Database(sqlite3.connect(db_path))
@@ -26,7 +26,10 @@ class Database:
def get_conversation(self, conversation_id): def get_conversation(self, conversation_id):
with self.conn: with self.conn:
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,)) cursor.execute(
"SELECT history, bot_name, last_messages FROM conversations WHERE id = ?",
(conversation_id,)
)
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:
history = json.loads(row[0]) history = json.loads(row[0])
@@ -38,11 +41,16 @@ class Database:
def save(self, conversation_id, history, bot_name, last_messages): def save(self, conversation_id, history, bot_name, last_messages):
with self.conn: with self.conn:
self.conn.execute( self.conn.execute(
"INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)", "INSERT OR REPLACE INTO conversations "
(conversation_id, json.dumps(history), bot_name, json.dumps(last_messages)) "(id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)",
(
conversation_id, json.dumps(history),
bot_name, json.dumps(last_messages)
),
) )
def delete(self, conversation_id): def delete(self, conversation_id):
with self.conn: with self.conn:
self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,)) self.conn.execute(
"DELETE FROM conversations WHERE id = ?", (conversation_id,),
)