go back to sync sqlite; async sqlite leaves hanging processes
This commit is contained in:
4
bot.py
4
bot.py
@@ -28,7 +28,7 @@ class AoiBot(commands.Bot):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
async def setup_hook(self):
|
async def setup_hook(self):
|
||||||
self.db = await Database.get()
|
self.db = Database.get()
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
intents.messages = True
|
intents.messages = True
|
||||||
intents.message_content = True
|
intents.message_content = True
|
||||||
@@ -169,7 +169,7 @@ async def newchat(interaction: discord.Interaction, prompt: str = None):
|
|||||||
name="changeprompt",
|
name="changeprompt",
|
||||||
description="Change the current chat's system prompt."
|
description="Change the current chat's system prompt."
|
||||||
)
|
)
|
||||||
async def newchat(interaction: discord.Interaction, prompt: str):
|
async def changeprompt(interaction: discord.Interaction, prompt: str):
|
||||||
await interaction.response.defer()
|
await interaction.response.defer()
|
||||||
channel_id = interaction.channel_id
|
channel_id = interaction.channel_id
|
||||||
conversation = await Conversation.get(channel_id, args.base_url, bot.db)
|
conversation = await Conversation.get(channel_id, args.base_url, bot.db)
|
||||||
|
80
database.py
80
database.py
@@ -1,34 +1,33 @@
|
|||||||
import aiosqlite
|
import sqlite3
|
||||||
import json
|
import json
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
def __init__(self, db_path='conversations.db'):
|
def __init__(self, db_conn):
|
||||||
self.db_path = db_path
|
self.conn = db_conn
|
||||||
self.conn = None
|
self._create_table()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(cls, db_path='conversations.db'):
|
def get(cls, db_path='conversations.db'):
|
||||||
"""Asynchronously creates and returns a connected Database instance."""
|
"""Creates and returns a connected Database instance."""
|
||||||
print(f"Initializing DB connection to: {db_path}")
|
print(f"Initializing DB connection to: {db_path}")
|
||||||
db = Database(db_path)
|
return Database(sqlite3.connect(db_path))
|
||||||
db.conn = await aiosqlite.connect(db.db_path)
|
|
||||||
await db._create_table()
|
|
||||||
return db
|
|
||||||
|
|
||||||
async def _create_table(self):
|
def _create_table(self):
|
||||||
await self.conn.execute("""
|
with self.conn:
|
||||||
CREATE TABLE IF NOT EXISTS conversations (
|
self.conn.execute("""
|
||||||
id TEXT PRIMARY KEY,
|
CREATE TABLE IF NOT EXISTS conversations (
|
||||||
history TEXT NOT NULL,
|
id TEXT PRIMARY KEY,
|
||||||
bot_name TEXT NOT NULL,
|
history TEXT NOT NULL,
|
||||||
last_messages TEXT NOT NULL
|
bot_name TEXT NOT NULL,
|
||||||
)
|
last_messages TEXT NOT NULL
|
||||||
""")
|
)
|
||||||
await self.conn.commit()
|
""")
|
||||||
|
|
||||||
async def get_conversation(self, conversation_id):
|
def get_conversation(self, conversation_id):
|
||||||
async with self.conn.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,)) as cursor:
|
with self.conn:
|
||||||
row = await cursor.fetchone()
|
cursor = self.conn.cursor()
|
||||||
|
cursor.execute("SELECT history, bot_name, last_messages FROM conversations WHERE id = ?", (conversation_id,))
|
||||||
|
row = cursor.fetchone()
|
||||||
if row:
|
if row:
|
||||||
history = json.loads(row[0])
|
history = json.loads(row[0])
|
||||||
bot_name = row[1]
|
bot_name = row[1]
|
||||||
@@ -36,31 +35,14 @@ class Database:
|
|||||||
return history, bot_name, last_messages
|
return history, bot_name, last_messages
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def save(self, conversation_id, history, bot_name, last_messages):
|
def save(self, conversation_id, history, bot_name, last_messages):
|
||||||
await self.conn.execute(
|
with self.conn:
|
||||||
"INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)",
|
self.conn.execute(
|
||||||
(conversation_id, json.dumps(history), bot_name, json.dumps(last_messages))
|
"INSERT OR REPLACE INTO conversations (id, history, bot_name, last_messages) VALUES (?, ?, ?, ?)",
|
||||||
)
|
(conversation_id, json.dumps(history), bot_name, json.dumps(last_messages))
|
||||||
await self.conn.commit()
|
)
|
||||||
|
|
||||||
async def delete(self, conversation_id):
|
def delete(self, conversation_id):
|
||||||
await self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,))
|
with self.conn:
|
||||||
await self.conn.commit()
|
self.conn.execute("DELETE FROM conversations WHERE id = ?", (conversation_id,))
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
if self.conn:
|
|
||||||
await self.conn.close()
|
|
||||||
self.conn = None
|
|
||||||
|
|
||||||
# To use this, you would typically do this in your main application file:
|
|
||||||
#
|
|
||||||
# import asyncio
|
|
||||||
# from database import Database
|
|
||||||
#
|
|
||||||
# async def main():
|
|
||||||
# db = await Database.get()
|
|
||||||
# # now you can use db.get_conversation, db.save, etc.
|
|
||||||
# await db.close()
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# asyncio.run(main())
|
|
||||||
|
@@ -28,13 +28,13 @@ class Conversation:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def save(self):
|
async def save(self):
|
||||||
await self.db.save(
|
self.db.save(
|
||||||
self.id, self.history, self.bot_name, self.last_messages
|
self.id, self.history, self.bot_name, self.last_messages
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(cls, key, base_url, db):
|
async def get(cls, key, base_url, db):
|
||||||
convo_data = await db.get_conversation(key)
|
convo_data = db.get_conversation(key)
|
||||||
if convo_data:
|
if convo_data:
|
||||||
history, bot_name, last_messages = convo_data
|
history, bot_name, last_messages = convo_data
|
||||||
client = AsyncOpenAI(base_url=base_url, api_key=API_KEY)
|
client = AsyncOpenAI(base_url=base_url, api_key=API_KEY)
|
||||||
@@ -80,6 +80,7 @@ class Conversation:
|
|||||||
async def update_prompt(self, prompt):
|
async def update_prompt(self, prompt):
|
||||||
self.history[0] = {"role": "system", "content": prompt}
|
self.history[0] = {"role": "system", "content": prompt}
|
||||||
self.bot_name = await self.get_name(self.client, prompt)
|
self.bot_name = await self.get_name(self.client, prompt)
|
||||||
|
await self.save()
|
||||||
|
|
||||||
async def generate(self, text, media=tuple()):
|
async def generate(self, text, media=tuple()):
|
||||||
# prepare text part
|
# prepare text part
|
||||||
|
Reference in New Issue
Block a user