From e944a0239d24eac5cbb587870f1df34e50dbe6b5 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Fri, 12 Dec 2025 21:34:33 +0000 Subject: [PATCH] fix(macos): share gateway websocket connection --- apps/macos/Sources/Clawdis/AgentRPC.swift | 22 +-- .../Sources/Clawdis/ControlChannel.swift | 29 ++-- .../Sources/Clawdis/GatewayChannel.swift | 58 ++++--- .../Sources/Clawdis/GatewayConnection.swift | 80 +++++++++ .../Clawdis/GatewayProcessManager.swift | 6 +- apps/macos/Sources/Clawdis/HealthStore.swift | 10 -- apps/macos/Sources/Clawdis/MenuBar.swift | 1 + .../Sources/Clawdis/WebChatSwiftUI.swift | 32 +--- .../GatewayChannelConfigureTests.swift | 160 ++++++++++++++++++ .../GatewayChannelConnectTests.swift | 4 +- .../GatewayChannelRequestTests.swift | 3 +- 11 files changed, 300 insertions(+), 105 deletions(-) create mode 100644 apps/macos/Sources/Clawdis/GatewayConnection.swift create mode 100644 apps/macos/Tests/ClawdisIPCTests/GatewayChannelConfigureTests.swift diff --git a/apps/macos/Sources/Clawdis/AgentRPC.swift b/apps/macos/Sources/Clawdis/AgentRPC.swift index 6c5d7b53b..3cee19cca 100644 --- a/apps/macos/Sources/Clawdis/AgentRPC.swift +++ b/apps/macos/Sources/Clawdis/AgentRPC.swift @@ -11,26 +11,9 @@ actor AgentRPC { static let shared = AgentRPC() private let logger = Logger(subsystem: "com.steipete.clawdis", category: "agent.rpc") - private let gateway = GatewayChannel() - private var configured = false - - private var gatewayURL: URL { - let port = GatewayEnvironment.gatewayPort() - return URL(string: "ws://127.0.0.1:\(port)")! - } - - private var gatewayToken: String? { - ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"] - } - - func start() async throws { - if self.configured { return } - await self.gateway.configure(url: self.gatewayURL, token: self.gatewayToken) - self.configured = true - } func shutdown() async { - // no-op for WS; socket managed by GatewayChannel + // no-op; socket managed by GatewayConnection } func setHeartbeatsEnabled(_ enabled: Bool) async -> Bool { @@ -85,8 +68,7 @@ actor AgentRPC { } func controlRequest(method: String, params: ControlRequestParams? = nil) async throws -> Data { - try await self.start() let rawParams = params?.raw.reduce(into: [String: AnyCodable]()) { $0[$1.key] = AnyCodable($1.value) } - return try await self.gateway.request(method: method, params: rawParams) + return try await GatewayConnection.shared.request(method: method, params: rawParams) } } diff --git a/apps/macos/Sources/Clawdis/ControlChannel.swift b/apps/macos/Sources/Clawdis/ControlChannel.swift index b9d7bd413..76b87c27e 100644 --- a/apps/macos/Sources/Clawdis/ControlChannel.swift +++ b/apps/macos/Sources/Clawdis/ControlChannel.swift @@ -55,35 +55,34 @@ final class ControlChannel: ObservableObject { @Published private(set) var lastPingMs: Double? private let logger = Logger(subsystem: "com.steipete.clawdis", category: "control") - private let gateway = GatewayChannel() - private var gatewayPort: Int = GatewayEnvironment.gatewayPort() - private var gatewayURL: URL { URL(string: "ws://127.0.0.1:\(self.gatewayPort)")! } - - private var gatewayToken: String? { - ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"] - } private var eventTokens: [NSObjectProtocol] = [] + private init() { + self.startEventStream() + } + func configure() async { self.state = .connecting - await self.gateway.configure(url: self.gatewayURL, token: self.gatewayToken) - self.startEventStream() - self.state = .connected - PresenceReporter.shared.sendImmediate(reason: "connect") + do { + try await GatewayConnection.shared.refresh() + self.state = .connected + PresenceReporter.shared.sendImmediate(reason: "connect") + } catch { + let message = self.friendlyGatewayMessage(error) + self.state = .degraded(message) + } } func configure(mode: Mode = .local) async throws { switch mode { case .local: - self.gatewayPort = GatewayEnvironment.gatewayPort() await self.configure() case let .remote(target, identity): // Create/ensure SSH tunnel, then talk to the forwarded local port. _ = (target, identity) do { - let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel() - self.gatewayPort = Int(forwarded) + _ = try await RemoteTunnelManager.shared.ensureControlTunnel() await self.configure() } catch { self.state = .degraded(error.localizedDescription) @@ -124,7 +123,7 @@ final class ControlChannel: ObservableObject { { do { let rawParams = params?.reduce(into: [String: AnyCodable]()) { $0[$1.key] = AnyCodable($1.value) } - let data = try await self.gateway.request(method: method, params: rawParams, timeoutMs: timeoutMs) + let data = try await GatewayConnection.shared.request(method: method, params: rawParams, timeoutMs: timeoutMs) self.state = .connected return data } catch { diff --git a/apps/macos/Sources/Clawdis/GatewayChannel.swift b/apps/macos/Sources/Clawdis/GatewayChannel.swift index 6ef99c4c8..e69e47ac3 100644 --- a/apps/macos/Sources/Clawdis/GatewayChannel.swift +++ b/apps/macos/Sources/Clawdis/GatewayChannel.swift @@ -47,6 +47,10 @@ extension URLSession: WebSocketSessioning { } } +struct WebSocketSessionBox: @unchecked Sendable { + let session: any WebSocketSessioning +} + struct GatewayEvent: Codable { let type: String let event: String? @@ -81,17 +85,40 @@ actor GatewayChannelActor { private let decoder = JSONDecoder() private let encoder = JSONEncoder() private var watchdogTask: Task? + private var tickTask: Task? private let defaultRequestTimeoutMs: Double = 15000 - init(url: URL, token: String?, session: WebSocketSessioning? = nil) { + init(url: URL, token: String?, session: WebSocketSessionBox? = nil) { self.url = url self.token = token - self.session = session ?? URLSession(configuration: .default) + self.session = session?.session ?? URLSession(configuration: .default) Task { [weak self] in await self?.startWatchdog() } } + func shutdown() async { + self.shouldReconnect = false + self.connected = false + + self.watchdogTask?.cancel() + self.watchdogTask = nil + + self.tickTask?.cancel() + self.tickTask = nil + + self.task?.cancel(with: .goingAway, reason: nil) + self.task = nil + + await self.failPending(NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway channel shutdown"])) + + let waiters = self.connectWaiters + self.connectWaiters.removeAll() + for waiter in waiters { + waiter.resume(throwing: NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway channel shutdown"])) + } + } + private func startWatchdog() { self.watchdogTask?.cancel() self.watchdogTask = Task { [weak self] in @@ -104,6 +131,7 @@ actor GatewayChannelActor { // Keep nudging reconnect in case exponential backoff stalls. while self.shouldReconnect { try? await Task.sleep(nanoseconds: 30 * 1_000_000_000) // 30s cadence + guard self.shouldReconnect else { return } if self.connected { continue } do { try await self.connect() @@ -207,7 +235,11 @@ actor GatewayChannelActor { self.tickIntervalMs = Double(tick) } self.lastTick = Date() - Task { await self.watchTicks() } + self.tickTask?.cancel() + self.tickTask = Task { [weak self] in + guard let self else { return } + await self.watchTicks() + } let frame = GatewayFrame.helloOk(ok) NotificationCenter.default.post(name: .gatewaySnapshot, object: frame) return @@ -314,6 +346,7 @@ actor GatewayChannelActor { let delay = self.backoffMs / 1000 self.backoffMs = min(self.backoffMs * 2, 30000) try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000)) + guard self.shouldReconnect else { return } do { try await self.connect() } catch { @@ -414,21 +447,4 @@ actor GatewayChannelActor { } } -actor GatewayChannel { - private var inner: GatewayChannelActor? - - func configure(url: URL, token: String?) { - self.inner = GatewayChannelActor(url: url, token: token) - } - - func request( - method: String, - params: [String: AnyCodable]?, - timeoutMs: Double? = nil) async throws -> Data - { - guard let inner else { - throw NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "not configured"]) - } - return try await inner.request(method: method, params: params, timeoutMs: timeoutMs) - } -} +// Intentionally no `GatewayChannel` wrapper: the app should use the single shared `GatewayConnection`. diff --git a/apps/macos/Sources/Clawdis/GatewayConnection.swift b/apps/macos/Sources/Clawdis/GatewayConnection.swift new file mode 100644 index 000000000..999dcf992 --- /dev/null +++ b/apps/macos/Sources/Clawdis/GatewayConnection.swift @@ -0,0 +1,80 @@ +import Foundation + +/// Single, shared Gateway websocket connection for the whole app. +/// +/// This owns exactly one `GatewayChannelActor` and reuses it across all callers +/// (ControlChannel, AgentRPC, SwiftUI WebChat, etc.). +actor GatewayConnection { + static let shared = GatewayConnection() + + typealias Config = (url: URL, token: String?) + + private let configProvider: @Sendable () async throws -> Config + private let sessionBox: WebSocketSessionBox? + + private var client: GatewayChannelActor? + private var configuredURL: URL? + private var configuredToken: String? + + init( + configProvider: @escaping @Sendable () async throws -> Config = GatewayConnection.defaultConfigProvider, + sessionBox: WebSocketSessionBox? = nil) + { + self.configProvider = configProvider + self.sessionBox = sessionBox + } + + func request( + method: String, + params: [String: AnyCodable]?, + timeoutMs: Double? = nil) async throws -> Data + { + let cfg = try await self.configProvider() + await self.configure(url: cfg.url, token: cfg.token) + guard let client else { + throw NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway not configured"]) + } + return try await client.request(method: method, params: params, timeoutMs: timeoutMs) + } + + /// Ensure the underlying socket is configured (and replaced if config changed). + func refresh() async throws { + let cfg = try await self.configProvider() + await self.configure(url: cfg.url, token: cfg.token) + } + + func shutdown() async { + if let client { + await client.shutdown() + } + self.client = nil + self.configuredURL = nil + self.configuredToken = nil + } + + private func configure(url: URL, token: String?) async { + if self.client != nil, self.configuredURL == url, self.configuredToken == token { + return + } + if let client { + await client.shutdown() + } + self.client = GatewayChannelActor(url: url, token: token, session: self.sessionBox) + self.configuredURL = url + self.configuredToken = token + } + + private static func defaultConfigProvider() async throws -> Config { + let mode = await MainActor.run { AppStateStore.shared.connectionMode } + let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"] + switch mode { + case .local: + let port = GatewayEnvironment.gatewayPort() + return (URL(string: "ws://127.0.0.1:\(port)")!, token) + case .remote: + let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel() + return (URL(string: "ws://127.0.0.1:\(Int(forwarded))")!, token) + } + } +} + diff --git a/apps/macos/Sources/Clawdis/GatewayProcessManager.swift b/apps/macos/Sources/Clawdis/GatewayProcessManager.swift index 232364b0e..040140fd1 100644 --- a/apps/macos/Sources/Clawdis/GatewayProcessManager.swift +++ b/apps/macos/Sources/Clawdis/GatewayProcessManager.swift @@ -149,12 +149,8 @@ final class GatewayProcessManager: ObservableObject { /// If successful, mark status as attached and skip spawning a new process. private func attachExistingGatewayIfAvailable() async -> Bool { let port = GatewayEnvironment.gatewayPort() - guard let url = URL(string: "ws://127.0.0.1:\(port)") else { return false } - let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"] - let channel = GatewayChannel() - await channel.configure(url: url, token: token) do { - let data = try await channel.request(method: "health", params: nil) + let data = try await GatewayConnection.shared.request(method: "health", params: nil) let details: String if let snap = decodeHealthSnapshot(from: data) { let linked = snap.web.linked ? "linked" : "not linked" diff --git a/apps/macos/Sources/Clawdis/HealthStore.swift b/apps/macos/Sources/Clawdis/HealthStore.swift index 94bd0bb21..ce4d0b69e 100644 --- a/apps/macos/Sources/Clawdis/HealthStore.swift +++ b/apps/macos/Sources/Clawdis/HealthStore.swift @@ -92,16 +92,6 @@ final class HealthStore: ObservableObject { defer { self.isRefreshing = false } do { - let mode = AppStateStore.shared.connectionMode - switch mode { - case .local: - try await ControlChannel.shared.configure(mode: .local) - case .remote: - let target = AppStateStore.shared.remoteTarget - let identity = AppStateStore.shared.remoteIdentity - try await ControlChannel.shared.configure(mode: .remote(target: target, identity: identity)) - } - let data = try await ControlChannel.shared.health(timeout: 15) if let decoded = decodeHealthSnapshot(from: data) { self.snapshot = decoded diff --git a/apps/macos/Sources/Clawdis/MenuBar.swift b/apps/macos/Sources/Clawdis/MenuBar.swift index 32329e7eb..fa8b7209a 100644 --- a/apps/macos/Sources/Clawdis/MenuBar.swift +++ b/apps/macos/Sources/Clawdis/MenuBar.swift @@ -189,6 +189,7 @@ final class AppDelegate: NSObject, NSApplicationDelegate { WebChatManager.shared.resetTunnels() Task { await RemoteTunnelManager.shared.stopAll() } Task { await AgentRPC.shared.shutdown() } + Task { await GatewayConnection.shared.shutdown() } Task { await self.socketServer.stop() } Task { await BridgeServer.shared.stop() } } diff --git a/apps/macos/Sources/Clawdis/WebChatSwiftUI.swift b/apps/macos/Sources/Clawdis/WebChatSwiftUI.swift index 400ee2df0..388cbb00e 100644 --- a/apps/macos/Sources/Clawdis/WebChatSwiftUI.swift +++ b/apps/macos/Sources/Clawdis/WebChatSwiftUI.swift @@ -79,11 +79,8 @@ final class WebChatViewModel: ObservableObject { @Published var healthOK: Bool = true private let sessionKey: String - private let gateway = GatewayChannel() - private var gatewayConfigured = false private var eventToken: NSObjectProtocol? private var pendingRuns = Set() - private var currentPort: Int? init(sessionKey: String) { self.sessionKey = sessionKey @@ -141,7 +138,6 @@ final class WebChatViewModel: ObservableObject { self.isLoading = true defer { self.isLoading = false } do { - try await self.ensureGatewayConfigured() let payload = try await self.requestHistory() self.messages = payload.messages ?? [] if let level = payload.thinkingLevel, !level.isEmpty { @@ -157,12 +153,6 @@ final class WebChatViewModel: ObservableObject { guard !self.isSending else { return } let trimmed = self.input.trimmingCharacters(in: .whitespacesAndNewlines) guard !trimmed.isEmpty || !self.attachments.isEmpty else { return } - do { - try await self.ensureGatewayConfigured() - } catch { - self.errorText = error.localizedDescription - return - } self.isSending = true self.errorText = nil @@ -202,7 +192,7 @@ final class WebChatViewModel: ObservableObject { "idempotencyKey": AnyCodable(runId), "timeoutMs": AnyCodable(30_000) ] - let data = try await self.gateway.request(method: "chat.send", params: params) + let data = try await GatewayConnection.shared.request(method: "chat.send", params: params) let response = try JSONDecoder().decode(ChatSendResponse.self, from: data) self.pendingRuns.insert(response.runId) } catch { @@ -215,26 +205,8 @@ final class WebChatViewModel: ObservableObject { self.isSending = false } - private func ensureGatewayConfigured() async throws { - guard !self.gatewayConfigured else { return } - let port = try await self.resolveGatewayPort() - self.currentPort = port - let url = URL(string: "ws://127.0.0.1:\(port)")! - let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"] - await self.gateway.configure(url: url, token: token) - self.gatewayConfigured = true - } - - private func resolveGatewayPort() async throws -> Int { - if CommandResolver.connectionModeIsRemote() { - let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel() - return Int(forwarded) - } - return GatewayEnvironment.gatewayPort() - } - private func requestHistory() async throws -> ChatHistoryPayload { - let data = try await self.gateway.request( + let data = try await GatewayConnection.shared.request( method: "chat.history", params: ["sessionKey": AnyCodable(self.sessionKey)]) return try JSONDecoder().decode(ChatHistoryPayload.self, from: data) diff --git a/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConfigureTests.swift b/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConfigureTests.swift new file mode 100644 index 000000000..86dbdece0 --- /dev/null +++ b/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConfigureTests.swift @@ -0,0 +1,160 @@ +import Foundation +import os +import Testing +@testable import Clawdis + +@Suite struct GatewayConnectionTests { + private final class FakeWebSocketTask: WebSocketTasking, @unchecked Sendable { + private let pendingReceiveHandler = + OSAllocatedUnfairLock<(@Sendable (Result) -> Void)?>(initialState: nil) + private let cancelCount = OSAllocatedUnfairLock(initialState: 0) + private let sendCount = OSAllocatedUnfairLock(initialState: 0) + + var state: URLSessionTask.State = .suspended + + func snapshotCancelCount() -> Int { self.cancelCount.withLock { $0 } } + + func resume() { + self.state = .running + } + + func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) { + _ = (closeCode, reason) + self.state = .canceling + self.cancelCount.withLock { $0 += 1 } + let handler = self.pendingReceiveHandler.withLock { handler in + defer { handler = nil } + return handler + } + handler?(Result.failure(URLError(.cancelled))) + } + + func send(_ message: URLSessionWebSocketTask.Message) async throws { + let currentSendCount = self.sendCount.withLock { count in + defer { count += 1 } + return count + } + + // First send is the hello frame. Subsequent sends are request frames. + if currentSendCount == 0 { return } + + guard case let .data(data) = message else { return } + guard + let obj = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + (obj["type"] as? String) == "req", + let id = obj["id"] as? String + else { + return + } + + let response = Self.responseData(id: id) + let handler = self.pendingReceiveHandler.withLock { $0 } + handler?(Result.success(.data(response))) + } + + func receive() async throws -> URLSessionWebSocketTask.Message { + .data(Self.helloOkData()) + } + + func receive( + completionHandler: @escaping @Sendable (Result) -> Void) + { + self.pendingReceiveHandler.withLock { $0 = completionHandler } + } + + private static func helloOkData() -> Data { + let json = """ + { + "type": "hello-ok", + "protocol": 1, + "server": { "version": "test", "connId": "test" }, + "features": { "methods": [], "events": [] }, + "snapshot": { + "presence": [ { "ts": 1 } ], + "health": {}, + "stateVersion": { "presence": 0, "health": 0 }, + "uptimeMs": 0 + }, + "policy": { "maxPayload": 1, "maxBufferedBytes": 1, "tickIntervalMs": 30000 } + } + """ + return Data(json.utf8) + } + + private static func responseData(id: String) -> Data { + let json = """ + { + "type": "res", + "id": "\(id)", + "ok": true, + "payload": { "ok": true } + } + """ + return Data(json.utf8) + } + } + + private final class FakeWebSocketSession: WebSocketSessioning, @unchecked Sendable { + private let makeCount = OSAllocatedUnfairLock(initialState: 0) + private let tasks = OSAllocatedUnfairLock(initialState: [FakeWebSocketTask]()) + + func snapshotMakeCount() -> Int { self.makeCount.withLock { $0 } } + func snapshotCancelCount() -> Int { + self.tasks.withLock { tasks in + tasks.reduce(0) { $0 + $1.snapshotCancelCount() } + } + } + + func makeWebSocketTask(url: URL) -> WebSocketTaskBox { + _ = url + self.makeCount.withLock { $0 += 1 } + let task = FakeWebSocketTask() + self.tasks.withLock { $0.append(task) } + return WebSocketTaskBox(task: task) + } + } + + private final class ConfigSource: @unchecked Sendable { + private let token = OSAllocatedUnfairLock(initialState: nil) + + init(token: String?) { + self.token.withLock { $0 = token } + } + + func snapshotToken() -> String? { self.token.withLock { $0 } } + func setToken(_ value: String?) { self.token.withLock { $0 = value } } + } + + @Test func requestReusesSingleWebSocketForSameConfig() async throws { + let session = FakeWebSocketSession() + let url = URL(string: "ws://example.invalid")! + let cfg = ConfigSource(token: nil) + let conn = GatewayConnection( + configProvider: { (url, cfg.snapshotToken()) }, + sessionBox: WebSocketSessionBox(session: session)) + + _ = try await conn.request(method: "status", params: nil) + #expect(session.snapshotMakeCount() == 1) + + _ = try await conn.request(method: "status", params: nil) + #expect(session.snapshotMakeCount() == 1) + #expect(session.snapshotCancelCount() == 0) + } + + @Test func requestReconfiguresAndCancelsOnTokenChange() async throws { + let session = FakeWebSocketSession() + let url = URL(string: "ws://example.invalid")! + let cfg = ConfigSource(token: "a") + let conn = GatewayConnection( + configProvider: { (url, cfg.snapshotToken()) }, + sessionBox: WebSocketSessionBox(session: session)) + + _ = try await conn.request(method: "status", params: nil) + #expect(session.snapshotMakeCount() == 1) + + cfg.setToken("b") + _ = try await conn.request(method: "status", params: nil) + #expect(session.snapshotMakeCount() == 2) + #expect(session.snapshotCancelCount() == 1) + } +} diff --git a/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConnectTests.swift b/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConnectTests.swift index 1f4085298..af5c335ca 100644 --- a/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConnectTests.swift +++ b/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConnectTests.swift @@ -101,7 +101,7 @@ import Testing let channel = GatewayChannelActor( url: URL(string: "ws://example.invalid")!, token: nil, - session: session) + session: WebSocketSessionBox(session: session)) let t1 = Task { try await channel.connect() } let t2 = Task { try await channel.connect() } @@ -117,7 +117,7 @@ import Testing let channel = GatewayChannelActor( url: URL(string: "ws://example.invalid")!, token: nil, - session: session) + session: WebSocketSessionBox(session: session)) let t1 = Task { try await channel.connect() } let t2 = Task { try await channel.connect() } diff --git a/apps/macos/Tests/ClawdisIPCTests/GatewayChannelRequestTests.swift b/apps/macos/Tests/ClawdisIPCTests/GatewayChannelRequestTests.swift index f053d7da8..7bac7d4f3 100644 --- a/apps/macos/Tests/ClawdisIPCTests/GatewayChannelRequestTests.swift +++ b/apps/macos/Tests/ClawdisIPCTests/GatewayChannelRequestTests.swift @@ -93,7 +93,7 @@ import Testing let channel = GatewayChannelActor( url: URL(string: "ws://example.invalid")!, token: nil, - session: session) + session: WebSocketSessionBox(session: session)) do { _ = try await channel.request(method: "test", params: nil, timeoutMs: 10) @@ -108,4 +108,3 @@ import Testing try? await Task.sleep(nanoseconds: 250 * 1_000_000) } } -