feat: Resume after disconnect
This commit is contained in:
@@ -6,20 +6,18 @@ import FoundationNetworking
|
|||||||
public actor Bot {
|
public actor Bot {
|
||||||
public let client: DiscordClient
|
public let client: DiscordClient
|
||||||
let gateway: GatewayClient
|
let gateway: GatewayClient
|
||||||
let intents: Intents
|
|
||||||
|
|
||||||
public init(token: String, intents: Intents) async throws {
|
public init(token: String, intents: Intents) async throws {
|
||||||
client = DiscordClient(token: token)
|
client = DiscordClient(token: token)
|
||||||
self.intents = intents
|
|
||||||
let gatewayURL = try await client.getGatewayURL()
|
let gatewayURL = try await client.getGatewayURL()
|
||||||
gateway = GatewayClient(gatewayURL: gatewayURL, token: token)
|
gateway = GatewayClient(gatewayURL: gatewayURL, token: token, intents: intents)
|
||||||
}
|
}
|
||||||
|
|
||||||
public func connect() async throws {
|
public func connect() async throws {
|
||||||
try await gateway.openConnection(intents: intents)
|
try await gateway.openConnection()
|
||||||
}
|
}
|
||||||
|
|
||||||
public var events: AsyncStream<GatewayMessage> {
|
public var events: AsyncStream<GatewayPayload> {
|
||||||
get async { await gateway.events }
|
get async { await gateway.events }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,24 +4,41 @@ import FoundationNetworking
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
actor GatewayClient {
|
actor GatewayClient {
|
||||||
private let ws: URLSessionWebSocketTask
|
private var ws: URLSessionWebSocketTask
|
||||||
private let token: String
|
private let token: String
|
||||||
private(set) var open = false
|
private(set) var open = false
|
||||||
var sequenceNum: Int? = nil
|
private var sequenceNum: Int? = nil
|
||||||
|
private var sessionID: String? = nil
|
||||||
|
private var resumeURL: URL? = nil
|
||||||
|
private let intents: Intents
|
||||||
|
private let gatewayURL: URL
|
||||||
|
private var hbTask: Task<(), any Error>? = nil
|
||||||
|
|
||||||
init(gatewayURL: URL, token: String) {
|
|
||||||
|
init(gatewayURL: URL, token: String, intents: Intents) {
|
||||||
|
self.intents = intents
|
||||||
|
self.token = token
|
||||||
|
self.gatewayURL = gatewayURL
|
||||||
let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")]
|
let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")]
|
||||||
ws = URLSession.shared.webSocketTask(with: gatewayURL.appending(queryItems: queryItems))
|
ws = URLSession.shared.webSocketTask(with: gatewayURL.appending(queryItems: queryItems))
|
||||||
self.token = token
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func openConnection(intents: Intents) async throws {
|
func openConnection() async throws {
|
||||||
ws.resume()
|
ws.resume()
|
||||||
open = true
|
try await setupHeartbeat()
|
||||||
guard case .hello(let helloMessage) = try await getMessage().d else { throw GatewayError.mismatchedOpcode }
|
|
||||||
dump(helloMessage)
|
|
||||||
|
|
||||||
let heartbeatTask = Task() {
|
try await sendIdentify(intents: intents)
|
||||||
|
guard case .ready(let readyMessage) = try await getMessage().d else { throw GatewayError.connectionFailure }
|
||||||
|
open = true
|
||||||
|
sessionID = readyMessage.session_id
|
||||||
|
resumeURL = readyMessage.resume_gateway_url
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupHeartbeat() async throws {
|
||||||
|
guard case .hello(let helloMessage) = try await getMessage().d else { throw GatewayError.connectionFailure }
|
||||||
|
|
||||||
|
self.hbTask = Task.detached { [self] in
|
||||||
|
do {
|
||||||
try await Task.sleep(for: .milliseconds(Int.random(in: 0...helloMessage.heartbeat_interval)))
|
try await Task.sleep(for: .milliseconds(Int.random(in: 0...helloMessage.heartbeat_interval)))
|
||||||
try await sendHeartbeat()
|
try await sendHeartbeat()
|
||||||
|
|
||||||
@@ -29,10 +46,10 @@ actor GatewayClient {
|
|||||||
try await Task.sleep(for: .milliseconds(helloMessage.heartbeat_interval))
|
try await Task.sleep(for: .milliseconds(helloMessage.heartbeat_interval))
|
||||||
try await sendHeartbeat()
|
try await sendHeartbeat()
|
||||||
}
|
}
|
||||||
|
} catch {
|
||||||
|
print("Heartbeat task canceled")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
try await sendIdentify(intents: intents)
|
|
||||||
_ = await heartbeatTask.result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendIdentify(intents: Intents) async throws {
|
func sendIdentify(intents: Intents) async throws {
|
||||||
@@ -58,11 +75,15 @@ actor GatewayClient {
|
|||||||
var gwMessage: GatewayMessage? = nil
|
var gwMessage: GatewayMessage? = nil
|
||||||
let json = JSONDecoder()
|
let json = JSONDecoder()
|
||||||
while gwMessage == nil {
|
while gwMessage == nil {
|
||||||
|
do {
|
||||||
let wsMessage = try await ws.receive()
|
let wsMessage = try await ws.receive()
|
||||||
|
//print(wsMessage)
|
||||||
guard case .string(let str) = wsMessage else { throw GatewayError.invalidMessage }
|
guard case .string(let str) = wsMessage else { throw GatewayError.invalidMessage }
|
||||||
strBuffer.append(str)
|
strBuffer.append(str)
|
||||||
do {
|
|
||||||
gwMessage = try json.decode(GatewayMessage.self, from: Data(strBuffer.utf8))
|
gwMessage = try json.decode(GatewayMessage.self, from: Data(strBuffer.utf8))
|
||||||
|
} catch URLError.networkConnectionLost {
|
||||||
|
self.open = false
|
||||||
|
try await reconnect()
|
||||||
} catch DecodingError.dataCorrupted {
|
} catch DecodingError.dataCorrupted {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -72,26 +93,75 @@ actor GatewayClient {
|
|||||||
return gwMessage
|
return gwMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private func reconnect() async throws {
|
||||||
|
hbTask?.cancel()
|
||||||
|
do {
|
||||||
|
try await attemptResume()
|
||||||
|
} catch {
|
||||||
|
print(error)
|
||||||
|
ws.cancel()
|
||||||
|
let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")]
|
||||||
|
ws = URLSession.shared.webSocketTask(with: gatewayURL.appending(queryItems: queryItems))
|
||||||
|
try await openConnection()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func attemptResume() async throws {
|
||||||
|
guard
|
||||||
|
ws.closeCode.rawValue != 4004 && ws.closeCode.rawValue < 4010,
|
||||||
|
let resumeURL = resumeURL,
|
||||||
|
let sessionID = sessionID,
|
||||||
|
let sequenceNum = sequenceNum
|
||||||
|
else {
|
||||||
|
if let closeReason = ws.closeReason {
|
||||||
|
throw GatewayError.disconected(code: ws.closeCode.rawValue, message: String(data: closeReason, encoding: .utf8))
|
||||||
|
} else {
|
||||||
|
throw GatewayError.disconected(code: ws.closeCode.rawValue, message: "unknown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ws.cancel()
|
||||||
|
let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")]
|
||||||
|
ws = URLSession.shared.webSocketTask(with: resumeURL.appending(queryItems: queryItems))
|
||||||
|
ws.resume()
|
||||||
|
try await setupHeartbeat()
|
||||||
|
|
||||||
|
let payload = """
|
||||||
|
{
|
||||||
|
"op": 6,
|
||||||
|
"d": {
|
||||||
|
"token": "\(token)",
|
||||||
|
"session_id": "\(sessionID)",
|
||||||
|
"seq": \(sequenceNum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""" // Im lazy
|
||||||
|
try await ws.send(.string(payload))
|
||||||
|
self.open = true
|
||||||
|
}
|
||||||
|
|
||||||
private func sendHeartbeat() async throws {
|
private func sendHeartbeat() async throws {
|
||||||
|
guard self.open else { print("conn closed, skipping hb"); return }
|
||||||
let hbMessage = "{\"op\":1,\"d\":\(sequenceNum == nil ? "null" : String(sequenceNum!))}"
|
let hbMessage = "{\"op\":1,\"d\":\(sequenceNum == nil ? "null" : String(sequenceNum!))}"
|
||||||
try await ws.send(.string(hbMessage))
|
try await ws.send(.string(hbMessage))
|
||||||
}
|
}
|
||||||
|
|
||||||
var events: AsyncStream<GatewayMessage> {
|
var events: AsyncStream<GatewayPayload> {
|
||||||
AsyncStream { [self] in
|
AsyncStream { [self] in
|
||||||
var event: GatewayMessage? = nil
|
var payload: GatewayPayload? = nil
|
||||||
while event == nil {
|
while payload == nil {
|
||||||
do {
|
do {
|
||||||
while await !open {
|
while await !self.open {
|
||||||
try await Task.sleep(for: .seconds(1))
|
try await Task.sleep(for: .seconds(1))
|
||||||
}
|
}
|
||||||
event = try await getMessage()
|
let event = try await getMessage()
|
||||||
if event!.op == 1 { try await sendHeartbeat() }
|
dump(event)
|
||||||
|
if event.op == 1 { try await sendHeartbeat() }
|
||||||
|
payload = event.d
|
||||||
} catch {
|
} catch {
|
||||||
print("Error listening to gateway: \(error)")
|
print("Error listening to gateway: \(error)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return event!
|
return payload!
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,4 +171,6 @@ public enum GatewayError: Error {
|
|||||||
case invalidMessage
|
case invalidMessage
|
||||||
case invalidOpcode
|
case invalidOpcode
|
||||||
case mismatchedOpcode
|
case mismatchedOpcode
|
||||||
|
case disconected(code: Int, message: String?)
|
||||||
|
case connectionFailure
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ public struct SessionStartLimit: Codable, Sendable {
|
|||||||
public enum GatewayPayload: Decodable, Sendable {
|
public enum GatewayPayload: Decodable, Sendable {
|
||||||
case hello(GatewayHello)
|
case hello(GatewayHello)
|
||||||
case messageCreate(MessageCreate)
|
case messageCreate(MessageCreate)
|
||||||
|
case ready(GatewayReady)
|
||||||
}
|
}
|
||||||
|
|
||||||
public struct GatewayMessage: Decodable, Sendable {
|
public struct GatewayMessage: Decodable, Sendable {
|
||||||
@@ -77,6 +78,9 @@ public struct GatewayMessage: Decodable, Sendable {
|
|||||||
case 0 where t == "MESSAGE_CREATE":
|
case 0 where t == "MESSAGE_CREATE":
|
||||||
let messageCreate = try container.decode(MessageCreate.self, forKey: .d)
|
let messageCreate = try container.decode(MessageCreate.self, forKey: .d)
|
||||||
d = .messageCreate(messageCreate)
|
d = .messageCreate(messageCreate)
|
||||||
|
case 0 where t == "READY":
|
||||||
|
let ready = try container.decode(GatewayReady.self, forKey: .d)
|
||||||
|
d = .ready(ready)
|
||||||
default:
|
default:
|
||||||
d = nil
|
d = nil
|
||||||
break
|
break
|
||||||
@@ -91,8 +95,13 @@ public struct GatewayMessage: Decodable, Sendable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public struct GatewayReady: Codable, Sendable {
|
||||||
|
public let session_id: String
|
||||||
|
public let resume_gateway_url: URL
|
||||||
|
}
|
||||||
|
|
||||||
public struct GatewayHello: Codable, Sendable {
|
public struct GatewayHello: Codable, Sendable {
|
||||||
let heartbeat_interval: Int
|
public let heartbeat_interval: Int
|
||||||
}
|
}
|
||||||
|
|
||||||
public struct CreateMessageReq: Codable, Sendable {
|
public struct CreateMessageReq: Codable, Sendable {
|
||||||
|
|||||||
@@ -21,32 +21,16 @@ struct Zundamon {
|
|||||||
ownID = try await bot.client.getOwnUser().id
|
ownID = try await bot.client.getOwnUser().id
|
||||||
guard ownID != nil else { fatalError("Failed to get own User ID") }
|
guard ownID != nil else { fatalError("Failed to get own User ID") }
|
||||||
|
|
||||||
await withThrowingTaskGroup(of: Void.self) { taskGroup in
|
|
||||||
taskGroup.addTask {
|
|
||||||
try await bot.connect()
|
try await bot.connect()
|
||||||
}
|
|
||||||
|
|
||||||
taskGroup.addTask {
|
|
||||||
for await event in await bot.events {
|
for await event in await bot.events {
|
||||||
dump(event)
|
switch event {
|
||||||
switch event.d {
|
|
||||||
case .messageCreate(let event):
|
case .messageCreate(let event):
|
||||||
try await MessageHandler(ctx: event, client: bot.client).handle()
|
try await MessageHandler(ctx: event, client: bot.client).handle()
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
print("bottom of main")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//struct EventHandler: GatewayEventHandler {
|
|
||||||
// let event: Gateway.Event
|
|
||||||
// let client: any DiscordClient
|
|
||||||
//
|
|
||||||
// func onMessageCreate(_ payload: Gateway.MessageCreate) async throws {
|
|
||||||
// try await MessageHandler(ctx: payload, client: client).handle()
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
//}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user