diff --git a/apps/macos/Sources/Clawdis/GatewayChannel.swift b/apps/macos/Sources/Clawdis/GatewayChannel.swift index b85cf54d0..db37c6a06 100644 --- a/apps/macos/Sources/Clawdis/GatewayChannel.swift +++ b/apps/macos/Sources/Clawdis/GatewayChannel.swift @@ -2,6 +2,51 @@ import ClawdisProtocol import Foundation import OSLog +protocol WebSocketTasking: AnyObject { + var state: URLSessionTask.State { get } + func resume() + func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) + func send(_ message: URLSessionWebSocketTask.Message) async throws + func receive() async throws -> URLSessionWebSocketTask.Message + func receive(completionHandler: @escaping @Sendable (Result) -> Void) +} + +extension URLSessionWebSocketTask: WebSocketTasking {} + +struct WebSocketTaskBox: @unchecked Sendable { + let task: any WebSocketTasking + + var state: URLSessionTask.State { self.task.state } + + func resume() { self.task.resume() } + + func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) { + self.task.cancel(with: closeCode, reason: reason) + } + + func send(_ message: URLSessionWebSocketTask.Message) async throws { + try await self.task.send(message) + } + + func receive() async throws -> URLSessionWebSocketTask.Message { + try await self.task.receive() + } + + func receive(completionHandler: @escaping @Sendable (Result) -> Void) { + self.task.receive(completionHandler: completionHandler) + } +} + +protocol WebSocketSessioning: AnyObject { + func makeWebSocketTask(url: URL) -> WebSocketTaskBox +} + +extension URLSession: WebSocketSessioning { + func makeWebSocketTask(url: URL) -> WebSocketTaskBox { + WebSocketTaskBox(task: self.webSocketTask(with: url)) + } +} + struct GatewayEvent: Codable { let type: String let event: String? @@ -18,16 +63,16 @@ extension Notification.Name { static let gatewaySeqGap = Notification.Name("clawdis.gateway.seqgap") } -private actor GatewayChannelActor { +actor GatewayChannelActor { private let logger = Logger(subsystem: "com.steipete.clawdis", category: "gateway") - private var task: URLSessionWebSocketTask? + private var task: WebSocketTaskBox? private var pending: [String: CheckedContinuation] = [:] private var connected = false private var isConnecting = false private var connectWaiters: [CheckedContinuation] = [] private var url: URL private var token: String? - private let session = URLSession(configuration: .default) + private let session: WebSocketSessioning private var backoffMs: Double = 500 private var shouldReconnect = true private var lastSeq: Int? @@ -38,9 +83,10 @@ private actor GatewayChannelActor { private var watchdogTask: Task? private let defaultRequestTimeoutMs: Double = 15000 - init(url: URL, token: String?) { + init(url: URL, token: String?, session: WebSocketSessioning? = nil) { self.url = url self.token = token + self.session = session ?? URLSession(configuration: .default) Task { [weak self] in await self?.startWatchdog() } @@ -80,7 +126,7 @@ private actor GatewayChannelActor { defer { self.isConnecting = false } self.task?.cancel(with: .goingAway, reason: nil) - self.task = self.session.webSocketTask(with: self.url) + self.task = self.session.makeWebSocketTask(url: self.url) self.task?.resume() do { try await self.sendHello() diff --git a/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConnectTests.swift b/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConnectTests.swift new file mode 100644 index 000000000..1f4085298 --- /dev/null +++ b/apps/macos/Tests/ClawdisIPCTests/GatewayChannelConnectTests.swift @@ -0,0 +1,136 @@ +import Foundation +import os +import Testing +@testable import Clawdis + +@Suite struct GatewayChannelConnectTests { + private enum FakeResponse { + case helloOk(delayMs: Int) + case invalid(delayMs: Int) + } + + private final class FakeWebSocketTask: WebSocketTasking, @unchecked Sendable { + private let response: FakeResponse + private let pendingReceiveHandler = + OSAllocatedUnfairLock<(@Sendable (Result) -> Void)?>( + initialState: nil) + + var state: URLSessionTask.State = .suspended + + init(response: FakeResponse) { + self.response = response + } + + func resume() { + self.state = .running + } + + func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) { + _ = (closeCode, reason) + self.state = .canceling + let handler = self.pendingReceiveHandler.withLock { handler in + defer { handler = nil } + return handler + } + handler?(Result.failure(URLError(.cancelled))) + } + + func send(_ message: URLSessionWebSocketTask.Message) async throws { + _ = message + } + + func receive() async throws -> URLSessionWebSocketTask.Message { + let (delayMs, msg): (Int, URLSessionWebSocketTask.Message) = switch self.response { + case let .helloOk(delayMs): + (delayMs, .data(Self.helloOkData())) + case let .invalid(delayMs): + (delayMs, .string("not json")) + } + try await Task.sleep(nanoseconds: UInt64(delayMs) * 1_000_000) + return msg + } + + func receive( + completionHandler: @escaping @Sendable (Result) -> Void) + { + // The production channel sets up a continuous receive loop after hello. + // Tests only need the handshake receive; keep the loop idle. + self.pendingReceiveHandler.withLock { $0 = completionHandler } + } + + private static func helloOkData() -> Data { + let json = """ + { + "type": "hello-ok", + "protocol": 1, + "server": { "version": "test", "connId": "test" }, + "features": { "methods": [], "events": [] }, + "snapshot": { + "presence": [ { "ts": 1 } ], + "health": {}, + "stateVersion": { "presence": 0, "health": 0 }, + "uptimeMs": 0 + }, + "policy": { "maxPayload": 1, "maxBufferedBytes": 1, "tickIntervalMs": 30000 } + } + """ + return Data(json.utf8) + } + } + + private final class FakeWebSocketSession: WebSocketSessioning, @unchecked Sendable { + private let response: FakeResponse + private let makeCount = OSAllocatedUnfairLock(initialState: 0) + + init(response: FakeResponse) { + self.response = response + } + + func snapshotMakeCount() -> Int { self.makeCount.withLock { $0 } } + + func makeWebSocketTask(url: URL) -> WebSocketTaskBox { + _ = url + self.makeCount.withLock { $0 += 1 } + let task = FakeWebSocketTask(response: self.response) + return WebSocketTaskBox(task: task) + } + } + + @Test func concurrentConnectIsSingleFlightOnSuccess() async throws { + let session = FakeWebSocketSession(response: .helloOk(delayMs: 200)) + let channel = GatewayChannelActor( + url: URL(string: "ws://example.invalid")!, + token: nil, + session: session) + + let t1 = Task { try await channel.connect() } + let t2 = Task { try await channel.connect() } + + _ = try await t1.value + _ = try await t2.value + + #expect(session.snapshotMakeCount() == 1) + } + + @Test func concurrentConnectSharesFailure() async { + let session = FakeWebSocketSession(response: .invalid(delayMs: 200)) + let channel = GatewayChannelActor( + url: URL(string: "ws://example.invalid")!, + token: nil, + session: session) + + let t1 = Task { try await channel.connect() } + let t2 = Task { try await channel.connect() } + + let r1 = await t1.result + let r2 = await t2.result + + #expect({ + if case .failure = r1 { true } else { false } + }()) + #expect({ + if case .failure = r2 { true } else { false } + }()) + #expect(session.snapshotMakeCount() == 1) + } +}