From e6426225d7c2f381ce7408fb200eb6d966abf250 Mon Sep 17 00:00:00 2001 From: Andrew Glaze Date: Sun, 22 Mar 2026 13:14:26 -0400 Subject: [PATCH] feat: Resume after disconnect --- Sources/DiscordKit/Bot.swift | 8 +- Sources/DiscordKit/GatewayClient.swift | 130 +++++++++++++++++++------ Sources/DiscordKit/Models.swift | 11 ++- Sources/zundamon/Zundamon.swift | 32 ++---- 4 files changed, 122 insertions(+), 59 deletions(-) diff --git a/Sources/DiscordKit/Bot.swift b/Sources/DiscordKit/Bot.swift index 0c14316..3877309 100644 --- a/Sources/DiscordKit/Bot.swift +++ b/Sources/DiscordKit/Bot.swift @@ -6,20 +6,18 @@ import FoundationNetworking public actor Bot { public let client: DiscordClient let gateway: GatewayClient - let intents: Intents public init(token: String, intents: Intents) async throws { client = DiscordClient(token: token) - self.intents = intents let gatewayURL = try await client.getGatewayURL() - gateway = GatewayClient(gatewayURL: gatewayURL, token: token) + gateway = GatewayClient(gatewayURL: gatewayURL, token: token, intents: intents) } public func connect() async throws { - try await gateway.openConnection(intents: intents) + try await gateway.openConnection() } - public var events: AsyncStream { + public var events: AsyncStream { get async { await gateway.events } } } diff --git a/Sources/DiscordKit/GatewayClient.swift b/Sources/DiscordKit/GatewayClient.swift index f09ee71..ff1cd86 100644 --- a/Sources/DiscordKit/GatewayClient.swift +++ b/Sources/DiscordKit/GatewayClient.swift @@ -4,35 +4,52 @@ import FoundationNetworking #endif actor GatewayClient { - private let ws: URLSessionWebSocketTask + private var ws: URLSessionWebSocketTask private let token: String private(set) var open = false - var sequenceNum: Int? = nil + private var sequenceNum: Int? = nil + private var sessionID: String? = nil + private var resumeURL: URL? = nil + private let intents: Intents + private let gatewayURL: URL + private var hbTask: Task<(), any Error>? = nil - init(gatewayURL: URL, token: String) { + + init(gatewayURL: URL, token: String, intents: Intents) { + self.intents = intents + self.token = token + self.gatewayURL = gatewayURL let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")] ws = URLSession.shared.webSocketTask(with: gatewayURL.appending(queryItems: queryItems)) - self.token = token } - func openConnection(intents: Intents) async throws { + func openConnection() async throws { ws.resume() - open = true - guard case .hello(let helloMessage) = try await getMessage().d else { throw GatewayError.mismatchedOpcode } - dump(helloMessage) - - let heartbeatTask = Task() { - try await Task.sleep(for: .milliseconds(Int.random(in: 0...helloMessage.heartbeat_interval))) - try await sendHeartbeat() - - while !Task.isCancelled { - try await Task.sleep(for: .milliseconds(helloMessage.heartbeat_interval)) - try await sendHeartbeat() - } - } + try await setupHeartbeat() try await sendIdentify(intents: intents) - _ = await heartbeatTask.result + guard case .ready(let readyMessage) = try await getMessage().d else { throw GatewayError.connectionFailure } + open = true + sessionID = readyMessage.session_id + resumeURL = readyMessage.resume_gateway_url + } + + func setupHeartbeat() async throws { + guard case .hello(let helloMessage) = try await getMessage().d else { throw GatewayError.connectionFailure } + + self.hbTask = Task.detached { [self] in + do { + try await Task.sleep(for: .milliseconds(Int.random(in: 0...helloMessage.heartbeat_interval))) + try await sendHeartbeat() + + while !Task.isCancelled { + try await Task.sleep(for: .milliseconds(helloMessage.heartbeat_interval)) + try await sendHeartbeat() + } + } catch { + print("Heartbeat task canceled") + } + } } func sendIdentify(intents: Intents) async throws { @@ -58,11 +75,15 @@ actor GatewayClient { var gwMessage: GatewayMessage? = nil let json = JSONDecoder() while gwMessage == nil { - let wsMessage = try await ws.receive() - guard case .string(let str) = wsMessage else { throw GatewayError.invalidMessage } - strBuffer.append(str) do { + let wsMessage = try await ws.receive() + //print(wsMessage) + guard case .string(let str) = wsMessage else { throw GatewayError.invalidMessage } + strBuffer.append(str) gwMessage = try json.decode(GatewayMessage.self, from: Data(strBuffer.utf8)) + } catch URLError.networkConnectionLost { + self.open = false + try await reconnect() } catch DecodingError.dataCorrupted { continue } @@ -72,26 +93,75 @@ actor GatewayClient { return gwMessage } + private func reconnect() async throws { + hbTask?.cancel() + do { + try await attemptResume() + } catch { + print(error) + ws.cancel() + let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")] + ws = URLSession.shared.webSocketTask(with: gatewayURL.appending(queryItems: queryItems)) + try await openConnection() + } + } + + private func attemptResume() async throws { + guard + ws.closeCode.rawValue != 4004 && ws.closeCode.rawValue < 4010, + let resumeURL = resumeURL, + let sessionID = sessionID, + let sequenceNum = sequenceNum + else { + if let closeReason = ws.closeReason { + throw GatewayError.disconected(code: ws.closeCode.rawValue, message: String(data: closeReason, encoding: .utf8)) + } else { + throw GatewayError.disconected(code: ws.closeCode.rawValue, message: "unknown") + } + } + ws.cancel() + let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")] + ws = URLSession.shared.webSocketTask(with: resumeURL.appending(queryItems: queryItems)) + ws.resume() + try await setupHeartbeat() + + let payload = """ + { + "op": 6, + "d": { + "token": "\(token)", + "session_id": "\(sessionID)", + "seq": \(sequenceNum) + } + } + """ // Im lazy + try await ws.send(.string(payload)) + self.open = true + } + private func sendHeartbeat() async throws { + guard self.open else { print("conn closed, skipping hb"); return } let hbMessage = "{\"op\":1,\"d\":\(sequenceNum == nil ? "null" : String(sequenceNum!))}" try await ws.send(.string(hbMessage)) } - var events: AsyncStream { + var events: AsyncStream { AsyncStream { [self] in - var event: GatewayMessage? = nil - while event == nil { + var payload: GatewayPayload? = nil + while payload == nil { do { - while await !open { + while await !self.open { try await Task.sleep(for: .seconds(1)) } - event = try await getMessage() - if event!.op == 1 { try await sendHeartbeat() } + let event = try await getMessage() + dump(event) + if event.op == 1 { try await sendHeartbeat() } + payload = event.d } catch { print("Error listening to gateway: \(error)") } } - return event! + return payload! } } @@ -101,4 +171,6 @@ public enum GatewayError: Error { case invalidMessage case invalidOpcode case mismatchedOpcode + case disconected(code: Int, message: String?) + case connectionFailure } diff --git a/Sources/DiscordKit/Models.swift b/Sources/DiscordKit/Models.swift index 238e600..761c4e1 100644 --- a/Sources/DiscordKit/Models.swift +++ b/Sources/DiscordKit/Models.swift @@ -50,6 +50,7 @@ public struct SessionStartLimit: Codable, Sendable { public enum GatewayPayload: Decodable, Sendable { case hello(GatewayHello) case messageCreate(MessageCreate) + case ready(GatewayReady) } public struct GatewayMessage: Decodable, Sendable { @@ -77,6 +78,9 @@ public struct GatewayMessage: Decodable, Sendable { case 0 where t == "MESSAGE_CREATE": let messageCreate = try container.decode(MessageCreate.self, forKey: .d) d = .messageCreate(messageCreate) + case 0 where t == "READY": + let ready = try container.decode(GatewayReady.self, forKey: .d) + d = .ready(ready) default: d = nil break @@ -91,8 +95,13 @@ public struct GatewayMessage: Decodable, Sendable { } } +public struct GatewayReady: Codable, Sendable { + public let session_id: String + public let resume_gateway_url: URL +} + public struct GatewayHello: Codable, Sendable { - let heartbeat_interval: Int + public let heartbeat_interval: Int } public struct CreateMessageReq: Codable, Sendable { diff --git a/Sources/zundamon/Zundamon.swift b/Sources/zundamon/Zundamon.swift index 245704f..939dc72 100644 --- a/Sources/zundamon/Zundamon.swift +++ b/Sources/zundamon/Zundamon.swift @@ -21,32 +21,16 @@ struct Zundamon { ownID = try await bot.client.getOwnUser().id guard ownID != nil else { fatalError("Failed to get own User ID") } - await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - try await bot.connect() - } + try await bot.connect() - taskGroup.addTask { - for await event in await bot.events { - dump(event) - switch event.d { - case .messageCreate(let event): - try await MessageHandler(ctx: event, client: bot.client).handle() - default: - continue - } - } + for await event in await bot.events { + switch event { + case .messageCreate(let event): + try await MessageHandler(ctx: event, client: bot.client).handle() + default: + continue } } + print("bottom of main") } } - -//struct EventHandler: GatewayEventHandler { -// let event: Gateway.Event -// let client: any DiscordClient -// -// func onMessageCreate(_ payload: Gateway.MessageCreate) async throws { -// try await MessageHandler(ctx: payload, client: client).handle() -// } -// -//}