feat: enforce device-bound connect challenge

This commit is contained in:
Peter Steinberger
2026-01-20 11:15:10 +00:00
parent 121ae6036b
commit dfbf6ac263
21 changed files with 953 additions and 129 deletions

View File

@@ -78,6 +78,7 @@ let package = Package(
.executableTarget( .executableTarget(
name: "ClawdbotWizardCLI", name: "ClawdbotWizardCLI",
dependencies: [ dependencies: [
.product(name: "ClawdbotKit", package: "ClawdbotKit"),
.product(name: "ClawdbotProtocol", package: "ClawdbotKit"), .product(name: "ClawdbotProtocol", package: "ClawdbotKit"),
], ],
path: "Sources/ClawdbotWizardCLI", path: "Sources/ClawdbotWizardCLI",

View File

@@ -20,7 +20,7 @@ public struct ConnectParams: Codable, Sendable {
public let permissions: [String: AnyCodable]? public let permissions: [String: AnyCodable]?
public let role: String? public let role: String?
public let scopes: [String]? public let scopes: [String]?
public let device: [String: AnyCodable] public let device: [String: AnyCodable]?
public let auth: [String: AnyCodable]? public let auth: [String: AnyCodable]?
public let locale: String? public let locale: String?
public let useragent: String? public let useragent: String?
@@ -34,7 +34,7 @@ public struct ConnectParams: Codable, Sendable {
permissions: [String: AnyCodable]?, permissions: [String: AnyCodable]?,
role: String?, role: String?,
scopes: [String]?, scopes: [String]?,
device: [String: AnyCodable], device: [String: AnyCodable]?,
auth: [String: AnyCodable]?, auth: [String: AnyCodable]?,
locale: String?, locale: String?,
useragent: String? useragent: String?

View File

@@ -1,7 +1,10 @@
import ClawdbotKit
import ClawdbotProtocol import ClawdbotProtocol
import Darwin import Darwin
import Foundation import Foundation
private typealias ProtoAnyCodable = ClawdbotProtocol.AnyCodable
struct WizardCliOptions { struct WizardCliOptions {
var url: String? var url: String?
var token: String? var token: String?
@@ -228,6 +231,10 @@ private func parseInt(_ value: Any?) -> Int? {
} }
actor GatewayWizardClient { actor GatewayWizardClient {
private enum ConnectChallengeError: Error {
case timeout
}
private let url: URL private let url: URL
private let token: String? private let token: String?
private let password: String? private let password: String?
@@ -235,6 +242,7 @@ actor GatewayWizardClient {
private let encoder = JSONEncoder() private let encoder = JSONEncoder()
private let decoder = JSONDecoder() private let decoder = JSONDecoder()
private let session = URLSession(configuration: .default) private let session = URLSession(configuration: .default)
private let connectChallengeTimeoutSeconds: Double = 0.75
private var task: URLSessionWebSocketTask? private var task: URLSessionWebSocketTask?
init(url: URL, token: String?, password: String?, json: Bool) { init(url: URL, token: String?, password: String?, json: Bool) {
@@ -257,7 +265,7 @@ actor GatewayWizardClient {
self.task = nil self.task = nil
} }
func request(method: String, params: [String: AnyCodable]?) async throws -> ResponseFrame { func request(method: String, params: [String: ProtoAnyCodable]?) async throws -> ResponseFrame {
guard let task = self.task else { guard let task = self.task else {
throw WizardCliError.gatewayError("gateway not connected") throw WizardCliError.gatewayError("gateway not connected")
} }
@@ -266,7 +274,7 @@ actor GatewayWizardClient {
type: "req", type: "req",
id: id, id: id,
method: method, method: method,
params: params.map { AnyCodable($0) }) params: params.map { ProtoAnyCodable($0) })
let data = try self.encoder.encode(frame) let data = try self.encoder.encode(frame)
try await task.send(.data(data)) try await task.send(.data(data))
@@ -309,28 +317,65 @@ actor GatewayWizardClient {
} }
let osVersion = ProcessInfo.processInfo.operatingSystemVersion let osVersion = ProcessInfo.processInfo.operatingSystemVersion
let platform = "macos \(osVersion.majorVersion).\(osVersion.minorVersion).\(osVersion.patchVersion)" let platform = "macos \(osVersion.majorVersion).\(osVersion.minorVersion).\(osVersion.patchVersion)"
let client: [String: AnyCodable] = [ let clientId = "clawdbot-macos"
"id": AnyCodable("clawdbot-macos"), let clientMode = "ui"
"displayName": AnyCodable(Host.current().localizedName ?? "Clawdbot macOS Wizard CLI"), let role = "operator"
"version": AnyCodable("dev"), let scopes: [String] = []
"platform": AnyCodable(platform), let client: [String: ProtoAnyCodable] = [
"deviceFamily": AnyCodable("Mac"), "id": ProtoAnyCodable(clientId),
"mode": AnyCodable("ui"), "displayName": ProtoAnyCodable(Host.current().localizedName ?? "Clawdbot macOS Wizard CLI"),
"instanceId": AnyCodable(UUID().uuidString), "version": ProtoAnyCodable("dev"),
"platform": ProtoAnyCodable(platform),
"deviceFamily": ProtoAnyCodable("Mac"),
"mode": ProtoAnyCodable(clientMode),
"instanceId": ProtoAnyCodable(UUID().uuidString),
] ]
var params: [String: AnyCodable] = [ var params: [String: ProtoAnyCodable] = [
"minProtocol": AnyCodable(GATEWAY_PROTOCOL_VERSION), "minProtocol": ProtoAnyCodable(GATEWAY_PROTOCOL_VERSION),
"maxProtocol": AnyCodable(GATEWAY_PROTOCOL_VERSION), "maxProtocol": ProtoAnyCodable(GATEWAY_PROTOCOL_VERSION),
"client": AnyCodable(client), "client": ProtoAnyCodable(client),
"caps": AnyCodable([String]()), "caps": ProtoAnyCodable([String]()),
"locale": AnyCodable(Locale.preferredLanguages.first ?? Locale.current.identifier), "locale": ProtoAnyCodable(Locale.preferredLanguages.first ?? Locale.current.identifier),
"userAgent": AnyCodable(ProcessInfo.processInfo.operatingSystemVersionString), "userAgent": ProtoAnyCodable(ProcessInfo.processInfo.operatingSystemVersionString),
"role": ProtoAnyCodable(role),
"scopes": ProtoAnyCodable(scopes),
] ]
if let token = self.token { if let token = self.token {
params["auth"] = AnyCodable(["token": AnyCodable(token)]) params["auth"] = ProtoAnyCodable(["token": ProtoAnyCodable(token)])
} else if let password = self.password { } else if let password = self.password {
params["auth"] = AnyCodable(["password": AnyCodable(password)]) params["auth"] = ProtoAnyCodable(["password": ProtoAnyCodable(password)])
}
let connectNonce = try await self.waitForConnectChallenge()
let identity = DeviceIdentityStore.loadOrCreate()
let signedAtMs = Int(Date().timeIntervalSince1970 * 1000)
let scopesValue = scopes.joined(separator: ",")
var payloadParts = [
connectNonce == nil ? "v1" : "v2",
identity.deviceId,
clientId,
clientMode,
role,
scopesValue,
String(signedAtMs),
self.token ?? "",
]
if let connectNonce {
payloadParts.append(connectNonce)
}
let payload = payloadParts.joined(separator: "|")
if let signature = DeviceIdentityStore.signPayload(payload, identity: identity),
let publicKey = DeviceIdentityStore.publicKeyBase64Url(identity) {
var device: [String: ProtoAnyCodable] = [
"id": ProtoAnyCodable(identity.deviceId),
"publicKey": ProtoAnyCodable(publicKey),
"signature": ProtoAnyCodable(signature),
"signedAt": ProtoAnyCodable(signedAtMs),
]
if let connectNonce {
device["nonce"] = ProtoAnyCodable(connectNonce)
}
params["device"] = ProtoAnyCodable(device)
} }
let reqId = UUID().uuidString let reqId = UUID().uuidString
@@ -338,31 +383,57 @@ actor GatewayWizardClient {
type: "req", type: "req",
id: reqId, id: reqId,
method: "connect", method: "connect",
params: AnyCodable(params)) params: ProtoAnyCodable(params))
let data = try self.encoder.encode(frame) let data = try self.encoder.encode(frame)
try await task.send(.data(data)) try await task.send(.data(data))
let message = try await task.receive() while true {
let frameResponse = try decodeFrame(message) let message = try await task.receive()
guard case let .res(res) = frameResponse, res.id == reqId else { let frameResponse = try decodeFrame(message)
throw WizardCliError.gatewayError("connect failed (unexpected response)") if case let .res(res) = frameResponse, res.id == reqId {
if res.ok == false {
let msg = (res.error?["message"]?.value as? String) ?? "gateway connect failed"
throw WizardCliError.gatewayError(msg)
}
_ = try self.decodePayload(res, as: HelloOk.self)
return
}
} }
if res.ok == false { }
let msg = (res.error?["message"]?.value as? String) ?? "gateway connect failed"
throw WizardCliError.gatewayError(msg) private func waitForConnectChallenge() async throws -> String? {
guard let task = self.task else { return nil }
do {
return try await AsyncTimeout.withTimeout(
seconds: self.connectChallengeTimeoutSeconds,
onTimeout: { ConnectChallengeError.timeout },
operation: {
while true {
let message = try await task.receive()
let frame = try decodeFrame(message)
if case let .event(evt) = frame, evt.event == "connect.challenge" {
if let payload = evt.payload?.value as? [String: ProtoAnyCodable],
let nonce = payload["nonce"]?.value as? String {
return nonce
}
}
}
})
} catch {
if error is ConnectChallengeError { return nil }
throw error
} }
_ = try self.decodePayload(res, as: HelloOk.self)
} }
} }
private func runWizard(client: GatewayWizardClient, opts: WizardCliOptions) async throws { private func runWizard(client: GatewayWizardClient, opts: WizardCliOptions) async throws {
var params: [String: AnyCodable] = [:] var params: [String: ProtoAnyCodable] = [:]
let mode = opts.mode.trimmingCharacters(in: .whitespacesAndNewlines).lowercased() let mode = opts.mode.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
if mode == "local" || mode == "remote" { if mode == "local" || mode == "remote" {
params["mode"] = AnyCodable(mode) params["mode"] = ProtoAnyCodable(mode)
} }
if let workspace = opts.workspace?.trimmingCharacters(in: .whitespacesAndNewlines), !workspace.isEmpty { if let workspace = opts.workspace?.trimmingCharacters(in: .whitespacesAndNewlines), !workspace.isEmpty {
params["workspace"] = AnyCodable(workspace) params["workspace"] = ProtoAnyCodable(workspace)
} }
let startResponse = try await client.request(method: "wizard.start", params: params) let startResponse = try await client.request(method: "wizard.start", params: params)
@@ -395,17 +466,17 @@ private func runWizard(client: GatewayWizardClient, opts: WizardCliOptions) asyn
if let step = decodeWizardStep(nextResult.step) { if let step = decodeWizardStep(nextResult.step) {
let answer = try promptAnswer(for: step) let answer = try promptAnswer(for: step)
var answerPayload: [String: AnyCodable] = [ var answerPayload: [String: ProtoAnyCodable] = [
"stepId": AnyCodable(step.id), "stepId": ProtoAnyCodable(step.id),
] ]
if !(answer is NSNull) { if !(answer is NSNull) {
answerPayload["value"] = AnyCodable(answer) answerPayload["value"] = ProtoAnyCodable(answer)
} }
let response = try await client.request( let response = try await client.request(
method: "wizard.next", method: "wizard.next",
params: [ params: [
"sessionId": AnyCodable(sessionId), "sessionId": ProtoAnyCodable(sessionId),
"answer": AnyCodable(answerPayload), "answer": ProtoAnyCodable(answerPayload),
]) ])
nextResult = try await client.decodePayload(response, as: WizardNextResult.self) nextResult = try await client.decodePayload(response, as: WizardNextResult.self)
if opts.json { if opts.json {
@@ -414,7 +485,7 @@ private func runWizard(client: GatewayWizardClient, opts: WizardCliOptions) asyn
} else { } else {
let response = try await client.request( let response = try await client.request(
method: "wizard.next", method: "wizard.next",
params: ["sessionId": AnyCodable(sessionId)]) params: ["sessionId": ProtoAnyCodable(sessionId)])
nextResult = try await client.decodePayload(response, as: WizardNextResult.self) nextResult = try await client.decodePayload(response, as: WizardNextResult.self)
if opts.json { if opts.json {
dumpResult(response) dumpResult(response)
@@ -424,7 +495,7 @@ private func runWizard(client: GatewayWizardClient, opts: WizardCliOptions) asyn
} catch WizardCliError.cancelled { } catch WizardCliError.cancelled {
_ = try? await client.request( _ = try? await client.request(
method: "wizard.cancel", method: "wizard.cancel",
params: ["sessionId": AnyCodable(sessionId)]) params: ["sessionId": ProtoAnyCodable(sessionId)])
throw WizardCliError.cancelled throw WizardCliError.cancelled
} }
} }

View File

@@ -1,11 +1,18 @@
import CryptoKit import CryptoKit
import Foundation import Foundation
struct DeviceIdentity: Codable, Sendable { public struct DeviceIdentity: Codable, Sendable {
var deviceId: String public var deviceId: String
var publicKey: String public var publicKey: String
var privateKey: String public var privateKey: String
var createdAtMs: Int public var createdAtMs: Int
public init(deviceId: String, publicKey: String, privateKey: String, createdAtMs: Int) {
self.deviceId = deviceId
self.publicKey = publicKey
self.privateKey = privateKey
self.createdAtMs = createdAtMs
}
} }
enum DeviceIdentityPaths { enum DeviceIdentityPaths {
@@ -27,10 +34,10 @@ enum DeviceIdentityPaths {
} }
} }
enum DeviceIdentityStore { public enum DeviceIdentityStore {
private static let fileName = "device.json" private static let fileName = "device.json"
static func loadOrCreate() -> DeviceIdentity { public static func loadOrCreate() -> DeviceIdentity {
let url = self.fileURL() let url = self.fileURL()
if let data = try? Data(contentsOf: url), if let data = try? Data(contentsOf: url),
let decoded = try? JSONDecoder().decode(DeviceIdentity.self, from: data), let decoded = try? JSONDecoder().decode(DeviceIdentity.self, from: data),
@@ -44,7 +51,7 @@ enum DeviceIdentityStore {
return identity return identity
} }
static func signPayload(_ payload: String, identity: DeviceIdentity) -> String? { public static func signPayload(_ payload: String, identity: DeviceIdentity) -> String? {
guard let privateKeyData = Data(base64Encoded: identity.privateKey) else { return nil } guard let privateKeyData = Data(base64Encoded: identity.privateKey) else { return nil }
do { do {
let privateKey = try Curve25519.Signing.PrivateKey(rawRepresentation: privateKeyData) let privateKey = try Curve25519.Signing.PrivateKey(rawRepresentation: privateKeyData)
@@ -76,7 +83,7 @@ enum DeviceIdentityStore {
.replacingOccurrences(of: "=", with: "") .replacingOccurrences(of: "=", with: "")
} }
static func publicKeyBase64Url(_ identity: DeviceIdentity) -> String? { public static func publicKeyBase64Url(_ identity: DeviceIdentity) -> String? {
guard let data = Data(base64Encoded: identity.publicKey) else { return nil } guard let data = Data(base64Encoded: identity.publicKey) else { return nil }
return self.base64UrlEncode(data) return self.base64UrlEncode(data)
} }

View File

@@ -94,6 +94,10 @@ public struct GatewayConnectOptions: Sendable {
// Avoid ambiguity with the app's own AnyCodable type. // Avoid ambiguity with the app's own AnyCodable type.
private typealias ProtoAnyCodable = ClawdbotProtocol.AnyCodable private typealias ProtoAnyCodable = ClawdbotProtocol.AnyCodable
private enum ConnectChallengeError: Error {
case timeout
}
public actor GatewayChannelActor { public actor GatewayChannelActor {
private let logger = Logger(subsystem: "com.clawdbot", category: "gateway") private let logger = Logger(subsystem: "com.clawdbot", category: "gateway")
private var task: WebSocketTaskBox? private var task: WebSocketTaskBox?
@@ -113,6 +117,7 @@ public actor GatewayChannelActor {
private let decoder = JSONDecoder() private let decoder = JSONDecoder()
private let encoder = JSONEncoder() private let encoder = JSONEncoder()
private let connectTimeoutSeconds: Double = 6 private let connectTimeoutSeconds: Double = 6
private let connectChallengeTimeoutSeconds: Double = 0.75
private var watchdogTask: Task<Void, Never>? private var watchdogTask: Task<Void, Never>?
private var tickTask: Task<Void, Never>? private var tickTask: Task<Void, Never>?
private let defaultRequestTimeoutMs: Double = 15000 private let defaultRequestTimeoutMs: Double = 15000
@@ -294,9 +299,10 @@ public actor GatewayChannelActor {
} }
let identity = DeviceIdentityStore.loadOrCreate() let identity = DeviceIdentityStore.loadOrCreate()
let signedAtMs = Int(Date().timeIntervalSince1970 * 1000) let signedAtMs = Int(Date().timeIntervalSince1970 * 1000)
let connectNonce = try await self.waitForConnectChallenge()
let scopes = options.scopes.joined(separator: ",") let scopes = options.scopes.joined(separator: ",")
let payload = [ var payloadParts = [
"v1", connectNonce == nil ? "v1" : "v2",
identity.deviceId, identity.deviceId,
clientId, clientId,
clientMode, clientMode,
@@ -304,15 +310,23 @@ public actor GatewayChannelActor {
scopes, scopes,
String(signedAtMs), String(signedAtMs),
self.token ?? "", self.token ?? "",
].joined(separator: "|") ]
if let connectNonce {
payloadParts.append(connectNonce)
}
let payload = payloadParts.joined(separator: "|")
if let signature = DeviceIdentityStore.signPayload(payload, identity: identity), if let signature = DeviceIdentityStore.signPayload(payload, identity: identity),
let publicKey = DeviceIdentityStore.publicKeyBase64Url(identity) { let publicKey = DeviceIdentityStore.publicKeyBase64Url(identity) {
params["device"] = ProtoAnyCodable([ var device: [String: ProtoAnyCodable] = [
"id": ProtoAnyCodable(identity.deviceId), "id": ProtoAnyCodable(identity.deviceId),
"publicKey": ProtoAnyCodable(publicKey), "publicKey": ProtoAnyCodable(publicKey),
"signature": ProtoAnyCodable(signature), "signature": ProtoAnyCodable(signature),
"signedAt": ProtoAnyCodable(signedAtMs), "signedAt": ProtoAnyCodable(signedAtMs),
]) ]
if let connectNonce {
device["nonce"] = ProtoAnyCodable(connectNonce)
}
params["device"] = ProtoAnyCodable(device)
} }
let frame = RequestFrame( let frame = RequestFrame(
@@ -322,40 +336,11 @@ public actor GatewayChannelActor {
params: ProtoAnyCodable(params)) params: ProtoAnyCodable(params))
let data = try self.encoder.encode(frame) let data = try self.encoder.encode(frame)
try await self.task?.send(.data(data)) try await self.task?.send(.data(data))
guard let msg = try await task?.receive() else { let response = try await self.waitForConnectResponse(reqId: reqId)
throw NSError( try await self.handleConnectResponse(response)
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (no response)"])
}
try await self.handleConnectResponse(msg, reqId: reqId)
} }
private func handleConnectResponse(_ msg: URLSessionWebSocketTask.Message, reqId: String) async throws { private func handleConnectResponse(_ res: ResponseFrame) async throws {
let data: Data? = switch msg {
case let .data(d): d
case let .string(s): s.data(using: .utf8)
@unknown default: nil
}
guard let data else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (empty response)"])
}
let decoder = JSONDecoder()
guard let frame = try? decoder.decode(GatewayFrame.self, from: data) else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (invalid response)"])
}
guard case let .res(res) = frame, res.id == reqId else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (unexpected response)"])
}
if res.ok == false { if res.ok == false {
let msg = (res.error?["message"]?.value as? String) ?? "gateway connect failed" let msg = (res.error?["message"]?.value as? String) ?? "gateway connect failed"
throw NSError(domain: "Gateway", code: 1008, userInfo: [NSLocalizedDescriptionKey: msg]) throw NSError(domain: "Gateway", code: 1008, userInfo: [NSLocalizedDescriptionKey: msg])
@@ -424,6 +409,7 @@ public actor GatewayChannelActor {
waiter.resume(returning: .res(res)) waiter.resume(returning: .res(res))
} }
case let .event(evt): case let .event(evt):
if evt.event == "connect.challenge" { return }
if let seq = evt.seq { if let seq = evt.seq {
if let last = lastSeq, seq > last + 1 { if let last = lastSeq, seq > last + 1 {
await self.pushHandler?(.seqGap(expected: last + 1, received: seq)) await self.pushHandler?(.seqGap(expected: last + 1, received: seq))
@@ -437,6 +423,63 @@ public actor GatewayChannelActor {
} }
} }
private func waitForConnectChallenge() async throws -> String? {
guard let task = self.task else { return nil }
do {
return try await AsyncTimeout.withTimeout(
seconds: self.connectChallengeTimeoutSeconds,
onTimeout: { ConnectChallengeError.timeout },
operation: { [weak self] in
guard let self else { return nil }
while true {
let msg = try await task.receive()
guard let data = self.decodeMessageData(msg) else { continue }
guard let frame = try? self.decoder.decode(GatewayFrame.self, from: data) else { continue }
if case let .event(evt) = frame, evt.event == "connect.challenge" {
if let payload = evt.payload?.value as? [String: ProtoAnyCodable],
let nonce = payload["nonce"]?.value as? String {
return nonce
}
}
}
})
} catch {
if error is ConnectChallengeError { return nil }
throw error
}
}
private func waitForConnectResponse(reqId: String) async throws -> ResponseFrame {
guard let task = self.task else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (no response)"])
}
while true {
let msg = try await task.receive()
guard let data = self.decodeMessageData(msg) else { continue }
guard let frame = try? self.decoder.decode(GatewayFrame.self, from: data) else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "connect failed (invalid response)"])
}
if case let .res(res) = frame, res.id == reqId {
return res
}
}
}
private func decodeMessageData(_ msg: URLSessionWebSocketTask.Message) -> Data? {
let data: Data? = switch msg {
case let .data(data): data
case let .string(text): text.data(using: .utf8)
@unknown default: nil
}
return data
}
private func watchTicks() async { private func watchTicks() async {
let tolerance = self.tickIntervalMs * 2 let tolerance = self.tickIntervalMs * 2
while self.connected { while self.connected {

View File

@@ -5,6 +5,7 @@ public let GATEWAY_PROTOCOL_VERSION = 3
public enum ErrorCode: String, Codable, Sendable { public enum ErrorCode: String, Codable, Sendable {
case notLinked = "NOT_LINKED" case notLinked = "NOT_LINKED"
case notPaired = "NOT_PAIRED"
case agentTimeout = "AGENT_TIMEOUT" case agentTimeout = "AGENT_TIMEOUT"
case invalidRequest = "INVALID_REQUEST" case invalidRequest = "INVALID_REQUEST"
case unavailable = "UNAVAILABLE" case unavailable = "UNAVAILABLE"
@@ -15,6 +16,11 @@ public struct ConnectParams: Codable, Sendable {
public let maxprotocol: Int public let maxprotocol: Int
public let client: [String: AnyCodable] public let client: [String: AnyCodable]
public let caps: [String]? public let caps: [String]?
public let commands: [String]?
public let permissions: [String: AnyCodable]?
public let role: String?
public let scopes: [String]?
public let device: [String: AnyCodable]?
public let auth: [String: AnyCodable]? public let auth: [String: AnyCodable]?
public let locale: String? public let locale: String?
public let useragent: String? public let useragent: String?
@@ -24,6 +30,11 @@ public struct ConnectParams: Codable, Sendable {
maxprotocol: Int, maxprotocol: Int,
client: [String: AnyCodable], client: [String: AnyCodable],
caps: [String]?, caps: [String]?,
commands: [String]?,
permissions: [String: AnyCodable]?,
role: String?,
scopes: [String]?,
device: [String: AnyCodable]?,
auth: [String: AnyCodable]?, auth: [String: AnyCodable]?,
locale: String?, locale: String?,
useragent: String? useragent: String?
@@ -32,6 +43,11 @@ public struct ConnectParams: Codable, Sendable {
self.maxprotocol = maxprotocol self.maxprotocol = maxprotocol
self.client = client self.client = client
self.caps = caps self.caps = caps
self.commands = commands
self.permissions = permissions
self.role = role
self.scopes = scopes
self.device = device
self.auth = auth self.auth = auth
self.locale = locale self.locale = locale
self.useragent = useragent self.useragent = useragent
@@ -41,6 +57,11 @@ public struct ConnectParams: Codable, Sendable {
case maxprotocol = "maxProtocol" case maxprotocol = "maxProtocol"
case client case client
case caps case caps
case commands
case permissions
case role
case scopes
case device
case auth case auth
case locale case locale
case useragent = "userAgent" case useragent = "userAgent"
@@ -54,6 +75,7 @@ public struct HelloOk: Codable, Sendable {
public let features: [String: AnyCodable] public let features: [String: AnyCodable]
public let snapshot: Snapshot public let snapshot: Snapshot
public let canvashosturl: String? public let canvashosturl: String?
public let auth: [String: AnyCodable]?
public let policy: [String: AnyCodable] public let policy: [String: AnyCodable]
public init( public init(
@@ -63,6 +85,7 @@ public struct HelloOk: Codable, Sendable {
features: [String: AnyCodable], features: [String: AnyCodable],
snapshot: Snapshot, snapshot: Snapshot,
canvashosturl: String?, canvashosturl: String?,
auth: [String: AnyCodable]?,
policy: [String: AnyCodable] policy: [String: AnyCodable]
) { ) {
self.type = type self.type = type
@@ -71,6 +94,7 @@ public struct HelloOk: Codable, Sendable {
self.features = features self.features = features
self.snapshot = snapshot self.snapshot = snapshot
self.canvashosturl = canvashosturl self.canvashosturl = canvashosturl
self.auth = auth
self.policy = policy self.policy = policy
} }
private enum CodingKeys: String, CodingKey { private enum CodingKeys: String, CodingKey {
@@ -80,6 +104,7 @@ public struct HelloOk: Codable, Sendable {
case features case features
case snapshot case snapshot
case canvashosturl = "canvasHostUrl" case canvashosturl = "canvasHostUrl"
case auth
case policy case policy
} }
} }
@@ -706,6 +731,93 @@ public struct NodeInvokeParams: Codable, Sendable {
} }
} }
public struct NodeInvokeResultParams: Codable, Sendable {
public let id: String
public let nodeid: String
public let ok: Bool
public let payload: AnyCodable?
public let payloadjson: String?
public let error: [String: AnyCodable]?
public init(
id: String,
nodeid: String,
ok: Bool,
payload: AnyCodable?,
payloadjson: String?,
error: [String: AnyCodable]?
) {
self.id = id
self.nodeid = nodeid
self.ok = ok
self.payload = payload
self.payloadjson = payloadjson
self.error = error
}
private enum CodingKeys: String, CodingKey {
case id
case nodeid = "nodeId"
case ok
case payload
case payloadjson = "payloadJSON"
case error
}
}
public struct NodeEventParams: Codable, Sendable {
public let event: String
public let payload: AnyCodable?
public let payloadjson: String?
public init(
event: String,
payload: AnyCodable?,
payloadjson: String?
) {
self.event = event
self.payload = payload
self.payloadjson = payloadjson
}
private enum CodingKeys: String, CodingKey {
case event
case payload
case payloadjson = "payloadJSON"
}
}
public struct NodeInvokeRequestEvent: Codable, Sendable {
public let id: String
public let nodeid: String
public let command: String
public let paramsjson: String?
public let timeoutms: Int?
public let idempotencykey: String?
public init(
id: String,
nodeid: String,
command: String,
paramsjson: String?,
timeoutms: Int?,
idempotencykey: String?
) {
self.id = id
self.nodeid = nodeid
self.command = command
self.paramsjson = paramsjson
self.timeoutms = timeoutms
self.idempotencykey = idempotencykey
}
private enum CodingKeys: String, CodingKey {
case id
case nodeid = "nodeId"
case command
case paramsjson = "paramsJSON"
case timeoutms = "timeoutMs"
case idempotencykey = "idempotencyKey"
}
}
public struct SessionsListParams: Codable, Sendable { public struct SessionsListParams: Codable, Sendable {
public let limit: Int? public let limit: Int?
public let activeminutes: Int? public let activeminutes: Int?
@@ -1381,6 +1493,22 @@ public struct ModelsListResult: Codable, Sendable {
public struct SkillsStatusParams: Codable, Sendable { public struct SkillsStatusParams: Codable, Sendable {
} }
public struct SkillsBinsParams: Codable, Sendable {
}
public struct SkillsBinsResult: Codable, Sendable {
public let bins: [String]
public init(
bins: [String]
) {
self.bins = bins
}
private enum CodingKeys: String, CodingKey {
case bins
}
}
public struct SkillsInstallParams: Codable, Sendable { public struct SkillsInstallParams: Codable, Sendable {
public let name: String public let name: String
public let installid: String public let installid: String
@@ -1735,6 +1863,225 @@ public struct ExecApprovalsSnapshot: Codable, Sendable {
} }
} }
public struct ExecApprovalRequestParams: Codable, Sendable {
public let command: String
public let cwd: String?
public let host: String?
public let security: String?
public let ask: String?
public let agentid: String?
public let resolvedpath: String?
public let sessionkey: String?
public let timeoutms: Int?
public init(
command: String,
cwd: String?,
host: String?,
security: String?,
ask: String?,
agentid: String?,
resolvedpath: String?,
sessionkey: String?,
timeoutms: Int?
) {
self.command = command
self.cwd = cwd
self.host = host
self.security = security
self.ask = ask
self.agentid = agentid
self.resolvedpath = resolvedpath
self.sessionkey = sessionkey
self.timeoutms = timeoutms
}
private enum CodingKeys: String, CodingKey {
case command
case cwd
case host
case security
case ask
case agentid = "agentId"
case resolvedpath = "resolvedPath"
case sessionkey = "sessionKey"
case timeoutms = "timeoutMs"
}
}
public struct ExecApprovalResolveParams: Codable, Sendable {
public let id: String
public let decision: String
public init(
id: String,
decision: String
) {
self.id = id
self.decision = decision
}
private enum CodingKeys: String, CodingKey {
case id
case decision
}
}
public struct DevicePairListParams: Codable, Sendable {
}
public struct DevicePairApproveParams: Codable, Sendable {
public let requestid: String
public init(
requestid: String
) {
self.requestid = requestid
}
private enum CodingKeys: String, CodingKey {
case requestid = "requestId"
}
}
public struct DevicePairRejectParams: Codable, Sendable {
public let requestid: String
public init(
requestid: String
) {
self.requestid = requestid
}
private enum CodingKeys: String, CodingKey {
case requestid = "requestId"
}
}
public struct DeviceTokenRotateParams: Codable, Sendable {
public let deviceid: String
public let role: String
public let scopes: [String]?
public init(
deviceid: String,
role: String,
scopes: [String]?
) {
self.deviceid = deviceid
self.role = role
self.scopes = scopes
}
private enum CodingKeys: String, CodingKey {
case deviceid = "deviceId"
case role
case scopes
}
}
public struct DeviceTokenRevokeParams: Codable, Sendable {
public let deviceid: String
public let role: String
public init(
deviceid: String,
role: String
) {
self.deviceid = deviceid
self.role = role
}
private enum CodingKeys: String, CodingKey {
case deviceid = "deviceId"
case role
}
}
public struct DevicePairRequestedEvent: Codable, Sendable {
public let requestid: String
public let deviceid: String
public let publickey: String
public let displayname: String?
public let platform: String?
public let clientid: String?
public let clientmode: String?
public let role: String?
public let roles: [String]?
public let scopes: [String]?
public let remoteip: String?
public let silent: Bool?
public let isrepair: Bool?
public let ts: Int
public init(
requestid: String,
deviceid: String,
publickey: String,
displayname: String?,
platform: String?,
clientid: String?,
clientmode: String?,
role: String?,
roles: [String]?,
scopes: [String]?,
remoteip: String?,
silent: Bool?,
isrepair: Bool?,
ts: Int
) {
self.requestid = requestid
self.deviceid = deviceid
self.publickey = publickey
self.displayname = displayname
self.platform = platform
self.clientid = clientid
self.clientmode = clientmode
self.role = role
self.roles = roles
self.scopes = scopes
self.remoteip = remoteip
self.silent = silent
self.isrepair = isrepair
self.ts = ts
}
private enum CodingKeys: String, CodingKey {
case requestid = "requestId"
case deviceid = "deviceId"
case publickey = "publicKey"
case displayname = "displayName"
case platform
case clientid = "clientId"
case clientmode = "clientMode"
case role
case roles
case scopes
case remoteip = "remoteIp"
case silent
case isrepair = "isRepair"
case ts
}
}
public struct DevicePairResolvedEvent: Codable, Sendable {
public let requestid: String
public let deviceid: String
public let decision: String
public let ts: Int
public init(
requestid: String,
deviceid: String,
decision: String,
ts: Int
) {
self.requestid = requestid
self.deviceid = deviceid
self.decision = decision
self.ts = ts
}
private enum CodingKeys: String, CodingKey {
case requestid = "requestId"
case deviceid = "deviceId"
case decision
case ts
}
}
public struct ChatHistoryParams: Codable, Sendable { public struct ChatHistoryParams: Codable, Sendable {
public let sessionkey: String public let sessionkey: String
public let limit: Int? public let limit: Int?

View File

@@ -20,6 +20,16 @@ handshake time.
## Handshake (connect) ## Handshake (connect)
Gateway → Client (pre-connect challenge):
```json
{
"type": "event",
"event": "connect.challenge",
"payload": { "nonce": "…", "ts": 1737264000000 }
}
```
Client → Gateway: Client → Gateway:
```json ```json
@@ -43,7 +53,14 @@ Client → Gateway:
"permissions": {}, "permissions": {},
"auth": { "token": "…" }, "auth": { "token": "…" },
"locale": "en-US", "locale": "en-US",
"userAgent": "clawdbot-cli/1.2.3" "userAgent": "clawdbot-cli/1.2.3",
"device": {
"id": "device_fingerprint",
"publicKey": "…",
"signature": "…",
"signedAt": 1737264000000,
"nonce": "…"
}
} }
} }
``` ```
@@ -99,7 +116,8 @@ When a device token is issued, `hello-ok` also includes:
"id": "device_fingerprint", "id": "device_fingerprint",
"publicKey": "…", "publicKey": "…",
"signature": "…", "signature": "…",
"signedAt": 1737264000000 "signedAt": 1737264000000,
"nonce": "…"
} }
} }
} }
@@ -167,6 +185,7 @@ The Gateway treats these as **claims** and enforces server-side allowlists.
- Pairing approvals are required for new device IDs unless local auto-approval - Pairing approvals are required for new device IDs unless local auto-approval
is enabled. is enabled.
- All WS clients must include `device` identity during `connect` (operator + node). - All WS clients must include `device` identity during `connect` (operator + node).
- Non-local connections must sign the server-provided `connect.challenge` nonce.
## TLS + pinning ## TLS + pinning

View File

@@ -288,6 +288,26 @@ Same `deviceId` across roles → single “Instance” row:
--- ---
# Execution checklist (ship order)
- [x] **Devicebound auth (PoP):** nonce challenge + signature verify on connect; remove beareronly for nonlocal.
- [ ] **Rolescoped creds:** issue perrole tokens, rotate, revoke, list; UI/CLI surfaced; audit log entries.
- [ ] **Scope enforcement:** keep paired scopes in sync on rotation; reject/upgrade flows explicit; tests.
- [ ] **Approvals routing:** gatewayhosted approvals; operator UI prompt/resolve; node stops prompting.
- [ ] **TLS pinning for WS:** reuse bridge TLS runtime; discovery advertises fingerprint; client validation.
- [ ] **Discovery + allowlist:** WS discovery TXT includes TLS fingerprint + role hints; node commands filtered by server allowlist.
- [ ] **Presence unification:** dedupe deviceId across roles; include role/scope metadata; “single instance row”.
- [ ] **Docs + examples:** protocol doc, CLI docs, onboarding + security notes; no personal hostnames.
- [ ] **Test coverage:** connect auth paths, rotation/revoke, approvals, TLS fingerprint mismatch, presence.
Process per item:
- Do implementation.
- Fresheyes review (scan for regressions + missing tests).
- Fix issues.
- Commit with Conventional Commit.
- Move to next item.
---
# Security notes # Security notes
- Role/allowlist enforced at gateway boundary. - Role/allowlist enforced at gateway boundary.

8
pnpm-lock.yaml generated
View File

@@ -361,6 +361,9 @@ importers:
ui: ui:
dependencies: dependencies:
'@noble/ed25519':
specifier: 3.0.0
version: 3.0.0
dompurify: dompurify:
specifier: ^3.3.1 specifier: ^3.3.1
version: 3.3.1 version: 3.3.1
@@ -1292,6 +1295,9 @@ packages:
'@napi-rs/wasm-runtime@1.1.1': '@napi-rs/wasm-runtime@1.1.1':
resolution: {integrity: sha512-p64ah1M1ld8xjWv3qbvFwHiFVWrq1yFvV4f7w+mzaqiR4IlSgkqhcRdHwsGgomwzBH51sRY4NEowLxnaBjcW/A==} resolution: {integrity: sha512-p64ah1M1ld8xjWv3qbvFwHiFVWrq1yFvV4f7w+mzaqiR4IlSgkqhcRdHwsGgomwzBH51sRY4NEowLxnaBjcW/A==}
'@noble/ed25519@3.0.0':
resolution: {integrity: sha512-QyteqMNm0GLqfa5SoYbSC3+Pvykwpn95Zgth4MFVSMKBB75ELl9tX1LAVsN4c3HXOrakHsF2gL4zWDAYCcsnzg==}
'@node-llama-cpp/linux-arm64@3.14.5': '@node-llama-cpp/linux-arm64@3.14.5':
resolution: {integrity: sha512-58IcWW7EOqc/66mYWXRsoMCy1MR3pTX/YaC0HYF9Rg5XeAPKhUP7NHrglbqgjO62CkcuFZaSEiX2AtG972GQYQ==} resolution: {integrity: sha512-58IcWW7EOqc/66mYWXRsoMCy1MR3pTX/YaC0HYF9Rg5XeAPKhUP7NHrglbqgjO62CkcuFZaSEiX2AtG972GQYQ==}
engines: {node: '>=20.0.0'} engines: {node: '>=20.0.0'}
@@ -6131,6 +6137,8 @@ snapshots:
'@tybys/wasm-util': 0.10.1 '@tybys/wasm-util': 0.10.1
optional: true optional: true
'@noble/ed25519@3.0.0': {}
'@node-llama-cpp/linux-arm64@3.14.5': '@node-llama-cpp/linux-arm64@3.14.5':
optional: true optional: true

View File

@@ -18,14 +18,25 @@ type JsonSchema = {
const __dirname = path.dirname(fileURLToPath(import.meta.url)); const __dirname = path.dirname(fileURLToPath(import.meta.url));
const repoRoot = path.resolve(__dirname, ".."); const repoRoot = path.resolve(__dirname, "..");
const outPath = path.join( const outPaths = [
repoRoot, path.join(
"apps", repoRoot,
"macos", "apps",
"Sources", "macos",
"ClawdbotProtocol", "Sources",
"GatewayModels.swift", "ClawdbotProtocol",
); "GatewayModels.swift",
),
path.join(
repoRoot,
"apps",
"shared",
"ClawdbotKit",
"Sources",
"ClawdbotProtocol",
"GatewayModels.swift",
),
];
const header = `// Generated by scripts/protocol-gen-swift.ts — do not edit by hand\nimport Foundation\n\npublic let GATEWAY_PROTOCOL_VERSION = ${PROTOCOL_VERSION}\n\npublic enum ErrorCode: String, Codable, Sendable {\n${Object.values(ErrorCodes) const header = `// Generated by scripts/protocol-gen-swift.ts — do not edit by hand\nimport Foundation\n\npublic let GATEWAY_PROTOCOL_VERSION = ${PROTOCOL_VERSION}\n\npublic enum ErrorCode: String, Codable, Sendable {\n${Object.values(ErrorCodes)
.map((c) => ` case ${camelCase(c)} = "${c}"`) .map((c) => ` case ${camelCase(c)} = "${c}"`)
@@ -221,9 +232,11 @@ async function generate() {
parts.push(emitGatewayFrame()); parts.push(emitGatewayFrame());
const content = parts.join("\n"); const content = parts.join("\n");
await fs.mkdir(path.dirname(outPath), { recursive: true }); for (const outPath of outPaths) {
await fs.writeFile(outPath, content); await fs.mkdir(path.dirname(outPath), { recursive: true });
console.log(`wrote ${outPath}`); await fs.writeFile(outPath, content);
console.log(`wrote ${outPath}`);
}
} }
generate().catch((err) => { generate().catch((err) => {

View File

@@ -72,11 +72,14 @@ export function describeGatewayCloseCode(code: number): string | undefined {
export class GatewayClient { export class GatewayClient {
private ws: WebSocket | null = null; private ws: WebSocket | null = null;
private opts: GatewayClientOptions & { deviceIdentity: DeviceIdentity }; private opts: GatewayClientOptions;
private pending = new Map<string, Pending>(); private pending = new Map<string, Pending>();
private backoffMs = 1000; private backoffMs = 1000;
private closed = false; private closed = false;
private lastSeq: number | null = null; private lastSeq: number | null = null;
private connectNonce: string | null = null;
private connectSent = false;
private connectTimer: NodeJS.Timeout | null = null;
// Track last tick to detect silent stalls. // Track last tick to detect silent stalls.
private lastTick: number | null = null; private lastTick: number | null = null;
private tickIntervalMs = 30_000; private tickIntervalMs = 30_000;
@@ -121,7 +124,7 @@ export class GatewayClient {
} }
this.ws = new WebSocket(url, wsOptions); this.ws = new WebSocket(url, wsOptions);
this.ws.on("open", () => this.sendConnect()); this.ws.on("open", () => this.queueConnect());
this.ws.on("message", (data) => this.handleMessage(rawDataToString(data))); this.ws.on("message", (data) => this.handleMessage(rawDataToString(data)));
this.ws.on("close", (code, reason) => { this.ws.on("close", (code, reason) => {
const reasonText = rawDataToString(reason); const reasonText = rawDataToString(reason);
@@ -147,6 +150,12 @@ export class GatewayClient {
} }
private sendConnect() { private sendConnect() {
if (this.connectSent) return;
this.connectSent = true;
if (this.connectTimer) {
clearTimeout(this.connectTimer);
this.connectTimer = null;
}
const role = this.opts.role ?? "operator"; const role = this.opts.role ?? "operator";
const storedToken = this.opts.deviceIdentity const storedToken = this.opts.deviceIdentity
? loadDeviceAuthToken({ deviceId: this.opts.deviceIdentity.deviceId, role })?.token ? loadDeviceAuthToken({ deviceId: this.opts.deviceIdentity.deviceId, role })?.token
@@ -160,24 +169,29 @@ export class GatewayClient {
} }
: undefined; : undefined;
const signedAtMs = Date.now(); const signedAtMs = Date.now();
const nonce = this.connectNonce ?? undefined;
const scopes = this.opts.scopes ?? ["operator.admin"]; const scopes = this.opts.scopes ?? ["operator.admin"];
const deviceIdentity = this.opts.deviceIdentity; const device = (() => {
const payload = buildDeviceAuthPayload({ if (!this.opts.deviceIdentity) return undefined;
deviceId: deviceIdentity.deviceId, const payload = buildDeviceAuthPayload({
clientId: this.opts.clientName ?? GATEWAY_CLIENT_NAMES.GATEWAY_CLIENT, deviceId: this.opts.deviceIdentity.deviceId,
clientMode: this.opts.mode ?? GATEWAY_CLIENT_MODES.BACKEND, clientId: this.opts.clientName ?? GATEWAY_CLIENT_NAMES.GATEWAY_CLIENT,
role, clientMode: this.opts.mode ?? GATEWAY_CLIENT_MODES.BACKEND,
scopes, role,
signedAtMs, scopes,
token: authToken ?? null, signedAtMs,
}); token: authToken ?? null,
const signature = signDevicePayload(deviceIdentity.privateKeyPem, payload); nonce,
const device = { });
id: deviceIdentity.deviceId, const signature = signDevicePayload(this.opts.deviceIdentity.privateKeyPem, payload);
publicKey: publicKeyRawBase64UrlFromPem(deviceIdentity.publicKeyPem), return {
signature, id: this.opts.deviceIdentity.deviceId,
signedAt: signedAtMs, publicKey: publicKeyRawBase64UrlFromPem(this.opts.deviceIdentity.publicKeyPem),
}; signature,
signedAt: signedAtMs,
nonce,
};
})();
const params: ConnectParams = { const params: ConnectParams = {
minProtocol: this.opts.minProtocol ?? PROTOCOL_VERSION, minProtocol: this.opts.minProtocol ?? PROTOCOL_VERSION,
maxProtocol: this.opts.maxProtocol ?? PROTOCOL_VERSION, maxProtocol: this.opts.maxProtocol ?? PROTOCOL_VERSION,
@@ -235,6 +249,15 @@ export class GatewayClient {
const parsed = JSON.parse(raw); const parsed = JSON.parse(raw);
if (validateEventFrame(parsed)) { if (validateEventFrame(parsed)) {
const evt = parsed as EventFrame; const evt = parsed as EventFrame;
if (evt.event === "connect.challenge") {
const payload = evt.payload as { nonce?: unknown } | undefined;
const nonce = payload && typeof payload.nonce === "string" ? payload.nonce : null;
if (nonce) {
this.connectNonce = nonce;
this.sendConnect();
}
return;
}
const seq = typeof evt.seq === "number" ? evt.seq : null; const seq = typeof evt.seq === "number" ? evt.seq : null;
if (seq !== null) { if (seq !== null) {
if (this.lastSeq !== null && seq > this.lastSeq + 1) { if (this.lastSeq !== null && seq > this.lastSeq + 1) {
@@ -266,6 +289,15 @@ export class GatewayClient {
} }
} }
private queueConnect() {
this.connectNonce = null;
this.connectSent = false;
if (this.connectTimer) clearTimeout(this.connectTimer);
this.connectTimer = setTimeout(() => {
this.sendConnect();
}, 750);
}
private scheduleReconnect() { private scheduleReconnect() {
if (this.closed) return; if (this.closed) return;
if (this.tickTimer) { if (this.tickTimer) {

View File

@@ -6,13 +6,16 @@ export type DeviceAuthPayloadParams = {
scopes: string[]; scopes: string[];
signedAtMs: number; signedAtMs: number;
token?: string | null; token?: string | null;
nonce?: string | null;
version?: "v1" | "v2";
}; };
export function buildDeviceAuthPayload(params: DeviceAuthPayloadParams): string { export function buildDeviceAuthPayload(params: DeviceAuthPayloadParams): string {
const version = params.version ?? (params.nonce ? "v2" : "v1");
const scopes = params.scopes.join(","); const scopes = params.scopes.join(",");
const token = params.token ?? ""; const token = params.token ?? "";
return [ const base = [
"v1", version,
params.deviceId, params.deviceId,
params.clientId, params.clientId,
params.clientMode, params.clientMode,
@@ -20,5 +23,9 @@ export function buildDeviceAuthPayload(params: DeviceAuthPayloadParams): string
scopes, scopes,
String(params.signedAtMs), String(params.signedAtMs),
token, token,
].join("|"); ];
if (version === "v2") {
base.push(params.nonce ?? "");
}
return base.join("|");
} }

View File

@@ -46,6 +46,7 @@ export const ConnectParamsSchema = Type.Object(
publicKey: NonEmptyString, publicKey: NonEmptyString,
signature: NonEmptyString, signature: NonEmptyString,
signedAt: Type.Integer({ minimum: 0 }), signedAt: Type.Integer({ minimum: 0 }),
nonce: Type.Optional(NonEmptyString),
}, },
{ additionalProperties: false }, { additionalProperties: false },
), ),

View File

@@ -81,6 +81,7 @@ export function listGatewayMethods(): string[] {
} }
export const GATEWAY_EVENTS = [ export const GATEWAY_EVENTS = [
"connect.challenge",
"agent", "agent",
"chat", "chat",
"presence", "presence",

View File

@@ -57,6 +57,22 @@ describe("gateway server auth/connect", () => {
await server.close(); await server.close();
}); });
test("sends connect challenge on open", async () => {
const port = await getFreePort();
const server = await startGatewayServer(port);
const ws = new WebSocket(`ws://127.0.0.1:${port}`);
const evtPromise = onceMessage<{ payload?: unknown }>(
ws,
(o) => o.type === "event" && o.event === "connect.challenge",
);
await new Promise<void>((resolve) => ws.once("open", resolve));
const evt = await evtPromise;
const nonce = (evt.payload as { nonce?: unknown } | undefined)?.nonce;
expect(typeof nonce).toBe("string");
ws.close();
await server.close();
});
test("rejects protocol mismatch", async () => { test("rejects protocol mismatch", async () => {
const { server, ws } = await startServerWithClient(); const { server, ws } = await startServerWithClient();
try { try {

View File

@@ -116,6 +116,13 @@ export function attachGatewayWsConnectionHandler(params: {
} }
}; };
const connectNonce = randomUUID();
send({
type: "event",
event: "connect.challenge",
payload: { nonce: connectNonce, ts: Date.now() },
});
const close = (code = 1000, reason?: string) => { const close = (code = 1000, reason?: string) => {
if (closed) return; if (closed) return;
closed = true; closed = true;
@@ -224,6 +231,7 @@ export function attachGatewayWsConnectionHandler(params: {
requestOrigin, requestOrigin,
requestUserAgent, requestUserAgent,
canvasHostUrl, canvasHostUrl,
connectNonce,
resolvedAuth, resolvedAuth,
gatewayMethods, gatewayMethods,
events, events,

View File

@@ -68,6 +68,7 @@ export function attachGatewayWsMessageHandler(params: {
requestOrigin?: string; requestOrigin?: string;
requestUserAgent?: string; requestUserAgent?: string;
canvasHostUrl?: string; canvasHostUrl?: string;
connectNonce: string;
resolvedAuth: ResolvedGatewayAuth; resolvedAuth: ResolvedGatewayAuth;
gatewayMethods: string[]; gatewayMethods: string[];
events: string[]; events: string[];
@@ -96,6 +97,7 @@ export function attachGatewayWsMessageHandler(params: {
requestOrigin, requestOrigin,
requestUserAgent, requestUserAgent,
canvasHostUrl, canvasHostUrl,
connectNonce,
resolvedAuth, resolvedAuth,
gatewayMethods, gatewayMethods,
events, events,
@@ -307,6 +309,40 @@ export function attachGatewayWsMessageHandler(params: {
close(1008, "device signature expired"); close(1008, "device signature expired");
return; return;
} }
const nonceRequired = !isLoopbackAddress(remoteAddr);
const providedNonce = typeof device.nonce === "string" ? device.nonce.trim() : "";
if (nonceRequired && !providedNonce) {
setHandshakeState("failed");
setCloseCause("device-auth-invalid", {
reason: "device-nonce-missing",
client: connectParams.client.id,
deviceId: device.id,
});
send({
type: "res",
id: frame.id,
ok: false,
error: errorShape(ErrorCodes.INVALID_REQUEST, "device nonce required"),
});
close(1008, "device nonce required");
return;
}
if (providedNonce && providedNonce !== connectNonce) {
setHandshakeState("failed");
setCloseCause("device-auth-invalid", {
reason: "device-nonce-mismatch",
client: connectParams.client.id,
deviceId: device.id,
});
send({
type: "res",
id: frame.id,
ok: false,
error: errorShape(ErrorCodes.INVALID_REQUEST, "device nonce mismatch"),
});
close(1008, "device nonce mismatch");
return;
}
const payload = buildDeviceAuthPayload({ const payload = buildDeviceAuthPayload({
deviceId: device.id, deviceId: device.id,
clientId: connectParams.client.id, clientId: connectParams.client.id,
@@ -315,8 +351,41 @@ export function attachGatewayWsMessageHandler(params: {
scopes: requestedScopes, scopes: requestedScopes,
signedAtMs: signedAt, signedAtMs: signedAt,
token: connectParams.auth?.token ?? null, token: connectParams.auth?.token ?? null,
nonce: providedNonce || undefined,
version: providedNonce ? "v2" : "v1",
}); });
if (!verifyDeviceSignature(device.publicKey, payload, device.signature)) { const signatureOk = verifyDeviceSignature(device.publicKey, payload, device.signature);
const allowLegacy = !nonceRequired && !providedNonce;
if (!signatureOk && allowLegacy) {
const legacyPayload = buildDeviceAuthPayload({
deviceId: device.id,
clientId: connectParams.client.id,
clientMode: connectParams.client.mode,
role,
scopes: requestedScopes,
signedAtMs: signedAt,
token: connectParams.auth?.token ?? null,
version: "v1",
});
if (verifyDeviceSignature(device.publicKey, legacyPayload, device.signature)) {
// accepted legacy loopback signature
} else {
setHandshakeState("failed");
setCloseCause("device-auth-invalid", {
reason: "device-signature",
client: connectParams.client.id,
deviceId: device.id,
});
send({
type: "res",
id: frame.id,
ok: false,
error: errorShape(ErrorCodes.INVALID_REQUEST, "device signature invalid"),
});
close(1008, "device signature invalid");
return;
}
} else if (!signatureOk) {
setHandshakeState("failed"); setHandshakeState("failed");
setCloseCause("device-auth-invalid", { setCloseCause("device-auth-invalid", {
reason: "device-signature", reason: "device-signature",
@@ -460,11 +529,7 @@ export function attachGatewayWsMessageHandler(params: {
if (!ok) return; if (!ok) return;
} else { } else {
const allowedRoles = new Set( const allowedRoles = new Set(
Array.isArray(paired.roles) Array.isArray(paired.roles) ? paired.roles : paired.role ? [paired.role] : [],
? paired.roles
: paired.role
? [paired.role]
: [],
); );
if (allowedRoles.size === 0) { if (allowedRoles.size === 0) {
const ok = await requirePairing("role-upgrade", paired); const ok = await requirePairing("role-upgrade", paired);

View File

@@ -279,6 +279,7 @@ export async function connectReq(
publicKey: string; publicKey: string;
signature: string; signature: string;
signedAt: number; signedAt: number;
nonce?: string;
}; };
}, },
): Promise<ConnectResponse> { ): Promise<ConnectResponse> {
@@ -310,6 +311,7 @@ export async function connectReq(
publicKey: publicKeyRawBase64UrlFromPem(identity.publicKeyPem), publicKey: publicKeyRawBase64UrlFromPem(identity.publicKeyPem),
signature: signDevicePayload(identity.privateKeyPem, payload), signature: signDevicePayload(identity.privateKeyPem, payload),
signedAt: signedAtMs, signedAt: signedAtMs,
nonce: opts?.device?.nonce,
}; };
})(); })();
ws.send( ws.send(

View File

@@ -9,6 +9,7 @@
"test": "vitest run --config vitest.config.ts" "test": "vitest run --config vitest.config.ts"
}, },
"dependencies": { "dependencies": {
"@noble/ed25519": "3.0.0",
"dompurify": "^3.3.1", "dompurify": "^3.3.1",
"lit": "^3.3.2", "lit": "^3.3.2",
"marked": "^17.0.1", "marked": "^17.0.1",

View File

@@ -0,0 +1,108 @@
import { ed25519 } from "@noble/ed25519";
type StoredIdentity = {
version: 1;
deviceId: string;
publicKey: string;
privateKey: string;
createdAtMs: number;
};
export type DeviceIdentity = {
deviceId: string;
publicKey: string;
privateKey: string;
};
const STORAGE_KEY = "clawdbot-device-identity-v1";
function base64UrlEncode(bytes: Uint8Array): string {
let binary = "";
for (const byte of bytes) binary += String.fromCharCode(byte);
return btoa(binary).replaceAll("+", "-").replaceAll("/", "_").replace(/=+$/g, "");
}
function base64UrlDecode(input: string): Uint8Array {
const normalized = input.replaceAll("-", "+").replaceAll("_", "/");
const padded = normalized + "=".repeat((4 - (normalized.length % 4)) % 4);
const binary = atob(padded);
const out = new Uint8Array(binary.length);
for (let i = 0; i < binary.length; i += 1) out[i] = binary.charCodeAt(i);
return out;
}
function bytesToHex(bytes: Uint8Array): string {
return Array.from(bytes)
.map((b) => b.toString(16).padStart(2, "0"))
.join("");
}
async function fingerprintPublicKey(publicKey: Uint8Array): Promise<string> {
const hash = await crypto.subtle.digest("SHA-256", publicKey);
return bytesToHex(new Uint8Array(hash));
}
async function generateIdentity(): Promise<DeviceIdentity> {
const privateKey = ed25519.utils.randomPrivateKey();
const publicKey = await ed25519.getPublicKey(privateKey);
const deviceId = await fingerprintPublicKey(publicKey);
return {
deviceId,
publicKey: base64UrlEncode(publicKey),
privateKey: base64UrlEncode(privateKey),
};
}
export async function loadOrCreateDeviceIdentity(): Promise<DeviceIdentity> {
try {
const raw = localStorage.getItem(STORAGE_KEY);
if (raw) {
const parsed = JSON.parse(raw) as StoredIdentity;
if (
parsed?.version === 1 &&
typeof parsed.deviceId === "string" &&
typeof parsed.publicKey === "string" &&
typeof parsed.privateKey === "string"
) {
const derivedId = await fingerprintPublicKey(base64UrlDecode(parsed.publicKey));
if (derivedId !== parsed.deviceId) {
const updated: StoredIdentity = {
...parsed,
deviceId: derivedId,
};
localStorage.setItem(STORAGE_KEY, JSON.stringify(updated));
return {
deviceId: derivedId,
publicKey: parsed.publicKey,
privateKey: parsed.privateKey,
};
}
return {
deviceId: parsed.deviceId,
publicKey: parsed.publicKey,
privateKey: parsed.privateKey,
};
}
}
} catch {
// fall through to regenerate
}
const identity = await generateIdentity();
const stored: StoredIdentity = {
version: 1,
deviceId: identity.deviceId,
publicKey: identity.publicKey,
privateKey: identity.privateKey,
createdAtMs: Date.now(),
};
localStorage.setItem(STORAGE_KEY, JSON.stringify(stored));
return identity;
}
export async function signDevicePayload(privateKeyBase64Url: string, payload: string) {
const key = base64UrlDecode(privateKeyBase64Url);
const data = new TextEncoder().encode(payload);
const sig = await ed25519.sign(data, key);
return base64UrlEncode(sig);
}

View File

@@ -5,6 +5,8 @@ import {
type GatewayClientMode, type GatewayClientMode,
type GatewayClientName, type GatewayClientName,
} from "../../../src/gateway/protocol/client-info.js"; } from "../../../src/gateway/protocol/client-info.js";
import { buildDeviceAuthPayload } from "../../../src/gateway/device-auth.js";
import { loadOrCreateDeviceIdentity, signDevicePayload } from "./device-identity";
export type GatewayEventFrame = { export type GatewayEventFrame = {
type: "event"; type: "event";
@@ -58,6 +60,9 @@ export class GatewayBrowserClient {
private pending = new Map<string, Pending>(); private pending = new Map<string, Pending>();
private closed = false; private closed = false;
private lastSeq: number | null = null; private lastSeq: number | null = null;
private connectNonce: string | null = null;
private connectSent = false;
private connectTimer: number | null = null;
private backoffMs = 800; private backoffMs = 800;
constructor(private opts: GatewayBrowserClientOptions) {} constructor(private opts: GatewayBrowserClientOptions) {}
@@ -81,7 +86,7 @@ export class GatewayBrowserClient {
private connect() { private connect() {
if (this.closed) return; if (this.closed) return;
this.ws = new WebSocket(this.opts.url); this.ws = new WebSocket(this.opts.url);
this.ws.onopen = () => this.sendConnect(); this.ws.onopen = () => this.queueConnect();
this.ws.onmessage = (ev) => this.handleMessage(String(ev.data ?? "")); this.ws.onmessage = (ev) => this.handleMessage(String(ev.data ?? ""));
this.ws.onclose = (ev) => { this.ws.onclose = (ev) => {
const reason = String(ev.reason ?? ""); const reason = String(ev.reason ?? "");
@@ -107,7 +112,14 @@ export class GatewayBrowserClient {
this.pending.clear(); this.pending.clear();
} }
private sendConnect() { private async sendConnect() {
if (this.connectSent) return;
this.connectSent = true;
if (this.connectTimer !== null) {
window.clearTimeout(this.connectTimer);
this.connectTimer = null;
}
const deviceIdentity = await loadOrCreateDeviceIdentity();
const auth = const auth =
this.opts.token || this.opts.password this.opts.token || this.opts.password
? { ? {
@@ -115,6 +127,21 @@ export class GatewayBrowserClient {
password: this.opts.password, password: this.opts.password,
} }
: undefined; : undefined;
const scopes = ["operator.admin"];
const role = "operator";
const signedAtMs = Date.now();
const nonce = this.connectNonce ?? undefined;
const payload = buildDeviceAuthPayload({
deviceId: deviceIdentity.deviceId,
clientId: this.opts.clientName ?? GATEWAY_CLIENT_NAMES.CONTROL_UI,
clientMode: this.opts.mode ?? GATEWAY_CLIENT_MODES.WEBCHAT,
role,
scopes,
signedAtMs,
token: this.opts.token ?? null,
nonce,
});
const signature = await signDevicePayload(deviceIdentity.privateKey, payload);
const params = { const params = {
minProtocol: 3, minProtocol: 3,
maxProtocol: 3, maxProtocol: 3,
@@ -125,6 +152,15 @@ export class GatewayBrowserClient {
mode: this.opts.mode ?? GATEWAY_CLIENT_MODES.WEBCHAT, mode: this.opts.mode ?? GATEWAY_CLIENT_MODES.WEBCHAT,
instanceId: this.opts.instanceId, instanceId: this.opts.instanceId,
}, },
role,
scopes,
device: {
id: deviceIdentity.deviceId,
publicKey: deviceIdentity.publicKey,
signature,
signedAt: signedAtMs,
nonce,
},
caps: [], caps: [],
auth, auth,
userAgent: navigator.userAgent, userAgent: navigator.userAgent,
@@ -152,6 +188,15 @@ export class GatewayBrowserClient {
const frame = parsed as { type?: unknown }; const frame = parsed as { type?: unknown };
if (frame.type === "event") { if (frame.type === "event") {
const evt = parsed as GatewayEventFrame; const evt = parsed as GatewayEventFrame;
if (evt.event === "connect.challenge") {
const payload = evt.payload as { nonce?: unknown } | undefined;
const nonce = payload && typeof payload.nonce === "string" ? payload.nonce : null;
if (nonce) {
this.connectNonce = nonce;
void this.sendConnect();
}
return;
}
const seq = typeof evt.seq === "number" ? evt.seq : null; const seq = typeof evt.seq === "number" ? evt.seq : null;
if (seq !== null) { if (seq !== null) {
if (this.lastSeq !== null && seq > this.lastSeq + 1) { if (this.lastSeq !== null && seq > this.lastSeq + 1) {
@@ -186,4 +231,13 @@ export class GatewayBrowserClient {
this.ws.send(JSON.stringify(frame)); this.ws.send(JSON.stringify(frame));
return p; return p;
} }
private queueConnect() {
this.connectNonce = null;
this.connectSent = false;
if (this.connectTimer !== null) window.clearTimeout(this.connectTimer);
this.connectTimer = window.setTimeout(() => {
void this.sendConnect();
}, 750);
}
} }