feat: migrate android node to gateway ws

This commit is contained in:
Peter Steinberger
2026-01-19 11:05:50 +00:00
parent fcea6303ed
commit e6a4cf01ee
18 changed files with 1793 additions and 288 deletions

View File

@@ -103,6 +103,7 @@ dependencies {
implementation("androidx.security:security-crypto:1.1.0")
implementation("androidx.exifinterface:exifinterface:1.4.2")
implementation("com.squareup.okhttp3:okhttp:4.12.0")
// CameraX (for node.invoke camera.* parity)
implementation("androidx.camera:camera-core:1.5.2")

View File

@@ -2,7 +2,7 @@ package com.clawdbot.android
import android.app.Application
import androidx.lifecycle.AndroidViewModel
import com.clawdbot.android.bridge.BridgeEndpoint
import com.clawdbot.android.gateway.GatewayEndpoint
import com.clawdbot.android.chat.OutgoingAttachment
import com.clawdbot.android.node.CameraCaptureManager
import com.clawdbot.android.node.CanvasController
@@ -18,7 +18,7 @@ class MainViewModel(app: Application) : AndroidViewModel(app) {
val screenRecorder: ScreenRecordManager = runtime.screenRecorder
val sms: SmsManager = runtime.sms
val bridges: StateFlow<List<BridgeEndpoint>> = runtime.bridges
val gateways: StateFlow<List<GatewayEndpoint>> = runtime.gateways
val discoveryStatusText: StateFlow<String> = runtime.discoveryStatusText
val isConnected: StateFlow<Boolean> = runtime.isConnected
@@ -50,6 +50,7 @@ class MainViewModel(app: Application) : AndroidViewModel(app) {
val manualEnabled: StateFlow<Boolean> = runtime.manualEnabled
val manualHost: StateFlow<String> = runtime.manualHost
val manualPort: StateFlow<Int> = runtime.manualPort
val manualTls: StateFlow<Boolean> = runtime.manualTls
val canvasDebugStatusEnabled: StateFlow<Boolean> = runtime.canvasDebugStatusEnabled
val chatSessionKey: StateFlow<String> = runtime.chatSessionKey
@@ -99,6 +100,10 @@ class MainViewModel(app: Application) : AndroidViewModel(app) {
runtime.setManualPort(value)
}
fun setManualTls(value: Boolean) {
runtime.setManualTls(value)
}
fun setCanvasDebugStatusEnabled(value: Boolean) {
runtime.setCanvasDebugStatusEnabled(value)
}
@@ -119,11 +124,11 @@ class MainViewModel(app: Application) : AndroidViewModel(app) {
runtime.setTalkEnabled(enabled)
}
fun refreshBridgeHello() {
runtime.refreshBridgeHello()
fun refreshGatewayConnection() {
runtime.refreshGatewayConnection()
}
fun connect(endpoint: BridgeEndpoint) {
fun connect(endpoint: GatewayEndpoint) {
runtime.connect(endpoint)
}

View File

@@ -12,11 +12,12 @@ import com.clawdbot.android.chat.ChatMessage
import com.clawdbot.android.chat.ChatPendingToolCall
import com.clawdbot.android.chat.ChatSessionEntry
import com.clawdbot.android.chat.OutgoingAttachment
import com.clawdbot.android.bridge.BridgeDiscovery
import com.clawdbot.android.bridge.BridgeEndpoint
import com.clawdbot.android.bridge.BridgePairingClient
import com.clawdbot.android.bridge.BridgeSession
import com.clawdbot.android.bridge.BridgeTlsParams
import com.clawdbot.android.gateway.DeviceIdentityStore
import com.clawdbot.android.gateway.GatewayConnectOptions
import com.clawdbot.android.gateway.GatewayDiscovery
import com.clawdbot.android.gateway.GatewayEndpoint
import com.clawdbot.android.gateway.GatewaySession
import com.clawdbot.android.gateway.GatewayTlsParams
import com.clawdbot.android.node.CameraCaptureManager
import com.clawdbot.android.node.LocationCaptureManager
import com.clawdbot.android.BuildConfig
@@ -74,7 +75,7 @@ class NodeRuntime(context: Context) {
context = appContext,
scope = scope,
onCommand = { command ->
session.sendEvent(
nodeSession.sendNodeEvent(
event = "agent.request",
payloadJson =
buildJsonObject {
@@ -103,10 +104,12 @@ class NodeRuntime(context: Context) {
val talkIsSpeaking: StateFlow<Boolean>
get() = talkMode.isSpeaking
private val discovery = BridgeDiscovery(appContext, scope = scope)
val bridges: StateFlow<List<BridgeEndpoint>> = discovery.bridges
private val discovery = GatewayDiscovery(appContext, scope = scope)
val gateways: StateFlow<List<GatewayEndpoint>> = discovery.gateways
val discoveryStatusText: StateFlow<String> = discovery.statusText
private val identityStore = DeviceIdentityStore(appContext)
private val _isConnected = MutableStateFlow(false)
val isConnected: StateFlow<Boolean> = _isConnected.asStateFlow()
@@ -139,52 +142,87 @@ class NodeRuntime(context: Context) {
val isForeground: StateFlow<Boolean> = _isForeground.asStateFlow()
private var lastAutoA2uiUrl: String? = null
private var operatorConnected = false
private var nodeConnected = false
private var operatorStatusText: String = "Offline"
private var nodeStatusText: String = "Offline"
private var connectedEndpoint: GatewayEndpoint? = null
private val session =
BridgeSession(
private val operatorSession =
GatewaySession(
scope = scope,
identityStore = identityStore,
onConnected = { name, remote, mainSessionKey ->
_statusText.value = "Connected"
operatorConnected = true
operatorStatusText = "Connected"
_serverName.value = name
_remoteAddress.value = remote
_isConnected.value = true
_seamColorArgb.value = DEFAULT_SEAM_COLOR_ARGB
applyMainSessionKey(mainSessionKey)
updateStatus()
scope.launch { refreshBrandingFromGateway() }
scope.launch { refreshWakeWordsFromGateway() }
},
onDisconnected = { message ->
operatorConnected = false
operatorStatusText = message
_serverName.value = null
_remoteAddress.value = null
_seamColorArgb.value = DEFAULT_SEAM_COLOR_ARGB
if (!isCanonicalMainSessionKey(_mainSessionKey.value)) {
_mainSessionKey.value = "main"
}
val mainKey = resolveMainSessionKey()
talkMode.setMainSessionKey(mainKey)
chat.applyMainSessionKey(mainKey)
chat.onDisconnected(message)
updateStatus()
},
onEvent = { event, payloadJson ->
handleGatewayEvent(event, payloadJson)
},
)
private val nodeSession =
GatewaySession(
scope = scope,
identityStore = identityStore,
onConnected = { _, _, _ ->
nodeConnected = true
nodeStatusText = "Connected"
updateStatus()
maybeNavigateToA2uiOnConnect()
},
onDisconnected = { message -> handleSessionDisconnected(message) },
onEvent = { event, payloadJson ->
handleBridgeEvent(event, payloadJson)
onDisconnected = { message ->
nodeConnected = false
nodeStatusText = message
updateStatus()
showLocalCanvasOnDisconnect()
},
onEvent = { _, _ -> },
onInvoke = { req ->
handleInvoke(req.command, req.paramsJson)
},
onTlsFingerprint = { stableId, fingerprint ->
prefs.saveBridgeTlsFingerprint(stableId, fingerprint)
prefs.saveGatewayTlsFingerprint(stableId, fingerprint)
},
)
private val chat = ChatController(scope = scope, session = session, json = json)
private val chat =
ChatController(
scope = scope,
session = operatorSession,
json = json,
supportsChatSubscribe = false,
)
private val talkMode: TalkModeManager by lazy {
TalkModeManager(context = appContext, scope = scope).also { it.attachSession(session) }
}
private fun handleSessionDisconnected(message: String) {
_statusText.value = message
_serverName.value = null
_remoteAddress.value = null
_isConnected.value = false
_seamColorArgb.value = DEFAULT_SEAM_COLOR_ARGB
if (!isCanonicalMainSessionKey(_mainSessionKey.value)) {
_mainSessionKey.value = "main"
}
val mainKey = resolveMainSessionKey()
talkMode.setMainSessionKey(mainKey)
chat.applyMainSessionKey(mainKey)
chat.onDisconnected(message)
showLocalCanvasOnDisconnect()
TalkModeManager(
context = appContext,
scope = scope,
session = operatorSession,
supportsChatSubscribe = false,
isConnected = { operatorConnected },
)
}
private fun applyMainSessionKey(candidate: String?) {
@@ -197,6 +235,18 @@ class NodeRuntime(context: Context) {
chat.applyMainSessionKey(trimmed)
}
private fun updateStatus() {
_isConnected.value = operatorConnected
_statusText.value =
when {
operatorConnected && nodeConnected -> "Connected"
operatorConnected && !nodeConnected -> "Connected (node offline)"
!operatorConnected && nodeConnected -> "Connected (operator offline)"
operatorStatusText.isNotBlank() && operatorStatusText != "Offline" -> operatorStatusText
else -> nodeStatusText
}
}
private fun resolveMainSessionKey(): String {
val trimmed = _mainSessionKey.value.trim()
return if (trimmed.isEmpty()) "main" else trimmed
@@ -228,6 +278,7 @@ class NodeRuntime(context: Context) {
val manualEnabled: StateFlow<Boolean> = prefs.manualEnabled
val manualHost: StateFlow<String> = prefs.manualHost
val manualPort: StateFlow<Int> = prefs.manualPort
val manualTls: StateFlow<Boolean> = prefs.manualTls
val lastDiscoveredStableId: StateFlow<String> = prefs.lastDiscoveredStableId
val canvasDebugStatusEnabled: StateFlow<Boolean> = prefs.canvasDebugStatusEnabled
@@ -288,24 +339,21 @@ class NodeRuntime(context: Context) {
}
scope.launch(Dispatchers.Default) {
bridges.collect { list ->
gateways.collect { list ->
if (list.isNotEmpty()) {
// Persist the last discovered bridge (best-effort UX parity with iOS).
// Persist the last discovered gateway (best-effort UX parity with iOS).
prefs.setLastDiscoveredStableId(list.last().stableId)
}
if (didAutoConnect) return@collect
if (_isConnected.value) return@collect
val token = prefs.loadBridgeToken()
if (token.isNullOrBlank()) return@collect
if (manualEnabled.value) {
val host = manualHost.value.trim()
val port = manualPort.value
if (host.isNotEmpty() && port in 1..65535) {
didAutoConnect = true
connect(BridgeEndpoint.manual(host = host, port = port))
connect(GatewayEndpoint.manual(host = host, port = port))
}
return@collect
}
@@ -371,6 +419,10 @@ class NodeRuntime(context: Context) {
prefs.setManualPort(value)
}
fun setManualTls(value: Boolean) {
prefs.setManualTls(value)
}
fun setCanvasDebugStatusEnabled(value: Boolean) {
prefs.setCanvasDebugStatusEnabled(value)
}
@@ -429,99 +481,78 @@ class NodeRuntime(context: Context) {
}
}
private fun buildPairingHello(token: String?): BridgePairingClient.Hello {
val modelIdentifier = listOfNotNull(Build.MANUFACTURER, Build.MODEL)
.joinToString(" ")
.trim()
.ifEmpty { null }
private fun resolvedVersionName(): String {
val versionName = BuildConfig.VERSION_NAME.trim().ifEmpty { "dev" }
val advertisedVersion =
if (BuildConfig.DEBUG && !versionName.contains("dev", ignoreCase = true)) {
"$versionName-dev"
} else {
versionName
}
return BridgePairingClient.Hello(
nodeId = instanceId.value,
displayName = displayName.value,
token = token,
platform = "Android",
version = advertisedVersion,
deviceFamily = "Android",
modelIdentifier = modelIdentifier,
caps = buildCapabilities(),
commands = buildInvokeCommands(),
)
}
private fun buildSessionHello(token: String?): BridgeSession.Hello {
val modelIdentifier = listOfNotNull(Build.MANUFACTURER, Build.MODEL)
.joinToString(" ")
.trim()
.ifEmpty { null }
val versionName = BuildConfig.VERSION_NAME.trim().ifEmpty { "dev" }
val advertisedVersion =
if (BuildConfig.DEBUG && !versionName.contains("dev", ignoreCase = true)) {
"$versionName-dev"
} else {
versionName
}
return BridgeSession.Hello(
nodeId = instanceId.value,
displayName = displayName.value,
token = token,
platform = "Android",
version = advertisedVersion,
deviceFamily = "Android",
modelIdentifier = modelIdentifier,
caps = buildCapabilities(),
commands = buildInvokeCommands(),
)
}
fun refreshBridgeHello() {
scope.launch {
if (!_isConnected.value) return@launch
val token = prefs.loadBridgeToken()
if (token.isNullOrBlank()) return@launch
session.updateHello(buildSessionHello(token))
return if (BuildConfig.DEBUG && !versionName.contains("dev", ignoreCase = true)) {
"$versionName-dev"
} else {
versionName
}
}
fun connect(endpoint: BridgeEndpoint) {
scope.launch {
_statusText.value = "Connecting…"
val storedToken = prefs.loadBridgeToken()
val tls = resolveTlsParams(endpoint)
val resolved =
if (storedToken.isNullOrBlank()) {
_statusText.value = "Pairing…"
BridgePairingClient().pairAndHello(
endpoint = endpoint,
hello = buildPairingHello(token = null),
tls = tls,
onTlsFingerprint = { fingerprint ->
prefs.saveBridgeTlsFingerprint(endpoint.stableId, fingerprint)
},
)
} else {
BridgePairingClient.PairResult(ok = true, token = storedToken.trim())
}
private fun resolveModelIdentifier(): String? {
return listOfNotNull(Build.MANUFACTURER, Build.MODEL)
.joinToString(" ")
.trim()
.ifEmpty { null }
}
if (!resolved.ok || resolved.token.isNullOrBlank()) {
val errorMessage = resolved.error?.trim().orEmpty().ifEmpty { "pairing required" }
_statusText.value = "Failed: $errorMessage"
return@launch
}
private fun buildClientInfo(clientId: String, clientMode: String): GatewayClientInfo {
return GatewayClientInfo(
id = clientId,
displayName = displayName.value,
version = resolvedVersionName(),
platform = "android",
mode = clientMode,
instanceId = instanceId.value,
deviceFamily = "Android",
modelIdentifier = resolveModelIdentifier(),
)
}
val authToken = requireNotNull(resolved.token).trim()
prefs.saveBridgeToken(authToken)
session.connect(
endpoint = endpoint,
hello = buildSessionHello(token = authToken),
tls = tls,
)
}
private fun buildNodeConnectOptions(): GatewayConnectOptions {
return GatewayConnectOptions(
role = "node",
scopes = emptyList(),
caps = buildCapabilities(),
commands = buildInvokeCommands(),
permissions = emptyMap(),
client = buildClientInfo(clientId = "node-host", clientMode = "node"),
)
}
private fun buildOperatorConnectOptions(): GatewayConnectOptions {
return GatewayConnectOptions(
role = "operator",
scopes = emptyList(),
caps = emptyList(),
commands = emptyList(),
permissions = emptyMap(),
client = buildClientInfo(clientId = "clawdbot-control-ui", clientMode = "ui"),
)
}
fun refreshGatewayConnection() {
val endpoint = connectedEndpoint ?: return
val token = prefs.loadGatewayToken()
val password = prefs.loadGatewayPassword()
val tls = resolveTlsParams(endpoint)
operatorSession.connect(endpoint, token, password, buildOperatorConnectOptions(), tls)
nodeSession.connect(endpoint, token, password, buildNodeConnectOptions(), tls)
operatorSession.reconnect()
nodeSession.reconnect()
}
fun connect(endpoint: GatewayEndpoint) {
connectedEndpoint = endpoint
operatorStatusText = "Connecting…"
nodeStatusText = "Connecting…"
updateStatus()
val token = prefs.loadGatewayToken()
val password = prefs.loadGatewayPassword()
val tls = resolveTlsParams(endpoint)
operatorSession.connect(endpoint, token, password, buildOperatorConnectOptions(), tls)
nodeSession.connect(endpoint, token, password, buildNodeConnectOptions(), tls)
}
private fun hasRecordAudioPermission(): Boolean {
@@ -559,20 +590,32 @@ class NodeRuntime(context: Context) {
_statusText.value = "Failed: invalid manual host/port"
return
}
connect(BridgeEndpoint.manual(host = host, port = port))
connect(GatewayEndpoint.manual(host = host, port = port))
}
fun disconnect() {
session.disconnect()
connectedEndpoint = null
operatorSession.disconnect()
nodeSession.disconnect()
}
private fun resolveTlsParams(endpoint: BridgeEndpoint): BridgeTlsParams? {
val stored = prefs.loadBridgeTlsFingerprint(endpoint.stableId)
private fun resolveTlsParams(endpoint: GatewayEndpoint): GatewayTlsParams? {
val stored = prefs.loadGatewayTlsFingerprint(endpoint.stableId)
val hinted = endpoint.tlsEnabled || !endpoint.tlsFingerprintSha256.isNullOrBlank()
val manual = endpoint.stableId.startsWith("manual|")
if (manual) {
if (!manualTls.value) return null
return GatewayTlsParams(
required = true,
expectedFingerprint = endpoint.tlsFingerprintSha256 ?: stored,
allowTOFU = stored == null,
stableId = endpoint.stableId,
)
}
if (hinted) {
return BridgeTlsParams(
return GatewayTlsParams(
required = true,
expectedFingerprint = endpoint.tlsFingerprintSha256 ?: stored,
allowTOFU = stored == null,
@@ -581,7 +624,7 @@ class NodeRuntime(context: Context) {
}
if (!stored.isNullOrBlank()) {
return BridgeTlsParams(
return GatewayTlsParams(
required = true,
expectedFingerprint = stored,
allowTOFU = false,
@@ -589,15 +632,6 @@ class NodeRuntime(context: Context) {
)
}
if (manual) {
return BridgeTlsParams(
required = false,
expectedFingerprint = null,
allowTOFU = true,
stableId = endpoint.stableId,
)
}
return null
}
@@ -637,11 +671,11 @@ class NodeRuntime(context: Context) {
contextJson = contextJson,
)
val connected = isConnected.value
val connected = nodeConnected
var error: String? = null
if (connected) {
try {
session.sendEvent(
nodeSession.sendNodeEvent(
event = "agent.request",
payloadJson =
buildJsonObject {
@@ -656,7 +690,7 @@ class NodeRuntime(context: Context) {
error = e.message ?: "send failed"
}
} else {
error = "bridge not connected"
error = "gateway not connected"
}
try {
@@ -702,7 +736,7 @@ class NodeRuntime(context: Context) {
chat.sendMessage(message = message, thinkingLevel = thinking, attachments = attachments)
}
private fun handleBridgeEvent(event: String, payloadJson: String?) {
private fun handleGatewayEvent(event: String, payloadJson: String?) {
if (event == "voicewake.changed") {
if (payloadJson.isNullOrBlank()) return
try {
@@ -716,8 +750,8 @@ class NodeRuntime(context: Context) {
return
}
talkMode.handleBridgeEvent(event, payloadJson)
chat.handleBridgeEvent(event, payloadJson)
talkMode.handleGatewayEvent(event, payloadJson)
chat.handleGatewayEvent(event, payloadJson)
}
private fun applyWakeWordsFromGateway(words: List<String>) {
@@ -738,7 +772,7 @@ class NodeRuntime(context: Context) {
val jsonList = snapshot.joinToString(separator = ",") { it.toJsonString() }
val params = """{"triggers":[$jsonList]}"""
try {
session.request("voicewake.set", params)
operatorSession.request("voicewake.set", params)
} catch (_: Throwable) {
// ignore
}
@@ -748,7 +782,7 @@ class NodeRuntime(context: Context) {
private suspend fun refreshWakeWordsFromGateway() {
if (!_isConnected.value) return
try {
val res = session.request("voicewake.get", "{}")
val res = operatorSession.request("voicewake.get", "{}")
val payload = json.parseToJsonElement(res).asObjectOrNull() ?: return
val array = payload["triggers"] as? JsonArray ?: return
val triggers = array.mapNotNull { it.asStringOrNull() }
@@ -761,7 +795,7 @@ class NodeRuntime(context: Context) {
private suspend fun refreshBrandingFromGateway() {
if (!_isConnected.value) return
try {
val res = session.request("config.get", "{}")
val res = operatorSession.request("config.get", "{}")
val root = json.parseToJsonElement(res).asObjectOrNull()
val config = root?.get("config").asObjectOrNull()
val ui = config?.get("ui").asObjectOrNull()
@@ -777,7 +811,7 @@ class NodeRuntime(context: Context) {
}
}
private suspend fun handleInvoke(command: String, paramsJson: String?): BridgeSession.InvokeResult {
private suspend fun handleInvoke(command: String, paramsJson: String?): GatewaySession.InvokeResult {
if (
command.startsWith(ClawdbotCanvasCommand.NamespacePrefix) ||
command.startsWith(ClawdbotCanvasA2UICommand.NamespacePrefix) ||
@@ -785,14 +819,14 @@ class NodeRuntime(context: Context) {
command.startsWith(ClawdbotScreenCommand.NamespacePrefix)
) {
if (!isForeground.value) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "NODE_BACKGROUND_UNAVAILABLE",
message = "NODE_BACKGROUND_UNAVAILABLE: canvas/camera/screen commands require foreground",
)
}
}
if (command.startsWith(ClawdbotCameraCommand.NamespacePrefix) && !cameraEnabled.value) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "CAMERA_DISABLED",
message = "CAMERA_DISABLED: enable Camera in Settings",
)
@@ -800,7 +834,7 @@ class NodeRuntime(context: Context) {
if (command.startsWith(ClawdbotLocationCommand.NamespacePrefix) &&
locationMode.value == LocationMode.Off
) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "LOCATION_DISABLED",
message = "LOCATION_DISABLED: enable Location in Settings",
)
@@ -810,18 +844,18 @@ class NodeRuntime(context: Context) {
ClawdbotCanvasCommand.Present.rawValue -> {
val url = CanvasController.parseNavigateUrl(paramsJson)
canvas.navigate(url)
BridgeSession.InvokeResult.ok(null)
GatewaySession.InvokeResult.ok(null)
}
ClawdbotCanvasCommand.Hide.rawValue -> BridgeSession.InvokeResult.ok(null)
ClawdbotCanvasCommand.Hide.rawValue -> GatewaySession.InvokeResult.ok(null)
ClawdbotCanvasCommand.Navigate.rawValue -> {
val url = CanvasController.parseNavigateUrl(paramsJson)
canvas.navigate(url)
BridgeSession.InvokeResult.ok(null)
GatewaySession.InvokeResult.ok(null)
}
ClawdbotCanvasCommand.Eval.rawValue -> {
val js =
CanvasController.parseEvalJs(paramsJson)
?: return BridgeSession.InvokeResult.error(
?: return GatewaySession.InvokeResult.error(
code = "INVALID_REQUEST",
message = "INVALID_REQUEST: javaScript required",
)
@@ -829,12 +863,12 @@ class NodeRuntime(context: Context) {
try {
canvas.eval(js)
} catch (err: Throwable) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "NODE_BACKGROUND_UNAVAILABLE",
message = "NODE_BACKGROUND_UNAVAILABLE: canvas unavailable",
)
}
BridgeSession.InvokeResult.ok("""{"result":${result.toJsonString()}}""")
GatewaySession.InvokeResult.ok("""{"result":${result.toJsonString()}}""")
}
ClawdbotCanvasCommand.Snapshot.rawValue -> {
val snapshotParams = CanvasController.parseSnapshotParams(paramsJson)
@@ -846,51 +880,51 @@ class NodeRuntime(context: Context) {
maxWidth = snapshotParams.maxWidth,
)
} catch (err: Throwable) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "NODE_BACKGROUND_UNAVAILABLE",
message = "NODE_BACKGROUND_UNAVAILABLE: canvas unavailable",
)
}
BridgeSession.InvokeResult.ok("""{"format":"${snapshotParams.format.rawValue}","base64":"$base64"}""")
GatewaySession.InvokeResult.ok("""{"format":"${snapshotParams.format.rawValue}","base64":"$base64"}""")
}
ClawdbotCanvasA2UICommand.Reset.rawValue -> {
val a2uiUrl = resolveA2uiHostUrl()
?: return BridgeSession.InvokeResult.error(
?: return GatewaySession.InvokeResult.error(
code = "A2UI_HOST_NOT_CONFIGURED",
message = "A2UI_HOST_NOT_CONFIGURED: gateway did not advertise canvas host",
)
val ready = ensureA2uiReady(a2uiUrl)
if (!ready) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "A2UI_HOST_UNAVAILABLE",
message = "A2UI host not reachable",
)
}
val res = canvas.eval(a2uiResetJS)
BridgeSession.InvokeResult.ok(res)
GatewaySession.InvokeResult.ok(res)
}
ClawdbotCanvasA2UICommand.Push.rawValue, ClawdbotCanvasA2UICommand.PushJSONL.rawValue -> {
val messages =
try {
decodeA2uiMessages(command, paramsJson)
} catch (err: Throwable) {
return BridgeSession.InvokeResult.error(code = "INVALID_REQUEST", message = err.message ?: "invalid A2UI payload")
return GatewaySession.InvokeResult.error(code = "INVALID_REQUEST", message = err.message ?: "invalid A2UI payload")
}
val a2uiUrl = resolveA2uiHostUrl()
?: return BridgeSession.InvokeResult.error(
?: return GatewaySession.InvokeResult.error(
code = "A2UI_HOST_NOT_CONFIGURED",
message = "A2UI_HOST_NOT_CONFIGURED: gateway did not advertise canvas host",
)
val ready = ensureA2uiReady(a2uiUrl)
if (!ready) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "A2UI_HOST_UNAVAILABLE",
message = "A2UI host not reachable",
)
}
val js = a2uiApplyMessagesJS(messages)
val res = canvas.eval(js)
BridgeSession.InvokeResult.ok(res)
GatewaySession.InvokeResult.ok(res)
}
ClawdbotCameraCommand.Snap.rawValue -> {
showCameraHud(message = "Taking photo…", kind = CameraHudKind.Photo)
@@ -901,10 +935,10 @@ class NodeRuntime(context: Context) {
} catch (err: Throwable) {
val (code, message) = invokeErrorFromThrowable(err)
showCameraHud(message = message, kind = CameraHudKind.Error, autoHideMs = 2200)
return BridgeSession.InvokeResult.error(code = code, message = message)
return GatewaySession.InvokeResult.error(code = code, message = message)
}
showCameraHud(message = "Photo captured", kind = CameraHudKind.Success, autoHideMs = 1600)
BridgeSession.InvokeResult.ok(res.payloadJson)
GatewaySession.InvokeResult.ok(res.payloadJson)
}
ClawdbotCameraCommand.Clip.rawValue -> {
val includeAudio = paramsJson?.contains("\"includeAudio\":true") != false
@@ -917,10 +951,10 @@ class NodeRuntime(context: Context) {
} catch (err: Throwable) {
val (code, message) = invokeErrorFromThrowable(err)
showCameraHud(message = message, kind = CameraHudKind.Error, autoHideMs = 2400)
return BridgeSession.InvokeResult.error(code = code, message = message)
return GatewaySession.InvokeResult.error(code = code, message = message)
}
showCameraHud(message = "Clip captured", kind = CameraHudKind.Success, autoHideMs = 1800)
BridgeSession.InvokeResult.ok(res.payloadJson)
GatewaySession.InvokeResult.ok(res.payloadJson)
} finally {
if (includeAudio) externalAudioCaptureActive.value = false
}
@@ -928,19 +962,19 @@ class NodeRuntime(context: Context) {
ClawdbotLocationCommand.Get.rawValue -> {
val mode = locationMode.value
if (!isForeground.value && mode != LocationMode.Always) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "LOCATION_BACKGROUND_UNAVAILABLE",
message = "LOCATION_BACKGROUND_UNAVAILABLE: background location requires Always",
)
}
if (!hasFineLocationPermission() && !hasCoarseLocationPermission()) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "LOCATION_PERMISSION_REQUIRED",
message = "LOCATION_PERMISSION_REQUIRED: grant Location permission",
)
}
if (!isForeground.value && mode == LocationMode.Always && !hasBackgroundLocationPermission()) {
return BridgeSession.InvokeResult.error(
return GatewaySession.InvokeResult.error(
code = "LOCATION_PERMISSION_REQUIRED",
message = "LOCATION_PERMISSION_REQUIRED: enable Always in system Settings",
)
@@ -967,15 +1001,15 @@ class NodeRuntime(context: Context) {
timeoutMs = timeoutMs,
isPrecise = accuracy == "precise",
)
BridgeSession.InvokeResult.ok(payload.payloadJson)
GatewaySession.InvokeResult.ok(payload.payloadJson)
} catch (err: TimeoutCancellationException) {
BridgeSession.InvokeResult.error(
GatewaySession.InvokeResult.error(
code = "LOCATION_TIMEOUT",
message = "LOCATION_TIMEOUT: no fix in time",
)
} catch (err: Throwable) {
val message = err.message ?: "LOCATION_UNAVAILABLE: no fix"
BridgeSession.InvokeResult.error(code = "LOCATION_UNAVAILABLE", message = message)
GatewaySession.InvokeResult.error(code = "LOCATION_UNAVAILABLE", message = message)
}
}
ClawdbotScreenCommand.Record.rawValue -> {
@@ -987,9 +1021,9 @@ class NodeRuntime(context: Context) {
screenRecorder.record(paramsJson)
} catch (err: Throwable) {
val (code, message) = invokeErrorFromThrowable(err)
return BridgeSession.InvokeResult.error(code = code, message = message)
return GatewaySession.InvokeResult.error(code = code, message = message)
}
BridgeSession.InvokeResult.ok(res.payloadJson)
GatewaySession.InvokeResult.ok(res.payloadJson)
} finally {
_screenRecordActive.value = false
}
@@ -997,16 +1031,16 @@ class NodeRuntime(context: Context) {
ClawdbotSmsCommand.Send.rawValue -> {
val res = sms.send(paramsJson)
if (res.ok) {
BridgeSession.InvokeResult.ok(res.payloadJson)
GatewaySession.InvokeResult.ok(res.payloadJson)
} else {
val error = res.error ?: "SMS_SEND_FAILED"
val idx = error.indexOf(':')
val code = if (idx > 0) error.substring(0, idx).trim() else "SMS_SEND_FAILED"
BridgeSession.InvokeResult.error(code = code, message = error)
GatewaySession.InvokeResult.error(code = code, message = error)
}
}
else ->
BridgeSession.InvokeResult.error(
GatewaySession.InvokeResult.error(
code = "INVALID_REQUEST",
message = "INVALID_REQUEST: unknown command",
)
@@ -1062,7 +1096,9 @@ class NodeRuntime(context: Context) {
}
private fun resolveA2uiHostUrl(): String? {
val raw = session.currentCanvasHostUrl()?.trim().orEmpty()
val nodeRaw = nodeSession.currentCanvasHostUrl()?.trim().orEmpty()
val operatorRaw = operatorSession.currentCanvasHostUrl()?.trim().orEmpty()
val raw = if (nodeRaw.isNotBlank()) nodeRaw else operatorRaw
if (raw.isBlank()) return null
val base = raw.trimEnd('/')
return "${base}/__clawdbot__/a2ui/?platform=android"

View File

@@ -58,17 +58,30 @@ class SecurePrefs(context: Context) {
private val _preventSleep = MutableStateFlow(prefs.getBoolean("screen.preventSleep", true))
val preventSleep: StateFlow<Boolean> = _preventSleep
private val _manualEnabled = MutableStateFlow(prefs.getBoolean("bridge.manual.enabled", false))
private val _manualEnabled =
MutableStateFlow(readBoolWithMigration("gateway.manual.enabled", "bridge.manual.enabled", false))
val manualEnabled: StateFlow<Boolean> = _manualEnabled
private val _manualHost = MutableStateFlow(prefs.getString("bridge.manual.host", "")!!)
private val _manualHost =
MutableStateFlow(readStringWithMigration("gateway.manual.host", "bridge.manual.host", ""))
val manualHost: StateFlow<String> = _manualHost
private val _manualPort = MutableStateFlow(prefs.getInt("bridge.manual.port", 18790))
private val _manualPort =
MutableStateFlow(readIntWithMigration("gateway.manual.port", "bridge.manual.port", 18789))
val manualPort: StateFlow<Int> = _manualPort
private val _manualTls =
MutableStateFlow(readBoolWithMigration("gateway.manual.tls", null, true))
val manualTls: StateFlow<Boolean> = _manualTls
private val _lastDiscoveredStableId =
MutableStateFlow(prefs.getString("bridge.lastDiscoveredStableId", "")!!)
MutableStateFlow(
readStringWithMigration(
"gateway.lastDiscoveredStableID",
"bridge.lastDiscoveredStableId",
"",
),
)
val lastDiscoveredStableId: StateFlow<String> = _lastDiscoveredStableId
private val _canvasDebugStatusEnabled =
@@ -86,7 +99,7 @@ class SecurePrefs(context: Context) {
fun setLastDiscoveredStableId(value: String) {
val trimmed = value.trim()
prefs.edit { putString("bridge.lastDiscoveredStableId", trimmed) }
prefs.edit { putString("gateway.lastDiscoveredStableID", trimmed) }
_lastDiscoveredStableId.value = trimmed
}
@@ -117,43 +130,62 @@ class SecurePrefs(context: Context) {
}
fun setManualEnabled(value: Boolean) {
prefs.edit { putBoolean("bridge.manual.enabled", value) }
prefs.edit { putBoolean("gateway.manual.enabled", value) }
_manualEnabled.value = value
}
fun setManualHost(value: String) {
val trimmed = value.trim()
prefs.edit { putString("bridge.manual.host", trimmed) }
prefs.edit { putString("gateway.manual.host", trimmed) }
_manualHost.value = trimmed
}
fun setManualPort(value: Int) {
prefs.edit { putInt("bridge.manual.port", value) }
prefs.edit { putInt("gateway.manual.port", value) }
_manualPort.value = value
}
fun setManualTls(value: Boolean) {
prefs.edit { putBoolean("gateway.manual.tls", value) }
_manualTls.value = value
}
fun setCanvasDebugStatusEnabled(value: Boolean) {
prefs.edit { putBoolean("canvas.debugStatusEnabled", value) }
_canvasDebugStatusEnabled.value = value
}
fun loadBridgeToken(): String? {
val key = "bridge.token.${_instanceId.value}"
return prefs.getString(key, null)
fun loadGatewayToken(): String? {
val key = "gateway.token.${_instanceId.value}"
val stored = prefs.getString(key, null)?.trim()
if (!stored.isNullOrEmpty()) return stored
val legacy = prefs.getString("bridge.token.${_instanceId.value}", null)?.trim()
return legacy?.takeIf { it.isNotEmpty() }
}
fun saveBridgeToken(token: String) {
val key = "bridge.token.${_instanceId.value}"
fun saveGatewayToken(token: String) {
val key = "gateway.token.${_instanceId.value}"
prefs.edit { putString(key, token.trim()) }
}
fun loadBridgeTlsFingerprint(stableId: String): String? {
val key = "bridge.tls.$stableId"
fun loadGatewayPassword(): String? {
val key = "gateway.password.${_instanceId.value}"
val stored = prefs.getString(key, null)?.trim()
return stored?.takeIf { it.isNotEmpty() }
}
fun saveGatewayPassword(password: String) {
val key = "gateway.password.${_instanceId.value}"
prefs.edit { putString(key, password.trim()) }
}
fun loadGatewayTlsFingerprint(stableId: String): String? {
val key = "gateway.tls.$stableId"
return prefs.getString(key, null)?.trim()?.takeIf { it.isNotEmpty() }
}
fun saveBridgeTlsFingerprint(stableId: String, fingerprint: String) {
val key = "bridge.tls.$stableId"
fun saveGatewayTlsFingerprint(stableId: String, fingerprint: String) {
val key = "gateway.tls.$stableId"
prefs.edit { putString(key, fingerprint.trim()) }
}
@@ -225,4 +257,40 @@ class SecurePrefs(context: Context) {
defaultWakeWords
}
}
private fun readBoolWithMigration(newKey: String, oldKey: String?, defaultValue: Boolean): Boolean {
if (prefs.contains(newKey)) {
return prefs.getBoolean(newKey, defaultValue)
}
if (oldKey != null && prefs.contains(oldKey)) {
val value = prefs.getBoolean(oldKey, defaultValue)
prefs.edit { putBoolean(newKey, value) }
return value
}
return defaultValue
}
private fun readStringWithMigration(newKey: String, oldKey: String?, defaultValue: String): String {
if (prefs.contains(newKey)) {
return prefs.getString(newKey, defaultValue) ?: defaultValue
}
if (oldKey != null && prefs.contains(oldKey)) {
val value = prefs.getString(oldKey, defaultValue) ?: defaultValue
prefs.edit { putString(newKey, value) }
return value
}
return defaultValue
}
private fun readIntWithMigration(newKey: String, oldKey: String?, defaultValue: Int): Int {
if (prefs.contains(newKey)) {
return prefs.getInt(newKey, defaultValue)
}
if (oldKey != null && prefs.contains(oldKey)) {
val value = prefs.getInt(oldKey, defaultValue)
prefs.edit { putInt(newKey, value) }
return value
}
return defaultValue
}
}

View File

@@ -1,6 +1,6 @@
package com.clawdbot.android.chat
import com.clawdbot.android.bridge.BridgeSession
import com.clawdbot.android.gateway.GatewaySession
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import kotlinx.coroutines.CoroutineScope
@@ -20,8 +20,9 @@ import kotlinx.serialization.json.buildJsonObject
class ChatController(
private val scope: CoroutineScope,
private val session: BridgeSession,
private val session: GatewaySession,
private val json: Json,
private val supportsChatSubscribe: Boolean,
) {
private val _sessionKey = MutableStateFlow("main")
val sessionKey: StateFlow<String> = _sessionKey.asStateFlow()
@@ -224,7 +225,7 @@ class ChatController(
}
}
fun handleBridgeEvent(event: String, payloadJson: String?) {
fun handleGatewayEvent(event: String, payloadJson: String?) {
when (event) {
"tick" -> {
scope.launch { pollHealthIfNeeded(force = false) }
@@ -259,10 +260,12 @@ class ChatController(
val key = _sessionKey.value
try {
try {
session.sendEvent("chat.subscribe", """{"sessionKey":"$key"}""")
} catch (_: Throwable) {
// best-effort
if (supportsChatSubscribe) {
try {
session.sendNodeEvent("chat.subscribe", """{"sessionKey":"$key"}""")
} catch (_: Throwable) {
// best-effort
}
}
val historyJson = session.request("chat.history", """{"sessionKey":"$key"}""")

View File

@@ -0,0 +1,146 @@
package com.clawdbot.android.gateway
import android.content.Context
import android.util.Base64
import java.io.File
import java.security.KeyFactory
import java.security.KeyPairGenerator
import java.security.MessageDigest
import java.security.Signature
import java.security.spec.PKCS8EncodedKeySpec
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
@Serializable
data class DeviceIdentity(
val deviceId: String,
val publicKeyRawBase64: String,
val privateKeyPkcs8Base64: String,
val createdAtMs: Long,
)
class DeviceIdentityStore(context: Context) {
private val json = Json { ignoreUnknownKeys = true }
private val identityFile = File(context.filesDir, "clawdbot/identity/device.json")
@Synchronized
fun loadOrCreate(): DeviceIdentity {
val existing = load()
if (existing != null) {
val derived = deriveDeviceId(existing.publicKeyRawBase64)
if (derived != null && derived != existing.deviceId) {
val updated = existing.copy(deviceId = derived)
save(updated)
return updated
}
return existing
}
val fresh = generate()
save(fresh)
return fresh
}
fun signPayload(payload: String, identity: DeviceIdentity): String? {
return try {
val privateKeyBytes = Base64.decode(identity.privateKeyPkcs8Base64, Base64.DEFAULT)
val keySpec = PKCS8EncodedKeySpec(privateKeyBytes)
val keyFactory = KeyFactory.getInstance("Ed25519")
val privateKey = keyFactory.generatePrivate(keySpec)
val signature = Signature.getInstance("Ed25519")
signature.initSign(privateKey)
signature.update(payload.toByteArray(Charsets.UTF_8))
base64UrlEncode(signature.sign())
} catch (_: Throwable) {
null
}
}
fun publicKeyBase64Url(identity: DeviceIdentity): String? {
return try {
val raw = Base64.decode(identity.publicKeyRawBase64, Base64.DEFAULT)
base64UrlEncode(raw)
} catch (_: Throwable) {
null
}
}
private fun load(): DeviceIdentity? {
return try {
if (!identityFile.exists()) return null
val raw = identityFile.readText(Charsets.UTF_8)
val decoded = json.decodeFromString(DeviceIdentity.serializer(), raw)
if (decoded.deviceId.isBlank() ||
decoded.publicKeyRawBase64.isBlank() ||
decoded.privateKeyPkcs8Base64.isBlank()
) {
null
} else {
decoded
}
} catch (_: Throwable) {
null
}
}
private fun save(identity: DeviceIdentity) {
try {
identityFile.parentFile?.mkdirs()
val encoded = json.encodeToString(DeviceIdentity.serializer(), identity)
identityFile.writeText(encoded, Charsets.UTF_8)
} catch (_: Throwable) {
// best-effort only
}
}
private fun generate(): DeviceIdentity {
val keyPair = KeyPairGenerator.getInstance("Ed25519").generateKeyPair()
val spki = keyPair.public.encoded
val rawPublic = stripSpkiPrefix(spki)
val deviceId = sha256Hex(rawPublic)
val privateKey = keyPair.private.encoded
return DeviceIdentity(
deviceId = deviceId,
publicKeyRawBase64 = Base64.encodeToString(rawPublic, Base64.NO_WRAP),
privateKeyPkcs8Base64 = Base64.encodeToString(privateKey, Base64.NO_WRAP),
createdAtMs = System.currentTimeMillis(),
)
}
private fun deriveDeviceId(publicKeyRawBase64: String): String? {
return try {
val raw = Base64.decode(publicKeyRawBase64, Base64.DEFAULT)
sha256Hex(raw)
} catch (_: Throwable) {
null
}
}
private fun stripSpkiPrefix(spki: ByteArray): ByteArray {
if (spki.size == ED25519_SPKI_PREFIX.size + 32 &&
spki.copyOfRange(0, ED25519_SPKI_PREFIX.size).contentEquals(ED25519_SPKI_PREFIX)
) {
return spki.copyOfRange(ED25519_SPKI_PREFIX.size, spki.size)
}
return spki
}
private fun sha256Hex(data: ByteArray): String {
val digest = MessageDigest.getInstance("SHA-256").digest(data)
val out = StringBuilder(digest.size * 2)
for (byte in digest) {
out.append(String.format("%02x", byte))
}
return out.toString()
}
private fun base64UrlEncode(data: ByteArray): String {
return Base64.encodeToString(data, Base64.URL_SAFE or Base64.NO_WRAP or Base64.NO_PADDING)
}
companion object {
private val ED25519_SPKI_PREFIX =
byteArrayOf(
0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00,
)
}
}

View File

@@ -0,0 +1,520 @@
package com.clawdbot.android.gateway
import android.content.Context
import android.net.ConnectivityManager
import android.net.DnsResolver
import android.net.NetworkCapabilities
import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo
import android.os.CancellationSignal
import android.util.Log
import com.clawdbot.android.bridge.BonjourEscapes
import java.io.IOException
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.charset.CodingErrorAction
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executor
import java.util.concurrent.Executors
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import org.xbill.DNS.AAAARecord
import org.xbill.DNS.ARecord
import org.xbill.DNS.DClass
import org.xbill.DNS.ExtendedResolver
import org.xbill.DNS.Message
import org.xbill.DNS.Name
import org.xbill.DNS.PTRRecord
import org.xbill.DNS.Record
import org.xbill.DNS.Rcode
import org.xbill.DNS.Resolver
import org.xbill.DNS.SRVRecord
import org.xbill.DNS.Section
import org.xbill.DNS.SimpleResolver
import org.xbill.DNS.TextParseException
import org.xbill.DNS.TXTRecord
import org.xbill.DNS.Type
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
@Suppress("DEPRECATION")
class GatewayDiscovery(
context: Context,
private val scope: CoroutineScope,
) {
private val nsd = context.getSystemService(NsdManager::class.java)
private val connectivity = context.getSystemService(ConnectivityManager::class.java)
private val dns = DnsResolver.getInstance()
private val serviceType = "_clawdbot-gateway._tcp."
private val wideAreaDomain = "clawdbot.internal."
private val logTag = "Clawdbot/GatewayDiscovery"
private val localById = ConcurrentHashMap<String, GatewayEndpoint>()
private val unicastById = ConcurrentHashMap<String, GatewayEndpoint>()
private val _gateways = MutableStateFlow<List<GatewayEndpoint>>(emptyList())
val gateways: StateFlow<List<GatewayEndpoint>> = _gateways.asStateFlow()
private val _statusText = MutableStateFlow("Searching…")
val statusText: StateFlow<String> = _statusText.asStateFlow()
private var unicastJob: Job? = null
private val dnsExecutor: Executor = Executors.newCachedThreadPool()
@Volatile private var lastWideAreaRcode: Int? = null
@Volatile private var lastWideAreaCount: Int = 0
private val discoveryListener =
object : NsdManager.DiscoveryListener {
override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {}
override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) {}
override fun onDiscoveryStarted(serviceType: String) {}
override fun onDiscoveryStopped(serviceType: String) {}
override fun onServiceFound(serviceInfo: NsdServiceInfo) {
if (serviceInfo.serviceType != this@GatewayDiscovery.serviceType) return
resolve(serviceInfo)
}
override fun onServiceLost(serviceInfo: NsdServiceInfo) {
val serviceName = BonjourEscapes.decode(serviceInfo.serviceName)
val id = stableId(serviceName, "local.")
localById.remove(id)
publish()
}
}
init {
startLocalDiscovery()
startUnicastDiscovery(wideAreaDomain)
}
private fun startLocalDiscovery() {
try {
nsd.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryListener)
} catch (_: Throwable) {
// ignore (best-effort)
}
}
private fun stopLocalDiscovery() {
try {
nsd.stopServiceDiscovery(discoveryListener)
} catch (_: Throwable) {
// ignore (best-effort)
}
}
private fun startUnicastDiscovery(domain: String) {
unicastJob =
scope.launch(Dispatchers.IO) {
while (true) {
try {
refreshUnicast(domain)
} catch (_: Throwable) {
// ignore (best-effort)
}
delay(5000)
}
}
}
private fun resolve(serviceInfo: NsdServiceInfo) {
nsd.resolveService(
serviceInfo,
object : NsdManager.ResolveListener {
override fun onResolveFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {}
override fun onServiceResolved(resolved: NsdServiceInfo) {
val host = resolved.host?.hostAddress ?: return
val port = resolved.port
if (port <= 0) return
val rawServiceName = resolved.serviceName
val serviceName = BonjourEscapes.decode(rawServiceName)
val displayName = BonjourEscapes.decode(txt(resolved, "displayName") ?: serviceName)
val lanHost = txt(resolved, "lanHost")
val tailnetDns = txt(resolved, "tailnetDns")
val gatewayPort = txtInt(resolved, "gatewayPort")
val canvasPort = txtInt(resolved, "canvasPort")
val tlsEnabled = txtBool(resolved, "gatewayTls")
val tlsFingerprint = txt(resolved, "gatewayTlsSha256")
val id = stableId(serviceName, "local.")
localById[id] =
GatewayEndpoint(
stableId = id,
name = displayName,
host = host,
port = port,
lanHost = lanHost,
tailnetDns = tailnetDns,
gatewayPort = gatewayPort,
canvasPort = canvasPort,
tlsEnabled = tlsEnabled,
tlsFingerprintSha256 = tlsFingerprint,
)
publish()
}
},
)
}
private fun publish() {
_gateways.value =
(localById.values + unicastById.values).sortedBy { it.name.lowercase() }
_statusText.value = buildStatusText()
}
private fun buildStatusText(): String {
val localCount = localById.size
val wideRcode = lastWideAreaRcode
val wideCount = lastWideAreaCount
val wide =
when (wideRcode) {
null -> "Wide: ?"
Rcode.NOERROR -> "Wide: $wideCount"
Rcode.NXDOMAIN -> "Wide: NXDOMAIN"
else -> "Wide: ${Rcode.string(wideRcode)}"
}
return when {
localCount == 0 && wideRcode == null -> "Searching for gateways…"
localCount == 0 -> "$wide"
else -> "Local: $localCount$wide"
}
}
private fun stableId(serviceName: String, domain: String): String {
return "${serviceType}|${domain}|${normalizeName(serviceName)}"
}
private fun normalizeName(raw: String): String {
return raw.trim().split(Regex("\\s+")).joinToString(" ")
}
private fun txt(info: NsdServiceInfo, key: String): String? {
val bytes = info.attributes[key] ?: return null
return try {
String(bytes, Charsets.UTF_8).trim().ifEmpty { null }
} catch (_: Throwable) {
null
}
}
private fun txtInt(info: NsdServiceInfo, key: String): Int? {
return txt(info, key)?.toIntOrNull()
}
private fun txtBool(info: NsdServiceInfo, key: String): Boolean {
val raw = txt(info, key)?.trim()?.lowercase() ?: return false
return raw == "1" || raw == "true" || raw == "yes"
}
private suspend fun refreshUnicast(domain: String) {
val ptrName = "${serviceType}${domain}"
val ptrMsg = lookupUnicastMessage(ptrName, Type.PTR) ?: return
val ptrRecords = records(ptrMsg, Section.ANSWER).mapNotNull { it as? PTRRecord }
val next = LinkedHashMap<String, GatewayEndpoint>()
for (ptr in ptrRecords) {
val instanceFqdn = ptr.target.toString()
val srv =
recordByName(ptrMsg, instanceFqdn, Type.SRV) as? SRVRecord
?: run {
val msg = lookupUnicastMessage(instanceFqdn, Type.SRV) ?: return@run null
recordByName(msg, instanceFqdn, Type.SRV) as? SRVRecord
}
?: continue
val port = srv.port
if (port <= 0) continue
val targetFqdn = srv.target.toString()
val host =
resolveHostFromMessage(ptrMsg, targetFqdn)
?: resolveHostFromMessage(lookupUnicastMessage(instanceFqdn, Type.SRV), targetFqdn)
?: resolveHostUnicast(targetFqdn)
?: continue
val txtFromPtr =
recordsByName(ptrMsg, Section.ADDITIONAL)[keyName(instanceFqdn)]
.orEmpty()
.mapNotNull { it as? TXTRecord }
val txt =
if (txtFromPtr.isNotEmpty()) {
txtFromPtr
} else {
val msg = lookupUnicastMessage(instanceFqdn, Type.TXT)
records(msg, Section.ANSWER).mapNotNull { it as? TXTRecord }
}
val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain))
val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName)
val lanHost = txtValue(txt, "lanHost")
val tailnetDns = txtValue(txt, "tailnetDns")
val gatewayPort = txtIntValue(txt, "gatewayPort")
val canvasPort = txtIntValue(txt, "canvasPort")
val tlsEnabled = txtBoolValue(txt, "gatewayTls")
val tlsFingerprint = txtValue(txt, "gatewayTlsSha256")
val id = stableId(instanceName, domain)
next[id] =
GatewayEndpoint(
stableId = id,
name = displayName,
host = host,
port = port,
lanHost = lanHost,
tailnetDns = tailnetDns,
gatewayPort = gatewayPort,
canvasPort = canvasPort,
tlsEnabled = tlsEnabled,
tlsFingerprintSha256 = tlsFingerprint,
)
}
unicastById.clear()
unicastById.putAll(next)
lastWideAreaRcode = ptrMsg.header.rcode
lastWideAreaCount = next.size
publish()
if (next.isEmpty()) {
Log.d(
logTag,
"wide-area discovery: 0 results for $ptrName (rcode=${Rcode.string(ptrMsg.header.rcode)})",
)
}
}
private fun decodeInstanceName(instanceFqdn: String, domain: String): String {
val suffix = "${serviceType}${domain}"
val withoutSuffix =
if (instanceFqdn.endsWith(suffix)) {
instanceFqdn.removeSuffix(suffix)
} else {
instanceFqdn.substringBefore(serviceType)
}
return normalizeName(stripTrailingDot(withoutSuffix))
}
private fun stripTrailingDot(raw: String): String {
return raw.removeSuffix(".")
}
private suspend fun lookupUnicastMessage(name: String, type: Int): Message? {
val query =
try {
Message.newQuery(
org.xbill.DNS.Record.newRecord(
Name.fromString(name),
type,
DClass.IN,
),
)
} catch (_: TextParseException) {
return null
}
val system = queryViaSystemDns(query)
if (records(system, Section.ANSWER).any { it.type == type }) return system
val direct = createDirectResolver() ?: return system
return try {
val msg = direct.send(query)
if (records(msg, Section.ANSWER).any { it.type == type }) msg else system
} catch (_: Throwable) {
system
}
}
private suspend fun queryViaSystemDns(query: Message): Message? {
val network = preferredDnsNetwork()
val bytes =
try {
rawQuery(network, query.toWire())
} catch (_: Throwable) {
return null
}
return try {
Message(bytes)
} catch (_: IOException) {
null
}
}
private fun records(msg: Message?, section: Int): List<Record> {
return msg?.getSectionArray(section)?.toList() ?: emptyList()
}
private fun keyName(raw: String): String {
return raw.trim().lowercase()
}
private fun recordsByName(msg: Message, section: Int): Map<String, List<Record>> {
val next = LinkedHashMap<String, MutableList<Record>>()
for (r in records(msg, section)) {
val name = r.name?.toString() ?: continue
next.getOrPut(keyName(name)) { mutableListOf() }.add(r)
}
return next
}
private fun recordByName(msg: Message, fqdn: String, type: Int): Record? {
val key = keyName(fqdn)
val byNameAnswer = recordsByName(msg, Section.ANSWER)
val fromAnswer = byNameAnswer[key].orEmpty().firstOrNull { it.type == type }
if (fromAnswer != null) return fromAnswer
val byNameAdditional = recordsByName(msg, Section.ADDITIONAL)
return byNameAdditional[key].orEmpty().firstOrNull { it.type == type }
}
private fun resolveHostFromMessage(msg: Message?, hostname: String): String? {
val m = msg ?: return null
val key = keyName(hostname)
val additional = recordsByName(m, Section.ADDITIONAL)[key].orEmpty()
val a = additional.mapNotNull { it as? ARecord }.mapNotNull { it.address?.hostAddress }
val aaaa = additional.mapNotNull { it as? AAAARecord }.mapNotNull { it.address?.hostAddress }
return a.firstOrNull() ?: aaaa.firstOrNull()
}
private fun preferredDnsNetwork(): android.net.Network? {
val cm = connectivity ?: return null
// Prefer VPN (Tailscale) when present; otherwise use the active network.
cm.allNetworks.firstOrNull { n ->
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
}?.let { return it }
return cm.activeNetwork
}
private fun createDirectResolver(): Resolver? {
val cm = connectivity ?: return null
val candidateNetworks =
buildList {
cm.allNetworks
.firstOrNull { n ->
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
}?.let(::add)
cm.activeNetwork?.let(::add)
}.distinct()
val servers =
candidateNetworks
.asSequence()
.flatMap { n ->
cm.getLinkProperties(n)?.dnsServers?.asSequence() ?: emptySequence()
}
.distinctBy { it.hostAddress ?: it.toString() }
.toList()
if (servers.isEmpty()) return null
return try {
val resolvers =
servers.mapNotNull { addr ->
try {
SimpleResolver().apply {
setAddress(InetSocketAddress(addr, 53))
setTimeout(3)
}
} catch (_: Throwable) {
null
}
}
if (resolvers.isEmpty()) return null
ExtendedResolver(resolvers.toTypedArray()).apply { setTimeout(3) }
} catch (_: Throwable) {
null
}
}
private suspend fun rawQuery(network: android.net.Network?, wireQuery: ByteArray): ByteArray =
suspendCancellableCoroutine { cont ->
val signal = CancellationSignal()
cont.invokeOnCancellation { signal.cancel() }
dns.rawQuery(
network,
wireQuery,
DnsResolver.FLAG_EMPTY,
dnsExecutor,
signal,
object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) {
cont.resume(answer)
}
override fun onError(error: DnsResolver.DnsException) {
cont.resumeWithException(error)
}
},
)
}
private fun txtValue(records: List<TXTRecord>, key: String): String? {
val prefix = "$key="
for (r in records) {
val strings: List<String> =
try {
r.strings.mapNotNull { it as? String }
} catch (_: Throwable) {
emptyList()
}
for (s in strings) {
val trimmed = decodeDnsTxtString(s).trim()
if (trimmed.startsWith(prefix)) {
return trimmed.removePrefix(prefix).trim().ifEmpty { null }
}
}
}
return null
}
private fun txtIntValue(records: List<TXTRecord>, key: String): Int? {
return txtValue(records, key)?.toIntOrNull()
}
private fun txtBoolValue(records: List<TXTRecord>, key: String): Boolean {
val raw = txtValue(records, key)?.trim()?.lowercase() ?: return false
return raw == "1" || raw == "true" || raw == "yes"
}
private fun decodeDnsTxtString(raw: String): String {
// dnsjava treats TXT as opaque bytes and decodes as ISO-8859-1 to preserve bytes.
// Our TXT payload is UTF-8 (written by the gateway), so re-decode when possible.
val bytes = raw.toByteArray(Charsets.ISO_8859_1)
val decoder =
Charsets.UTF_8
.newDecoder()
.onMalformedInput(CodingErrorAction.REPORT)
.onUnmappableCharacter(CodingErrorAction.REPORT)
return try {
decoder.decode(ByteBuffer.wrap(bytes)).toString()
} catch (_: Throwable) {
raw
}
}
private suspend fun resolveHostUnicast(hostname: String): String? {
val a =
records(lookupUnicastMessage(hostname, Type.A), Section.ANSWER)
.mapNotNull { it as? ARecord }
.mapNotNull { it.address?.hostAddress }
val aaaa =
records(lookupUnicastMessage(hostname, Type.AAAA), Section.ANSWER)
.mapNotNull { it as? AAAARecord }
.mapNotNull { it.address?.hostAddress }
return a.firstOrNull() ?: aaaa.firstOrNull()
}
}

View File

@@ -0,0 +1,26 @@
package com.clawdbot.android.gateway
data class GatewayEndpoint(
val stableId: String,
val name: String,
val host: String,
val port: Int,
val lanHost: String? = null,
val tailnetDns: String? = null,
val gatewayPort: Int? = null,
val canvasPort: Int? = null,
val tlsEnabled: Boolean = false,
val tlsFingerprintSha256: String? = null,
) {
companion object {
fun manual(host: String, port: Int): GatewayEndpoint =
GatewayEndpoint(
stableId = "manual|${host.lowercase()}|$port",
name = "$host:$port",
host = host,
port = port,
tlsEnabled = false,
tlsFingerprintSha256 = null,
)
}
}

View File

@@ -0,0 +1,3 @@
package com.clawdbot.android.gateway
const val GATEWAY_PROTOCOL_VERSION = 3

View File

@@ -0,0 +1,599 @@
package com.clawdbot.android.gateway
import android.util.Log
import java.util.Locale
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonObject
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
import okhttp3.WebSocket
import okhttp3.WebSocketListener
data class GatewayClientInfo(
val id: String,
val displayName: String?,
val version: String,
val platform: String,
val mode: String,
val instanceId: String?,
val deviceFamily: String?,
val modelIdentifier: String?,
)
data class GatewayConnectOptions(
val role: String,
val scopes: List<String>,
val caps: List<String>,
val commands: List<String>,
val permissions: Map<String, Boolean>,
val client: GatewayClientInfo,
)
class GatewaySession(
private val scope: CoroutineScope,
private val identityStore: DeviceIdentityStore,
private val onConnected: (serverName: String?, remoteAddress: String?, mainSessionKey: String?) -> Unit,
private val onDisconnected: (message: String) -> Unit,
private val onEvent: (event: String, payloadJson: String?) -> Unit,
private val onInvoke: (suspend (InvokeRequest) -> InvokeResult)? = null,
private val onTlsFingerprint: ((stableId: String, fingerprint: String) -> Unit)? = null,
) {
data class InvokeRequest(
val id: String,
val nodeId: String,
val command: String,
val paramsJson: String?,
val timeoutMs: Long?,
)
data class InvokeResult(val ok: Boolean, val payloadJson: String?, val error: ErrorShape?) {
companion object {
fun ok(payloadJson: String?) = InvokeResult(ok = true, payloadJson = payloadJson, error = null)
fun error(code: String, message: String) =
InvokeResult(ok = false, payloadJson = null, error = ErrorShape(code = code, message = message))
}
}
data class ErrorShape(val code: String, val message: String)
private val json = Json { ignoreUnknownKeys = true }
private val writeLock = Mutex()
private val pending = ConcurrentHashMap<String, CompletableDeferred<RpcResponse>>()
@Volatile private var canvasHostUrl: String? = null
@Volatile private var mainSessionKey: String? = null
private data class DesiredConnection(
val endpoint: GatewayEndpoint,
val token: String?,
val password: String?,
val options: GatewayConnectOptions,
val tls: GatewayTlsParams?,
)
private var desired: DesiredConnection? = null
private var job: Job? = null
@Volatile private var currentConnection: Connection? = null
fun connect(
endpoint: GatewayEndpoint,
token: String?,
password: String?,
options: GatewayConnectOptions,
tls: GatewayTlsParams? = null,
) {
desired = DesiredConnection(endpoint, token, password, options, tls)
if (job == null) {
job = scope.launch(Dispatchers.IO) { runLoop() }
}
}
fun disconnect() {
desired = null
currentConnection?.closeQuietly()
scope.launch(Dispatchers.IO) {
job?.cancelAndJoin()
job = null
canvasHostUrl = null
mainSessionKey = null
onDisconnected("Offline")
}
}
fun reconnect() {
currentConnection?.closeQuietly()
}
fun currentCanvasHostUrl(): String? = canvasHostUrl
fun currentMainSessionKey(): String? = mainSessionKey
suspend fun sendNodeEvent(event: String, payloadJson: String?) {
val conn = currentConnection ?: return
val params =
buildJsonObject {
put("event", JsonPrimitive(event))
if (payloadJson != null) put("payloadJSON", JsonPrimitive(payloadJson)) else put("payloadJSON", JsonNull)
}
try {
conn.request("node.event", params, timeoutMs = 8_000)
} catch (err: Throwable) {
Log.w("ClawdbotGateway", "node.event failed: ${err.message ?: err::class.java.simpleName}")
}
}
suspend fun request(method: String, paramsJson: String?, timeoutMs: Long = 15_000): String {
val conn = currentConnection ?: throw IllegalStateException("not connected")
val params =
if (paramsJson.isNullOrBlank()) {
null
} else {
json.parseToJsonElement(paramsJson)
}
val res = conn.request(method, params, timeoutMs)
if (res.ok) return res.payloadJson ?: ""
val err = res.error
throw IllegalStateException("${err?.code ?: "UNAVAILABLE"}: ${err?.message ?: "request failed"}")
}
private data class RpcResponse(val id: String, val ok: Boolean, val payloadJson: String?, val error: ErrorShape?)
private inner class Connection(
private val endpoint: GatewayEndpoint,
private val token: String?,
private val password: String?,
private val options: GatewayConnectOptions,
private val tls: GatewayTlsParams?,
) {
private val connectDeferred = CompletableDeferred<Unit>()
private val closedDeferred = CompletableDeferred<Unit>()
private val isClosed = AtomicBoolean(false)
private val client: OkHttpClient = buildClient()
private var socket: WebSocket? = null
private val loggerTag = "ClawdbotGateway"
val remoteAddress: String =
if (endpoint.host.contains(":")) {
"[${endpoint.host}]:${endpoint.port}"
} else {
"${endpoint.host}:${endpoint.port}"
}
suspend fun connect() {
val scheme = if (tls != null) "wss" else "ws"
val url = "$scheme://${endpoint.host}:${endpoint.port}"
val request = Request.Builder().url(url).build()
socket = client.newWebSocket(request, Listener())
try {
connectDeferred.await()
} catch (err: Throwable) {
throw err
}
}
suspend fun request(method: String, params: JsonElement?, timeoutMs: Long): RpcResponse {
val id = UUID.randomUUID().toString()
val deferred = CompletableDeferred<RpcResponse>()
pending[id] = deferred
val frame =
buildJsonObject {
put("type", JsonPrimitive("req"))
put("id", JsonPrimitive(id))
put("method", JsonPrimitive(method))
if (params != null) put("params", params)
}
sendJson(frame)
return try {
withTimeout(timeoutMs) { deferred.await() }
} catch (err: TimeoutCancellationException) {
pending.remove(id)
throw IllegalStateException("request timeout")
}
}
suspend fun sendJson(obj: JsonObject) {
val jsonString = obj.toString()
writeLock.withLock {
socket?.send(jsonString)
}
}
fun awaitClose() = closedDeferred.await()
fun closeQuietly() {
if (isClosed.compareAndSet(false, true)) {
socket?.close(1000, "bye")
socket = null
closedDeferred.complete(Unit)
}
}
private fun buildClient(): OkHttpClient {
val builder = OkHttpClient.Builder()
val tlsConfig = buildGatewayTlsConfig(tls) { fingerprint ->
onTlsFingerprint?.invoke(tls?.stableId ?: endpoint.stableId, fingerprint)
}
if (tlsConfig != null) {
builder.sslSocketFactory(tlsConfig.sslSocketFactory, tlsConfig.trustManager)
builder.hostnameVerifier(tlsConfig.hostnameVerifier)
}
return builder.build()
}
private inner class Listener : WebSocketListener() {
override fun onOpen(webSocket: WebSocket, response: Response) {
scope.launch {
try {
sendConnect()
} catch (err: Throwable) {
connectDeferred.completeExceptionally(err)
closeQuietly()
}
}
}
override fun onMessage(webSocket: WebSocket, text: String) {
scope.launch { handleMessage(text) }
}
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
if (!connectDeferred.isCompleted) {
connectDeferred.completeExceptionally(t)
}
if (isClosed.compareAndSet(false, true)) {
failPending()
closedDeferred.complete(Unit)
onDisconnected("Gateway error: ${t.message ?: t::class.java.simpleName}")
}
}
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
if (!connectDeferred.isCompleted) {
connectDeferred.completeExceptionally(IllegalStateException("Gateway closed: $reason"))
}
if (isClosed.compareAndSet(false, true)) {
failPending()
closedDeferred.complete(Unit)
onDisconnected("Gateway closed: $reason")
}
}
}
private suspend fun sendConnect() {
val payload = buildConnectParams()
val res = request("connect", payload, timeoutMs = 8_000)
if (!res.ok) {
val msg = res.error?.message ?: "connect failed"
throw IllegalStateException(msg)
}
val payloadJson = res.payloadJson ?: throw IllegalStateException("connect failed: missing payload")
val obj = json.parseToJsonElement(payloadJson).asObjectOrNull() ?: throw IllegalStateException("connect failed")
val serverName = obj["server"].asObjectOrNull()?.get("host").asStringOrNull()
val rawCanvas = obj["canvasHostUrl"].asStringOrNull()
canvasHostUrl = normalizeCanvasHostUrl(rawCanvas, endpoint)
val sessionDefaults =
obj["snapshot"].asObjectOrNull()
?.get("sessionDefaults").asObjectOrNull()
mainSessionKey = sessionDefaults?.get("mainSessionKey").asStringOrNull()
onConnected(serverName, remoteAddress, mainSessionKey)
connectDeferred.complete(Unit)
}
private fun buildConnectParams(): JsonObject {
val client = options.client
val locale = Locale.getDefault().toLanguageTag()
val clientObj =
buildJsonObject {
put("id", JsonPrimitive(client.id))
client.displayName?.let { put("displayName", JsonPrimitive(it)) }
put("version", JsonPrimitive(client.version))
put("platform", JsonPrimitive(client.platform))
put("mode", JsonPrimitive(client.mode))
client.instanceId?.let { put("instanceId", JsonPrimitive(it)) }
client.deviceFamily?.let { put("deviceFamily", JsonPrimitive(it)) }
client.modelIdentifier?.let { put("modelIdentifier", JsonPrimitive(it)) }
}
val params =
buildJsonObject {
put("minProtocol", JsonPrimitive(GATEWAY_PROTOCOL_VERSION))
put("maxProtocol", JsonPrimitive(GATEWAY_PROTOCOL_VERSION))
put("client", clientObj)
if (options.caps.isNotEmpty()) put("caps", JsonArray(options.caps.map(::JsonPrimitive)))
if (options.commands.isNotEmpty()) put("commands", JsonArray(options.commands.map(::JsonPrimitive)))
if (options.permissions.isNotEmpty()) {
put(
"permissions",
buildJsonObject {
options.permissions.forEach { (key, value) ->
put(key, JsonPrimitive(value))
}
},
)
}
put("role", JsonPrimitive(options.role))
if (options.scopes.isNotEmpty()) put("scopes", JsonArray(options.scopes.map(::JsonPrimitive)))
put("locale", JsonPrimitive(locale))
}
val authToken = token?.trim().orEmpty()
val authPassword = password?.trim().orEmpty()
if (authToken.isNotEmpty()) {
params["auth"] =
buildJsonObject {
put("token", JsonPrimitive(authToken))
}
} else if (authPassword.isNotEmpty()) {
params["auth"] =
buildJsonObject {
put("password", JsonPrimitive(authPassword))
}
}
val identity = identityStore.loadOrCreate()
val signedAtMs = System.currentTimeMillis()
val payload =
buildDeviceAuthPayload(
deviceId = identity.deviceId,
clientId = client.id,
clientMode = client.mode,
role = options.role,
scopes = options.scopes,
signedAtMs = signedAtMs,
token = if (authToken.isNotEmpty()) authToken else null,
)
val signature = identityStore.signPayload(payload, identity)
val publicKey = identityStore.publicKeyBase64Url(identity)
if (!signature.isNullOrBlank() && !publicKey.isNullOrBlank()) {
params["device"] =
buildJsonObject {
put("id", JsonPrimitive(identity.deviceId))
put("publicKey", JsonPrimitive(publicKey))
put("signature", JsonPrimitive(signature))
put("signedAt", JsonPrimitive(signedAtMs))
}
}
return params
}
private suspend fun handleMessage(text: String) {
val frame = json.parseToJsonElement(text).asObjectOrNull() ?: return
when (frame["type"].asStringOrNull()) {
"res" -> handleResponse(frame)
"event" -> handleEvent(frame)
}
}
private fun handleResponse(frame: JsonObject) {
val id = frame["id"].asStringOrNull() ?: return
val ok = frame["ok"].asBooleanOrNull() ?: false
val payloadJson = frame["payload"]?.let { payload -> payload.toString() }
val error =
frame["error"]?.asObjectOrNull()?.let { obj ->
val code = obj["code"].asStringOrNull() ?: "UNAVAILABLE"
val msg = obj["message"].asStringOrNull() ?: "request failed"
ErrorShape(code, msg)
}
pending.remove(id)?.complete(RpcResponse(id, ok, payloadJson, error))
}
private fun handleEvent(frame: JsonObject) {
val event = frame["event"].asStringOrNull() ?: return
val payloadJson = frame["payload"]?.let { it.toString() }
if (event == "node.invoke.request" && payloadJson != null && onInvoke != null) {
handleInvokeEvent(payloadJson)
return
}
onEvent(event, payloadJson)
}
private fun handleInvokeEvent(payloadJson: String) {
val payload =
try {
json.parseToJsonElement(payloadJson).asObjectOrNull()
} catch (_: Throwable) {
null
} ?: return
val id = payload["id"].asStringOrNull() ?: return
val nodeId = payload["nodeId"].asStringOrNull() ?: return
val command = payload["command"].asStringOrNull() ?: return
val params = payload["paramsJSON"].asStringOrNull()
val timeoutMs = payload["timeoutMs"].asLongOrNull()
scope.launch {
val result =
try {
onInvoke?.invoke(InvokeRequest(id, nodeId, command, params, timeoutMs))
?: InvokeResult.error("UNAVAILABLE", "invoke handler missing")
} catch (err: Throwable) {
invokeErrorFromThrowable(err)
}
sendInvokeResult(id, nodeId, result)
}
}
private suspend fun sendInvokeResult(id: String, nodeId: String, result: InvokeResult) {
val params =
buildJsonObject {
put("id", JsonPrimitive(id))
put("nodeId", JsonPrimitive(nodeId))
put("ok", JsonPrimitive(result.ok))
if (result.payloadJson != null) put("payloadJSON", JsonPrimitive(result.payloadJson))
result.error?.let { err ->
put(
"error",
buildJsonObject {
put("code", JsonPrimitive(err.code))
put("message", JsonPrimitive(err.message))
},
)
}
}
try {
request("node.invoke.result", params, timeoutMs = 15_000)
} catch (err: Throwable) {
Log.w(loggerTag, "node.invoke.result failed: ${err.message ?: err::class.java.simpleName}")
}
}
private fun invokeErrorFromThrowable(err: Throwable): InvokeResult {
val msg = err.message?.trim().takeIf { !it.isNullOrEmpty() } ?: err::class.java.simpleName
val parts = msg.split(":", limit = 2)
if (parts.size == 2) {
val code = parts[0].trim()
val rest = parts[1].trim()
if (code.isNotEmpty() && code.all { it.isUpperCase() || it == '_' }) {
return InvokeResult.error(code = code, message = rest.ifEmpty { msg })
}
}
return InvokeResult.error(code = "UNAVAILABLE", message = msg)
}
private fun failPending() {
for ((_, waiter) in pending) {
waiter.cancel()
}
pending.clear()
}
}
private suspend fun runLoop() {
var attempt = 0
while (scope.isActive) {
val target = desired
if (target == null) {
currentConnection?.closeQuietly()
currentConnection = null
delay(250)
continue
}
try {
onDisconnected(if (attempt == 0) "Connecting…" else "Reconnecting…")
connectOnce(target)
attempt = 0
} catch (err: Throwable) {
attempt += 1
onDisconnected("Gateway error: ${err.message ?: err::class.java.simpleName}")
val sleepMs = minOf(8_000L, (350.0 * Math.pow(1.7, attempt.toDouble())).toLong())
delay(sleepMs)
}
}
}
private suspend fun connectOnce(target: DesiredConnection) = withContext(Dispatchers.IO) {
val conn = Connection(target.endpoint, target.token, target.password, target.options, target.tls)
currentConnection = conn
try {
conn.connect()
conn.awaitClose()
} finally {
currentConnection = null
canvasHostUrl = null
mainSessionKey = null
}
}
private fun buildDeviceAuthPayload(
deviceId: String,
clientId: String,
clientMode: String,
role: String,
scopes: List<String>,
signedAtMs: Long,
token: String?,
): String {
val scopeString = scopes.joinToString(",")
val authToken = token.orEmpty()
return listOf(
"v1",
deviceId,
clientId,
clientMode,
role,
scopeString,
signedAtMs.toString(),
authToken,
).joinToString("|")
}
private fun normalizeCanvasHostUrl(raw: String?, endpoint: GatewayEndpoint): String? {
val trimmed = raw?.trim().orEmpty()
val parsed = trimmed.takeIf { it.isNotBlank() }?.let { runCatching { java.net.URI(it) }.getOrNull() }
val host = parsed?.host?.trim().orEmpty()
val port = parsed?.port ?: -1
val scheme = parsed?.scheme?.trim().orEmpty().ifBlank { "http" }
if (trimmed.isNotBlank() && !isLoopbackHost(host)) {
return trimmed
}
val fallbackHost =
endpoint.tailnetDns?.trim().takeIf { !it.isNullOrEmpty() }
?: endpoint.lanHost?.trim().takeIf { !it.isNullOrEmpty() }
?: endpoint.host.trim()
if (fallbackHost.isEmpty()) return trimmed.ifBlank { null }
val fallbackPort = endpoint.canvasPort ?: if (port > 0) port else 18793
val formattedHost = if (fallbackHost.contains(":")) "[${fallbackHost}]" else fallbackHost
return "$scheme://$formattedHost:$fallbackPort"
}
private fun isLoopbackHost(raw: String?): Boolean {
val host = raw?.trim()?.lowercase().orEmpty()
if (host.isEmpty()) return false
if (host == "localhost") return true
if (host == "::1") return true
if (host == "0.0.0.0" || host == "::") return true
return host.startsWith("127.")
}
}
private fun JsonElement?.asObjectOrNull(): JsonObject? = this as? JsonObject
private fun JsonElement?.asStringOrNull(): String? =
when (this) {
is JsonNull -> null
is JsonPrimitive -> content
else -> null
}
private fun JsonElement?.asBooleanOrNull(): Boolean? =
when (this) {
is JsonPrimitive -> {
val c = content.trim()
when {
c.equals("true", ignoreCase = true) -> true
c.equals("false", ignoreCase = true) -> false
else -> null
}
}
else -> null
}
private fun JsonElement?.asLongOrNull(): Long? =
when (this) {
is JsonPrimitive -> content.toLongOrNull()
else -> null
}

View File

@@ -0,0 +1,88 @@
package com.clawdbot.android.gateway
import android.annotation.SuppressLint
import java.security.MessageDigest
import java.security.SecureRandom
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
data class GatewayTlsParams(
val required: Boolean,
val expectedFingerprint: String?,
val allowTOFU: Boolean,
val stableId: String,
)
data class GatewayTlsConfig(
val sslSocketFactory: SSLSocketFactory,
val trustManager: X509TrustManager,
val hostnameVerifier: HostnameVerifier,
)
fun buildGatewayTlsConfig(
params: GatewayTlsParams?,
onStore: ((String) -> Unit)? = null,
): GatewayTlsConfig? {
if (params == null) return null
val expected = params.expectedFingerprint?.let(::normalizeFingerprint)
val defaultTrust = defaultTrustManager()
@SuppressLint("CustomX509TrustManager")
val trustManager =
object : X509TrustManager {
override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {
defaultTrust.checkClientTrusted(chain, authType)
}
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {
if (chain.isEmpty()) throw CertificateException("empty certificate chain")
val fingerprint = sha256Hex(chain[0].encoded)
if (expected != null) {
if (fingerprint != expected) {
throw CertificateException("gateway TLS fingerprint mismatch")
}
return
}
if (params.allowTOFU) {
onStore?.invoke(fingerprint)
return
}
defaultTrust.checkServerTrusted(chain, authType)
}
override fun getAcceptedIssuers(): Array<X509Certificate> = defaultTrust.acceptedIssuers
}
val context = SSLContext.getInstance("TLS")
context.init(null, arrayOf(trustManager), SecureRandom())
return GatewayTlsConfig(
sslSocketFactory = context.socketFactory,
trustManager = trustManager,
hostnameVerifier = HostnameVerifier { _, _ -> true },
)
}
private fun defaultTrustManager(): X509TrustManager {
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
factory.init(null as java.security.KeyStore?)
val trust =
factory.trustManagers.firstOrNull { it is X509TrustManager } as? X509TrustManager
return trust ?: throw IllegalStateException("No default X509TrustManager found")
}
private fun sha256Hex(data: ByteArray): String {
val digest = MessageDigest.getInstance("SHA-256").digest(data)
val out = StringBuilder(digest.size * 2)
for (byte in digest) {
out.append(String.format("%02x", byte))
}
return out.toString()
}
private fun normalizeFingerprint(raw: String): String {
return raw.lowercase().filter { it in '0'..'9' || it in 'a'..'f' }
}

View File

@@ -118,7 +118,7 @@ fun RootScreen(viewModel: MainViewModel) {
contentDescription = "Approval pending",
)
}
// Avoid duplicating the primary bridge status ("Connecting…") in the activity slot.
// Avoid duplicating the primary gateway status ("Connecting…") in the activity slot.
if (screenRecordActive) {
return@remember StatusActivity(
@@ -179,14 +179,14 @@ fun RootScreen(viewModel: MainViewModel) {
null
}
val bridgeState =
val gatewayState =
remember(serverName, statusText) {
when {
serverName != null -> BridgeState.Connected
serverName != null -> GatewayState.Connected
statusText.contains("connecting", ignoreCase = true) ||
statusText.contains("reconnecting", ignoreCase = true) -> BridgeState.Connecting
statusText.contains("error", ignoreCase = true) -> BridgeState.Error
else -> BridgeState.Disconnected
statusText.contains("reconnecting", ignoreCase = true) -> GatewayState.Connecting
statusText.contains("error", ignoreCase = true) -> GatewayState.Error
else -> GatewayState.Disconnected
}
}
@@ -206,7 +206,7 @@ fun RootScreen(viewModel: MainViewModel) {
// Keep the overlay buttons above the WebView canvas (AndroidView), otherwise they may not receive touches.
Popup(alignment = Alignment.TopStart, properties = PopupProperties(focusable = false)) {
StatusPill(
bridge = bridgeState,
gateway = gatewayState,
voiceEnabled = voiceEnabled,
activity = activity,
onClick = { sheet = Sheet.Settings },

View File

@@ -48,6 +48,7 @@ import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.alpha
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
@@ -74,11 +75,12 @@ fun SettingsSheet(viewModel: MainViewModel) {
val manualEnabled by viewModel.manualEnabled.collectAsState()
val manualHost by viewModel.manualHost.collectAsState()
val manualPort by viewModel.manualPort.collectAsState()
val manualTls by viewModel.manualTls.collectAsState()
val canvasDebugStatusEnabled by viewModel.canvasDebugStatusEnabled.collectAsState()
val statusText by viewModel.statusText.collectAsState()
val serverName by viewModel.serverName.collectAsState()
val remoteAddress by viewModel.remoteAddress.collectAsState()
val bridges by viewModel.bridges.collectAsState()
val gateways by viewModel.gateways.collectAsState()
val discoveryStatusText by viewModel.discoveryStatusText.collectAsState()
val listState = rememberLazyListState()
@@ -163,7 +165,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
val smsPermissionLauncher =
rememberLauncherForActivityResult(ActivityResultContracts.RequestPermission()) { granted ->
smsPermissionGranted = granted
viewModel.refreshBridgeHello()
viewModel.refreshGatewayConnection()
}
fun setCameraEnabledChecked(checked: Boolean) {
@@ -223,20 +225,20 @@ fun SettingsSheet(viewModel: MainViewModel) {
}
}
val visibleBridges =
val visibleGateways =
if (isConnected && remoteAddress != null) {
bridges.filterNot { "${it.host}:${it.port}" == remoteAddress }
gateways.filterNot { "${it.host}:${it.port}" == remoteAddress }
} else {
bridges
gateways
}
val bridgeDiscoveryFooterText =
if (visibleBridges.isEmpty()) {
val gatewayDiscoveryFooterText =
if (visibleGateways.isEmpty()) {
discoveryStatusText
} else if (isConnected) {
"Discovery active • ${visibleBridges.size} other bridge${if (visibleBridges.size == 1) "" else "s"} found"
"Discovery active • ${visibleGateways.size} other gateway${if (visibleGateways.size == 1) "" else "s"} found"
} else {
"Discovery active • ${visibleBridges.size} bridge${if (visibleBridges.size == 1) "" else "s"} found"
"Discovery active • ${visibleGateways.size} gateway${if (visibleGateways.size == 1) "" else "s"} found"
}
LazyColumn(
@@ -250,7 +252,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
contentPadding = PaddingValues(16.dp),
verticalArrangement = Arrangement.spacedBy(6.dp),
) {
// Order parity: Node → Bridge → Voice → Camera → Messaging → Location → Screen.
// Order parity: Node → Gateway → Voice → Camera → Messaging → Location → Screen.
item { Text("Node", style = MaterialTheme.typography.titleSmall) }
item {
OutlinedTextField(
@@ -266,8 +268,8 @@ fun SettingsSheet(viewModel: MainViewModel) {
item { HorizontalDivider() }
// Bridge
item { Text("Bridge", style = MaterialTheme.typography.titleSmall) }
// Gateway
item { Text("Gateway", style = MaterialTheme.typography.titleSmall) }
item { ListItem(headlineContent = { Text("Status") }, supportingContent = { Text(statusText) }) }
if (serverName != null) {
item { ListItem(headlineContent = { Text("Server") }, supportingContent = { Text(serverName!!) }) }
@@ -291,31 +293,31 @@ fun SettingsSheet(viewModel: MainViewModel) {
item { HorizontalDivider() }
if (!isConnected || visibleBridges.isNotEmpty()) {
if (!isConnected || visibleGateways.isNotEmpty()) {
item {
Text(
if (isConnected) "Other Bridges" else "Discovered Bridges",
if (isConnected) "Other Gateways" else "Discovered Gateways",
style = MaterialTheme.typography.titleSmall,
)
}
if (!isConnected && visibleBridges.isEmpty()) {
item { Text("No bridges found yet.", color = MaterialTheme.colorScheme.onSurfaceVariant) }
if (!isConnected && visibleGateways.isEmpty()) {
item { Text("No gateways found yet.", color = MaterialTheme.colorScheme.onSurfaceVariant) }
} else {
items(items = visibleBridges, key = { it.stableId }) { bridge ->
items(items = visibleGateways, key = { it.stableId }) { gateway ->
val detailLines =
buildList {
add("IP: ${bridge.host}:${bridge.port}")
bridge.lanHost?.let { add("LAN: $it") }
bridge.tailnetDns?.let { add("Tailnet: $it") }
if (bridge.gatewayPort != null || bridge.bridgePort != null || bridge.canvasPort != null) {
val gw = bridge.gatewayPort?.toString() ?: ""
val br = (bridge.bridgePort ?: bridge.port).toString()
val canvas = bridge.canvasPort?.toString() ?: ""
add("IP: ${gateway.host}:${gateway.port}")
gateway.lanHost?.let { add("LAN: $it") }
gateway.tailnetDns?.let { add("Tailnet: $it") }
if (gateway.gatewayPort != null || gateway.bridgePort != null || gateway.canvasPort != null) {
val gw = gateway.gatewayPort?.toString() ?: ""
val br = (gateway.bridgePort ?: gateway.port).toString()
val canvas = gateway.canvasPort?.toString() ?: ""
add("Ports: gw $gw · bridge $br · canvas $canvas")
}
}
ListItem(
headlineContent = { Text(bridge.name) },
headlineContent = { Text(gateway.name) },
supportingContent = {
Column(verticalArrangement = Arrangement.spacedBy(2.dp)) {
detailLines.forEach { line ->
@@ -327,7 +329,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
Button(
onClick = {
NodeForegroundService.start(context)
viewModel.connect(bridge)
viewModel.connect(gateway)
},
) {
Text("Connect")
@@ -338,7 +340,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
}
item {
Text(
bridgeDiscoveryFooterText,
gatewayDiscoveryFooterText,
modifier = Modifier.fillMaxWidth(),
textAlign = TextAlign.Center,
style = MaterialTheme.typography.labelMedium,
@@ -352,7 +354,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
item {
ListItem(
headlineContent = { Text("Advanced") },
supportingContent = { Text("Manual bridge connection") },
supportingContent = { Text("Manual gateway connection") },
trailingContent = {
Icon(
imageVector = if (advancedExpanded) Icons.Filled.ExpandLess else Icons.Filled.ExpandMore,
@@ -369,7 +371,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
AnimatedVisibility(visible = advancedExpanded) {
Column(verticalArrangement = Arrangement.spacedBy(10.dp), modifier = Modifier.fillMaxWidth()) {
ListItem(
headlineContent = { Text("Use Manual Bridge") },
headlineContent = { Text("Use Manual Gateway") },
supportingContent = { Text("Use this when discovery is blocked.") },
trailingContent = { Switch(checked = manualEnabled, onCheckedChange = viewModel::setManualEnabled) },
)
@@ -388,6 +390,12 @@ fun SettingsSheet(viewModel: MainViewModel) {
modifier = Modifier.fillMaxWidth(),
enabled = manualEnabled,
)
ListItem(
headlineContent = { Text("Require TLS") },
supportingContent = { Text("Pin the gateway certificate on first connect.") },
trailingContent = { Switch(checked = manualTls, onCheckedChange = viewModel::setManualTls, enabled = manualEnabled) },
modifier = Modifier.alpha(if (manualEnabled) 1f else 0.5f),
)
val hostOk = manualHost.trim().isNotEmpty()
val portOk = manualPort in 1..65535
@@ -496,7 +504,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
item {
Text(
if (isConnected) {
"Any node can edit wake words. Changes sync via the gateway bridge."
"Any node can edit wake words. Changes sync via the gateway."
} else {
"Connect to a gateway to sync wake words globally."
},
@@ -511,7 +519,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
item {
ListItem(
headlineContent = { Text("Allow Camera") },
supportingContent = { Text("Allows the bridge to request photos or short video clips (foreground only).") },
supportingContent = { Text("Allows the gateway to request photos or short video clips (foreground only).") },
trailingContent = { Switch(checked = cameraEnabled, onCheckedChange = ::setCameraEnabledChecked) },
)
}
@@ -538,7 +546,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
supportingContent = {
Text(
if (smsPermissionAvailable) {
"Allow the bridge to send SMS from this device."
"Allow the gateway to send SMS from this device."
} else {
"SMS requires a device with telephony hardware."
},

View File

@@ -26,7 +26,7 @@ import androidx.compose.ui.unit.dp
@Composable
fun StatusPill(
bridge: BridgeState,
gateway: GatewayState,
voiceEnabled: Boolean,
onClick: () -> Unit,
modifier: Modifier = Modifier,
@@ -49,11 +49,11 @@ fun StatusPill(
Surface(
modifier = Modifier.size(9.dp),
shape = CircleShape,
color = bridge.color,
color = gateway.color,
) {}
Text(
text = bridge.title,
text = gateway.title,
style = MaterialTheme.typography.labelLarge,
)
}
@@ -106,7 +106,7 @@ data class StatusActivity(
val tint: Color? = null,
)
enum class BridgeState(val title: String, val color: Color) {
enum class GatewayState(val title: String, val color: Color) {
Connected("Connected", Color(0xFF2ECC71)),
Connecting("Connecting…", Color(0xFFF1C40F)),
Error("Error", Color(0xFFE74C3C)),

View File

@@ -20,7 +20,7 @@ import android.speech.tts.TextToSpeech
import android.speech.tts.UtteranceProgressListener
import android.util.Log
import androidx.core.content.ContextCompat
import com.clawdbot.android.bridge.BridgeSession
import com.clawdbot.android.gateway.GatewaySession
import com.clawdbot.android.isCanonicalMainSessionKey
import com.clawdbot.android.normalizeMainKey
import java.net.HttpURLConnection
@@ -46,6 +46,9 @@ import kotlin.math.max
class TalkModeManager(
private val context: Context,
private val scope: CoroutineScope,
private val session: GatewaySession,
private val supportsChatSubscribe: Boolean,
private val isConnected: () -> Boolean,
) {
companion object {
private const val tag = "TalkMode"
@@ -99,7 +102,6 @@ class TalkModeManager(
private var modelOverrideActive = false
private var mainSessionKey: String = "main"
private var session: BridgeSession? = null
private var pendingRunId: String? = null
private var pendingFinal: CompletableDeferred<Boolean>? = null
private var chatSubscribedSessionKey: String? = null
@@ -112,11 +114,6 @@ class TalkModeManager(
private var systemTtsPending: CompletableDeferred<Unit>? = null
private var systemTtsPendingId: String? = null
fun attachSession(session: BridgeSession) {
this.session = session
chatSubscribedSessionKey = null
}
fun setMainSessionKey(sessionKey: String?) {
val trimmed = sessionKey?.trim().orEmpty()
if (trimmed.isEmpty()) return
@@ -136,7 +133,7 @@ class TalkModeManager(
}
}
fun handleBridgeEvent(event: String, payloadJson: String?) {
fun handleGatewayEvent(event: String, payloadJson: String?) {
if (event != "chat") return
if (payloadJson.isNullOrBlank()) return
val pending = pendingRunId ?: return
@@ -306,25 +303,24 @@ class TalkModeManager(
reloadConfig()
val prompt = buildPrompt(transcript)
val bridge = session
if (bridge == null) {
_statusText.value = "Bridge not connected"
Log.w(tag, "finalize: bridge not connected")
if (!isConnected()) {
_statusText.value = "Gateway not connected"
Log.w(tag, "finalize: gateway not connected")
start()
return
}
try {
val startedAt = System.currentTimeMillis().toDouble() / 1000.0
subscribeChatIfNeeded(bridge = bridge, sessionKey = mainSessionKey)
subscribeChatIfNeeded(session = session, sessionKey = mainSessionKey)
Log.d(tag, "chat.send start sessionKey=${mainSessionKey.ifBlank { "main" }} chars=${prompt.length}")
val runId = sendChat(prompt, bridge)
val runId = sendChat(prompt, session)
Log.d(tag, "chat.send ok runId=$runId")
val ok = waitForChatFinal(runId)
if (!ok) {
Log.w(tag, "chat final timeout runId=$runId; attempting history fallback")
}
val assistant = waitForAssistantText(bridge, startedAt, if (ok) 12_000 else 25_000)
val assistant = waitForAssistantText(session, startedAt, if (ok) 12_000 else 25_000)
if (assistant.isNullOrBlank()) {
_statusText.value = "No reply"
Log.w(tag, "assistant text timeout runId=$runId")
@@ -343,12 +339,13 @@ class TalkModeManager(
}
}
private suspend fun subscribeChatIfNeeded(bridge: BridgeSession, sessionKey: String) {
private suspend fun subscribeChatIfNeeded(session: GatewaySession, sessionKey: String) {
if (!supportsChatSubscribe) return
val key = sessionKey.trim()
if (key.isEmpty()) return
if (chatSubscribedSessionKey == key) return
try {
bridge.sendEvent("chat.subscribe", """{"sessionKey":"$key"}""")
session.sendNodeEvent("chat.subscribe", """{"sessionKey":"$key"}""")
chatSubscribedSessionKey = key
Log.d(tag, "chat.subscribe ok sessionKey=$key")
} catch (err: Throwable) {
@@ -370,7 +367,7 @@ class TalkModeManager(
return lines.joinToString("\n")
}
private suspend fun sendChat(message: String, bridge: BridgeSession): String {
private suspend fun sendChat(message: String, session: GatewaySession): String {
val runId = UUID.randomUUID().toString()
val params =
buildJsonObject {
@@ -380,7 +377,7 @@ class TalkModeManager(
put("timeoutMs", JsonPrimitive(30_000))
put("idempotencyKey", JsonPrimitive(runId))
}
val res = bridge.request("chat.send", params.toString())
val res = session.request("chat.send", params.toString())
val parsed = parseRunId(res) ?: runId
if (parsed != runId) {
pendingRunId = parsed
@@ -411,13 +408,13 @@ class TalkModeManager(
}
private suspend fun waitForAssistantText(
bridge: BridgeSession,
session: GatewaySession,
sinceSeconds: Double,
timeoutMs: Long,
): String? {
val deadline = SystemClock.elapsedRealtime() + timeoutMs
while (SystemClock.elapsedRealtime() < deadline) {
val text = fetchLatestAssistantText(bridge, sinceSeconds)
val text = fetchLatestAssistantText(session, sinceSeconds)
if (!text.isNullOrBlank()) return text
delay(300)
}
@@ -425,11 +422,11 @@ class TalkModeManager(
}
private suspend fun fetchLatestAssistantText(
bridge: BridgeSession,
session: GatewaySession,
sinceSeconds: Double? = null,
): String? {
val key = mainSessionKey.ifBlank { "main" }
val res = bridge.request("chat.history", "{\"sessionKey\":\"$key\"}")
val res = session.request("chat.history", "{\"sessionKey\":\"$key\"}")
val root = json.parseToJsonElement(res).asObjectOrNull() ?: return null
val messages = root["messages"] as? JsonArray ?: return null
for (item in messages.reversed()) {
@@ -813,12 +810,11 @@ class TalkModeManager(
}
private suspend fun reloadConfig() {
val bridge = session ?: return
val envVoice = System.getenv("ELEVENLABS_VOICE_ID")?.trim()
val sagVoice = System.getenv("SAG_VOICE_ID")?.trim()
val envKey = System.getenv("ELEVENLABS_API_KEY")?.trim()
try {
val res = bridge.request("config.get", "{}")
val res = session.request("config.get", "{}")
val root = json.parseToJsonElement(res).asObjectOrNull()
val config = root?.get("config").asObjectOrNull()
val talk = config?.get("talk").asObjectOrNull()