persistent db
This commit is contained in:
		
							
								
								
									
										9
									
								
								bot.py
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								bot.py
									
									
									
									
									
								
							| @@ -73,7 +73,7 @@ async def on_message(message): | |||||||
|  |  | ||||||
|     bot_tag = f'<@{bot.user.id}>' |     bot_tag = f'<@{bot.user.id}>' | ||||||
|     channel = message.channel |     channel = message.channel | ||||||
|     conversation = await Conversation.get(channel.id) |     conversation = await Conversation.get(channel.id, args.base_url) | ||||||
|     user_message = message.content |     user_message = message.content | ||||||
|     if user_message.startswith(bot_tag): |     if user_message.startswith(bot_tag): | ||||||
|         user_message = user_message[len(bot_tag):] |         user_message = user_message[len(bot_tag):] | ||||||
| @@ -91,6 +91,7 @@ async def on_message(message): | |||||||
|         conversation.last_messages = await discord_send( |         conversation.last_messages = await discord_send( | ||||||
|             channel, response, conversation.bot_name, |             channel, response, conversation.bot_name, | ||||||
|         ) |         ) | ||||||
|  |         conversation.save() | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"An error occurred: {e}") |         print(f"An error occurred: {e}") | ||||||
|         await message.reply("Sorry, I had a little hiccup. Baka!") |         await message.reply("Sorry, I had a little hiccup. Baka!") | ||||||
| @@ -101,7 +102,7 @@ async def on_reaction_add(reaction, user): | |||||||
|         return |         return | ||||||
|     message = reaction.message |     message = reaction.message | ||||||
|     channel = message.channel |     channel = message.channel | ||||||
|     conversation = await Conversation.get(channel.id) |     conversation = await Conversation.get(channel.id, args.base_url) | ||||||
|     if message.id not in conversation.last_messages: |     if message.id not in conversation.last_messages: | ||||||
|         await reaction.clear() |         await reaction.clear() | ||||||
|         return |         return | ||||||
| @@ -114,7 +115,7 @@ async def on_reaction_add(reaction, user): | |||||||
|                     await channel.fetch_message(message_id) |                     await channel.fetch_message(message_id) | ||||||
|                     for message_id in conversation.last_messages |                     for message_id in conversation.last_messages | ||||||
|                 ] |                 ] | ||||||
|             except (discord.NotFound, discord.Forbidden): |             except (discord.NotFound, discord.Forbidden) as e: | ||||||
|                 # don't do anything if any message in the list is not found |                 # don't do anything if any message in the list is not found | ||||||
|                 await reaction.clear() |                 await reaction.clear() | ||||||
|             for message in messages: |             for message in messages: | ||||||
| @@ -123,6 +124,7 @@ async def on_reaction_add(reaction, user): | |||||||
|             conversation.last_messages = await discord_send( |             conversation.last_messages = await discord_send( | ||||||
|                 channel, response, conversation.bot_name, |                 channel, response, conversation.bot_name, | ||||||
|             ) |             ) | ||||||
|  |             conversation.save() | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(f"An error occurred: {e}") |         print(f"An error occurred: {e}") | ||||||
|         await message.reply("Sorry, I had a little hiccup. Baka!") |         await message.reply("Sorry, I had a little hiccup. Baka!") | ||||||
| @@ -145,3 +147,4 @@ async def newchat(interaction: discord.Interaction, prompt: str = None): | |||||||
| # --- Running the Bot --- | # --- Running the Bot --- | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     bot.run(args.discord_token) |     bot.run(args.discord_token) | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										44
									
								
								database.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								database.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | |||||||
|  | import sqlite3 | ||||||
|  | import json | ||||||
|  |  | ||||||
|  | class Database: | ||||||
|  |     def __init__(self, db_path='conversations.db'): | ||||||
|  |         self.db_path = db_path | ||||||
|  |         self.conn = sqlite3.connect(self.db_path) | ||||||
|  |         self.create_table() | ||||||
|  |  | ||||||
|  |     def create_table(self): | ||||||
|  |         with self.conn: | ||||||
|  |             self.conn.execute(""" | ||||||
|  |                 CREATE TABLE IF NOT EXISTS conversations ( | ||||||
|  |                     id TEXT PRIMARY KEY, | ||||||
|  |                     history TEXT NOT NULL, | ||||||
|  |                     bot_name TEXT NOT NULL, | ||||||
|  |                     last_messages TEXT NOT NULL | ||||||
|  |                 ) | ||||||
|  |             """) | ||||||
|  |  | ||||||
|  |     def get(self, conversation_id): | ||||||
|  |         with self.conn: | ||||||
|  |             cursor = self.conn.cursor() | ||||||
|  |             cursor.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,)) | ||||||
|  |             row = cursor.fetchone() | ||||||
|  |             if row: | ||||||
|  |                 history = json.loads(row[0]) | ||||||
|  |                 bot_name = row[1] | ||||||
|  |                 last_messages = json.loads(row[2]) | ||||||
|  |                 return history, bot_name, last_messages | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |     def save(self, conversation_id, history, bot_name, last_messages): | ||||||
|  |         with self.conn: | ||||||
|  |             self.conn.execute( | ||||||
|  |                 "INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)", | ||||||
|  |                 (conversation_id, json.dumps(history), bot_name, json.dumps(last_messages)) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def delete(self, conversation_id): | ||||||
|  |         with self.conn: | ||||||
|  |             self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,)) | ||||||
|  |  | ||||||
|  | db = Database() | ||||||
| @@ -1,6 +1,7 @@ | |||||||
| import aiohttp | import aiohttp | ||||||
| import base64 | import base64 | ||||||
| from openai import AsyncOpenAI | from openai import AsyncOpenAI | ||||||
|  | from database import db | ||||||
|  |  | ||||||
| API_KEY = "eh" | API_KEY = "eh" | ||||||
| MODEL = "p620" | MODEL = "p620" | ||||||
| @@ -10,30 +11,44 @@ DEFAULT_SYSTEM_PROMPT = ( | |||||||
| ) | ) | ||||||
| NAME_PROMPT = "reply with your name, nothing else, no punctuation" | NAME_PROMPT = "reply with your name, nothing else, no punctuation" | ||||||
|  |  | ||||||
| conversations = {} |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Conversation: | class Conversation: | ||||||
|     def __init__(self, client, name, prompt): |     def __init__(self, client, name, prompt, convo_id): | ||||||
|         self.history = [{"role": "system", "content": prompt}] |         self.history = [{"role": "system", "content": prompt}] | ||||||
|         self.bot_name = name |         self.bot_name = name | ||||||
|         self.last_messages = [] |         self.last_messages = [] | ||||||
|         self.client = client |         self.client = client | ||||||
|  |         self.id = convo_id | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return ( | ||||||
|  |             f"Conversation({self.bot_name}, {self.last_messages}, " | ||||||
|  |             f"{self.history}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def save(self): | ||||||
|  |         db.save(self.id, self.history, self.bot_name, self.last_messages) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     async def get(cls, key): |     async def get(cls, key, base_url): | ||||||
|         if key not in conversations: |         convo_data = db.get(key) | ||||||
|             conversations[key] = await Conversation.create(args.base_url) |         if convo_data: | ||||||
|         return conversations[key] |             history, bot_name, last_messages = convo_data | ||||||
|  |             client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) | ||||||
|  |             convo = cls(client, bot_name, history[0]['content'], key) | ||||||
|  |             convo.history = history | ||||||
|  |             convo.last_messages = last_messages | ||||||
|  |             return convo | ||||||
|  |         return await Conversation.create(key, base_url) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     async def create(cls, channel_id, base_url, prompt=None): |     async def create(cls, key, base_url, prompt=None): | ||||||
|         client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) |         client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) | ||||||
|         if not prompt: |         if not prompt: | ||||||
|             convo = cls(client, DEFAULT_NAME, DEFAULT_SYSTEM_PROMPT) |             convo = cls(client, DEFAULT_NAME, DEFAULT_SYSTEM_PROMPT, key) | ||||||
|         else: |         else: | ||||||
|             convo = cls(client, await cls.get_name(client, prompt), prompt) |             convo = cls(client, await cls.get_name(client, prompt), prompt, key) | ||||||
|         conversations[channel_id] = convo |         convo.save() | ||||||
|         return convo |         return convo | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -96,4 +111,3 @@ class Conversation: | |||||||
|         self.history[-1] = {"role": "assistant", "content": response} |         self.history[-1] = {"role": "assistant", "content": response} | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user