fix: stabilize gateway ws + iOS

This commit is contained in:
Peter Steinberger
2026-01-19 06:22:01 +00:00
parent 73afbc9193
commit 3776de906f
14 changed files with 105 additions and 46 deletions

View File

@@ -264,8 +264,8 @@ final class NodeAppModel {
self.gatewayRemoteAddress = nil self.gatewayRemoteAddress = nil
self.gatewayConnected = false self.gatewayConnected = false
self.showLocalCanvasOnDisconnect() self.showLocalCanvasOnDisconnect()
self.gatewayStatusText = "Disconnected: \(reason)"
} }
self.gatewayStatusText = "Disconnected: \(reason)"
}, },
onInvoke: { [weak self] req in onInvoke: { [weak self] req in
guard let self else { guard let self else {
@@ -409,8 +409,10 @@ final class NodeAppModel {
for await evt in stream { for await evt in stream {
if Task.isCancelled { return } if Task.isCancelled { return }
guard evt.event == "voicewake.changed" else { continue } guard evt.event == "voicewake.changed" else { continue }
guard let payloadJSON = evt.payloadJSON else { continue } guard let payload = evt.payload else { continue }
guard let triggers = VoiceWakePreferences.decodeGatewayTriggers(from: payloadJSON) else { continue } struct Payload: Decodable { var triggers: [String] }
guard let decoded = try? GatewayPayloadDecoding.decode(payload, as: Payload.self) else { continue }
let triggers = VoiceWakePreferences.sanitizeTriggerWords(decoded.triggers)
VoiceWakePreferences.saveTriggerWords(triggers) VoiceWakePreferences.saveTriggerWords(triggers)
} }
} }

View File

@@ -44,7 +44,7 @@ public protocol WebSocketSessioning: AnyObject {
} }
extension URLSession: WebSocketSessioning { extension URLSession: WebSocketSessioning {
func makeWebSocketTask(url: URL) -> WebSocketTaskBox { public func makeWebSocketTask(url: URL) -> WebSocketTaskBox {
let task = self.webSocketTask(with: url) let task = self.webSocketTask(with: url)
// Avoid "Message too long" receive errors for large snapshots / history payloads. // Avoid "Message too long" receive errors for large snapshots / history payloads.
task.maximumMessageSize = 16 * 1024 * 1024 // 16 MB task.maximumMessageSize = 16 * 1024 * 1024 // 16 MB
@@ -54,6 +54,10 @@ extension URLSession: WebSocketSessioning {
public struct WebSocketSessionBox: @unchecked Sendable { public struct WebSocketSessionBox: @unchecked Sendable {
public let session: any WebSocketSessioning public let session: any WebSocketSessioning
/// Creates a box that stores the given session in `session`.
/// - Parameter session: The `WebSocketSessioning` value to wrap.
public init(session: any WebSocketSessioning) {
self.session = session
}
} }
public struct GatewayConnectOptions: Sendable { public struct GatewayConnectOptions: Sendable {
@@ -472,7 +476,7 @@ public actor GatewayChannelActor {
public func request( public func request(
method: String, method: String,
params: [String: ClawdbotProtocol.AnyCodable]?, params: [String: AnyCodable]?,
timeoutMs: Double? = nil) async throws -> Data timeoutMs: Double? = nil) async throws -> Data
{ {
do { do {
@@ -525,8 +529,8 @@ public actor GatewayChannelActor {
if res.ok == false { if res.ok == false {
let code = res.error?["code"]?.value as? String let code = res.error?["code"]?.value as? String
let msg = res.error?["message"]?.value as? String let msg = res.error?["message"]?.value as? String
let details: [String: ClawdbotProtocol.AnyCodable] = (res.error ?? [:]).reduce(into: [:]) { acc, pair in let details: [String: AnyCodable] = (res.error ?? [:]).reduce(into: [:]) { acc, pair in
acc[pair.key] = ClawdbotProtocol.AnyCodable(pair.value.value) acc[pair.key] = AnyCodable(pair.value.value)
} }
throw GatewayResponseError(method: method, code: code, message: msg, details: details) throw GatewayResponseError(method: method, code: code, message: msg, details: details)
} }

View File

@@ -26,6 +26,8 @@ public actor GatewayNodeSession {
private var serverEventSubscribers: [UUID: AsyncStream<EventFrame>.Continuation] = [:] private var serverEventSubscribers: [UUID: AsyncStream<EventFrame>.Continuation] = [:]
private var canvasHostUrl: String? private var canvasHostUrl: String?
/// Creates a session. The initializer body is empty; it performs no work of its own.
public init() {}
public func connect( public func connect(
url: URL, url: URL,
token: String?, token: String?,
@@ -107,9 +109,9 @@ public actor GatewayNodeSession {
public func sendEvent(event: String, payloadJSON: String?) async { public func sendEvent(event: String, payloadJSON: String?) async {
guard let channel = self.channel else { return } guard let channel = self.channel else { return }
let params: [String: ClawdbotProtocol.AnyCodable] = [ let params: [String: AnyCodable] = [
"event": ClawdbotProtocol.AnyCodable(event), "event": AnyCodable(event),
"payloadJSON": ClawdbotProtocol.AnyCodable(payloadJSON ?? NSNull()), "payloadJSON": AnyCodable(payloadJSON ?? NSNull()),
] ]
do { do {
_ = try await channel.request(method: "node.event", params: params, timeoutMs: 8000) _ = try await channel.request(method: "node.event", params: params, timeoutMs: 8000)
@@ -174,16 +176,16 @@ public actor GatewayNodeSession {
private func sendInvokeResult(request: NodeInvokeRequestPayload, response: BridgeInvokeResponse) async { private func sendInvokeResult(request: NodeInvokeRequestPayload, response: BridgeInvokeResponse) async {
guard let channel = self.channel else { return } guard let channel = self.channel else { return }
var params: [String: ClawdbotProtocol.AnyCodable] = [ var params: [String: AnyCodable] = [
"id": ClawdbotProtocol.AnyCodable(request.id), "id": AnyCodable(request.id),
"nodeId": ClawdbotProtocol.AnyCodable(request.nodeId), "nodeId": AnyCodable(request.nodeId),
"ok": ClawdbotProtocol.AnyCodable(response.ok), "ok": AnyCodable(response.ok),
"payloadJSON": ClawdbotProtocol.AnyCodable(response.payloadJSON ?? NSNull()), "payloadJSON": AnyCodable(response.payloadJSON ?? NSNull()),
] ]
if let error = response.error { if let error = response.error {
params["error"] = ClawdbotProtocol.AnyCodable([ params["error"] = AnyCodable([
"code": ClawdbotProtocol.AnyCodable(error.code.rawValue), "code": AnyCodable(error.code.rawValue),
"message": ClawdbotProtocol.AnyCodable(error.message), "message": AnyCodable(error.message),
]) ])
} }
do { do {
@@ -194,7 +196,7 @@ public actor GatewayNodeSession {
} }
private func decodeParamsJSON( private func decodeParamsJSON(
_ paramsJSON: String?) throws -> [String: ClawdbotProtocol.AnyCodable]? _ paramsJSON: String?) throws -> [String: AnyCodable]?
{ {
guard let paramsJSON, !paramsJSON.isEmpty else { return nil } guard let paramsJSON, !paramsJSON.isEmpty else { return nil }
guard let data = paramsJSON.data(using: .utf8) else { guard let data = paramsJSON.data(using: .utf8) else {
@@ -207,13 +209,13 @@ public actor GatewayNodeSession {
return nil return nil
} }
return dict.reduce(into: [:]) { acc, entry in return dict.reduce(into: [:]) { acc, entry in
acc[entry.key] = ClawdbotProtocol.AnyCodable(entry.value) acc[entry.key] = AnyCodable(entry.value)
} }
} }
private func broadcastServerEvent(_ evt: EventFrame) { private func broadcastServerEvent(_ evt: EventFrame) {
for (id, continuation) in self.serverEventSubscribers { for (id, continuation) in self.serverEventSubscribers {
if continuation.yield(evt) == .terminated { if case .terminated = continuation.yield(evt) {
self.serverEventSubscribers.removeValue(forKey: id) self.serverEventSubscribers.removeValue(forKey: id)
} }
} }

View File

@@ -10,6 +10,14 @@ public enum GatewayPayloadDecoding {
return try JSONDecoder().decode(T.self, from: data) return try JSONDecoder().decode(T.self, from: data)
} }
/// Converts an `AnyCodable` payload into a concrete `Decodable` value.
///
/// The payload is first encoded with a fresh `JSONEncoder`, and the resulting
/// data is then decoded with a fresh `JSONDecoder`.
///
/// - Parameters:
///   - payload: The value to convert.
///   - as: The target type; defaults to the inferred `T`.
/// - Returns: The decoded `T`.
/// - Throws: Whatever the encode or decode step throws.
public static func decode<T: Decodable>(
    _ payload: AnyCodable,
    as _: T.Type = T.self) throws -> T
{
    let encoder = JSONEncoder()
    let decoder = JSONDecoder()
    let encoded = try encoder.encode(payload)
    let value: T = try decoder.decode(T.self, from: encoded)
    return value
}
public static func decodeIfPresent<T: Decodable>( public static func decodeIfPresent<T: Decodable>(
_ payload: ClawdbotProtocol.AnyCodable?, _ payload: ClawdbotProtocol.AnyCodable?,
as _: T.Type = T.self) throws -> T? as _: T.Type = T.self) throws -> T?
@@ -17,4 +25,12 @@ public enum GatewayPayloadDecoding {
guard let payload else { return nil } guard let payload else { return nil }
return try self.decode(payload, as: T.self) return try self.decode(payload, as: T.self)
} }
/// Decodes an optional `AnyCodable` payload, passing `nil` straight through.
///
/// - Parameters:
///   - payload: The value to convert, or `nil`.
///   - as: The target type; defaults to the inferred `T`.
/// - Returns: `nil` when `payload` is `nil`; otherwise the result of `decode(_:as:)`.
/// - Throws: Whatever `decode(_:as:)` throws for a non-nil payload.
public static func decodeIfPresent<T: Decodable>(
    _ payload: AnyCodable?,
    as _: T.Type = T.self) throws -> T?
{
    if let present = payload {
        return try self.decode(present, as: T.self)
    }
    return nil
}
} }

View File

@@ -12,6 +12,17 @@ public enum InstanceIdentity {
UserDefaults(suiteName: suiteName) ?? .standard UserDefaults(suiteName: suiteName) ?? .standard
} }
#if canImport(UIKit)
/// Synchronously evaluates `body` with main-actor isolation and returns its result.
///
/// When the caller is already on the main thread, `body` runs inline. Otherwise the
/// calling thread blocks in `DispatchQueue.main.sync` until the main queue has run it.
///
/// NOTE(review): this blocks the caller, and it is invoked from `static let`
/// initializers in this type. If the main thread is itself waiting on the calling
/// thread (for example on the same lazily-initialized static), this could deadlock.
/// Confirm call sites.
///
/// - Parameter body: A main-actor-isolated closure producing a `Sendable` value.
/// - Returns: The value produced by `body`.
private static func readMainActor<T: Sendable>(_ body: @MainActor () -> T) -> T {
    guard Thread.isMainThread else {
        // Off the main thread: hop to the main queue and wait for the result.
        return DispatchQueue.main.sync {
            MainActor.assumeIsolated { body() }
        }
    }
    // Already on the main thread: no hop required.
    return MainActor.assumeIsolated { body() }
}
#endif
public static let instanceId: String = { public static let instanceId: String = {
let defaults = Self.defaults let defaults = Self.defaults
if let existing = defaults.string(forKey: instanceIdKey)? if let existing = defaults.string(forKey: instanceIdKey)?
@@ -28,7 +39,9 @@ public enum InstanceIdentity {
public static let displayName: String = { public static let displayName: String = {
#if canImport(UIKit) #if canImport(UIKit)
let name = UIDevice.current.name.trimmingCharacters(in: .whitespacesAndNewlines) let name = Self.readMainActor {
UIDevice.current.name.trimmingCharacters(in: .whitespacesAndNewlines)
}
return name.isEmpty ? "clawdbot" : name return name.isEmpty ? "clawdbot" : name
#else #else
if let name = Host.current().localizedName?.trimmingCharacters(in: .whitespacesAndNewlines), if let name = Host.current().localizedName?.trimmingCharacters(in: .whitespacesAndNewlines),
@@ -65,10 +78,12 @@ public enum InstanceIdentity {
public static let deviceFamily: String = { public static let deviceFamily: String = {
#if canImport(UIKit) #if canImport(UIKit)
switch UIDevice.current.userInterfaceIdiom { return Self.readMainActor {
case .pad: return "iPad" switch UIDevice.current.userInterfaceIdiom {
case .phone: return "iPhone" case .pad: return "iPad"
default: return "iOS" case .phone: return "iPhone"
default: return "iOS"
}
} }
#else #else
return "Mac" return "Mac"
@@ -78,11 +93,12 @@ public enum InstanceIdentity {
public static let platformString: String = { public static let platformString: String = {
let v = ProcessInfo.processInfo.operatingSystemVersion let v = ProcessInfo.processInfo.operatingSystemVersion
#if canImport(UIKit) #if canImport(UIKit)
let name: String let name = Self.readMainActor {
switch UIDevice.current.userInterfaceIdiom { switch UIDevice.current.userInterfaceIdiom {
case .pad: name = "iPadOS" case .pad: return "iPadOS"
case .phone: name = "iOS" case .phone: return "iOS"
default: name = "iOS" default: return "iOS"
}
} }
return "\(name) \(v.majorVersion).\(v.minorVersion).\(v.patchVersion)" return "\(name) \(v.majorVersion).\(v.minorVersion).\(v.patchVersion)"
#else #else

View File

@@ -20,7 +20,9 @@ function normalizeForHash(value: unknown): unknown {
.filter((item): item is unknown => item !== undefined); .filter((item): item is unknown => item !== undefined);
const primitives = normalized.filter(isPrimitive); const primitives = normalized.filter(isPrimitive);
if (primitives.length === normalized.length) { if (primitives.length === normalized.length) {
return [...primitives].sort((a, b) => String(a).localeCompare(String(b))); return [...primitives].sort((a, b) =>
primitiveToString(a).localeCompare(primitiveToString(b)),
);
} }
return normalized; return normalized;
} }
@@ -36,6 +38,14 @@ function normalizeForHash(value: unknown): unknown {
return value; return value;
} }
/**
 * Renders a value as a string suitable for use as a sort key.
 *
 * - `null` becomes `"null"`; strings are returned unchanged.
 * - Numbers use `String(value)`; booleans become `"true"` / `"false"`.
 * - Anything else is serialized with `JSON.stringify`.
 *
 * The fallback is made total so the declared `string` return type always holds:
 * `JSON.stringify` throws a TypeError on bigint and returns `undefined` for
 * `undefined`, functions and symbols. Those cases now fall back to the value's
 * own string form instead of throwing or leaking `undefined` to the caller.
 *
 * @param value - The value to render.
 * @returns A string representation of `value`.
 */
function primitiveToString(value: unknown): string {
  if (value === null) return "null";
  switch (typeof value) {
    case "string":
      return value;
    case "number":
      return String(value);
    case "boolean":
      return value ? "true" : "false";
    case "bigint":
      // JSON.stringify cannot serialize bigint and would throw here.
      return value.toString();
    default:
      // JSON.stringify yields undefined (not a string) for undefined, functions and symbols.
      return JSON.stringify(value) ?? String(value);
  }
}
export function computeSandboxConfigHash(input: SandboxHashInput): string { export function computeSandboxConfigHash(input: SandboxHashInput): string {
const payload = normalizeForHash(input); const payload = normalizeForHash(input);
const raw = JSON.stringify(payload); const raw = JSON.stringify(payload);

View File

@@ -1,5 +1,5 @@
import { randomUUID } from "node:crypto"; import { randomUUID } from "node:crypto";
import { WebSocket } from "ws"; import { WebSocket, type ClientOptions, type CertMeta } from "ws";
import { rawDataToString } from "../infra/ws.js"; import { rawDataToString } from "../infra/ws.js";
import { logDebug, logError } from "../logger.js"; import { logDebug, logError } from "../logger.js";
import type { DeviceIdentity } from "../infra/device-identity.js"; import type { DeviceIdentity } from "../infra/device-identity.js";
@@ -85,18 +85,21 @@ export class GatewayClient {
if (this.closed) return; if (this.closed) return;
const url = this.opts.url ?? "ws://127.0.0.1:18789"; const url = this.opts.url ?? "ws://127.0.0.1:18789";
// Allow node screen snapshots and other large responses. // Allow node screen snapshots and other large responses.
const wsOptions: ConstructorParameters<typeof WebSocket>[1] = { const wsOptions: ClientOptions = {
maxPayload: 25 * 1024 * 1024, maxPayload: 25 * 1024 * 1024,
}; };
if (url.startsWith("wss://") && this.opts.tlsFingerprint) { if (url.startsWith("wss://") && this.opts.tlsFingerprint) {
wsOptions.rejectUnauthorized = false; wsOptions.rejectUnauthorized = false;
wsOptions.checkServerIdentity = (_host, cert) => { wsOptions.checkServerIdentity = (_host: string, cert: CertMeta) => {
const fingerprintValue =
typeof cert === "object" && cert && "fingerprint256" in cert
? (cert as { fingerprint256?: string }).fingerprint256 ?? ""
: "";
const fingerprint = normalizeFingerprint( const fingerprint = normalizeFingerprint(
typeof cert?.fingerprint256 === "string" ? cert.fingerprint256 : "", typeof fingerprintValue === "string" ? fingerprintValue : "",
); );
const expected = normalizeFingerprint(this.opts.tlsFingerprint ?? ""); const expected = normalizeFingerprint(this.opts.tlsFingerprint ?? "");
if (fingerprint && fingerprint === expected) return undefined; return Boolean(fingerprint && fingerprint === expected);
return new Error("gateway tls fingerprint mismatch");
}; };
} }
this.ws = new WebSocket(url, wsOptions); this.ws = new WebSocket(url, wsOptions);

View File

@@ -119,7 +119,7 @@ export class NodeRegistry {
timeoutMs: params.timeoutMs, timeoutMs: params.timeoutMs,
idempotencyKey: params.idempotencyKey, idempotencyKey: params.idempotencyKey,
}; };
const ok = this.sendEvent(node, "node.invoke.request", payload); const ok = this.sendEventToSession(node, "node.invoke.request", payload);
if (!ok) { if (!ok) {
return { return {
ok: false, ok: false,
@@ -172,7 +172,7 @@ export class NodeRegistry {
return this.sendEventToSession(node, event, payload); return this.sendEventToSession(node, event, payload);
} }
private sendEvent(node: NodeSession, event: string, payload: unknown): boolean { private sendEventInternal(node: NodeSession, event: string, payload: unknown): boolean {
try { try {
node.client.socket.send( node.client.socket.send(
JSON.stringify({ JSON.stringify({
@@ -188,6 +188,6 @@ export class NodeRegistry {
} }
private sendEventToSession(node: NodeSession, event: string, payload: unknown): boolean { private sendEventToSession(node: NodeSession, event: string, payload: unknown): boolean {
return this.sendEvent(node, event, payload); return this.sendEventInternal(node, event, payload);
} }
} }

View File

@@ -451,7 +451,6 @@ export const nodeHandlers: GatewayRequestHandlers = {
nodeContext, nodeContext,
"node", "node",
{ {
type: "event",
event: p.event, event: p.event,
payloadJSON, payloadJSON,
}, },

View File

@@ -356,13 +356,15 @@ export async function startGatewayServer(
const execApprovalManager = new ExecApprovalManager(); const execApprovalManager = new ExecApprovalManager();
const execApprovalHandlers = createExecApprovalHandlers(execApprovalManager); const execApprovalHandlers = createExecApprovalHandlers(execApprovalManager);
const canvasHostServerPort = (canvasHostServer as CanvasHostServer | null)?.port;
attachGatewayWsHandlers({ attachGatewayWsHandlers({
wss, wss,
clients, clients,
port, port,
gatewayHost: bindHost ?? undefined, gatewayHost: bindHost ?? undefined,
canvasHostEnabled: Boolean(canvasHost), canvasHostEnabled: Boolean(canvasHost),
canvasHostServerPort: canvasHostServer?.port ?? undefined, canvasHostServerPort,
resolvedAuth, resolvedAuth,
gatewayMethods, gatewayMethods,
events: GATEWAY_EVENTS, events: GATEWAY_EVENTS,

View File

@@ -11,7 +11,7 @@ import {
rpcReq, rpcReq,
startServerWithClient, startServerWithClient,
} from "./test-helpers.js"; } from "./test-helpers.js";
import { GATEWAY_CLIENT_MODES } from "../utils/message-channel.js"; import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js";
installGatewayTestHooks(); installGatewayTestHooks();
@@ -127,7 +127,7 @@ describe("gateway server models + voicewake", () => {
await connectOk(nodeWs, { await connectOk(nodeWs, {
role: "node", role: "node",
client: { client: {
id: "n1", id: GATEWAY_CLIENT_NAMES.NODE_HOST,
version: "1.0.0", version: "1.0.0",
platform: "ios", platform: "ios",
mode: GATEWAY_CLIENT_MODES.NODE, mode: GATEWAY_CLIENT_MODES.NODE,

View File

@@ -32,6 +32,7 @@ describe("sessions_send gateway loopback", () => {
it("returns reply when lifecycle ends before agent.wait", async () => { it("returns reply when lifecycle ends before agent.wait", async () => {
const port = await getFreePort(); const port = await getFreePort();
vi.stubEnv("CLAWDBOT_GATEWAY_PORT", String(port)); vi.stubEnv("CLAWDBOT_GATEWAY_PORT", String(port));
vi.stubEnv("CLAWDBOT_GATEWAY_TOKEN", "test-token");
const server = await startGatewayServer(port); const server = await startGatewayServer(port);
const spy = vi.mocked(agentCommand); const spy = vi.mocked(agentCommand);
@@ -105,6 +106,7 @@ describe("sessions_send label lookup", () => {
it("finds session by label and sends message", { timeout: 60_000 }, async () => { it("finds session by label and sends message", { timeout: 60_000 }, async () => {
const port = await getFreePort(); const port = await getFreePort();
vi.stubEnv("CLAWDBOT_GATEWAY_PORT", String(port)); vi.stubEnv("CLAWDBOT_GATEWAY_PORT", String(port));
vi.stubEnv("CLAWDBOT_GATEWAY_TOKEN", "test-token");
const server = await startGatewayServer(port); const server = await startGatewayServer(port);
servers.push(server); servers.push(server);
@@ -171,6 +173,7 @@ describe("sessions_send label lookup", () => {
it("returns error when label not found", { timeout: 60_000 }, async () => { it("returns error when label not found", { timeout: 60_000 }, async () => {
const port = await getFreePort(); const port = await getFreePort();
vi.stubEnv("CLAWDBOT_GATEWAY_PORT", String(port)); vi.stubEnv("CLAWDBOT_GATEWAY_PORT", String(port));
vi.stubEnv("CLAWDBOT_GATEWAY_TOKEN", "test-token");
const server = await startGatewayServer(port); const server = await startGatewayServer(port);
servers.push(server); servers.push(server);
@@ -191,6 +194,7 @@ describe("sessions_send label lookup", () => {
it("returns error when neither sessionKey nor label provided", { timeout: 60_000 }, async () => { it("returns error when neither sessionKey nor label provided", { timeout: 60_000 }, async () => {
const port = await getFreePort(); const port = await getFreePort();
vi.stubEnv("CLAWDBOT_GATEWAY_PORT", String(port)); vi.stubEnv("CLAWDBOT_GATEWAY_PORT", String(port));
vi.stubEnv("CLAWDBOT_GATEWAY_TOKEN", "test-token");
const server = await startGatewayServer(port); const server = await startGatewayServer(port);
servers.push(server); servers.push(server);

View File

@@ -4,6 +4,8 @@ import {
loadGatewayTlsRuntime as loadGatewayTlsRuntimeConfig, loadGatewayTlsRuntime as loadGatewayTlsRuntimeConfig,
} from "../../infra/tls/gateway.js"; } from "../../infra/tls/gateway.js";
export type { GatewayTlsRuntime } from "../../infra/tls/gateway.js";
export async function loadGatewayTlsRuntime( export async function loadGatewayTlsRuntime(
cfg: GatewayTlsConfig | undefined, cfg: GatewayTlsConfig | undefined,
log?: { info?: (msg: string) => void; warn?: (msg: string) => void }, log?: { info?: (msg: string) => void; warn?: (msg: string) => void },

View File

@@ -1,7 +1,6 @@
import crypto from "node:crypto"; import crypto from "node:crypto";
import { spawn } from "node:child_process"; import { spawn } from "node:child_process";
import fs from "node:fs"; import fs from "node:fs";
import os from "node:os";
import path from "node:path"; import path from "node:path";
import { import {