fix(macos): share gateway websocket connection

This commit is contained in:
Peter Steinberger
2025-12-12 21:34:33 +00:00
parent ce8db12b22
commit e944a0239d
11 changed files with 300 additions and 105 deletions

View File

@@ -11,26 +11,9 @@ actor AgentRPC {
static let shared = AgentRPC()
private let logger = Logger(subsystem: "com.steipete.clawdis", category: "agent.rpc")
private let gateway = GatewayChannel()
private var configured = false
private var gatewayURL: URL {
let port = GatewayEnvironment.gatewayPort()
return URL(string: "ws://127.0.0.1:\(port)")!
}
private var gatewayToken: String? {
ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
}
func start() async throws {
if self.configured { return }
await self.gateway.configure(url: self.gatewayURL, token: self.gatewayToken)
self.configured = true
}
func shutdown() async {
// no-op for WS; socket managed by GatewayChannel
// no-op; socket managed by GatewayConnection
}
func setHeartbeatsEnabled(_ enabled: Bool) async -> Bool {
@@ -85,8 +68,7 @@ actor AgentRPC {
}
func controlRequest(method: String, params: ControlRequestParams? = nil) async throws -> Data {
try await self.start()
let rawParams = params?.raw.reduce(into: [String: AnyCodable]()) { $0[$1.key] = AnyCodable($1.value) }
return try await self.gateway.request(method: method, params: rawParams)
return try await GatewayConnection.shared.request(method: method, params: rawParams)
}
}

View File

@@ -55,35 +55,34 @@ final class ControlChannel: ObservableObject {
@Published private(set) var lastPingMs: Double?
private let logger = Logger(subsystem: "com.steipete.clawdis", category: "control")
private let gateway = GatewayChannel()
private var gatewayPort: Int = GatewayEnvironment.gatewayPort()
private var gatewayURL: URL { URL(string: "ws://127.0.0.1:\(self.gatewayPort)")! }
private var gatewayToken: String? {
ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
}
private var eventTokens: [NSObjectProtocol] = []
private init() {
self.startEventStream()
}
func configure() async {
self.state = .connecting
await self.gateway.configure(url: self.gatewayURL, token: self.gatewayToken)
self.startEventStream()
self.state = .connected
PresenceReporter.shared.sendImmediate(reason: "connect")
do {
try await GatewayConnection.shared.refresh()
self.state = .connected
PresenceReporter.shared.sendImmediate(reason: "connect")
} catch {
let message = self.friendlyGatewayMessage(error)
self.state = .degraded(message)
}
}
func configure(mode: Mode = .local) async throws {
switch mode {
case .local:
self.gatewayPort = GatewayEnvironment.gatewayPort()
await self.configure()
case let .remote(target, identity):
// Create/ensure SSH tunnel, then talk to the forwarded local port.
_ = (target, identity)
do {
let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel()
self.gatewayPort = Int(forwarded)
_ = try await RemoteTunnelManager.shared.ensureControlTunnel()
await self.configure()
} catch {
self.state = .degraded(error.localizedDescription)
@@ -124,7 +123,7 @@ final class ControlChannel: ObservableObject {
{
do {
let rawParams = params?.reduce(into: [String: AnyCodable]()) { $0[$1.key] = AnyCodable($1.value) }
let data = try await self.gateway.request(method: method, params: rawParams, timeoutMs: timeoutMs)
let data = try await GatewayConnection.shared.request(method: method, params: rawParams, timeoutMs: timeoutMs)
self.state = .connected
return data
} catch {

View File

@@ -47,6 +47,10 @@ extension URLSession: WebSocketSessioning {
}
}
struct WebSocketSessionBox: @unchecked Sendable {
let session: any WebSocketSessioning
}
struct GatewayEvent: Codable {
let type: String
let event: String?
@@ -81,17 +85,40 @@ actor GatewayChannelActor {
private let decoder = JSONDecoder()
private let encoder = JSONEncoder()
private var watchdogTask: Task<Void, Never>?
private var tickTask: Task<Void, Never>?
private let defaultRequestTimeoutMs: Double = 15000
init(url: URL, token: String?, session: WebSocketSessioning? = nil) {
init(url: URL, token: String?, session: WebSocketSessionBox? = nil) {
self.url = url
self.token = token
self.session = session ?? URLSession(configuration: .default)
self.session = session?.session ?? URLSession(configuration: .default)
Task { [weak self] in
await self?.startWatchdog()
}
}
func shutdown() async {
self.shouldReconnect = false
self.connected = false
self.watchdogTask?.cancel()
self.watchdogTask = nil
self.tickTask?.cancel()
self.tickTask = nil
self.task?.cancel(with: .goingAway, reason: nil)
self.task = nil
await self.failPending(NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway channel shutdown"]))
let waiters = self.connectWaiters
self.connectWaiters.removeAll()
for waiter in waiters {
waiter.resume(throwing: NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway channel shutdown"]))
}
}
private func startWatchdog() {
self.watchdogTask?.cancel()
self.watchdogTask = Task { [weak self] in
@@ -104,6 +131,7 @@ actor GatewayChannelActor {
// Keep nudging reconnect in case exponential backoff stalls.
while self.shouldReconnect {
try? await Task.sleep(nanoseconds: 30 * 1_000_000_000) // 30s cadence
guard self.shouldReconnect else { return }
if self.connected { continue }
do {
try await self.connect()
@@ -207,7 +235,11 @@ actor GatewayChannelActor {
self.tickIntervalMs = Double(tick)
}
self.lastTick = Date()
Task { await self.watchTicks() }
self.tickTask?.cancel()
self.tickTask = Task { [weak self] in
guard let self else { return }
await self.watchTicks()
}
let frame = GatewayFrame.helloOk(ok)
NotificationCenter.default.post(name: .gatewaySnapshot, object: frame)
return
@@ -314,6 +346,7 @@ actor GatewayChannelActor {
let delay = self.backoffMs / 1000
self.backoffMs = min(self.backoffMs * 2, 30000)
try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000))
guard self.shouldReconnect else { return }
do {
try await self.connect()
} catch {
@@ -414,21 +447,4 @@ actor GatewayChannelActor {
}
}
actor GatewayChannel {
private var inner: GatewayChannelActor?
func configure(url: URL, token: String?) {
self.inner = GatewayChannelActor(url: url, token: token)
}
func request(
method: String,
params: [String: AnyCodable]?,
timeoutMs: Double? = nil) async throws -> Data
{
guard let inner else {
throw NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "not configured"])
}
return try await inner.request(method: method, params: params, timeoutMs: timeoutMs)
}
}
// Intentionally no `GatewayChannel` wrapper: the app should use the single shared `GatewayConnection`.

View File

@@ -0,0 +1,80 @@
import Foundation
/// Single, shared Gateway websocket connection for the whole app.
///
/// This owns exactly one `GatewayChannelActor` and reuses it across all callers
/// (ControlChannel, AgentRPC, SwiftUI WebChat, etc.).
actor GatewayConnection {
static let shared = GatewayConnection()
typealias Config = (url: URL, token: String?)
private let configProvider: @Sendable () async throws -> Config
private let sessionBox: WebSocketSessionBox?
private var client: GatewayChannelActor?
private var configuredURL: URL?
private var configuredToken: String?
init(
configProvider: @escaping @Sendable () async throws -> Config = GatewayConnection.defaultConfigProvider,
sessionBox: WebSocketSessionBox? = nil)
{
self.configProvider = configProvider
self.sessionBox = sessionBox
}
func request(
method: String,
params: [String: AnyCodable]?,
timeoutMs: Double? = nil) async throws -> Data
{
let cfg = try await self.configProvider()
await self.configure(url: cfg.url, token: cfg.token)
guard let client else {
throw NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway not configured"])
}
return try await client.request(method: method, params: params, timeoutMs: timeoutMs)
}
/// Ensure the underlying socket is configured (and replaced if config changed).
func refresh() async throws {
let cfg = try await self.configProvider()
await self.configure(url: cfg.url, token: cfg.token)
}
func shutdown() async {
if let client {
await client.shutdown()
}
self.client = nil
self.configuredURL = nil
self.configuredToken = nil
}
private func configure(url: URL, token: String?) async {
if self.client != nil, self.configuredURL == url, self.configuredToken == token {
return
}
if let client {
await client.shutdown()
}
self.client = GatewayChannelActor(url: url, token: token, session: self.sessionBox)
self.configuredURL = url
self.configuredToken = token
}
private static func defaultConfigProvider() async throws -> Config {
let mode = await MainActor.run { AppStateStore.shared.connectionMode }
let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
switch mode {
case .local:
let port = GatewayEnvironment.gatewayPort()
return (URL(string: "ws://127.0.0.1:\(port)")!, token)
case .remote:
let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel()
return (URL(string: "ws://127.0.0.1:\(Int(forwarded))")!, token)
}
}
}

View File

@@ -149,12 +149,8 @@ final class GatewayProcessManager: ObservableObject {
/// If successful, mark status as attached and skip spawning a new process.
private func attachExistingGatewayIfAvailable() async -> Bool {
let port = GatewayEnvironment.gatewayPort()
guard let url = URL(string: "ws://127.0.0.1:\(port)") else { return false }
let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
let channel = GatewayChannel()
await channel.configure(url: url, token: token)
do {
let data = try await channel.request(method: "health", params: nil)
let data = try await GatewayConnection.shared.request(method: "health", params: nil)
let details: String
if let snap = decodeHealthSnapshot(from: data) {
let linked = snap.web.linked ? "linked" : "not linked"

View File

@@ -92,16 +92,6 @@ final class HealthStore: ObservableObject {
defer { self.isRefreshing = false }
do {
let mode = AppStateStore.shared.connectionMode
switch mode {
case .local:
try await ControlChannel.shared.configure(mode: .local)
case .remote:
let target = AppStateStore.shared.remoteTarget
let identity = AppStateStore.shared.remoteIdentity
try await ControlChannel.shared.configure(mode: .remote(target: target, identity: identity))
}
let data = try await ControlChannel.shared.health(timeout: 15)
if let decoded = decodeHealthSnapshot(from: data) {
self.snapshot = decoded

View File

@@ -189,6 +189,7 @@ final class AppDelegate: NSObject, NSApplicationDelegate {
WebChatManager.shared.resetTunnels()
Task { await RemoteTunnelManager.shared.stopAll() }
Task { await AgentRPC.shared.shutdown() }
Task { await GatewayConnection.shared.shutdown() }
Task { await self.socketServer.stop() }
Task { await BridgeServer.shared.stop() }
}

View File

@@ -79,11 +79,8 @@ final class WebChatViewModel: ObservableObject {
@Published var healthOK: Bool = true
private let sessionKey: String
private let gateway = GatewayChannel()
private var gatewayConfigured = false
private var eventToken: NSObjectProtocol?
private var pendingRuns = Set<String>()
private var currentPort: Int?
init(sessionKey: String) {
self.sessionKey = sessionKey
@@ -141,7 +138,6 @@ final class WebChatViewModel: ObservableObject {
self.isLoading = true
defer { self.isLoading = false }
do {
try await self.ensureGatewayConfigured()
let payload = try await self.requestHistory()
self.messages = payload.messages ?? []
if let level = payload.thinkingLevel, !level.isEmpty {
@@ -157,12 +153,6 @@ final class WebChatViewModel: ObservableObject {
guard !self.isSending else { return }
let trimmed = self.input.trimmingCharacters(in: .whitespacesAndNewlines)
guard !trimmed.isEmpty || !self.attachments.isEmpty else { return }
do {
try await self.ensureGatewayConfigured()
} catch {
self.errorText = error.localizedDescription
return
}
self.isSending = true
self.errorText = nil
@@ -202,7 +192,7 @@ final class WebChatViewModel: ObservableObject {
"idempotencyKey": AnyCodable(runId),
"timeoutMs": AnyCodable(30_000)
]
let data = try await self.gateway.request(method: "chat.send", params: params)
let data = try await GatewayConnection.shared.request(method: "chat.send", params: params)
let response = try JSONDecoder().decode(ChatSendResponse.self, from: data)
self.pendingRuns.insert(response.runId)
} catch {
@@ -215,26 +205,8 @@ final class WebChatViewModel: ObservableObject {
self.isSending = false
}
private func ensureGatewayConfigured() async throws {
guard !self.gatewayConfigured else { return }
let port = try await self.resolveGatewayPort()
self.currentPort = port
let url = URL(string: "ws://127.0.0.1:\(port)")!
let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
await self.gateway.configure(url: url, token: token)
self.gatewayConfigured = true
}
private func resolveGatewayPort() async throws -> Int {
if CommandResolver.connectionModeIsRemote() {
let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel()
return Int(forwarded)
}
return GatewayEnvironment.gatewayPort()
}
private func requestHistory() async throws -> ChatHistoryPayload {
let data = try await self.gateway.request(
let data = try await GatewayConnection.shared.request(
method: "chat.history",
params: ["sessionKey": AnyCodable(self.sessionKey)])
return try JSONDecoder().decode(ChatHistoryPayload.self, from: data)

View File

@@ -0,0 +1,160 @@
import Foundation
import os
import Testing
@testable import Clawdis
@Suite struct GatewayConnectionTests {
private final class FakeWebSocketTask: WebSocketTasking, @unchecked Sendable {
private let pendingReceiveHandler =
OSAllocatedUnfairLock<(@Sendable (Result<URLSessionWebSocketTask.Message, Error>) -> Void)?>(initialState: nil)
private let cancelCount = OSAllocatedUnfairLock(initialState: 0)
private let sendCount = OSAllocatedUnfairLock(initialState: 0)
var state: URLSessionTask.State = .suspended
func snapshotCancelCount() -> Int { self.cancelCount.withLock { $0 } }
func resume() {
self.state = .running
}
func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) {
_ = (closeCode, reason)
self.state = .canceling
self.cancelCount.withLock { $0 += 1 }
let handler = self.pendingReceiveHandler.withLock { handler in
defer { handler = nil }
return handler
}
handler?(Result<URLSessionWebSocketTask.Message, Error>.failure(URLError(.cancelled)))
}
func send(_ message: URLSessionWebSocketTask.Message) async throws {
let currentSendCount = self.sendCount.withLock { count in
defer { count += 1 }
return count
}
// First send is the hello frame. Subsequent sends are request frames.
if currentSendCount == 0 { return }
guard case let .data(data) = message else { return }
guard
let obj = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
(obj["type"] as? String) == "req",
let id = obj["id"] as? String
else {
return
}
let response = Self.responseData(id: id)
let handler = self.pendingReceiveHandler.withLock { $0 }
handler?(Result<URLSessionWebSocketTask.Message, Error>.success(.data(response)))
}
func receive() async throws -> URLSessionWebSocketTask.Message {
.data(Self.helloOkData())
}
func receive(
completionHandler: @escaping @Sendable (Result<URLSessionWebSocketTask.Message, Error>) -> Void)
{
self.pendingReceiveHandler.withLock { $0 = completionHandler }
}
private static func helloOkData() -> Data {
let json = """
{
"type": "hello-ok",
"protocol": 1,
"server": { "version": "test", "connId": "test" },
"features": { "methods": [], "events": [] },
"snapshot": {
"presence": [ { "ts": 1 } ],
"health": {},
"stateVersion": { "presence": 0, "health": 0 },
"uptimeMs": 0
},
"policy": { "maxPayload": 1, "maxBufferedBytes": 1, "tickIntervalMs": 30000 }
}
"""
return Data(json.utf8)
}
private static func responseData(id: String) -> Data {
let json = """
{
"type": "res",
"id": "\(id)",
"ok": true,
"payload": { "ok": true }
}
"""
return Data(json.utf8)
}
}
private final class FakeWebSocketSession: WebSocketSessioning, @unchecked Sendable {
private let makeCount = OSAllocatedUnfairLock(initialState: 0)
private let tasks = OSAllocatedUnfairLock(initialState: [FakeWebSocketTask]())
func snapshotMakeCount() -> Int { self.makeCount.withLock { $0 } }
func snapshotCancelCount() -> Int {
self.tasks.withLock { tasks in
tasks.reduce(0) { $0 + $1.snapshotCancelCount() }
}
}
func makeWebSocketTask(url: URL) -> WebSocketTaskBox {
_ = url
self.makeCount.withLock { $0 += 1 }
let task = FakeWebSocketTask()
self.tasks.withLock { $0.append(task) }
return WebSocketTaskBox(task: task)
}
}
private final class ConfigSource: @unchecked Sendable {
private let token = OSAllocatedUnfairLock<String?>(initialState: nil)
init(token: String?) {
self.token.withLock { $0 = token }
}
func snapshotToken() -> String? { self.token.withLock { $0 } }
func setToken(_ value: String?) { self.token.withLock { $0 = value } }
}
@Test func requestReusesSingleWebSocketForSameConfig() async throws {
let session = FakeWebSocketSession()
let url = URL(string: "ws://example.invalid")!
let cfg = ConfigSource(token: nil)
let conn = GatewayConnection(
configProvider: { (url, cfg.snapshotToken()) },
sessionBox: WebSocketSessionBox(session: session))
_ = try await conn.request(method: "status", params: nil)
#expect(session.snapshotMakeCount() == 1)
_ = try await conn.request(method: "status", params: nil)
#expect(session.snapshotMakeCount() == 1)
#expect(session.snapshotCancelCount() == 0)
}
@Test func requestReconfiguresAndCancelsOnTokenChange() async throws {
let session = FakeWebSocketSession()
let url = URL(string: "ws://example.invalid")!
let cfg = ConfigSource(token: "a")
let conn = GatewayConnection(
configProvider: { (url, cfg.snapshotToken()) },
sessionBox: WebSocketSessionBox(session: session))
_ = try await conn.request(method: "status", params: nil)
#expect(session.snapshotMakeCount() == 1)
cfg.setToken("b")
_ = try await conn.request(method: "status", params: nil)
#expect(session.snapshotMakeCount() == 2)
#expect(session.snapshotCancelCount() == 1)
}
}

View File

@@ -101,7 +101,7 @@ import Testing
let channel = GatewayChannelActor(
url: URL(string: "ws://example.invalid")!,
token: nil,
session: session)
session: WebSocketSessionBox(session: session))
let t1 = Task { try await channel.connect() }
let t2 = Task { try await channel.connect() }
@@ -117,7 +117,7 @@ import Testing
let channel = GatewayChannelActor(
url: URL(string: "ws://example.invalid")!,
token: nil,
session: session)
session: WebSocketSessionBox(session: session))
let t1 = Task { try await channel.connect() }
let t2 = Task { try await channel.connect() }

View File

@@ -93,7 +93,7 @@ import Testing
let channel = GatewayChannelActor(
url: URL(string: "ws://example.invalid")!,
token: nil,
session: session)
session: WebSocketSessionBox(session: session))
do {
_ = try await channel.request(method: "test", params: nil, timeoutMs: 10)
@@ -108,4 +108,3 @@ import Testing
try? await Task.sleep(nanoseconds: 250 * 1_000_000)
}
}