go back to sync sqlite; async sqlite leaves hanging processes

This commit is contained in:
2025-08-22 00:51:09 -07:00
parent ad9c069993
commit fd08420f26
3 changed files with 36 additions and 53 deletions

4
bot.py
View File

@@ -28,7 +28,7 @@ class AoiBot(commands.Bot):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
async def setup_hook(self): async def setup_hook(self):
self.db = await Database.get() self.db = Database.get()
intents = discord.Intents.default() intents = discord.Intents.default()
intents.messages = True intents.messages = True
intents.message_content = True intents.message_content = True
@@ -169,7 +169,7 @@ async def newchat(interaction: discord.Interaction, prompt: str = None):
name="changeprompt", name="changeprompt",
description="Change the current chat's system prompt." description="Change the current chat's system prompt."
) )
async def newchat(interaction: discord.Interaction, prompt: str): async def changeprompt(interaction: discord.Interaction, prompt: str):
await interaction.response.defer() await interaction.response.defer()
channel_id = interaction.channel_id channel_id = interaction.channel_id
conversation = await Conversation.get(channel_id, args.base_url, bot.db) conversation = await Conversation.get(channel_id, args.base_url, bot.db)

View File

@@ -1,34 +1,33 @@
import aiosqlite import sqlite3
import json import json
class Database: class Database:
def __init__(self, db_path='conversations.db'): def __init__(self, db_conn):
self.db_path = db_path self.conn = db_conn
self.conn = None self._create_table()
@classmethod @classmethod
async def get(cls, db_path='conversations.db'): def get(cls, db_path='conversations.db'):
"""Asynchronously creates and returns a connected Database instance.""" """Creates and returns a connected Database instance."""
print(f"Initializing DB connection to: {db_path}") print(f"Initializing DB connection to: {db_path}")
db = Database(db_path) return Database(sqlite3.connect(db_path))
db.conn = await aiosqlite.connect(db.db_path)
await db._create_table()
return db
async def _create_table(self): def _create_table(self):
await self.conn.execute(""" with self.conn:
CREATE TABLE IF NOT EXISTS conversations ( self.conn.execute("""
id TEXT PRIMARY KEY, CREATE TABLE IF NOT EXISTS conversations (
history TEXT NOT NULL, id TEXT PRIMARY KEY,
bot_name TEXT NOT NULL, history TEXT NOT NULL,
last_messages TEXT NOT NULL bot_name TEXT NOT NULL,
) last_messages TEXT NOT NULL
""") )
await self.conn.commit() """)
async def get_conversation(self, conversation_id): def get_conversation(self, conversation_id):
async with self.conn.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,)) as cursor: with self.conn:
row = await cursor.fetchone() cursor = self.conn.cursor()
cursor.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,))
row = cursor.fetchone()
if row: if row:
history = json.loads(row[0]) history = json.loads(row[0])
bot_name = row[1] bot_name = row[1]
@@ -36,31 +35,14 @@ class Database:
return history, bot_name, last_messages return history, bot_name, last_messages
return None return None
async def save(self, conversation_id, history, bot_name, last_messages): def save(self, conversation_id, history, bot_name, last_messages):
await self.conn.execute( with self.conn:
"INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)", self.conn.execute(
(conversation_id, json.dumps(history), bot_name, json.dumps(last_messages)) "INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)",
) (conversation_id, json.dumps(history), bot_name, json.dumps(last_messages))
await self.conn.commit() )
async def delete(self, conversation_id): def delete(self, conversation_id):
await self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,)) with self.conn:
await self.conn.commit() self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,))
async def close(self):
if self.conn:
await self.conn.close()
self.conn = None
# To use this, you would typically do this in your main application file:
#
# import asyncio
# from database import Database
#
# async def main():
# db = await Database.get()
# # now you can use db.get_conversation, db.save, etc.
# await db.close()
#
# if __name__ == "__main__":
# asyncio.run(main())

View File

@@ -28,13 +28,13 @@ class Conversation:
) )
async def save(self): async def save(self):
await self.db.save( self.db.save(
self.id, self.history, self.bot_name, self.last_messages self.id, self.history, self.bot_name, self.last_messages
) )
@classmethod @classmethod
async def get(cls, key, base_url, db): async def get(cls, key, base_url, db):
convo_data = await db.get_conversation(key) convo_data = db.get_conversation(key)
if convo_data: if convo_data:
history, bot_name, last_messages = convo_data history, bot_name, last_messages = convo_data
client = AsyncOpenAI(base_url=base_url, api_key=API_KEY) client = AsyncOpenAI(base_url=base_url, api_key=API_KEY)
@@ -80,6 +80,7 @@ class Conversation:
async def update_prompt(self, prompt): async def update_prompt(self, prompt):
self.history[0] = {"role": "system", "content": prompt} self.history[0] = {"role": "system", "content": prompt}
self.bot_name = await self.get_name(self.client, prompt) self.bot_name = await self.get_name(self.client, prompt)
await self.save()
async def generate(self, text, media=tuple()): async def generate(self, text, media=tuple()):
# prepare text part # prepare text part