diff --git a/bot.py b/bot.py index 8275776..6162e86 100644 --- a/bot.py +++ b/bot.py @@ -1,3 +1,4 @@ +import collections import discord from discord.ext import commands from openai import AsyncOpenAI @@ -5,28 +6,23 @@ import os import base64 import aiohttp import argparse -from dataclasses import dataclass, field from typing import List, Dict, Any # --- Configuration --- OPENAI_API_KEY = "eh" +MODEL = "p620" DEFAULT_SYSTEM_PROMPT = "you are a catboy named Aoi with dark blue fur and is a tsundere" +DEFAULT_NAME = "Aoi" # --- Data Structures --- -@dataclass class Conversation: - channel_id: int - history: List[Dict[str, Any]] = field(default_factory=list) + def __init__(self, prompt, name): + self.history = [{"role": "system", "content": prompt}] + self.bot_name = name - def add_message(self, role: str, content: Any): + def add_message(self, role, content): self.history.append({"role": role, "content": content}) - def reset_history(self, system_prompt: str = DEFAULT_SYSTEM_PROMPT): - self.history = [{"role": "system", "content": system_prompt}] - - def get_history(self): - return self.history - # --- Command Line Arguments --- parser = argparse.ArgumentParser(description="Aoi Discord Bot") parser.add_argument('--base_url', type=str, required=True, @@ -42,7 +38,10 @@ intents.message_content = True bot = commands.Bot(command_prefix="/", intents=intents) # --- Data Storage --- -conversation_history: Dict[int, Conversation] = {} # Keyed by channel ID +# Keyed by channel ID +conversation_history: Dict[int, Conversation] = collections.defaultdict( + lambda: Conversation(prompt=DEFAULT_SYSTEM_PROMPT, name=DEFAULT_NAME) +) # --- OpenAI Client --- client = AsyncOpenAI( @@ -50,6 +49,20 @@ client = AsyncOpenAI( api_key=OPENAI_API_KEY, ) +# --- Helpers --- +async def get_user_from_id(ctx, userid): + if ctx.guild: + user = await ctx.guild.fetch_member(userid) + else: + user = await bot.fetch_user(userid) + return user.display_name + +async def get_user_from_mention(ctx, mention): + match = re.findall(r"<@!?(\d+)>", mention) + if not match: + return mention + return await get_user_from_id(ctx, int(match[0])) + # --- Bot Events --- @bot.event async def on_ready(): @@ -63,14 +76,14 @@ async def on_message(message): return if bot.user.mentioned_in(message): + bot_tag = f'<@{bot.user.id}>' channel_id = message.channel.id - user_message_text = message.content.replace(f'<@!{bot.user.id}>', 'Aoi').strip() - - if channel_id not in conversation_history: - conversation_history[channel_id] = Conversation(channel_id=channel_id) - conversation_history[channel_id].reset_history() - conversation = conversation_history[channel_id] + user_message_text = message.content + if user_message_text.startswith(bot_tag): + user_message_text = user_message_text[len(bot_tag):] + user_message_text = user_message_text.replace(bot_tag, conversation.bot_name).strip() + print(f'> {message.author.name}: {user_message_text}') # Prepare content for OpenAI API openai_content = [] @@ -81,19 +94,19 @@ async def on_message(message): async with aiohttp.ClientSession() as session: for attachment in message.attachments: if attachment.content_type and "image" in attachment.content_type: - try: - async with session.get(attachment.url) as resp: - if resp.status == 200: - image_data = await resp.read() - base64_image = base64.b64encode(image_data).decode('utf-8') - image_url = f"data:{attachment.content_type};base64,{base64_image}" - openai_content.append({ - "type": "image_url", - "image_url": {"url": image_url} - }) - except Exception as e: - print(f"Error downloading or processing attachment: {e}") - + async with message.channel.typing(): + try: + async with session.get(attachment.url) as resp: + if resp.status == 200: + image_data = await resp.read() + base64_image = base64.b64encode(image_data).decode('utf-8') + image_url = f"data:{attachment.content_type};base64,{base64_image}" + openai_content.append({ + "type": "image_url", + "image_url": {"url": image_url} + }) + except Exception as e: + print(f"Error downloading or processing attachment: {e}") if not openai_content: # Don't send empty messages return @@ -105,12 +118,11 @@ async def on_message(message): else: conversation.add_message("user", openai_content) - try: async with message.channel.typing(): response = await client.chat.completions.create( - model="gpt-4", # Or any other model you are using - messages=conversation.get_history() + model=MODEL, + messages=conversation.history, ) bot_response = response.choices[0].message.content conversation.add_message("assistant", bot_response) @@ -129,20 +141,17 @@ async def on_message(message): @bot.tree.command(name="newchat", description="Start a new chat with a new system prompt.") async def newchat(interaction: discord.Interaction, prompt: str = None): channel_id = interaction.channel_id - - system_prompt = prompt - if system_prompt is None: - system_prompt = DEFAULT_SYSTEM_PROMPT - - if channel_id not in conversation_history: - conversation_history[channel_id] = Conversation(channel_id=channel_id) - - conversation_history[channel_id].reset_history(system_prompt) - - if prompt is None: - await interaction.response.send_message("Starting a new chat with the default prompt.") - else: - await interaction.response.send_message(f'Starting a new chat with the prompt: "{prompt}"') + prompt = prompt or DEFAULT_SYSTEM_PROMPT + name_response = await client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": prompt}, + {"role": "user", "content": "reply with your name, nothing else, no punctuation"} + ], + ) + name = name_response.choices[0].message.content + conversation_history[channel_id] = Conversation(prompt=prompt, name=name) + await interaction.response.send_message(f'Starting a new chat with the prompt: "{prompt}"') # --- Running the Bot ---