diff --git a/Package.resolved b/Package.resolved index 994e5f9..308c81c 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "48b60a6f8caccb2179aca6ee731ee84b313e8f73da96743e41514dc5bf29a68e", + "originHash" : "217db9374e82d9bb33f20f600906b3522dd8ec57390c3a666fcba5ce08c1af12", "pins" : [ { "identity" : "async-http-client", @@ -55,24 +55,6 @@ "version" : "4.8.1" } }, - { - "identity" : "jwt", - "kind" : "remoteSourceControl", - "location" : "https://github.com/vapor/jwt.git", - "state" : { - "revision" : "af1c59762d70d1065ddbc0d7902ea9b3dacd1a26", - "version" : "5.1.2" - } - }, - { - "identity" : "jwt-kit", - "kind" : "remoteSourceControl", - "location" : "https://github.com/vapor/jwt-kit.git", - "state" : { - "revision" : "03f5013f0b547ce43abe45e7e90711303a3e5495", - "version" : "5.1.2" - } - }, { "identity" : "multipart-kit", "kind" : "remoteSourceControl", @@ -145,15 +127,6 @@ "version" : "1.2.0" } }, - { - "identity" : "swift-certificates", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-certificates.git", - "state" : { - "revision" : "999fd70c7803da89f3904d635a6815a2a7cd7585", - "version" : "1.10.0" - } - }, { "identity" : "swift-collections", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index e917f55..d656d6d 100644 --- a/Package.swift +++ b/Package.swift @@ -15,8 +15,6 @@ let package = Package( .package(url: "https://github.com/vapor/fluent-sqlite-driver.git", from: "4.6.0"), // 🔵 Non-blocking, event-driven networking for Swift. Used for custom executors .package(url: "https://github.com/apple/swift-nio.git", from: "2.65.0"), - // JWTs - .package(url: "https://github.com/vapor/jwt.git", from: "5.0.0"), // MsgPack .package(url: "https://github.com/nnabeyang/swift-msgpack.git", from: "0.7.0") ], @@ -29,7 +27,6 @@ let package = Package( .product(name: "Vapor", package: "vapor"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), - .product(name: "JWT", package: "jwt"), .product(name: "SwiftMsgpack", package: "swift-msgpack"), ], swiftSettings: swiftSettings diff --git a/Sources/stella/Controllers/Api/ToolController.swift b/Sources/stella/Controllers/Api/ToolController.swift index 6bacfbc..37c7368 100644 --- a/Sources/stella/Controllers/Api/ToolController.swift +++ b/Sources/stella/Controllers/Api/ToolController.swift @@ -1,4 +1,5 @@ import Vapor +import Fluent import SwiftMsgpack import JWT @@ -13,11 +14,24 @@ struct ToolController: RouteCollection { func signup(req: Request) async throws -> Response { let body = try req.content.decode(SignupReq.self, using: MsgPackDecoder()) - let session = try await req.jwt.verify(body.access_token, as: SessionPayload.self) - guard session.type == SessionType.ZAT.rawValue else { + guard + let session = try await Session.query(on: req.db).filter(\.$id == body.access_token).first(), + session.type == SessionType.ZAT + else { throw Abort(.forbidden, reason: "Invalid access token") } + let accountId = session.$account.id + if let player = try await Player.query(on: req.db).filter(\.$account.$id == accountId).first() { + + } else { + guard let account = try await Account.query(on: req.db).filter(\.$id == accountId).first() else { + throw Abort(.forbidden, reason: "Account ID does not exist") + } + // Create new Player + let player = try Player.createDefault(account: account) + try await player.save(on: req.db) + } throw Abort(.notImplemented) } diff --git a/Sources/stella/Controllers/OpenApi/AuthController.swift b/Sources/stella/Controllers/OpenApi/AuthController.swift index 02423f5..80b105f 100644 --- a/Sources/stella/Controllers/OpenApi/AuthController.swift +++ b/Sources/stella/Controllers/OpenApi/AuthController.swift @@ -61,12 +61,24 @@ struct AuthController: RouteCollection { let idpId = body.whiteKey var account: Account? = nil - if let rawAccountId = req.headers["playerId"].first, let accountId = Int(rawAccountId) { - if let existingAccount = try await Account.query(on: req.db).filter(\.$id == accountId).first() { - account = existingAccount - } + if + let rawAccountId = req.headers["playerId"].first, + let accountId = Int(rawAccountId), + let existingAccount = try await Account.query(on: req.db).filter(\.$id == accountId).first() + { + account = existingAccount + // Delete old sessions + try await Session.query(on: req.db) + .filter(\.$account.$id == existingAccount.requireID()) + .filter(\.$type ~~ [.ZAT, .ZRT]) + .delete() } else if let existingAccount = try await Account.query(on: req.db).filter(\.$idpId == idpId).first() { account = existingAccount + // Delete old sessions + try await Session.query(on: req.db) + .filter(\.$account.$id == existingAccount.requireID()) + .filter(\.$type ~~ [.ZAT, .ZRT]) + .delete() } else { account = Account(appId: body.appId, idpAlias: idpAlias, idpCode: "zd3", idpId: idpId, status: "normal") try await account!.create(on: req.db) @@ -75,6 +87,7 @@ struct AuthController: RouteCollection { guard let account = account else { throw Abort(.badRequest) } + print("got here") if account.idpAlias != idpAlias { account.idpAlias = idpAlias @@ -84,11 +97,13 @@ struct AuthController: RouteCollection { let zatExpiry = Date.now.advanced(by: 43200) let zrtExpiry = Date.now.advanced(by: 2592000) - let zatTokenJWT = generateToken(accountId: try account.requireID(), expires: zatExpiry, type: .ZAT) - let zrtTokenJWT = generateToken(accountId: try account.requireID(), expires: zrtExpiry, type: .ZRT) + let zatSession = try Session(account: account, expires: zatExpiry, type: .ZAT) + let zrtSession = try Session(account: account, expires: zrtExpiry, type: .ZRT) + try await zatSession.create(on: req.db) + try await zrtSession.create(on: req.db) - let zatToken = try await req.jwt.sign(zatTokenJWT) - let zrtToken = try await req.jwt.sign(zrtTokenJWT) + let zatToken = try zatSession.requireID() + let zrtToken = try zrtSession.requireID() let res = LoginDeviceRes( zatExpiryTime: Int(zatExpiry.timeIntervalSince1970) * 1000, @@ -119,18 +134,20 @@ struct AuthController: RouteCollection { @Sendable func zatLogin(req: Request) async throws -> LoginDeviceRes { let body = try req.content.decode(ZatLoginReq.self, as: .json) - if let session = try? await req.jwt.verify(body.zat, as: SessionPayload.self) { - guard session.type == SessionType.ZAT.rawValue && session.accountId.value == body.playerId else { + if let session = try await Session.query(on: req.db).filter(\.$id == body.zat).first() { + guard let accountId = Int(body.playerId), session.type == SessionType.ZAT && session.$account.id == accountId else { throw Abort(.badRequest, reason: "Invalid zat provided.") } + try await session.delete(on: req.db) } + guard let accountId = Int(body.playerId), let account = try await Account.query(on: req.db) .filter(\.$id == accountId) .first() else { - throw Abort(.badRequest, reason: "Invalid playerId") + throw Abort(.unauthorized, reason: "Invalid playerId") } account.lastLogin = Date.now @@ -138,8 +155,9 @@ struct AuthController: RouteCollection { let zatExpiry = Date.now.advanced(by: 43200) - let session = generateToken(accountId: try account.requireID(), expires: zatExpiry, type: SessionType.ZAT) - let zatToken = try await req.jwt.sign(session) + let session = try Session(account: account, expires: zatExpiry, type: SessionType.ZAT) + try await session.create(on: req.db) + let zatToken = try session.requireID() return LoginDeviceRes( zatExpiryTime: Int(zatExpiry.timeIntervalSince1970) * 1000, zrtExpiryTime: nil, @@ -303,8 +321,3 @@ struct PushOptionResponse: Content { let player: String } -enum SessionType: Int { - case ZAT = 0 - case ZRT = 1 - case VIEWER = 2 -} diff --git a/Sources/stella/Controllers/OpenApiController.swift b/Sources/stella/Controllers/OpenApiController.swift index 4393c45..ae5fa83 100644 --- a/Sources/stella/Controllers/OpenApiController.swift +++ b/Sources/stella/Controllers/OpenApiController.swift @@ -1,5 +1,5 @@ import Vapor -import JWT +import Fluent struct OpenApiController: RouteCollection { func boot(routes: any RoutesBuilder) throws { @@ -17,10 +17,13 @@ struct OpenApiController: RouteCollection { guard let zatToken = req.headers["zat"].first else { throw Abort(.badRequest, reason: "Missing zat header.") } - let jwt = try await req.jwt.verify(zatToken, as: SessionPayload.self) - guard jwt.accountId.value == beat.playerId else { + + guard let session = try await Session.query(on: req.db).filter(\.$id == zatToken).first(), + let playerId = Int(beat.playerId), + session.$account.id == playerId else { throw Abort(.unauthorized, reason: "zat invalid") } + return "{}" } @@ -29,8 +32,9 @@ struct OpenApiController: RouteCollection { guard let zatToken = req.headers["zat"].first else { throw Abort(.badRequest, reason: "Missing zat header.") } - let jwt = try await req.jwt.verify(zatToken, as: SessionPayload.self) - guard jwt.accountId.value == beat.playerId else { + guard let session = try await Session.query(on: req.db).filter(\.$id == zatToken).first(), + let playerId = Int(beat.playerId), + session.$account.id == playerId else { throw Abort(.unauthorized, reason: "zat invalid") } return "{}" diff --git a/Sources/stella/Migrations/CreatePlayers.swift b/Sources/stella/Migrations/CreatePlayers.swift index 42e0c21..02fd970 100644 --- a/Sources/stella/Migrations/CreatePlayers.swift +++ b/Sources/stella/Migrations/CreatePlayers.swift @@ -27,13 +27,13 @@ struct CreatePlayers: AsyncMigration { .field("free_mana", .int, .required) .field("paid_mana", .int, .required) .field("enable_auto_3x", .bool, .required) - .field("account_id", .int, .required) + .field("account_id", .int, .required, .references("accounts", "id")) .field("tutorial_step", .int, .required) .field("tutorial_skip_flag", .int, .required) .create() } func revert(on database: any Database) async throws { - + try await database.schema("players").delete() } } diff --git a/Sources/stella/Migrations/CreateSessions.swift b/Sources/stella/Migrations/CreateSessions.swift new file mode 100644 index 0000000..ddc2e16 --- /dev/null +++ b/Sources/stella/Migrations/CreateSessions.swift @@ -0,0 +1,16 @@ +import Fluent + +struct CreateSessions: AsyncMigration { + func prepare(on database: any Database) async throws { + try await database.schema("sessions") + .field("token", .string, .identifier(auto: false)) + .field("type", .int, .required) + .field("expires", .datetime, .required) + .field("accountId", .int, .required, .references("accounts", "id")) + .create() + } + + func revert(on database: any Database) async throws { + try await database.schema("sessions").delete() + } +} diff --git a/Sources/stella/Models/Player.swift b/Sources/stella/Models/Player.swift index 846f3eb..d49d0cf 100644 --- a/Sources/stella/Models/Player.swift +++ b/Sources/stella/Models/Player.swift @@ -88,7 +88,7 @@ final class Player: Model, @unchecked Sendable { account: Account, tutorialStep: Int?, tutorialSkipFlag: Int? - ) { + ) throws { self.stamina = stamina self.staminaHealTime = staminaHealTime self.boostPoint = boostPoint @@ -114,10 +114,12 @@ final class Player: Model, @unchecked Sendable { self.enableAuto3x = enableAuto3x self.tutorialStep = tutorialStep self.tutorialSkipFlag = tutorialSkipFlag + + self.$account.id = try account.requireID() } - static func createDefault(account: Account) -> Player { - return Player( + static func createDefault(account: Account) throws -> Player { + return try Player( stamina: 20, staminaHealTime: Date.now, boostPoint: 3, diff --git a/Sources/stella/Models/Session.swift b/Sources/stella/Models/Session.swift new file mode 100644 index 0000000..9003bba --- /dev/null +++ b/Sources/stella/Models/Session.swift @@ -0,0 +1,32 @@ +import Fluent + +final class Session: Model, @unchecked Sendable { + static let schema: String = "sessions" + + @ID(custom: "token", generatedBy: .random) + var id: String? + + @Parent(key: "accountId") + var account: Account + + @Field(key: "expires") + var expires: Date + + @Field(key: "type") + var type: SessionType + + init() { } + + init(account: Account, expires: Date, type: SessionType) throws { + self.id = UUID().uuidString + self.$account.id = try account.requireID() + self.expires = expires + self.type = type + } +} + +enum SessionType: Int, Codable { + case ZAT = 0 + case ZRT = 1 + case VIEWER = 2 +} diff --git a/Sources/stella/configure.swift b/Sources/stella/configure.swift index 3ddd91f..4764526 100644 --- a/Sources/stella/configure.swift +++ b/Sources/stella/configure.swift @@ -2,7 +2,6 @@ import NIOSSL import Fluent import FluentSQLiteDriver import Vapor -import JWT // configures your application public func configure(_ app: Application) async throws { @@ -12,12 +11,11 @@ public func configure(_ app: Application) async throws { app.databases.use(DatabaseConfigurationFactory.sqlite(.file("db.sqlite")), as: .sqlite) app.migrations.add(CreateAccount()) + app.migrations.add(CreatePlayers()) + app.migrations.add(CreateSessions()) app.http.server.configuration.hostname = "0.0.0.0" app.http.server.configuration.port = 8000 - // JWT - await app.jwt.keys.add(hmac: "secret", digestAlgorithm: .sha256) - // register routes try routes(app) } diff --git a/Sources/stella/util.swift b/Sources/stella/util.swift index d3c051c..a1e52f7 100644 --- a/Sources/stella/util.swift +++ b/Sources/stella/util.swift @@ -1,4 +1,3 @@ -import JWT import Fluent import Foundation @@ -6,38 +5,3 @@ func generateIdpAlias(appId: String, deviceId: String, serialNo: String) -> Stri return "\(appId):\(deviceId):\(serialNo)" } -func generateToken(accountId: Int, expires: Date, type: SessionType) -> SessionPayload { - return SessionPayload( - accountId: .init(value: String(accountId)), - expiration: .init(value: expires), - type: type.rawValue - ) -} - -struct SessionPayload: JWTPayload { - enum CodingKeys: String, CodingKey { - case accountId = "sub" - case expiration = "exp" - case type = "type" - } - - // The "sub" (subject) claim identifies the principal that is the - // subject of the JWT. - var accountId: SubjectClaim - - // The "exp" (expiration time) claim identifies the expiration time on - // or after which the JWT MUST NOT be accepted for processing. - var expiration: ExpirationClaim - - // Custom data. - // If true, the user is an admin. - var type: Int - - // Run any additional verification logic beyond - // signature verification here. - // Since we have an ExpirationClaim, we will - // call its verify method. - func verify(using algorithm: some JWTAlgorithm) async throws { - try self.expiration.verifyNotExpired() - } -}