feat: enforce device-bound connect challenge

This commit is contained in:
Peter Steinberger
2026-01-20 11:15:10 +00:00
parent 121ae6036b
commit dfbf6ac263
21 changed files with 953 additions and 129 deletions

View File

@@ -1,11 +1,18 @@
import CryptoKit
import Foundation
struct DeviceIdentity: Codable, Sendable {
var deviceId: String
var publicKey: String
var privateKey: String
var createdAtMs: Int
public struct DeviceIdentity: Codable, Sendable {
public var deviceId: String
public var publicKey: String
public var privateKey: String
public var createdAtMs: Int
public init(deviceId: String, publicKey: String, privateKey: String, createdAtMs: Int) {
self.deviceId = deviceId
self.publicKey = publicKey
self.privateKey = privateKey
self.createdAtMs = createdAtMs
}
}
enum DeviceIdentityPaths {
@@ -27,10 +34,10 @@ enum DeviceIdentityPaths {
}
}
enum DeviceIdentityStore {
public enum DeviceIdentityStore {
private static let fileName = "device.json"
static func loadOrCreate() -> DeviceIdentity {
public static func loadOrCreate() -> DeviceIdentity {
let url = self.fileURL()
if let data = try? Data(contentsOf: url),
let decoded = try? JSONDecoder().decode(DeviceIdentity.self, from: data),
@@ -44,7 +51,7 @@ enum DeviceIdentityStore {
return identity
}
static func signPayload(_ payload: String, identity: DeviceIdentity) -> String? {
public static func signPayload(_ payload: String, identity: DeviceIdentity) -> String? {
guard let privateKeyData = Data(base64Encoded: identity.privateKey) else { return nil }
do {
let privateKey = try Curve25519.Signing.PrivateKey(rawRepresentation: privateKeyData)
@@ -76,7 +83,7 @@ enum DeviceIdentityStore {
.replacingOccurrences(of: "=", with: "")
}
static func publicKeyBase64Url(_ identity: DeviceIdentity) -> String? {
public static func publicKeyBase64Url(_ identity: DeviceIdentity) -> String? {
guard let data = Data(base64Encoded: identity.publicKey) else { return nil }
return self.base64UrlEncode(data)
}

View File

@@ -94,6 +94,10 @@ public struct GatewayConnectOptions: Sendable {
// Avoid ambiguity with the app's own AnyCodable type.
private typealias ProtoAnyCodable = ClawdbotProtocol.AnyCodable
private enum ConnectChallengeError: Error {
case timeout
}
public actor GatewayChannelActor {
private let logger = Logger(subsystem: "com.clawdbot", category: "gateway")
private var task: WebSocketTaskBox?
@@ -113,6 +117,7 @@ public actor GatewayChannelActor {
private let decoder = JSONDecoder()
private let encoder = JSONEncoder()
private let connectTimeoutSeconds: Double = 6
private let connectChallengeTimeoutSeconds: Double = 0.75
private var watchdogTask: Task<Void, Never>?
private var tickTask: Task<Void, Never>?
private let defaultRequestTimeoutMs: Double = 15000
@@ -294,9 +299,10 @@ public actor GatewayChannelActor {
}
let identity = DeviceIdentityStore.loadOrCreate()
let signedAtMs = Int(Date().timeIntervalSince1970 * 1000)
let connectNonce = try await self.waitForConnectChallenge()
let scopes = options.scopes.joined(separator: ",")
let payload = [
"v1",
var payloadParts = [
connectNonce == nil ? "v1" : "v2",
identity.deviceId,
clientId,
clientMode,
@@ -304,15 +310,23 @@ public actor GatewayChannelActor {
scopes,
String(signedAtMs),
self.token ?? "",
].joined(separator: "|")
]
if let connectNonce {
payloadParts.append(connectNonce)
}
let payload = payloadParts.joined(separator: "|")
if let signature = DeviceIdentityStore.signPayload(payload, identity: identity),
let publicKey = DeviceIdentityStore.publicKeyBase64Url(identity) {
params["device"] = ProtoAnyCodable([
var device: [String: ProtoAnyCodable] = [
"id": ProtoAnyCodable(identity.deviceId),
"publicKey": ProtoAnyCodable(publicKey),
"signature": ProtoAnyCodable(signature),
"signedAt": ProtoAnyCodable(signedAtMs),
])
]
if let connectNonce {
device["nonce"] = ProtoAnyCodable(connectNonce)
}
params["device"] = ProtoAnyCodable(device)
}
let frame = RequestFrame(
@@ -322,40 +336,11 @@ public actor GatewayChannelActor {
params: ProtoAnyCodable(params))
let data = try self.encoder.encode(frame)
try await self.task?.send(.data(data))
guard let msg = try await task?.receive() else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (no response)"])
}
try await self.handleConnectResponse(msg, reqId: reqId)
let response = try await self.waitForConnectResponse(reqId: reqId)
try await self.handleConnectResponse(response)
}
private func handleConnectResponse(_ msg: URLSessionWebSocketTask.Message, reqId: String) async throws {
let data: Data? = switch msg {
case let .data(d): d
case let .string(s): s.data(using: .utf8)
@unknown default: nil
}
guard let data else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (empty response)"])
}
let decoder = JSONDecoder()
guard let frame = try? decoder.decode(GatewayFrame.self, from: data) else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (invalid response)"])
}
guard case let .res(res) = frame, res.id == reqId else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (unexpected response)"])
}
private func handleConnectResponse(_ res: ResponseFrame) async throws {
if res.ok == false {
let msg = (res.error?["message"]?.value as? String) ?? "gateway connect failed"
throw NSError(domain: "Gateway", code: 1008, userInfo: [NSLocalizedDescriptionKey: msg])
@@ -424,6 +409,7 @@ public actor GatewayChannelActor {
waiter.resume(returning: .res(res))
}
case let .event(evt):
if evt.event == "connect.challenge" { return }
if let seq = evt.seq {
if let last = lastSeq, seq > last + 1 {
await self.pushHandler?(.seqGap(expected: last + 1, received: seq))
@@ -437,6 +423,63 @@ public actor GatewayChannelActor {
}
}
private func waitForConnectChallenge() async throws -> String? {
guard let task = self.task else { return nil }
do {
return try await AsyncTimeout.withTimeout(
seconds: self.connectChallengeTimeoutSeconds,
onTimeout: { ConnectChallengeError.timeout },
operation: { [weak self] in
guard let self else { return nil }
while true {
let msg = try await task.receive()
guard let data = self.decodeMessageData(msg) else { continue }
guard let frame = try? self.decoder.decode(GatewayFrame.self, from: data) else { continue }
if case let .event(evt) = frame, evt.event == "connect.challenge" {
if let payload = evt.payload?.value as? [String: ProtoAnyCodable],
let nonce = payload["nonce"]?.value as? String {
return nonce
}
}
}
})
} catch {
if error is ConnectChallengeError { return nil }
throw error
}
}
private func waitForConnectResponse(reqId: String) async throws -> ResponseFrame {
guard let task = self.task else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (no response)"])
}
while true {
let msg = try await task.receive()
guard let data = self.decodeMessageData(msg) else { continue }
guard let frame = try? self.decoder.decode(GatewayFrame.self, from: data) else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (invalid response)"])
}
if case let .res(res) = frame, res.id == reqId {
return res
}
}
}
private func decodeMessageData(_ msg: URLSessionWebSocketTask.Message) -> Data? {
let data: Data? = switch msg {
case let .data(data): data
case let .string(text): text.data(using: .utf8)
@unknown default: nil
}
return data
}
private func watchTicks() async {
let tolerance = self.tickIntervalMs * 2
while self.connected {