diff --git a/bot.py b/bot.py index a2ad2c2..1280fed 100644 --- a/bot.py +++ b/bot.py @@ -86,7 +86,7 @@ async def on_message(message): if user_message.startswith(bot_tag): user_message = user_message[len(bot_tag):] user_message = user_message.replace(bot_tag, conversation.bot_name).strip() - print(f'> {message.author.name}: {user_message}') + print(f'{channel.id}> {message.author.name}: {user_message}') media = [] if message.attachments: @@ -165,7 +165,21 @@ async def newchat(interaction: discord.Interaction, prompt: str = None): f'Starting a new chat with {conversation.bot_name}: "{prompt}"' ) +@bot.tree.command( + name="changeprompt", + description="Change the current chat's system prompt." +) +async def newchat(interaction: discord.Interaction, prompt: str): + await interaction.response.defer() + channel_id = interaction.channel_id + conversation = await Conversation.get(channel_id, args.base_url, bot.db) + await conversation.update_prompt(prompt) + await interaction.followup.send( + f'Now chatting with {conversation.bot_name}: "{prompt}"' + ) + # --- Running the Bot --- if __name__ == "__main__": bot.run(args.discord_token) + diff --git a/database.py b/database.py index ba79395..c28985a 100644 --- a/database.py +++ b/database.py @@ -6,11 +6,10 @@ class Database: self.db_path = db_path self.conn = None - @staticmethod - async def get(db_path='conversations.db'): - """ - Asynchronously creates and returns a connected Database instance. - """ + @classmethod + async def get(cls, db_path='conversations.db'): + """Asynchronously creates and returns a connected Database instance.""" + print(f"Initializing DB connection to: {db_path}") db = Database(db_path) db.conn = await aiosqlite.connect(db.db_path) await db._create_table() @@ -64,4 +63,4 @@ class Database: # await db.close() # # if __name__ == "__main__": -# asyncio.run(main()) \ No newline at end of file +# asyncio.run(main()) diff --git a/llm_client.py b/llm_client.py index e9a6c68..53f1f4d 100644 --- a/llm_client.py +++ b/llm_client.py @@ -77,6 +77,10 @@ class Conversation: self.history = self.history[:-2] await self.save() + async def update_prompt(self, prompt): + self.history[0] = {"role": "system", "content": prompt} + self.bot_name = await self.get_name(self.client, prompt) + async def generate(self, text, media=tuple()): # prepare text part if text: @@ -119,3 +123,4 @@ class Conversation: response = llm_response.choices[0].message.content self.history[-1] = {"role": "assistant", "content": response} return response +