test(macos): cover gateway host resolution

This commit is contained in:
Peter Steinberger
2026-01-21 02:38:41 +00:00
parent d1c2fc4bc8
commit d3898ee8df
5 changed files with 88 additions and 53 deletions

View File

@@ -263,7 +263,8 @@ extension ConfigSettings {
let subsections = self.resolveSubsections(for: section) let subsections = self.resolveSubsections(for: section)
let resolved: (ConfigSchemaNode, ConfigPath) = { let resolved: (ConfigSchemaNode, ConfigPath) = {
if case let .key(key) = subsection, if case let .key(key) = subsection,
let match = subsections.first(where: { $0.key == key }) { let match = subsections.first(where: { $0.key == key })
{
return (match.node, match.path) return (match.node, match.path)
} }
return (self.resolvedSchemaNode(section.node), defaultPath) return (self.resolvedSchemaNode(section.node), defaultPath)

View File

@@ -319,7 +319,7 @@ private enum ExecHostExecutor {
security: context.security, security: context.security,
allowlistMatch: context.allowlistMatch, allowlistMatch: context.allowlistMatch,
skillAllow: context.skillAllow), skillAllow: context.skillAllow),
approvalDecision == nil approvalDecision == nil
{ {
let decision = ExecApprovalsPromptPresenter.prompt( let decision = ExecApprovalsPromptPresenter.prompt(
ExecApprovalPromptRequest( ExecApprovalPromptRequest(

View File

@@ -634,11 +634,12 @@ extension GatewayEndpointStore {
static func _testResolveLocalGatewayHost( static func _testResolveLocalGatewayHost(
bindMode: String?, bindMode: String?,
tailscaleIP: String?) -> String tailscaleIP: String?,
customBindHost: String? = nil) -> String
{ {
self.resolveLocalGatewayHost( self.resolveLocalGatewayHost(
bindMode: bindMode, bindMode: bindMode,
customBindHost: nil, customBindHost: customBindHost,
tailscaleIP: tailscaleIP) tailscaleIP: tailscaleIP)
} }
} }

View File

@@ -10,7 +10,7 @@ struct ConnectOptions {
var token: String? var token: String?
var password: String? var password: String?
var mode: String? var mode: String?
var timeoutMs: Int = 15_000 var timeoutMs: Int = 15000
var json: Bool = false var json: Bool = false
var probe: Bool = false var probe: Bool = false
var clientId: String = "clawdbot-macos" var clientId: String = "clawdbot-macos"
@@ -22,53 +22,43 @@ struct ConnectOptions {
static func parse(_ args: [String]) -> ConnectOptions { static func parse(_ args: [String]) -> ConnectOptions {
var opts = ConnectOptions() var opts = ConnectOptions()
let flagHandlers: [String: (inout ConnectOptions) -> Void] = [
"-h": { $0.help = true },
"--help": { $0.help = true },
"--json": { $0.json = true },
"--probe": { $0.probe = true },
]
let valueHandlers: [String: (inout ConnectOptions, String) -> Void] = [
"--url": { $0.url = $1 },
"--token": { $0.token = $1 },
"--password": { $0.password = $1 },
"--mode": { $0.mode = $1 },
"--timeout": { opts, raw in
if let parsed = Int(raw.trimmingCharacters(in: .whitespacesAndNewlines)) {
opts.timeoutMs = max(250, parsed)
}
},
"--client-id": { $0.clientId = $1 },
"--client-mode": { $0.clientMode = $1 },
"--display-name": { $0.displayName = $1 },
"--role": { $0.role = $1 },
"--scopes": { opts, raw in
opts.scopes = raw.split(separator: ",").map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }
.filter { !$0.isEmpty }
},
]
var i = 0 var i = 0
while i < args.count { while i < args.count {
let arg = args[i] let arg = args[i]
switch arg { if let handler = flagHandlers[arg] {
case "-h", "--help": handler(&opts)
opts.help = true i += 1
case "--json": continue
opts.json = true }
case "--probe": if let handler = valueHandlers[arg], let value = self.nextValue(args, index: &i) {
opts.probe = true handler(&opts, value)
case "--url": i += 1
opts.url = self.nextValue(args, index: &i) continue
case "--token":
opts.token = self.nextValue(args, index: &i)
case "--password":
opts.password = self.nextValue(args, index: &i)
case "--mode":
if let value = self.nextValue(args, index: &i) {
opts.mode = value
}
case "--timeout":
if let raw = self.nextValue(args, index: &i),
let parsed = Int(raw.trimmingCharacters(in: .whitespacesAndNewlines))
{
opts.timeoutMs = max(250, parsed)
}
case "--client-id":
if let value = self.nextValue(args, index: &i) {
opts.clientId = value
}
case "--client-mode":
if let value = self.nextValue(args, index: &i) {
opts.clientMode = value
}
case "--display-name":
opts.displayName = self.nextValue(args, index: &i)
case "--role":
if let value = self.nextValue(args, index: &i) {
opts.role = value
}
case "--scopes":
if let value = self.nextValue(args, index: &i) {
opts.scopes = value.split(separator: ",").map { $0.trimmingCharacters(in: .whitespacesAndNewlines) }
.filter { !$0.isEmpty }
}
default:
break
} }
i += 1 i += 1
} }
@@ -257,8 +247,12 @@ private func resolveGatewayEndpoint(opts: ConnectOptions, config: GatewayConfig)
if resolvedMode == "remote" { if resolvedMode == "remote" {
guard let raw = config.remoteUrl?.trimmingCharacters(in: .whitespacesAndNewlines), guard let raw = config.remoteUrl?.trimmingCharacters(in: .whitespacesAndNewlines),
!raw.isEmpty else { !raw.isEmpty
throw NSError(domain: "Gateway", code: 1, userInfo: [NSLocalizedDescriptionKey: "gateway.remote.url is missing"]) else {
throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "gateway.remote.url is missing"])
} }
guard let url = URL(string: raw) else { guard let url = URL(string: raw) else {
throw NSError(domain: "Gateway", code: 1, userInfo: [NSLocalizedDescriptionKey: "invalid url: \(raw)"]) throw NSError(domain: "Gateway", code: 1, userInfo: [NSLocalizedDescriptionKey: "invalid url: \(raw)"])
@@ -273,7 +267,10 @@ private func resolveGatewayEndpoint(opts: ConnectOptions, config: GatewayConfig)
let port = config.port ?? 18789 let port = config.port ?? 18789
let host = resolveLocalHost(bind: config.bind) let host = resolveLocalHost(bind: config.bind)
guard let url = URL(string: "ws://\(host):\(port)") else { guard let url = URL(string: "ws://\(host):\(port)") else {
throw NSError(domain: "Gateway", code: 1, userInfo: [NSLocalizedDescriptionKey: "invalid url: ws://\(host):\(port)"]) throw NSError(
domain: "Gateway",
code: 1,
userInfo: [NSLocalizedDescriptionKey: "invalid url: ws://\(host):\(port)"])
} }
return GatewayEndpoint( return GatewayEndpoint(
url: url, url: url,
@@ -283,7 +280,7 @@ private func resolveGatewayEndpoint(opts: ConnectOptions, config: GatewayConfig)
} }
private func bestEffortEndpoint(opts: ConnectOptions, config: GatewayConfig) -> GatewayEndpoint? { private func bestEffortEndpoint(opts: ConnectOptions, config: GatewayConfig) -> GatewayEndpoint? {
return try? resolveGatewayEndpoint(opts: opts, config: config) try? resolveGatewayEndpoint(opts: opts, config: config)
} }
private func resolvedToken(opts: ConnectOptions, mode: String, config: GatewayConfig) -> String? { private func resolvedToken(opts: ConnectOptions, mode: String, config: GatewayConfig) -> String? {

View File

@@ -139,4 +139,40 @@ import Testing
let resolved = ConnectionModeResolver.resolve(root: root, defaults: defaults) let resolved = ConnectionModeResolver.resolve(root: root, defaults: defaults)
#expect(resolved.mode == .remote) #expect(resolved.mode == .remote)
} }
@Test func resolveLocalGatewayHostPrefersTailnetForAuto() {
let host = GatewayEndpointStore._testResolveLocalGatewayHost(
bindMode: "auto",
tailscaleIP: "100.64.1.2")
#expect(host == "100.64.1.2")
}
@Test func resolveLocalGatewayHostFallsBackToLoopbackForAuto() {
let host = GatewayEndpointStore._testResolveLocalGatewayHost(
bindMode: "auto",
tailscaleIP: nil)
#expect(host == "127.0.0.1")
}
@Test func resolveLocalGatewayHostPrefersTailnetForTailnetMode() {
let host = GatewayEndpointStore._testResolveLocalGatewayHost(
bindMode: "tailnet",
tailscaleIP: "100.64.1.5")
#expect(host == "100.64.1.5")
}
@Test func resolveLocalGatewayHostFallsBackToLoopbackForTailnetMode() {
let host = GatewayEndpointStore._testResolveLocalGatewayHost(
bindMode: "tailnet",
tailscaleIP: nil)
#expect(host == "127.0.0.1")
}
@Test func resolveLocalGatewayHostUsesCustomBindHost() {
let host = GatewayEndpointStore._testResolveLocalGatewayHost(
bindMode: "custom",
tailscaleIP: "100.64.1.9",
customBindHost: "192.168.1.10")
#expect(host == "192.168.1.10")
}
} }