Files
zunda-bot/Sources/DiscordKit/GatewayClient.swift

206 lines
6.8 KiB
Swift

import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif
/// Client for a single Discord gateway (websocket) session.
///
/// Owns the websocket task and the heartbeat loop, and transparently resumes or
/// re-identifies when receiving from the socket fails. Consumers read payloads
/// through `events`.
actor GatewayClient {
    private var ws: URLSessionWebSocketTask
    private let token: String
    /// Set once READY arrives (or right after RESUME is sent); cleared by `reconnect()`.
    private(set) var open = false
    /// Most recent non-nil `s` received; echoed in heartbeats and RESUME.
    private var sequenceNum: Int? = nil
    private var sessionID: String? = nil
    private var resumeURL: URL? = nil
    private let intents: Intents
    private let gatewayURL: URL
    private var hbTask: Task<(), any Error>? = nil
    /// Reset to `false` whenever a heartbeat is sent, set to `true` by `getMessage()`
    /// when a heartbeat ACK arrives.
    var hbAck = false

    init(gatewayURL: URL, token: String, intents: Intents) {
        self.intents = intents
        self.token = token
        self.gatewayURL = gatewayURL
        ws = Self.makeSocket(for: gatewayURL)
    }

    /// Creates (but does not start) a websocket task for `baseURL`,
    /// requesting gateway API v10 with JSON encoding.
    private static func makeSocket(for baseURL: URL) -> URLSessionWebSocketTask {
        let queryItems = [
            URLQueryItem(name: "v", value: "10"),
            URLQueryItem(name: "encoding", value: "json"),
        ]
        return URLSession.shared.webSocketTask(with: baseURL.appending(queryItems: queryItems))
    }

    /// Encodes `value` as a JSON string literal (surrounding quotes and escaping included),
    /// so tokens and server-supplied IDs can be spliced into hand-written payloads safely.
    private static func jsonStringLiteral(_ value: String) throws -> String {
        let data = try JSONEncoder().encode(value)
        return String(decoding: data, as: UTF8.self)
    }

    /// Starts the socket and performs the HELLO -> IDENTIFY -> READY handshake.
    ///
    /// - Throws: `GatewayError.connectionFailure` if the message after IDENTIFY is not
    ///   READY (or HELLO is missing), plus any send/receive error.
    func openConnection() async throws {
        ws.resume()
        try await setupHeartbeat()
        try await sendIdentify(intents: intents)
        guard case .ready(let readyMessage) = try await getMessage().d else {
            throw GatewayError.connectionFailure
        }
        open = true
        sessionID = readyMessage.session_id
        resumeURL = readyMessage.resume_gateway_url
    }

    /// Waits for HELLO, then starts a background heartbeat loop at the interval it specifies.
    ///
    /// Any previously running heartbeat loop is cancelled first so only one loop
    /// pings the current socket.
    /// - Throws: `GatewayError.connectionFailure` if the next message is not HELLO.
    func setupHeartbeat() async throws {
        guard case .hello(let helloMessage) = try await getMessage().d else {
            throw GatewayError.connectionFailure
        }
        let interval = helloMessage.heartbeat_interval
        hbTask?.cancel()
        hbTask = Task.detached { [self] in
            do {
                // First beat is jittered within one interval.
                try await Task.sleep(for: .milliseconds(Int.random(in: 0...interval)))
                while !Task.isCancelled {
                    let sent = try await sendHeartbeat()
                    try await Task.sleep(for: .milliseconds(interval))
                    // Only judge the ACK when a heartbeat actually went out; a beat skipped
                    // because the session wasn't open yet leaves `hbAck` stale.
                    if sent, await !self.hbAck {
                        // Haven't heard back: drop the socket and stop pinging it.
                        await dropConnection()
                        return
                    }
                }
            } catch {
                print("Heartbeat task canceled")
            }
        }
    }

    /// Cancels the current socket after a missed heartbeat ACK. A receive in
    /// `getMessage()` is then expected to fail, which routes into `reconnect()`.
    private func dropConnection() {
        print("Heartbeat not acknowledged, dropping connection")
        ws.cancel()
    }

    /// Sends IDENTIFY (op 2) with the bot token and the requested intents.
    func sendIdentify(intents: Intents) async throws {
        let tokenLiteral = try Self.jsonStringLiteral(token)
        let payload = """
        {
          "op": 2,
          "d": {
            "token": \(tokenLiteral),
            "intents": \(intents.rawValue),
            "properties": {
              "os": "linux",
              "browser": "discordkit",
              "device": "discordkit"
            }
          }
        }
        """
        try await ws.send(.string(payload))
    }

    /// Receives and decodes the next gateway message, reconnecting if the socket fails.
    ///
    /// Text frames are accumulated until they decode as one `GatewayMessage`. The returned
    /// message updates `sequenceNum` when it carries `s`, and sets `hbAck` on a heartbeat ACK.
    /// - Throws: `GatewayError.invalidMessage` for non-text frames, decoding errors other
    ///   than `dataCorrupted`, and errors from `reconnect()`.
    func getMessage() async throws -> GatewayMessage {
        let decoder = JSONDecoder()
        var buffer = ""
        while true {
            let frame: URLSessionWebSocketTask.Message
            do {
                frame = try await ws.receive()
            } catch {
                print("Error listening to gateway: \(error)")
                try await reconnect()
                // Anything buffered came from the socket that was just replaced.
                buffer = ""
                continue
            }
            guard case .string(let text) = frame else { throw GatewayError.invalidMessage }
            buffer.append(text)
            do {
                let message = try decoder.decode(GatewayMessage.self, from: Data(buffer.utf8))
                sequenceNum = message.s ?? sequenceNum
                if message.d == .heartbeatAck { hbAck = true }
                return message
            } catch DecodingError.dataCorrupted {
                // Not a complete JSON document yet; wait for more text.
                continue
            }
        }
    }

    /// Re-establishes the session after the socket fails.
    ///
    /// Tries RESUME first; otherwise opens fresh connections (IDENTIFY) every 5 seconds
    /// until one reaches READY.
    /// - Throws: only if the retry sleep is interrupted (e.g. task cancellation).
    private func reconnect() async throws {
        open = false
        ws.cancel()
        hbTask?.cancel()
        do {
            try await attemptResume()
            return
        } catch {
            print("Error resuming session: \(error)")
        }
        // NOTE(review): non-resumable close codes (e.g. 4004) also land here and are retried
        // forever; consider surfacing them to the caller instead.
        while !open {
            ws.cancel()
            hbTask?.cancel()
            ws = Self.makeSocket(for: gatewayURL)
            do {
                try await openConnection()
            } catch {
                print("Error reconnecting: \(error)")
            }
            if open { break }
            try await Task.sleep(for: .seconds(5))
        }
    }

    /// Reopens the socket on `resumeURL` and sends RESUME (op 6) for the previous session.
    ///
    /// - Throws: `GatewayError.disconected` when the last close code is treated as
    ///   non-resumable or no session state is stored; otherwise any handshake error.
    private func attemptResume() async throws {
        let closeCode = ws.closeCode.rawValue
        // Codes excluded here are treated as non-resumable (see Discord's close-code table).
        let resumable = closeCode != 4004 && closeCode != 1000 && closeCode != 1001 && closeCode < 4010
        guard resumable, let resumeURL, let sessionID, let sequenceNum else {
            if let closeReason = ws.closeReason {
                throw GatewayError.disconected(code: closeCode, message: String(data: closeReason, encoding: .utf8))
            } else {
                throw GatewayError.disconected(code: closeCode, message: "unknown")
            }
        }
        ws.cancel()
        ws = Self.makeSocket(for: resumeURL)
        ws.resume()
        try await setupHeartbeat()
        let tokenLiteral = try Self.jsonStringLiteral(token)
        let sessionLiteral = try Self.jsonStringLiteral(sessionID)
        let payload = """
        {
          "op": 6,
          "d": {
            "token": \(tokenLiteral),
            "session_id": \(sessionLiteral),
            "seq": \(sequenceNum)
          }
        }
        """
        try await ws.send(.string(payload))
        open = true
    }

    /// Sends a heartbeat (op 1) carrying the last sequence number, or `null` if none.
    ///
    /// - Returns: `false` without sending while the session is not open; `true` after
    ///   the heartbeat was sent and `hbAck` reset.
    @discardableResult
    private func sendHeartbeat() async throws -> Bool {
        guard open else {
            print("conn closed, skipping hb")
            return false
        }
        hbAck = false
        let seq = sequenceNum.map { String($0) } ?? "null"
        try await ws.send(.string("{\"op\":1,\"d\":\(seq)}"))
        return true
    }

    /// A new stream of gateway payloads (each access creates a fresh stream).
    ///
    /// Waits while the session is not open, logs and skips errors, and answers server
    /// heartbeat requests (op 1) immediately. Finishes once the consuming task is
    /// cancelled instead of spinning on cancellation errors.
    var events: AsyncStream<GatewayPayload> {
        AsyncStream { [self] in
            var payload: GatewayPayload? = nil
            while payload == nil {
                if Task.isCancelled { return nil }
                do {
                    while await !self.open {
                        try await Task.sleep(for: .seconds(1))
                    }
                    let event = try await self.getMessage()
                    #if DEBUG
                    dump(event)
                    #endif
                    if event.op == 1 { try await self.sendHeartbeat() }
                    payload = event.d
                } catch {
                    if Task.isCancelled { return nil }
                    print("Error listening to gateway: \(error)")
                }
            }
            return payload
        }
    }
}
/// Errors raised by the gateway client.
public enum GatewayError: Error {
    /// A received frame was not a text frame that could be treated as a gateway message.
    case invalidMessage
    /// Opcode validation failures (not thrown by `GatewayClient` itself).
    case invalidOpcode, mismatchedOpcode
    /// The socket closed and the session could not be resumed.
    /// `message` is the UTF-8 close reason, `"unknown"` when none was sent,
    /// or `nil` if the reason was not valid UTF-8. (Spelling kept for source compatibility.)
    case disconected(code: Int, message: String?)
    /// The handshake did not yield the expected HELLO or READY message.
    case connectionFailure
}