Compare commits

...

24 Commits

Author SHA1 Message Date
6822d3ca71 requirements.txt 2025-08-23 23:49:46 -07:00
4900778a97 change " to ' 2025-08-23 23:42:00 -07:00
6f001f2924 refactor all parameters to commandline arguments 2025-08-23 23:34:31 -07:00
ede777ee42 clear reactions on /newchat 2025-08-22 18:29:34 -07:00
4ba6e64403 fixed newchat response when there's no prompt 2025-08-22 18:07:25 -07:00
fd08420f26 go back to sync sqlite; async sqlite leaves hanging processes 2025-08-22 00:51:09 -07:00
ad9c069993 changeprompt 2025-08-22 00:21:37 -07:00
8c750b2fbb async sqlite 2025-08-21 22:36:17 -07:00
dd47cc05a6 add deletion and prepopulated reacts 2025-08-21 17:23:20 -07:00
ae7843cf17 persistent db 2025-08-21 11:04:04 -07:00
8437cc6940 store message.id instead of message in reroll history 2025-08-21 10:36:14 -07:00
3d9c5d8e71 fix convo and hook lookup 2025-08-21 10:18:25 -07:00
452bd41e7e refactor llm out of bot 2025-08-21 09:37:35 -07:00
5ecf47a451 beautify 2025-08-21 01:22:30 -07:00
b175019321 reaction to regenerate 2025-08-21 01:16:02 -07:00
005e9ac80b refactor and add last_messages 2025-08-21 01:16:02 -07:00
d6a770e011 refactor and add last_messages 2025-08-21 01:16:02 -07:00
4261991022 refactor and add last_messages 2025-08-21 01:16:02 -07:00
7b052e0ff9 refactor AI out of Discord function 2025-08-21 01:16:02 -07:00
c615726127 fix race condition 2025-08-21 01:16:02 -07:00
e1749dfe92 webhook 2025-08-21 01:16:02 -07:00
e1a8cbf98d name aware 2025-08-21 01:16:02 -07:00
a3bb1a499c feat: introduce Conversation dataclass to encapsulate conversation history 2025-08-21 01:16:02 -07:00
a94daf3654 remove useless tests 2025-08-21 01:16:02 -07:00
5 changed files with 353 additions and 214 deletions

241
bot.py
View File

@@ -1,37 +1,95 @@
import collections
import discord
from discord.ext import commands
from openai import AsyncOpenAI
import os
import base64
import aiohttp
import argparse
from typing import List, Dict, Any
from conversations import Conversation, ConversationManager
from database import Database
# --- Configuration ---
OPENAI_API_KEY = "eh"
DEFAULT_SYSTEM_PROMPT = "you are a catboy named Aoi with dark blue fur and is a tsundere"
DEFAULT_AVATAR = 'https://cdn.discordapp.com/avatars/1406466525858369716/f1dfeaf2a1c361dbf981e2e899c7f981?size=256'
# --- Command Line Arguments ---
parser = argparse.ArgumentParser(description="Aoi Discord Bot")
parser.add_argument('--base_url', type=str, required=True,
help='The base URL for the OpenAI API.')
parser.add_argument('--discord_token', type=str, required=True,
help='The Discord bot token.')
parser = argparse.ArgumentParser(description='Aoi Discord Bot')
parser.add_argument(
'--base_url', default='http://localhost:8080/v1',
help='The base URL for the OpenAI API server.',
)
parser.add_argument(
'--api_key', default='', help='The API key for OpenAI API.',
)
parser.add_argument(
'--model', default='', help='The model to use from OpenAI API.',
)
parser.add_argument(
'--default_prompt',
default='you are a catboy named Aoi with dark blue fur and is a tsundere',
help='Default system prompt when not given in chat.',
)
parser.add_argument(
'--db', default='conversations.db', help='SQLite DB to use.',
)
parser.add_argument(
'--discord_token', required=True, help='The Discord bot token.',
)
args = parser.parse_args()
# --- Bot Setup ---
class AoiBot(commands.Bot):
async def setup_hook(self):
db = Database.get(args.db)
openai = AsyncOpenAI(base_url=args.base_url, api_key=args.api_key)
self.manager = ConversationManager(
openai, args.model, db, args.default_prompt,
)
intents = discord.Intents.default()
intents.messages = True
intents.message_content = True
bot = commands.Bot(command_prefix="/", intents=intents)
bot = AoiBot(command_prefix='/', intents=intents)
# --- Data Storage ---
conversation_history = {} # Keyed by channel ID
# --- OpenAI Client ---
client = AsyncOpenAI(
base_url=args.base_url,
api_key=OPENAI_API_KEY,
)
# --- Helpers ---
async def discord_send(channel, text, name, avatar=DEFAULT_AVATAR):
chunks = [text[i:i+2000] for i in range(0, len(text), 2000)]
messages = []
for chunk in chunks:
if channel.guild:
hook = await webhook(channel)
message = await hook.send(
content=chunk,
username=name,
avatar_url=avatar,
wait=True,
)
else:
message = await channel.send(content=chunk)
messages.append(message.id)
await message.add_reaction('🔁')
await message.add_reaction('')
return messages
async def webhook(channel):
name = f'aoi-{channel.id}'
channel_hooks = [
hook for hook in (await channel.webhooks()) if hook.name == name
]
if not channel_hooks:
return await channel.create_webhook(name=f'aoi-{channel.id}')
return channel_hooks[0]
async def clear_reactions(channel, message_ids):
for message_id in message_ids:
try:
message = await channel.fetch_message(message_id)
await message.clear_reaction('🔁')
await message.clear_reaction('')
except (discord.NotFound, discord.Forbidden):
pass # Ignore if message is not found or we don't have perms
# --- Bot Events ---
@bot.event
@@ -44,84 +102,105 @@ async def on_ready():
async def on_message(message):
if message.author == bot.user:
return
if not bot.user.mentioned_in(message):
return
if bot.user.mentioned_in(message):
channel_id = message.channel.id
user_message_text = message.content.replace(f'<@!{bot.user.id}>', 'Aoi').strip()
bot_tag = f'<@{bot.user.id}>'
channel = message.channel
conversation = await bot.manager.get(channel.id)
user_message = message.content
if user_message.startswith(bot_tag):
user_message = user_message[len(bot_tag):]
user_message = user_message.replace(bot_tag, conversation.bot_name).strip()
print(f'{channel.id}> {message.author.name}: {user_message}')
if channel_id not in conversation_history:
conversation_history[channel_id] = [
{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}
]
media = []
if message.attachments:
for attachment in message.attachments:
media.append((attachment.content_type, attachment.url))
# Prepare content for OpenAI API
openai_content = []
if user_message_text:
openai_content.append({"type": "text", "text": user_message_text})
if message.attachments:
async with aiohttp.ClientSession() as session:
for attachment in message.attachments:
if attachment.content_type and "image" in attachment.content_type:
try:
async with session.get(attachment.url) as resp:
if resp.status == 200:
image_data = await resp.read()
base64_image = base64.b64encode(image_data).decode('utf-8')
image_url = f"data:{attachment.content_type};base64,{base64_image}"
openai_content.append({
"type": "image_url",
"image_url": {"url": image_url}
})
except Exception as e:
print(f"Error downloading or processing attachment: {e}")
if not openai_content: # Don't send empty messages
return
# Add to conversation history
if len(openai_content) == 1 and openai_content[0]['type'] == 'text':
# Keep original format for text-only messages for compatibility
conversation_history[channel_id].append({"role": "user", "content": openai_content[0]['text']})
else:
conversation_history[channel_id].append({"role": "user", "content": openai_content})
try:
async with channel.typing():
response = await conversation.generate(user_message, media)
await clear_reactions(channel, conversation.last_messages)
conversation.last_messages = await discord_send(
channel, response, conversation.bot_name,
)
await conversation.save()
except Exception as e:
print(f'An error occurred: {e}')
await message.reply('Sorry, I had a little hiccup. Baka!')
@bot.event
async def on_reaction_add(reaction, user):
if reaction.emoji not in ('🔁', '') or user == bot.user:
return
message = reaction.message
channel = message.channel
conversation = await bot.manager.get(channel.id)
if message.id not in conversation.last_messages:
await reaction.clear()
return
print(f'_ {user}: {reaction}')
try:
try:
async with message.channel.typing():
response = await client.chat.completions.create(
model="gpt-4", # Or any other model you are using
messages=conversation_history[channel_id]
messages = [
await channel.fetch_message(message_id)
for message_id in conversation.last_messages
]
except (discord.NotFound, discord.Forbidden):
# don't do anything if any message in the list is not found
await reaction.clear()
return
for message in messages:
await message.delete()
if reaction.emoji == '':
await conversation.pop()
elif reaction.emoji == '🔁':
async with channel.typing():
response = await conversation.regenerate()
conversation.last_messages = await discord_send(
channel, response, conversation.bot_name,
)
bot_response = response.choices[0].message.content
conversation_history[channel_id].append({"role": "assistant", "content": bot_response})
await message.reply(bot_response)
except Exception as e:
print(f"An error occurred: {e}")
conversation_history[channel_id].pop() # Remove user message on error
await message.reply("Sorry, I had a little hiccup. Baka!")
await conversation.save()
except Exception as e:
print(f'An error occurred: {e}')
await message.reply('Sorry, I had a little hiccup. Baka!')
# --- Slash Commands ---
@bot.tree.command(name="newchat", description="Start a new chat with a new system prompt.")
@bot.tree.command(
name='newchat',
description='Start a new chat with an optional system prompt.'
)
async def newchat(interaction: discord.Interaction, prompt: str = None):
await interaction.response.defer()
channel_id = interaction.channel_id
system_prompt = prompt
if system_prompt is None:
system_prompt = DEFAULT_SYSTEM_PROMPT
old_convo = await bot.manager.get(channel_id, create_if_missing=False)
if old_convo:
await clear_reactions(interaction.channel, old_convo.last_messages)
conversation = await bot.manager.new_conversation(channel_id, prompt)
await interaction.followup.send(
f'Starting a new chat with {conversation.bot_name}: '
f'"{conversation.history[0]["content"]}"'
)
conversation_history[channel_id] = [
{"role": "system", "content": system_prompt}
]
if prompt is None:
await interaction.response.send_message("Starting a new chat with the default prompt.")
else:
await interaction.response.send_message(f'Starting a new chat with the prompt: "{prompt}"')
@bot.tree.command(
name='changeprompt',
description='Change the system prompt of the current conversation.'
)
async def changeprompt(interaction: discord.Interaction, prompt: str):
await interaction.response.defer()
channel_id = interaction.channel_id
conversation = await bot.manager.get(channel_id)
await conversation.update_prompt(prompt)
await interaction.followup.send(
f'Now chatting with {conversation.bot_name}: "{prompt}"'
)
# --- Running the Bot ---
if __name__ == "__main__":
if __name__ == '__main__':
bot.run(args.discord_token)

135
conversations.py Normal file
View File

@@ -0,0 +1,135 @@
import aiohttp
import base64
from database import Database
API_KEY = "eh"
MODEL = "p620"
DEFAULT_NAME = "Aoi"
NAME_PROMPT = "reply with your name, nothing else, no punctuation"
async def get_name(client, model, prompt):
"""Generates an assistant name for the given prompt."""
name_response = await client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": prompt},
{"role": "user", "content": NAME_PROMPT}
],
)
return name_response.choices[0].message.content.split('\n')[0]
class ConversationManager:
"""Creates and retrieves Conversations."""
def __init__(self, openai_client, model, db, default_prompt):
self.model = model
self.client = openai_client
self.db = db
self.default_prompt = default_prompt
async def get(self, key, create_if_missing=True):
"""Gets a conversation based on |key|, optionally create when not found."""
convo_data = self.db.get_conversation(key)
if convo_data:
history, bot_name, last_messages = convo_data
return Conversation(
key, bot_name, history, last_messages,
self.client, self.model, self.db,
)
if create_if_missing:
return await self.new_conversation(key, self.default_prompt)
return None
async def new_conversation(self, key, prompt = None):
"""Creates a new Conversation with key based on given prompt."""
prompt = prompt or self.default_prompt
name = await get_name(self.client, self.model, prompt)
history = [{"role": "system", "content": prompt}]
last_messages = []
convo = Conversation(
key, name, history, last_messages, self.client, self.model, self.db,
)
await convo.save()
return convo
class Conversation:
"""Holds data about a conversation thread."""
def __init__(
self, convo_id, name, history, last_messages, api_client, model, db,
):
self.id = convo_id
self.bot_name = name
self.history = history
self.last_messages = last_messages
self.client = api_client
self.model = model
self.db = db
async def save(self):
"""Saves the conversation to the DB."""
self.db.save(self.id, self.history, self.bot_name, self.last_messages)
def add_message_pair(self, user, assistant):
"""Adds a user/assistant convesation turn pair."""
self.history.extend([
{"role": "user", "content": user},
{"role": "assistant", "content": assistant},
])
async def pop(self):
"""Removes one user/assistant converation turn pair."""
if len(self.history) >= 3:
self.history = self.history[:-2]
await self.save()
async def update_prompt(self, prompt):
"""Changes current prompt to a new one, keeping the rest of history."""
self.history[0] = {"role": "system", "content": prompt}
self.bot_name = await get_name(self.client, self.model, prompt)
await self.save()
async def generate(self, text, media=tuple()):
"""Generates next assistant conversation turn."""
# prepare text part
if text:
openai_content = [{"type": "text", "text": text}]
else:
openai_content = [{"type": "text", "text": "."}]
# prepare images part
async with aiohttp.ClientSession() as session:
for (content_type, url) in media:
if "image" not in content_type:
continue
try:
async with session.get(url) as resp:
resp.raise_for_status()
image_data = await resp.read()
b64_image = base64.b64encode(image_data).decode('utf-8')
b64_url = f"data:{content_type};base64,{b64_image}"
openai_content.append({
"type": "image_url",
"image_url": {"url": b64_url}
})
except Exception as e:
print(f"Error downloading or processing attachment: {e}")
# send request to openai api and return response
request = self.history + [{"role": "user", "content": openai_content}]
llm_response = await self.client.chat.completions.create(
model=MODEL, messages=request,
)
response = llm_response.choices[0].message.content
self.add_message_pair(openai_content, response)
return response
async def regenerate(self):
"""Regenerates the last assistant turn."""
llm_response = await self.client.chat.completions.create(
model=MODEL, messages=self.history[:-1]
)
response = llm_response.choices[0].message.content
self.history[-1] = {"role": "assistant", "content": response}
return response

56
database.py Normal file
View File

@@ -0,0 +1,56 @@
import sqlite3
import json
class Database:
def __init__(self, db_conn):
self.conn = db_conn
self._create_table()
@classmethod
def get(cls, db_path):
"""Creates and returns a connected Database instance."""
print(f"Initializing DB connection to: {db_path}")
return Database(sqlite3.connect(db_path))
def _create_table(self):
with self.conn:
self.conn.execute("""
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
history TEXT NOT NULL,
bot_name TEXT NOT NULL,
last_messages TEXT NOT NULL
)
""")
def get_conversation(self, conversation_id):
with self.conn:
cursor = self.conn.cursor()
cursor.execute(
"SELECT history, bot_name, last_messages FROM conversations WHERE id = ?",
(conversation_id,)
)
row = cursor.fetchone()
if row:
history = json.loads(row[0])
bot_name = row[1]
last_messages = json.loads(row[2])
return history, bot_name, last_messages
return None
def save(self, conversation_id, history, bot_name, last_messages):
with self.conn:
self.conn.execute(
"INSERT OR REPLACE INTO conversations "
"(id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)",
(
conversation_id, json.dumps(history),
bot_name, json.dumps(last_messages)
),
)
def delete(self, conversation_id):
with self.conn:
self.conn.execute(
"DELETE FROM conversations WHERE id = ?", (conversation_id,),
)

2
requirements.txt Normal file
View File

@@ -0,0 +1,2 @@
discord
openai

View File

@@ -1,133 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch, AsyncMock
import sys
import os
import base64
# Add the parent directory to the Python path to import the bot
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Patch sys.argv before importing the bot to prevent argparse errors
with patch.object(sys, 'argv', ['bot.py', '--base_url', 'http://fake.url', '--discord_token', 'fake_token']):
with patch('discord.ext.commands.Bot') as BotMock:
bot_instance = BotMock()
with patch('bot.bot', bot_instance):
import bot
class TestAoiBot(unittest.IsolatedAsyncioTestCase):
def setUp(self):
# Reset conversation history before each test
bot.conversation_history = {}
bot.bot.user = MagicMock()
bot.bot.user.id = 12345
bot.bot.user.mentioned_in = MagicMock(return_value=True)
bot.on_message = AsyncMock()
bot.newchat.callback = AsyncMock()
@patch('bot.AsyncOpenAI')
async def test_on_message_text_only(self, MockOpenAI):
# Mock the OpenAI client and its response
mock_openai_instance = MockOpenAI.return_value
mock_response = MagicMock()
mock_response.choices[0].message.content = "Hello from Aoi!"
mock_openai_instance.chat.completions.create = AsyncMock(return_value=mock_response)
# Mock a Discord message
message = AsyncMock()
message.author = MagicMock()
message.author.bot = False
message.channel = AsyncMock()
message.channel.id = 123
message.content = f"<@!{bot.bot.user.id}> Hello there"
message.attachments = []
# Call the on_message event handler
await bot.on_message(message)
# Assertions
bot.on_message.assert_awaited_once_with(message)
@patch('bot.AsyncOpenAI')
@patch('bot.aiohttp.ClientSession')
async def test_on_message_with_image(self, MockClientSession, MockOpenAI):
# Mock the OpenAI client
mock_openai_instance = MockOpenAI.return_value
mock_response = MagicMock()
mock_response.choices[0].message.content = "I see an image!"
mock_openai_instance.chat.completions.create = AsyncMock(return_value=mock_response)
# Mock aiohttp session to simulate image download
mock_session = MockClientSession.return_value.__aenter__.return_value
mock_resp = mock_session.get.return_value.__aenter__.return_value
mock_resp.status = 200
mock_resp.read = AsyncMock(return_value=b'fake_image_data')
# Mock a Discord message with an attachment
message = AsyncMock()
message.author = MagicMock()
message.author.bot = False
message.channel = AsyncMock()
message.channel.id = 456
message.content = f"<@!{bot.bot.user.id}> Look at this!"
attachment = MagicMock()
attachment.content_type = 'image/jpeg'
attachment.url = 'http://fakeurl.com/image.jpg'
message.attachments = [attachment]
# Call the on_message event handler
await bot.on_message(message)
# Assertions
bot.on_message.assert_awaited_once_with(message)
async def test_newchat_command_with_prompt(self):
# Mock a Discord interaction
interaction = AsyncMock()
interaction.channel_id = 789
prompt = "You are a helpful assistant."
# Call the newchat command
await bot.newchat.callback(interaction, prompt=prompt)
# Assertions
bot.newchat.callback.assert_awaited_once_with(interaction, prompt=prompt)
async def test_newchat_command_no_prompt(self):
# Mock a Discord interaction
interaction = AsyncMock()
interaction.channel_id = 789
# Call the newchat command
await bot.newchat.callback(interaction, prompt=None)
# Assertions
bot.newchat.callback.assert_awaited_once_with(interaction, prompt=None)
@patch('bot.AsyncOpenAI')
async def test_on_message_api_error(self, MockOpenAI):
# Mock the OpenAI client to raise an error
mock_openai_instance = MockOpenAI.return_value
mock_openai_instance.chat.completions.create.side_effect = Exception("API Error")
# Mock a Discord message
message = AsyncMock()
message.author = MagicMock()
message.author.bot = False
message.channel = AsyncMock()
message.channel.id = 123
message.content = f"<@!{bot.bot.user.id}> This will fail"
message.attachments = []
# Call the on_message event handler
await bot.on_message(message)
# Assertions
bot.on_message.assert_awaited_once_with(message)
if __name__ == '__main__':
unittest.main()