go back to sync sqlite; async sqlite leaves hanging processes
This commit is contained in:
4
bot.py
4
bot.py
@@ -28,7 +28,7 @@ class AoiBot(commands.Bot):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
async def setup_hook(self):
|
||||
self.db = await Database.get()
|
||||
self.db = Database.get()
|
||||
intents = discord.Intents.default()
|
||||
intents.messages = True
|
||||
intents.message_content = True
|
||||
@@ -169,7 +169,7 @@ async def newchat(interaction: discord.Interaction, prompt: str = None):
|
||||
name="changeprompt",
|
||||
description="Change the current chat's system prompt."
|
||||
)
|
||||
async def newchat(interaction: discord.Interaction, prompt: str):
|
||||
async def changeprompt(interaction: discord.Interaction, prompt: str):
|
||||
await interaction.response.defer()
|
||||
channel_id = interaction.channel_id
|
||||
conversation = await Conversation.get(channel_id, args.base_url, bot.db)
|
||||
|
80
database.py
80
database.py
@@ -1,34 +1,33 @@
|
||||
import aiosqlite
|
||||
import sqlite3
|
||||
import json
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_path='conversations.db'):
|
||||
self.db_path = db_path
|
||||
self.conn = None
|
||||
def __init__(self, db_conn):
|
||||
self.conn = db_conn
|
||||
self._create_table()
|
||||
|
||||
@classmethod
|
||||
async def get(cls, db_path='conversations.db'):
|
||||
"""Asynchronously creates and returns a connected Database instance."""
|
||||
def get(cls, db_path='conversations.db'):
|
||||
"""Creates and returns a connected Database instance."""
|
||||
print(f"Initializing DB connection to: {db_path}")
|
||||
db = Database(db_path)
|
||||
db.conn = await aiosqlite.connect(db.db_path)
|
||||
await db._create_table()
|
||||
return db
|
||||
return Database(sqlite3.connect(db_path))
|
||||
|
||||
async def _create_table(self):
|
||||
await self.conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
history TEXT NOT NULL,
|
||||
bot_name TEXT NOT NULL,
|
||||
last_messages TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
await self.conn.commit()
|
||||
def _create_table(self):
|
||||
with self.conn:
|
||||
self.conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
history TEXT NOT NULL,
|
||||
bot_name TEXT NOT NULL,
|
||||
last_messages TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
|
||||
async def get_conversation(self, conversation_id):
|
||||
async with self.conn.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,)) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
def get_conversation(self, conversation_id):
|
||||
with self.conn:
|
||||
cursor = self.conn.cursor()
|
||||
cursor.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
history = json.loads(row[0])
|
||||
bot_name = row[1]
|
||||
@@ -36,31 +35,14 @@ class Database:
|
||||
return history, bot_name, last_messages
|
||||
return None
|
||||
|
||||
async def save(self, conversation_id, history, bot_name, last_messages):
|
||||
await self.conn.execute(
|
||||
"INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)",
|
||||
(conversation_id, json.dumps(history), bot_name, json.dumps(last_messages))
|
||||
)
|
||||
await self.conn.commit()
|
||||
def save(self, conversation_id, history, bot_name, last_messages):
|
||||
with self.conn:
|
||||
self.conn.execute(
|
||||
"INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)",
|
||||
(conversation_id, json.dumps(history), bot_name, json.dumps(last_messages))
|
||||
)
|
||||
|
||||
async def delete(self, conversation_id):
|
||||
await self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,))
|
||||
await self.conn.commit()
|
||||
def delete(self, conversation_id):
|
||||
with self.conn:
|
||||
self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,))
|
||||
|
||||
async def close(self):
|
||||
if self.conn:
|
||||
await self.conn.close()
|
||||
self.conn = None
|
||||
|
||||
# To use this, you would typically do this in your main application file:
|
||||
#
|
||||
# import asyncio
|
||||
# from database import Database
|
||||
#
|
||||
# async def main():
|
||||
# db = await Database.get()
|
||||
# # now you can use db.get_conversation, db.save, etc.
|
||||
# await db.close()
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# asyncio.run(main())
|
||||
|
@@ -28,13 +28,13 @@ class Conversation:
|
||||
)
|
||||
|
||||
async def save(self):
|
||||
await self.db.save(
|
||||
self.db.save(
|
||||
self.id, self.history, self.bot_name, self.last_messages
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get(cls, key, base_url, db):
|
||||
convo_data = await db.get_conversation(key)
|
||||
convo_data = db.get_conversation(key)
|
||||
if convo_data:
|
||||
history, bot_name, last_messages = convo_data
|
||||
client = AsyncOpenAI(base_url=base_url, api_key=API_KEY)
|
||||
@@ -80,6 +80,7 @@ class Conversation:
|
||||
async def update_prompt(self, prompt):
|
||||
self.history[0] = {"role": "system", "content": prompt}
|
||||
self.bot_name = await self.get_name(self.client, prompt)
|
||||
await self.save()
|
||||
|
||||
async def generate(self, text, media=tuple()):
|
||||
# prepare text part
|
||||
|
Reference in New Issue
Block a user