feat: Resume after disconnect

This commit is contained in:
Andrew Glaze
2026-03-22 13:14:26 -04:00
parent bc72fc9781
commit e6426225d7
4 changed files with 122 additions and 59 deletions

View File

@@ -6,20 +6,18 @@ import FoundationNetworking
public actor Bot { public actor Bot {
public let client: DiscordClient public let client: DiscordClient
let gateway: GatewayClient let gateway: GatewayClient
let intents: Intents
public init(token: String, intents: Intents) async throws { public init(token: String, intents: Intents) async throws {
client = DiscordClient(token: token) client = DiscordClient(token: token)
self.intents = intents
let gatewayURL = try await client.getGatewayURL() let gatewayURL = try await client.getGatewayURL()
gateway = GatewayClient(gatewayURL: gatewayURL, token: token) gateway = GatewayClient(gatewayURL: gatewayURL, token: token, intents: intents)
} }
public func connect() async throws { public func connect() async throws {
try await gateway.openConnection(intents: intents) try await gateway.openConnection()
} }
public var events: AsyncStream<GatewayMessage> { public var events: AsyncStream<GatewayPayload> {
get async { await gateway.events } get async { await gateway.events }
} }
} }

View File

@@ -4,35 +4,52 @@ import FoundationNetworking
#endif #endif
actor GatewayClient { actor GatewayClient {
private let ws: URLSessionWebSocketTask private var ws: URLSessionWebSocketTask
private let token: String private let token: String
private(set) var open = false private(set) var open = false
var sequenceNum: Int? = nil private var sequenceNum: Int? = nil
private var sessionID: String? = nil
private var resumeURL: URL? = nil
private let intents: Intents
private let gatewayURL: URL
private var hbTask: Task<(), any Error>? = nil
init(gatewayURL: URL, token: String) {
init(gatewayURL: URL, token: String, intents: Intents) {
self.intents = intents
self.token = token
self.gatewayURL = gatewayURL
let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")] let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")]
ws = URLSession.shared.webSocketTask(with: gatewayURL.appending(queryItems: queryItems)) ws = URLSession.shared.webSocketTask(with: gatewayURL.appending(queryItems: queryItems))
self.token = token
} }
func openConnection(intents: Intents) async throws { func openConnection() async throws {
ws.resume() ws.resume()
open = true try await setupHeartbeat()
guard case .hello(let helloMessage) = try await getMessage().d else { throw GatewayError.mismatchedOpcode }
dump(helloMessage)
let heartbeatTask = Task() {
try await Task.sleep(for: .milliseconds(Int.random(in: 0...helloMessage.heartbeat_interval)))
try await sendHeartbeat()
while !Task.isCancelled {
try await Task.sleep(for: .milliseconds(helloMessage.heartbeat_interval))
try await sendHeartbeat()
}
}
try await sendIdentify(intents: intents) try await sendIdentify(intents: intents)
_ = await heartbeatTask.result guard case .ready(let readyMessage) = try await getMessage().d else { throw GatewayError.connectionFailure }
open = true
sessionID = readyMessage.session_id
resumeURL = readyMessage.resume_gateway_url
}
func setupHeartbeat() async throws {
guard case .hello(let helloMessage) = try await getMessage().d else { throw GatewayError.connectionFailure }
self.hbTask = Task.detached { [self] in
do {
try await Task.sleep(for: .milliseconds(Int.random(in: 0...helloMessage.heartbeat_interval)))
try await sendHeartbeat()
while !Task.isCancelled {
try await Task.sleep(for: .milliseconds(helloMessage.heartbeat_interval))
try await sendHeartbeat()
}
} catch {
print("Heartbeat task canceled")
}
}
} }
func sendIdentify(intents: Intents) async throws { func sendIdentify(intents: Intents) async throws {
@@ -58,11 +75,15 @@ actor GatewayClient {
var gwMessage: GatewayMessage? = nil var gwMessage: GatewayMessage? = nil
let json = JSONDecoder() let json = JSONDecoder()
while gwMessage == nil { while gwMessage == nil {
let wsMessage = try await ws.receive()
guard case .string(let str) = wsMessage else { throw GatewayError.invalidMessage }
strBuffer.append(str)
do { do {
let wsMessage = try await ws.receive()
//print(wsMessage)
guard case .string(let str) = wsMessage else { throw GatewayError.invalidMessage }
strBuffer.append(str)
gwMessage = try json.decode(GatewayMessage.self, from: Data(strBuffer.utf8)) gwMessage = try json.decode(GatewayMessage.self, from: Data(strBuffer.utf8))
} catch URLError.networkConnectionLost {
self.open = false
try await reconnect()
} catch DecodingError.dataCorrupted { } catch DecodingError.dataCorrupted {
continue continue
} }
@@ -72,26 +93,75 @@ actor GatewayClient {
return gwMessage return gwMessage
} }
private func reconnect() async throws {
hbTask?.cancel()
do {
try await attemptResume()
} catch {
print(error)
ws.cancel()
let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")]
ws = URLSession.shared.webSocketTask(with: gatewayURL.appending(queryItems: queryItems))
try await openConnection()
}
}
private func attemptResume() async throws {
guard
ws.closeCode.rawValue != 4004 && ws.closeCode.rawValue < 4010,
let resumeURL = resumeURL,
let sessionID = sessionID,
let sequenceNum = sequenceNum
else {
if let closeReason = ws.closeReason {
throw GatewayError.disconected(code: ws.closeCode.rawValue, message: String(data: closeReason, encoding: .utf8))
} else {
throw GatewayError.disconected(code: ws.closeCode.rawValue, message: "unknown")
}
}
ws.cancel()
let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")]
ws = URLSession.shared.webSocketTask(with: resumeURL.appending(queryItems: queryItems))
ws.resume()
try await setupHeartbeat()
let payload = """
{
"op": 6,
"d": {
"token": "\(token)",
"session_id": "\(sessionID)",
"seq": \(sequenceNum)
}
}
""" // Im lazy
try await ws.send(.string(payload))
self.open = true
}
private func sendHeartbeat() async throws { private func sendHeartbeat() async throws {
guard self.open else { print("conn closed, skipping hb"); return }
let hbMessage = "{\"op\":1,\"d\":\(sequenceNum == nil ? "null" : String(sequenceNum!))}" let hbMessage = "{\"op\":1,\"d\":\(sequenceNum == nil ? "null" : String(sequenceNum!))}"
try await ws.send(.string(hbMessage)) try await ws.send(.string(hbMessage))
} }
var events: AsyncStream<GatewayMessage> { var events: AsyncStream<GatewayPayload> {
AsyncStream { [self] in AsyncStream { [self] in
var event: GatewayMessage? = nil var payload: GatewayPayload? = nil
while event == nil { while payload == nil {
do { do {
while await !open { while await !self.open {
try await Task.sleep(for: .seconds(1)) try await Task.sleep(for: .seconds(1))
} }
event = try await getMessage() let event = try await getMessage()
if event!.op == 1 { try await sendHeartbeat() } dump(event)
if event.op == 1 { try await sendHeartbeat() }
payload = event.d
} catch { } catch {
print("Error listening to gateway: \(error)") print("Error listening to gateway: \(error)")
} }
} }
return event! return payload!
} }
} }
@@ -101,4 +171,6 @@ public enum GatewayError: Error {
case invalidMessage case invalidMessage
case invalidOpcode case invalidOpcode
case mismatchedOpcode case mismatchedOpcode
case disconected(code: Int, message: String?)
case connectionFailure
} }

View File

@@ -50,6 +50,7 @@ public struct SessionStartLimit: Codable, Sendable {
public enum GatewayPayload: Decodable, Sendable { public enum GatewayPayload: Decodable, Sendable {
case hello(GatewayHello) case hello(GatewayHello)
case messageCreate(MessageCreate) case messageCreate(MessageCreate)
case ready(GatewayReady)
} }
public struct GatewayMessage: Decodable, Sendable { public struct GatewayMessage: Decodable, Sendable {
@@ -77,6 +78,9 @@ public struct GatewayMessage: Decodable, Sendable {
case 0 where t == "MESSAGE_CREATE": case 0 where t == "MESSAGE_CREATE":
let messageCreate = try container.decode(MessageCreate.self, forKey: .d) let messageCreate = try container.decode(MessageCreate.self, forKey: .d)
d = .messageCreate(messageCreate) d = .messageCreate(messageCreate)
case 0 where t == "READY":
let ready = try container.decode(GatewayReady.self, forKey: .d)
d = .ready(ready)
default: default:
d = nil d = nil
break break
@@ -91,8 +95,13 @@ public struct GatewayMessage: Decodable, Sendable {
} }
} }
public struct GatewayReady: Codable, Sendable {
public let session_id: String
public let resume_gateway_url: URL
}
public struct GatewayHello: Codable, Sendable { public struct GatewayHello: Codable, Sendable {
let heartbeat_interval: Int public let heartbeat_interval: Int
} }
public struct CreateMessageReq: Codable, Sendable { public struct CreateMessageReq: Codable, Sendable {

View File

@@ -21,32 +21,16 @@ struct Zundamon {
ownID = try await bot.client.getOwnUser().id ownID = try await bot.client.getOwnUser().id
guard ownID != nil else { fatalError("Failed to get own User ID") } guard ownID != nil else { fatalError("Failed to get own User ID") }
await withThrowingTaskGroup(of: Void.self) { taskGroup in try await bot.connect()
taskGroup.addTask {
try await bot.connect()
}
taskGroup.addTask { for await event in await bot.events {
for await event in await bot.events { switch event {
dump(event) case .messageCreate(let event):
switch event.d { try await MessageHandler(ctx: event, client: bot.client).handle()
case .messageCreate(let event): default:
try await MessageHandler(ctx: event, client: bot.client).handle() continue
default:
continue
}
}
} }
} }
print("bottom of main")
} }
} }
//struct EventHandler: GatewayEventHandler {
// let event: Gateway.Event
// let client: any DiscordClient
//
// func onMessageCreate(_ payload: Gateway.MessageCreate) async throws {
// try await MessageHandler(ctx: payload, client: client).handle()
// }
//
//}