Files
zunda-bot/Sources/DiscordKit/GatewayClient.swift
2026-03-22 13:24:56 -04:00

181 lines
5.9 KiB
Swift

import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif
/// Maintains a single Discord gateway WebSocket session.
///
/// It connects and identifies, runs the heartbeat loop, and tracks the last sequence
/// number. When the connection is lost it resumes, or starts a new session if resuming
/// is not possible.
actor GatewayClient {
    private var ws: URLSessionWebSocketTask
    private let token: String
    /// `true` once READY has been received, or once a RESUME payload has been sent.
    private(set) var open = false
    /// Last sequence number (`s`) seen on the socket; echoed in heartbeats and RESUME.
    private var sequenceNum: Int? = nil
    private var sessionID: String? = nil
    private var resumeURL: URL? = nil
    private let intents: Intents
    private let gatewayURL: URL
    private var hbTask: Task<(), any Error>? = nil

    /// Creates a client for `gatewayURL`.
    ///
    /// The socket is created here but not started until `openConnection()` is called.
    init(gatewayURL: URL, token: String, intents: Intents) {
        self.intents = intents
        self.token = token
        self.gatewayURL = gatewayURL
        ws = GatewayClient.makeSocket(for: gatewayURL)
    }

    /// Builds a WebSocket task for `baseURL` with the `v=10&encoding=json` query appended.
    private static func makeSocket(for baseURL: URL) -> URLSessionWebSocketTask {
        let queryItems = [URLQueryItem(name: "v", value: "10"), URLQueryItem(name: "encoding", value: "json")]
        return URLSession.shared.webSocketTask(with: baseURL.appending(queryItems: queryItems))
    }

    /// Encodes `value` as a JSON string literal, including the surrounding quotes.
    ///
    /// Quotes, backslashes and control characters are escaped, so the result can be
    /// spliced into the hand-built payloads below without producing invalid JSON.
    private static func jsonString(_ value: String) -> String {
        var escaped = "\""
        for scalar in value.unicodeScalars {
            switch scalar {
            case "\"": escaped += "\\\""
            case "\\": escaped += "\\\\"
            case "\n": escaped += "\\n"
            case "\r": escaped += "\\r"
            case "\t": escaped += "\\t"
            case _ where scalar.value < 0x20:
                let hex = String(scalar.value, radix: 16, uppercase: true)
                escaped += "\\u" + String(repeating: "0", count: 4 - hex.count) + hex
            default:
                escaped.unicodeScalars.append(scalar)
            }
        }
        return escaped + "\""
    }

    /// Starts the socket and performs a fresh handshake.
    ///
    /// The handshake waits for HELLO, starts the heartbeat, sends IDENTIFY, and then
    /// requires the next decoded message to be READY. The session ID and resume URL
    /// are taken from READY.
    /// - Throws: `GatewayError.connectionFailure` if HELLO or READY is not the expected
    ///   message; otherwise, any socket error.
    func openConnection() async throws {
        ws.resume()
        // A new session has no sequence number yet; don't echo the previous session's.
        sequenceNum = nil
        try await setupHeartbeat()
        try await sendIdentify(intents: intents)
        guard case .ready(let readyMessage) = try await getMessage().d else { throw GatewayError.connectionFailure }
        open = true
        sessionID = readyMessage.session_id
        resumeURL = readyMessage.resume_gateway_url
    }

    /// Waits for HELLO and (re)starts the heartbeat task.
    ///
    /// The task sends a heartbeat after a random delay in `0...heartbeat_interval` ms,
    /// then every `heartbeat_interval` ms, until it is cancelled or a send throws.
    /// Any previously running heartbeat task is cancelled first.
    /// - Throws: `GatewayError.connectionFailure` if the next decoded message is not HELLO.
    func setupHeartbeat() async throws {
        guard case .hello(let helloMessage) = try await getMessage().d else { throw GatewayError.connectionFailure }
        // Never let two heartbeat loops run against the same client.
        hbTask?.cancel()
        hbTask = Task { [self] in
            do {
                try await Task.sleep(for: .milliseconds(Int.random(in: 0...helloMessage.heartbeat_interval)))
                try await sendHeartbeat()
                while !Task.isCancelled {
                    try await Task.sleep(for: .milliseconds(helloMessage.heartbeat_interval))
                    try await sendHeartbeat()
                }
            } catch is CancellationError {
                print("Heartbeat task canceled")
            } catch {
                // A failed send is not a cancellation; report what actually happened.
                print("Heartbeat task failed: \(error)")
            }
        }
    }

    /// Sends the IDENTIFY (op 2) payload with the bot token and `intents`.
    func sendIdentify(intents: Intents) async throws {
        let payload = """
        {
          "op": 2,
          "d": {
            "token": \(GatewayClient.jsonString(token)),
            "intents": \(intents.rawValue),
            "properties": {
              "os": "linux",
              "browser": "discordkit",
              "device": "discordkit"
            }
          }
        }
        """
        try await ws.send(.string(payload))
    }

    /// Receives text frames until one decodes as a `GatewayMessage`, then returns it.
    ///
    /// It records the message's `s` in `sequenceNum` when present.
    ///
    /// - When a receive fails with `URLError.networkConnectionLost`, the client is
    ///   marked closed, `reconnect()` runs, and receiving continues on the new socket.
    /// - When decoding fails with `DecodingError.dataCorrupted` and the buffered text is
    ///   already a complete JSON document, that frame is dropped (its `s` is still
    ///   recorded). Otherwise the text is kept and the next frame is appended to it.
    /// - Throws: `GatewayError.invalidMessage` for a non-text frame. Other decoding and
    ///   socket errors propagate.
    func getMessage() async throws -> GatewayMessage {
        let decoder = JSONDecoder()
        var buffer = ""
        while true {
            let text: String
            do {
                guard case .string(let received) = try await ws.receive() else { throw GatewayError.invalidMessage }
                text = received
            } catch URLError.networkConnectionLost {
                open = false
                // Partial text from the dead socket can never complete; start clean.
                buffer = ""
                try await reconnect()
                continue
            }
            buffer.append(text)
            let data = Data(buffer.utf8)
            do {
                let message = try decoder.decode(GatewayMessage.self, from: data)
                sequenceNum = message.s ?? sequenceNum
                return message
            } catch DecodingError.dataCorrupted {
                // Valid JSON that still fails to decode will not be fixed by appending
                // more frames, so drop it. Keeping it would make every later frame fail
                // as well.
                if let object = try? JSONSerialization.jsonObject(with: data) {
                    if let seq = (object as? [String: Any])?["s"] as? Int { sequenceNum = seq }
                    buffer = ""
                }
            }
        }
    }

    /// Stops the heartbeat and tries to resume the session.
    ///
    /// If the resume is rejected locally (see `attemptResume`) or fails, it opens a
    /// brand-new connection against `gatewayURL` instead.
    private func reconnect() async throws {
        hbTask?.cancel()
        do {
            try await attemptResume()
        } catch {
            print(error)
            ws.cancel()
            ws = GatewayClient.makeSocket(for: gatewayURL)
            try await openConnection()
        }
    }

    /// Reconnects to `resumeURL`, waits for HELLO, and sends RESUME (op 6).
    ///
    /// A resume is only attempted when a session ID, resume URL and sequence number are
    /// all known. It is not attempted when the previous socket's close code indicates
    /// that a new session is required.
    /// - Throws: `GatewayError.disconected` when resuming is not possible. Its message is
    ///   the close reason decoded as UTF-8, or `"unknown"` when there is none.
    private func attemptResume() async throws {
        let closeCode = ws.closeCode.rawValue
        // Per Discord's close-code table: 4004 and 4010+ are not reconnectable. 4007
        // (invalid seq) and 4009 (session timed out) need a new session, not a RESUME.
        let requiresNewSession = closeCode == 4004 || closeCode == 4007 || closeCode == 4009 || closeCode >= 4010
        guard
            !requiresNewSession,
            let resumeURL = resumeURL,
            let sessionID = sessionID,
            let sequenceNum = sequenceNum
        else {
            let reason: String?
            if let closeReason = ws.closeReason {
                reason = String(data: closeReason, encoding: .utf8)
            } else {
                reason = "unknown"
            }
            throw GatewayError.disconected(code: closeCode, message: reason)
        }
        ws.cancel()
        ws = GatewayClient.makeSocket(for: resumeURL)
        ws.resume()
        try await setupHeartbeat()
        let payload = """
        {
          "op": 6,
          "d": {
            "token": \(GatewayClient.jsonString(token)),
            "session_id": \(GatewayClient.jsonString(sessionID)),
            "seq": \(sequenceNum)
          }
        }
        """
        try await ws.send(.string(payload))
        open = true
    }

    /// Sends a heartbeat (op 1) carrying the last sequence number, or `null` if none.
    ///
    /// It does nothing while the client is not open.
    private func sendHeartbeat() async throws {
        guard open else { print("conn closed, skipping hb"); return }
        let seq = sequenceNum.map { String($0) } ?? "null"
        try await ws.send(.string("{\"op\":1,\"d\":\(seq)}"))
    }

    /// A stream of gateway payloads (`d`), yielded once the client is open.
    ///
    /// - Messages whose `d` is nil are skipped.
    /// - An op 1 message is answered with an immediate heartbeat.
    /// - Errors are logged and retried after a short pause.
    /// - The stream finishes when the consuming task is cancelled.
    var events: AsyncStream<GatewayPayload> {
        AsyncStream { [self] in
            while !Task.isCancelled {
                do {
                    while await !self.open {
                        try await Task.sleep(for: .seconds(1))
                    }
                    let event = try await self.getMessage()
                    #if DEBUG
                    dump(event)
                    #endif
                    if event.op == 1 { try await self.sendHeartbeat() }
                    // NOTE(review): works whether `d` is optional or not; a nil `d` means read on.
                    let payload: GatewayPayload? = event.d
                    if let payload = payload { return payload }
                } catch is CancellationError {
                    // Consumer went away: finish the stream instead of spinning on sleep errors.
                    return nil
                } catch {
                    print("Error listening to gateway: \(error)")
                    // Back off so a persistent failure doesn't become a hot loop.
                    try? await Task.sleep(for: .seconds(1))
                }
            }
            return nil
        }
    }
}
/// Failures surfaced by the Discord gateway client.
public enum GatewayError: Error {
    /// A frame could not be used; for example, `GatewayClient` received a non-text frame.
    case invalidMessage
    /// Opcode problems. NOTE(review): not thrown by the gateway client shown here —
    /// confirm where these are raised.
    case invalidOpcode, mismatchedOpcode
    /// The socket closed and the session could not be resumed.
    ///
    /// - `code` is the WebSocket close code.
    /// - `message` is the decoded close reason: `"unknown"` when none was sent, or
    ///   `nil` if the reason was not valid UTF-8.
    /// - Note: The name misspells "disconnected"; it is kept for source compatibility.
    case disconected(code: Int, message: String?)
    /// The expected handshake message (HELLO or READY) did not arrive.
    case connectionFailure
}