// The Swift Programming Language
// https://docs.swift.org/swift-book
import DiscordBM
import Foundation
import GeminiKit
@main
struct Nahi {
    /// Thrown when a credential file could be read but contains nothing besides whitespace.
    enum SecretError: Error {
        case empty(path: String)
    }

    /// Starts the bot.
    ///
    /// It reads the Discord token and Gemini API key, builds the gateway manager and the Gemini
    /// client, and then runs two tasks concurrently:
    /// - the gateway connection;
    /// - a loop that hands every gateway event to `EventHandler`.
    ///
    /// Credentials are read from `discord_token.txt` and `gemini_token.txt`. These are relative
    /// paths, so they resolve against the current working directory.
    ///
    /// - Throws: If either credential file cannot be read as UTF-8 or is empty after trimming.
    ///   Before this change a missing file crashed via `try!`.
    static func main() async throws {
        let discordToken = try readSecret(atPath: "discord_token.txt")
        let geminiAPIKey = try readSecret(atPath: "gemini_token.txt")

        let bot = await BotGatewayManager(
            token: discordToken,
            presence: .init(activities: [.init(name: "Vibing", type: .competing)], status: .online, afk: false),
            intents: [.guildMessages, .messageContent]
        )
        let gemini = GeminiKit(configuration: .init(apiKey: geminiAPIKey))

        await withTaskGroup(of: Void.self) { taskGroup in
            taskGroup.addTask {
                await bot.connect()
            }
            taskGroup.addTask {
                for await event in await bot.events {
                    // Give each event its own task. A reply streams from Gemini for several
                    // seconds; awaiting it inline would stall every event queued behind it.
                    Task {
                        await EventHandler(event: event, client: bot.client, gemini: gemini).handleAsync()
                    }
                }
            }
        }
    }

    /// Reads the UTF-8 text file at `path` and trims surrounding whitespace and newlines.
    ///
    /// - Parameter path: File to read. A relative path resolves against the working directory.
    /// - Returns: The trimmed file contents.
    /// - Throws: Any error from `String(contentsOfFile:encoding:)`, or `SecretError.empty` when
    ///   nothing is left after trimming. The empty check surfaces a blank token file here rather
    ///   than as an authentication failure later.
    private static func readSecret(atPath path: String) throws -> String {
        let secret = try String(contentsOfFile: path, encoding: .utf8)
            .trimmingCharacters(in: .whitespacesAndNewlines)
        guard !secret.isEmpty else { throw SecretError.empty(path: path) }
        return secret
    }
}
/// Answers messages that mention the bot by streaming a Gemini completion into Discord.
///
/// The reply is posted as a message and then edited in place as more text streams in. When the
/// next piece of text would push that message past `messageLimit`, the text continues in a new
/// message instead.
struct EventHandler: GatewayEventHandler {
    let event: Gateway.Event
    let client: any DiscordClient
    let gemini: GeminiKit

    /// Maximum length of one Discord message's content.
    /// NOTE(review): this is counted in Swift `Character`s (grapheme clusters). Confirm it matches
    /// how Discord measures its 2000-character limit for emoji and combining marks.
    private static let messageLimit = 2000

    /// Replies with a streamed Gemini answer when a non-bot message mentions this application.
    ///
    /// - Parameter payload: The incoming message. Its `content` is sent to Gemini verbatim as the
    ///   prompt.
    /// - Throws: Errors from Discord requests or from iterating the Gemini stream. A failure to
    ///   *start* the stream is not thrown; it is reported in the channel instead.
    func onMessageCreate(_ payload: Gateway.MessageCreate) async throws {
        guard !(payload.author?.bot ?? false) else { return }
        // If `appId` is unknown we cannot tell whether we were mentioned. Ignore the message
        // rather than crash on a force unwrap.
        guard let appId = client.appId,
              payload.mentions.contains(where: { $0.id.rawValue == appId.rawValue })
        else { return }

        _ = try await client.triggerTypingIndicator(channelId: payload.channel_id)

        let streamResult = await capturingResult {
            try await gemini.streamGenerateContent(model: .gemini20Flash, prompt: payload.content)
        }

        switch streamResult {
        case .success(let stream):
            var buffer = ""                        // full text of the message currently being edited
            var message: DiscordChannel.Message?   // nil until that message has been posted

            for try await chunk in stream {
                let candidate = chunk.candidates?.first

                // Collect the text of every part. A chunk may carry zero or several parts, so do
                // not assume exactly one.
                var chunkText = ""
                for part in candidate?.content.parts ?? [] {
                    switch part {
                    case .text(let text):
                        chunkText.append(text)
                    default:
                        print("Unknown content type")
                    }
                }

                // Discord rejects empty content, so only publish when this chunk added text.
                if !chunkText.isEmpty {
                    var updated = buffer + chunkText
                    if updated.count > Self.messageLimit {
                        // The current message already shows `buffer` and this chunk would overflow
                        // it. Leave that message as is and continue in a fresh one.
                        message = nil
                        var rest = Substring(chunkText)
                        // A chunk that is itself longer than the limit is posted in full-size
                        // slices first.
                        while rest.count > Self.messageLimit {
                            try await publish(String(rest.prefix(Self.messageLimit)), replacing: nil, for: payload)
                            rest = rest.dropFirst(Self.messageLimit)
                        }
                        updated = String(rest)
                    }
                    buffer = updated
                    message = try await publish(buffer, replacing: message, for: payload)
                }

                if candidate?.finishReason == nil {
                    _ = try await client.triggerTypingIndicator(channelId: payload.channel_id)
                } else {
                    // Post a throwaway message and delete it immediately. This appears intended to
                    // cancel the typing indicator triggered above, which Discord clears when the
                    // bot sends a message.
                    let tmpMsg = try await client.createMessage(
                        channelId: payload.channel_id,
                        payload: .init(content: "a")
                    ).decode()
                    _ = try await client.deleteMessage(channelId: tmpMsg.channel_id, messageId: tmpMsg.id)
                }
            }

        case .failure(let error):
            // Clip long descriptions so the error report itself is not rejected for length.
            try await client.createMessage(
                channelId: payload.channel_id,
                payload: .init(content: String(error.localizedDescription.prefix(Self.messageLimit)))
            ).guardSuccess()
        }
    }

    /// Shows `content` in Discord.
    ///
    /// If `message` exists it is edited in place. Otherwise a new message is posted in the channel
    /// that `payload` came from.
    ///
    /// - Returns: The message now holding `content`. When an existing message was edited, this is
    ///   that same `message`.
    /// - Throws: If Discord rejects the request or the created message cannot be decoded.
    @discardableResult
    private func publish(
        _ content: String,
        replacing message: DiscordChannel.Message?,
        for payload: Gateway.MessageCreate
    ) async throws -> DiscordChannel.Message {
        if let message = message {
            try await client.updateMessage(
                channelId: message.channel_id,
                messageId: message.id,
                payload: .init(content: content)
            ).guardSuccess()
            return message
        }
        let response = try await client.createMessage(
            channelId: payload.channel_id,
            payload: .init(content: content)
        )
        try response.guardSuccess()
        return try response.decode()
    }

    /// Runs `operation` and captures its value or thrown error as a `Result` instead of throwing.
    ///
    /// Unlike `Task { ... }.result`, this stays structured, so cancellation of the caller reaches
    /// `operation`.
    private func capturingResult<T>(of operation: () async throws -> T) async -> Result<T, Error> {
        do {
            return .success(try await operation())
        } catch {
            return .failure(error)
        }
    }
}