refactor(android): drop legacy bridge transport

This commit is contained in:
Peter Steinberger
2026-01-19 15:45:50 +00:00
parent c7808a543d
commit 079c29ceb8
12 changed files with 7 additions and 1619 deletions

View File

@@ -1,523 +0,0 @@
package com.clawdbot.android.bridge
import android.content.Context
import android.net.ConnectivityManager
import android.net.DnsResolver
import android.net.NetworkCapabilities
import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo
import android.os.CancellationSignal
import android.util.Log
import java.io.IOException
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.charset.CodingErrorAction
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executor
import java.util.concurrent.Executors
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import org.xbill.DNS.AAAARecord
import org.xbill.DNS.ARecord
import org.xbill.DNS.DClass
import org.xbill.DNS.ExtendedResolver
import org.xbill.DNS.Message
import org.xbill.DNS.Name
import org.xbill.DNS.PTRRecord
import org.xbill.DNS.Record
import org.xbill.DNS.Rcode
import org.xbill.DNS.Resolver
import org.xbill.DNS.SRVRecord
import org.xbill.DNS.Section
import org.xbill.DNS.SimpleResolver
import org.xbill.DNS.TextParseException
import org.xbill.DNS.TXTRecord
import org.xbill.DNS.Type
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
@Suppress("DEPRECATION")
class BridgeDiscovery(
context: Context,
private val scope: CoroutineScope,
) {
private val nsd = context.getSystemService(NsdManager::class.java)
private val connectivity = context.getSystemService(ConnectivityManager::class.java)
private val dns = DnsResolver.getInstance()
private val serviceType = "_clawdbot-bridge._tcp."
private val wideAreaDomain = "clawdbot.internal."
private val logTag = "Clawdbot/BridgeDiscovery"
private val localById = ConcurrentHashMap<String, BridgeEndpoint>()
private val unicastById = ConcurrentHashMap<String, BridgeEndpoint>()
private val _bridges = MutableStateFlow<List<BridgeEndpoint>>(emptyList())
val bridges: StateFlow<List<BridgeEndpoint>> = _bridges.asStateFlow()
private val _statusText = MutableStateFlow("Searching…")
val statusText: StateFlow<String> = _statusText.asStateFlow()
private var unicastJob: Job? = null
private val dnsExecutor: Executor = Executors.newCachedThreadPool()
@Volatile private var lastWideAreaRcode: Int? = null
@Volatile private var lastWideAreaCount: Int = 0
private val discoveryListener =
object : NsdManager.DiscoveryListener {
override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {}
override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) {}
override fun onDiscoveryStarted(serviceType: String) {}
override fun onDiscoveryStopped(serviceType: String) {}
override fun onServiceFound(serviceInfo: NsdServiceInfo) {
if (serviceInfo.serviceType != this@BridgeDiscovery.serviceType) return
resolve(serviceInfo)
}
override fun onServiceLost(serviceInfo: NsdServiceInfo) {
val serviceName = BonjourEscapes.decode(serviceInfo.serviceName)
val id = stableId(serviceName, "local.")
localById.remove(id)
publish()
}
}
init {
startLocalDiscovery()
startUnicastDiscovery(wideAreaDomain)
}
private fun startLocalDiscovery() {
try {
nsd.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryListener)
} catch (_: Throwable) {
// ignore (best-effort)
}
}
private fun stopLocalDiscovery() {
try {
nsd.stopServiceDiscovery(discoveryListener)
} catch (_: Throwable) {
// ignore (best-effort)
}
}
private fun startUnicastDiscovery(domain: String) {
unicastJob =
scope.launch(Dispatchers.IO) {
while (true) {
try {
refreshUnicast(domain)
} catch (_: Throwable) {
// ignore (best-effort)
}
delay(5000)
}
}
}
private fun resolve(serviceInfo: NsdServiceInfo) {
nsd.resolveService(
serviceInfo,
object : NsdManager.ResolveListener {
override fun onResolveFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {}
override fun onServiceResolved(resolved: NsdServiceInfo) {
val host = resolved.host?.hostAddress ?: return
val port = resolved.port
if (port <= 0) return
val rawServiceName = resolved.serviceName
val serviceName = BonjourEscapes.decode(rawServiceName)
val displayName = BonjourEscapes.decode(txt(resolved, "displayName") ?: serviceName)
val lanHost = txt(resolved, "lanHost")
val tailnetDns = txt(resolved, "tailnetDns")
val gatewayPort = txtInt(resolved, "gatewayPort")
val bridgePort = txtInt(resolved, "bridgePort")
val canvasPort = txtInt(resolved, "canvasPort")
val tlsEnabled = txtBool(resolved, "bridgeTls")
val tlsFingerprint = txt(resolved, "bridgeTlsSha256")
val id = stableId(serviceName, "local.")
localById[id] =
BridgeEndpoint(
stableId = id,
name = displayName,
host = host,
port = port,
lanHost = lanHost,
tailnetDns = tailnetDns,
gatewayPort = gatewayPort,
bridgePort = bridgePort,
canvasPort = canvasPort,
tlsEnabled = tlsEnabled,
tlsFingerprintSha256 = tlsFingerprint,
)
publish()
}
},
)
}
private fun publish() {
_bridges.value =
(localById.values + unicastById.values).sortedBy { it.name.lowercase() }
_statusText.value = buildStatusText()
}
private fun buildStatusText(): String {
val localCount = localById.size
val wideRcode = lastWideAreaRcode
val wideCount = lastWideAreaCount
val wide =
when (wideRcode) {
null -> "Wide: ?"
Rcode.NOERROR -> "Wide: $wideCount"
Rcode.NXDOMAIN -> "Wide: NXDOMAIN"
else -> "Wide: ${Rcode.string(wideRcode)}"
}
return when {
localCount == 0 && wideRcode == null -> "Searching for bridges…"
localCount == 0 -> "$wide"
else -> "Local: $localCount$wide"
}
}
private fun stableId(serviceName: String, domain: String): String {
return "${serviceType}|${domain}|${normalizeName(serviceName)}"
}
private fun normalizeName(raw: String): String {
return raw.trim().split(Regex("\\s+")).joinToString(" ")
}
private fun txt(info: NsdServiceInfo, key: String): String? {
val bytes = info.attributes[key] ?: return null
return try {
String(bytes, Charsets.UTF_8).trim().ifEmpty { null }
} catch (_: Throwable) {
null
}
}
private fun txtInt(info: NsdServiceInfo, key: String): Int? {
return txt(info, key)?.toIntOrNull()
}
private fun txtBool(info: NsdServiceInfo, key: String): Boolean {
val raw = txt(info, key)?.trim()?.lowercase() ?: return false
return raw == "1" || raw == "true" || raw == "yes"
}
private suspend fun refreshUnicast(domain: String) {
val ptrName = "${serviceType}${domain}"
val ptrMsg = lookupUnicastMessage(ptrName, Type.PTR) ?: return
val ptrRecords = records(ptrMsg, Section.ANSWER).mapNotNull { it as? PTRRecord }
val next = LinkedHashMap<String, BridgeEndpoint>()
for (ptr in ptrRecords) {
val instanceFqdn = ptr.target.toString()
val srv =
recordByName(ptrMsg, instanceFqdn, Type.SRV) as? SRVRecord
?: run {
val msg = lookupUnicastMessage(instanceFqdn, Type.SRV) ?: return@run null
recordByName(msg, instanceFqdn, Type.SRV) as? SRVRecord
}
?: continue
val port = srv.port
if (port <= 0) continue
val targetFqdn = srv.target.toString()
val host =
resolveHostFromMessage(ptrMsg, targetFqdn)
?: resolveHostFromMessage(lookupUnicastMessage(instanceFqdn, Type.SRV), targetFqdn)
?: resolveHostUnicast(targetFqdn)
?: continue
val txtFromPtr =
recordsByName(ptrMsg, Section.ADDITIONAL)[keyName(instanceFqdn)]
.orEmpty()
.mapNotNull { it as? TXTRecord }
val txt =
if (txtFromPtr.isNotEmpty()) {
txtFromPtr
} else {
val msg = lookupUnicastMessage(instanceFqdn, Type.TXT)
records(msg, Section.ANSWER).mapNotNull { it as? TXTRecord }
}
val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain))
val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName)
val lanHost = txtValue(txt, "lanHost")
val tailnetDns = txtValue(txt, "tailnetDns")
val gatewayPort = txtIntValue(txt, "gatewayPort")
val bridgePort = txtIntValue(txt, "bridgePort")
val canvasPort = txtIntValue(txt, "canvasPort")
val tlsEnabled = txtBoolValue(txt, "bridgeTls")
val tlsFingerprint = txtValue(txt, "bridgeTlsSha256")
val id = stableId(instanceName, domain)
next[id] =
BridgeEndpoint(
stableId = id,
name = displayName,
host = host,
port = port,
lanHost = lanHost,
tailnetDns = tailnetDns,
gatewayPort = gatewayPort,
bridgePort = bridgePort,
canvasPort = canvasPort,
tlsEnabled = tlsEnabled,
tlsFingerprintSha256 = tlsFingerprint,
)
}
unicastById.clear()
unicastById.putAll(next)
lastWideAreaRcode = ptrMsg.header.rcode
lastWideAreaCount = next.size
publish()
if (next.isEmpty()) {
Log.d(
logTag,
"wide-area discovery: 0 results for $ptrName (rcode=${Rcode.string(ptrMsg.header.rcode)})",
)
}
}
private fun decodeInstanceName(instanceFqdn: String, domain: String): String {
val suffix = "${serviceType}${domain}"
val withoutSuffix =
if (instanceFqdn.endsWith(suffix)) {
instanceFqdn.removeSuffix(suffix)
} else {
instanceFqdn.substringBefore(serviceType)
}
return normalizeName(stripTrailingDot(withoutSuffix))
}
private fun stripTrailingDot(raw: String): String {
return raw.removeSuffix(".")
}
private suspend fun lookupUnicastMessage(name: String, type: Int): Message? {
val query =
try {
Message.newQuery(
org.xbill.DNS.Record.newRecord(
Name.fromString(name),
type,
DClass.IN,
),
)
} catch (_: TextParseException) {
return null
}
val system = queryViaSystemDns(query)
if (records(system, Section.ANSWER).any { it.type == type }) return system
val direct = createDirectResolver() ?: return system
return try {
val msg = direct.send(query)
if (records(msg, Section.ANSWER).any { it.type == type }) msg else system
} catch (_: Throwable) {
system
}
}
private suspend fun queryViaSystemDns(query: Message): Message? {
val network = preferredDnsNetwork()
val bytes =
try {
rawQuery(network, query.toWire())
} catch (_: Throwable) {
return null
}
return try {
Message(bytes)
} catch (_: IOException) {
null
}
}
private fun records(msg: Message?, section: Int): List<Record> {
return msg?.getSectionArray(section)?.toList() ?: emptyList()
}
private fun keyName(raw: String): String {
return raw.trim().lowercase()
}
private fun recordsByName(msg: Message, section: Int): Map<String, List<Record>> {
val next = LinkedHashMap<String, MutableList<Record>>()
for (r in records(msg, section)) {
val name = r.name?.toString() ?: continue
next.getOrPut(keyName(name)) { mutableListOf() }.add(r)
}
return next
}
private fun recordByName(msg: Message, fqdn: String, type: Int): Record? {
val key = keyName(fqdn)
val byNameAnswer = recordsByName(msg, Section.ANSWER)
val fromAnswer = byNameAnswer[key].orEmpty().firstOrNull { it.type == type }
if (fromAnswer != null) return fromAnswer
val byNameAdditional = recordsByName(msg, Section.ADDITIONAL)
return byNameAdditional[key].orEmpty().firstOrNull { it.type == type }
}
private fun resolveHostFromMessage(msg: Message?, hostname: String): String? {
val m = msg ?: return null
val key = keyName(hostname)
val additional = recordsByName(m, Section.ADDITIONAL)[key].orEmpty()
val a = additional.mapNotNull { it as? ARecord }.mapNotNull { it.address?.hostAddress }
val aaaa = additional.mapNotNull { it as? AAAARecord }.mapNotNull { it.address?.hostAddress }
return a.firstOrNull() ?: aaaa.firstOrNull()
}
private fun preferredDnsNetwork(): android.net.Network? {
val cm = connectivity ?: return null
// Prefer VPN (Tailscale) when present; otherwise use the active network.
cm.allNetworks.firstOrNull { n ->
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
}?.let { return it }
return cm.activeNetwork
}
private fun createDirectResolver(): Resolver? {
val cm = connectivity ?: return null
val candidateNetworks =
buildList {
cm.allNetworks
.firstOrNull { n ->
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
}?.let(::add)
cm.activeNetwork?.let(::add)
}.distinct()
val servers =
candidateNetworks
.asSequence()
.flatMap { n ->
cm.getLinkProperties(n)?.dnsServers?.asSequence() ?: emptySequence()
}
.distinctBy { it.hostAddress ?: it.toString() }
.toList()
if (servers.isEmpty()) return null
return try {
val resolvers =
servers.mapNotNull { addr ->
try {
SimpleResolver().apply {
setAddress(InetSocketAddress(addr, 53))
setTimeout(3)
}
} catch (_: Throwable) {
null
}
}
if (resolvers.isEmpty()) return null
ExtendedResolver(resolvers.toTypedArray()).apply { setTimeout(3) }
} catch (_: Throwable) {
null
}
}
private suspend fun rawQuery(network: android.net.Network?, wireQuery: ByteArray): ByteArray =
suspendCancellableCoroutine { cont ->
val signal = CancellationSignal()
cont.invokeOnCancellation { signal.cancel() }
dns.rawQuery(
network,
wireQuery,
DnsResolver.FLAG_EMPTY,
dnsExecutor,
signal,
object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) {
cont.resume(answer)
}
override fun onError(error: DnsResolver.DnsException) {
cont.resumeWithException(error)
}
},
)
}
private fun txtValue(records: List<TXTRecord>, key: String): String? {
val prefix = "$key="
for (r in records) {
val strings: List<String> =
try {
r.strings.mapNotNull { it as? String }
} catch (_: Throwable) {
emptyList()
}
for (s in strings) {
val trimmed = decodeDnsTxtString(s).trim()
if (trimmed.startsWith(prefix)) {
return trimmed.removePrefix(prefix).trim().ifEmpty { null }
}
}
}
return null
}
private fun txtIntValue(records: List<TXTRecord>, key: String): Int? {
return txtValue(records, key)?.toIntOrNull()
}
private fun txtBoolValue(records: List<TXTRecord>, key: String): Boolean {
val raw = txtValue(records, key)?.trim()?.lowercase() ?: return false
return raw == "1" || raw == "true" || raw == "yes"
}
private fun decodeDnsTxtString(raw: String): String {
// dnsjava treats TXT as opaque bytes and decodes as ISO-8859-1 to preserve bytes.
// Our TXT payload is UTF-8 (written by the gateway), so re-decode when possible.
val bytes = raw.toByteArray(Charsets.ISO_8859_1)
val decoder =
Charsets.UTF_8
.newDecoder()
.onMalformedInput(CodingErrorAction.REPORT)
.onUnmappableCharacter(CodingErrorAction.REPORT)
return try {
decoder.decode(ByteBuffer.wrap(bytes)).toString()
} catch (_: Throwable) {
raw
}
}
private suspend fun resolveHostUnicast(hostname: String): String? {
val a =
records(lookupUnicastMessage(hostname, Type.A), Section.ANSWER)
.mapNotNull { it as? ARecord }
.mapNotNull { it.address?.hostAddress }
val aaaa =
records(lookupUnicastMessage(hostname, Type.AAAA), Section.ANSWER)
.mapNotNull { it as? AAAARecord }
.mapNotNull { it.address?.hostAddress }
return a.firstOrNull() ?: aaaa.firstOrNull()
}
}

View File

@@ -1,27 +0,0 @@
package com.clawdbot.android.bridge
data class BridgeEndpoint(
val stableId: String,
val name: String,
val host: String,
val port: Int,
val lanHost: String? = null,
val tailnetDns: String? = null,
val gatewayPort: Int? = null,
val bridgePort: Int? = null,
val canvasPort: Int? = null,
val tlsEnabled: Boolean = false,
val tlsFingerprintSha256: String? = null,
) {
companion object {
fun manual(host: String, port: Int): BridgeEndpoint =
BridgeEndpoint(
stableId = "manual|$host|$port",
name = "$host:$port",
host = host,
port = port,
tlsEnabled = false,
tlsFingerprintSha256 = null,
)
}
}

View File

@@ -1,158 +0,0 @@
package com.clawdbot.android.bridge
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.buildJsonObject
import java.io.BufferedReader
import java.io.BufferedWriter
import java.io.InputStreamReader
import java.io.OutputStreamWriter
import java.net.InetSocketAddress
class BridgePairingClient {
private val json = Json { ignoreUnknownKeys = true }
data class Hello(
val nodeId: String,
val displayName: String?,
val token: String?,
val platform: String?,
val version: String?,
val deviceFamily: String?,
val modelIdentifier: String?,
val caps: List<String>?,
val commands: List<String>?,
)
data class PairResult(val ok: Boolean, val token: String?, val error: String? = null)
suspend fun pairAndHello(
endpoint: BridgeEndpoint,
hello: Hello,
tls: BridgeTlsParams? = null,
onTlsFingerprint: ((String) -> Unit)? = null,
): PairResult =
withContext(Dispatchers.IO) {
if (tls != null) {
try {
return@withContext pairAndHelloWithTls(endpoint, hello, tls, onTlsFingerprint)
} catch (e: Exception) {
if (tls.required) throw e
}
}
pairAndHelloWithTls(endpoint, hello, null, null)
}
private fun pairAndHelloWithTls(
endpoint: BridgeEndpoint,
hello: Hello,
tls: BridgeTlsParams?,
onTlsFingerprint: ((String) -> Unit)?,
): PairResult {
val socket =
createBridgeSocket(tls) { fingerprint ->
onTlsFingerprint?.invoke(fingerprint)
}
socket.tcpNoDelay = true
try {
socket.connect(InetSocketAddress(endpoint.host, endpoint.port), 8_000)
socket.soTimeout = 60_000
startTlsHandshakeIfNeeded(socket)
val reader = BufferedReader(InputStreamReader(socket.getInputStream(), Charsets.UTF_8))
val writer = BufferedWriter(OutputStreamWriter(socket.getOutputStream(), Charsets.UTF_8))
fun send(line: String) {
writer.write(line)
writer.write("\n")
writer.flush()
}
fun sendJson(obj: JsonObject) = send(obj.toString())
sendJson(
buildJsonObject {
put("type", JsonPrimitive("hello"))
put("nodeId", JsonPrimitive(hello.nodeId))
hello.displayName?.let { put("displayName", JsonPrimitive(it)) }
hello.token?.let { put("token", JsonPrimitive(it)) }
hello.platform?.let { put("platform", JsonPrimitive(it)) }
hello.version?.let { put("version", JsonPrimitive(it)) }
hello.deviceFamily?.let { put("deviceFamily", JsonPrimitive(it)) }
hello.modelIdentifier?.let { put("modelIdentifier", JsonPrimitive(it)) }
hello.caps?.let { put("caps", JsonArray(it.map(::JsonPrimitive))) }
hello.commands?.let { put("commands", JsonArray(it.map(::JsonPrimitive))) }
},
)
val firstObj = json.parseToJsonElement(reader.readLine()).asObjectOrNull()
?: return PairResult(ok = false, token = null, error = "unexpected bridge response")
return when (firstObj["type"].asStringOrNull()) {
"hello-ok" -> PairResult(ok = true, token = hello.token)
"error" -> {
val code = firstObj["code"].asStringOrNull() ?: "UNAVAILABLE"
val message = firstObj["message"].asStringOrNull() ?: "pairing required"
if (code != "NOT_PAIRED" && code != "UNAUTHORIZED") {
return PairResult(ok = false, token = null, error = "$code: $message")
}
sendJson(
buildJsonObject {
put("type", JsonPrimitive("pair-request"))
put("nodeId", JsonPrimitive(hello.nodeId))
hello.displayName?.let { put("displayName", JsonPrimitive(it)) }
hello.platform?.let { put("platform", JsonPrimitive(it)) }
hello.version?.let { put("version", JsonPrimitive(it)) }
hello.deviceFamily?.let { put("deviceFamily", JsonPrimitive(it)) }
hello.modelIdentifier?.let { put("modelIdentifier", JsonPrimitive(it)) }
hello.caps?.let { put("caps", JsonArray(it.map(::JsonPrimitive))) }
hello.commands?.let { put("commands", JsonArray(it.map(::JsonPrimitive))) }
},
)
while (true) {
val nextLine = reader.readLine() ?: break
val next = json.parseToJsonElement(nextLine).asObjectOrNull() ?: continue
when (next["type"].asStringOrNull()) {
"pair-ok" -> {
val token = next["token"].asStringOrNull()
return PairResult(ok = !token.isNullOrBlank(), token = token)
}
"error" -> {
val c = next["code"].asStringOrNull() ?: "UNAVAILABLE"
val m = next["message"].asStringOrNull() ?: "pairing failed"
return PairResult(ok = false, token = null, error = "$c: $m")
}
}
}
PairResult(ok = false, token = null, error = "pairing failed")
}
else -> PairResult(ok = false, token = null, error = "unexpected bridge response")
}
} catch (e: Exception) {
val message = e.message?.trim().orEmpty().ifEmpty { "gateway unreachable" }
return PairResult(ok = false, token = null, error = message)
} finally {
try {
socket.close()
} catch (_: Throwable) {
// ignore
}
}
}
}
private fun JsonElement?.asObjectOrNull(): JsonObject? = this as? JsonObject
private fun JsonElement?.asStringOrNull(): String? =
when (this) {
is JsonNull -> null
is JsonPrimitive -> content
else -> null
}

View File

@@ -1,398 +0,0 @@
package com.clawdbot.android.bridge
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
import com.clawdbot.android.BuildConfig
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonObject
import java.io.BufferedReader
import java.io.BufferedWriter
import java.io.InputStreamReader
import java.io.OutputStreamWriter
import java.net.InetSocketAddress
import java.net.URI
import java.net.Socket
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
class BridgeSession(
private val scope: CoroutineScope,
private val onConnected: (serverName: String, remoteAddress: String?, mainSessionKey: String?) -> Unit,
private val onDisconnected: (message: String) -> Unit,
private val onEvent: (event: String, payloadJson: String?) -> Unit,
private val onInvoke: suspend (InvokeRequest) -> InvokeResult,
private val onTlsFingerprint: ((stableId: String, fingerprint: String) -> Unit)? = null,
) {
data class Hello(
val nodeId: String,
val displayName: String?,
val token: String?,
val platform: String?,
val version: String?,
val deviceFamily: String?,
val modelIdentifier: String?,
val caps: List<String>?,
val commands: List<String>?,
)
data class InvokeRequest(val id: String, val command: String, val paramsJson: String?)
data class InvokeResult(val ok: Boolean, val payloadJson: String?, val error: ErrorShape?) {
companion object {
fun ok(payloadJson: String?) = InvokeResult(ok = true, payloadJson = payloadJson, error = null)
fun error(code: String, message: String) =
InvokeResult(ok = false, payloadJson = null, error = ErrorShape(code = code, message = message))
}
}
data class ErrorShape(val code: String, val message: String)
private val json = Json { ignoreUnknownKeys = true }
private val writeLock = Mutex()
private val pending = ConcurrentHashMap<String, CompletableDeferred<RpcResponse>>()
@Volatile private var canvasHostUrl: String? = null
@Volatile private var mainSessionKey: String? = null
private data class DesiredConnection(
val endpoint: BridgeEndpoint,
val hello: Hello,
val tls: BridgeTlsParams?,
)
private var desired: DesiredConnection? = null
private var job: Job? = null
fun connect(endpoint: BridgeEndpoint, hello: Hello, tls: BridgeTlsParams? = null) {
desired = DesiredConnection(endpoint, hello, tls)
if (job == null) {
job = scope.launch(Dispatchers.IO) { runLoop() }
}
}
suspend fun updateHello(hello: Hello) {
val target = desired ?: return
desired = target.copy(hello = hello)
val conn = currentConnection ?: return
conn.sendJson(buildHelloJson(hello))
}
fun disconnect() {
desired = null
// Unblock connectOnce() read loop. Coroutine cancellation alone won't interrupt BufferedReader.readLine().
currentConnection?.closeQuietly()
scope.launch(Dispatchers.IO) {
job?.cancelAndJoin()
job = null
canvasHostUrl = null
mainSessionKey = null
onDisconnected("Offline")
}
}
fun currentCanvasHostUrl(): String? = canvasHostUrl
fun currentMainSessionKey(): String? = mainSessionKey
suspend fun sendEvent(event: String, payloadJson: String?) {
val conn = currentConnection ?: return
conn.sendJson(
buildJsonObject {
put("type", JsonPrimitive("event"))
put("event", JsonPrimitive(event))
if (payloadJson != null) put("payloadJSON", JsonPrimitive(payloadJson)) else put("payloadJSON", JsonNull)
},
)
}
suspend fun request(method: String, paramsJson: String?): String {
val conn = currentConnection ?: throw IllegalStateException("not connected")
val id = UUID.randomUUID().toString()
val deferred = CompletableDeferred<RpcResponse>()
pending[id] = deferred
conn.sendJson(
buildJsonObject {
put("type", JsonPrimitive("req"))
put("id", JsonPrimitive(id))
put("method", JsonPrimitive(method))
if (paramsJson != null) put("paramsJSON", JsonPrimitive(paramsJson)) else put("paramsJSON", JsonNull)
},
)
val res = deferred.await()
if (res.ok) return res.payloadJson ?: ""
val err = res.error
throw IllegalStateException("${err?.code ?: "UNAVAILABLE"}: ${err?.message ?: "request failed"}")
}
private data class RpcResponse(val id: String, val ok: Boolean, val payloadJson: String?, val error: ErrorShape?)
private class Connection(private val socket: Socket, private val reader: BufferedReader, private val writer: BufferedWriter, private val writeLock: Mutex) {
val remoteAddress: String? =
socket.inetAddress?.hostAddress?.takeIf { it.isNotBlank() }?.let { "${it}:${socket.port}" }
suspend fun sendJson(obj: JsonObject) {
writeLock.withLock {
writer.write(obj.toString())
writer.write("\n")
writer.flush()
}
}
fun closeQuietly() {
try {
socket.close()
} catch (_: Throwable) {
// ignore
}
}
}
@Volatile private var currentConnection: Connection? = null
private suspend fun runLoop() {
var attempt = 0
while (scope.isActive) {
val target = desired
if (target == null) {
currentConnection?.closeQuietly()
currentConnection = null
delay(250)
continue
}
val (endpoint, hello, tls) = target
try {
onDisconnected(if (attempt == 0) "Connecting…" else "Reconnecting…")
connectOnce(endpoint, hello, tls)
attempt = 0
} catch (err: Throwable) {
attempt += 1
onDisconnected("Bridge error: ${err.message ?: err::class.java.simpleName}")
val sleepMs = minOf(8_000L, (350.0 * Math.pow(1.7, attempt.toDouble())).toLong())
delay(sleepMs)
}
}
}
private fun invokeErrorFromThrowable(err: Throwable): InvokeResult {
val msg = err.message?.trim().takeIf { !it.isNullOrEmpty() } ?: err::class.java.simpleName
val parts = msg.split(":", limit = 2)
if (parts.size == 2) {
val code = parts[0].trim()
val rest = parts[1].trim()
if (code.isNotEmpty() && code.all { it.isUpperCase() || it == '_' }) {
return InvokeResult.error(code = code, message = rest.ifEmpty { msg })
}
}
return InvokeResult.error(code = "UNAVAILABLE", message = msg)
}
private suspend fun connectOnce(endpoint: BridgeEndpoint, hello: Hello, tls: BridgeTlsParams?) =
withContext(Dispatchers.IO) {
if (tls != null) {
try {
connectWithSocket(endpoint, hello, tls)
return@withContext
} catch (err: Throwable) {
if (tls.required) throw err
}
}
connectWithSocket(endpoint, hello, null)
}
private suspend fun connectWithSocket(endpoint: BridgeEndpoint, hello: Hello, tls: BridgeTlsParams?) {
val socket =
createBridgeSocket(tls) { fingerprint ->
onTlsFingerprint?.invoke(tls?.stableId ?: endpoint.stableId, fingerprint)
}
socket.tcpNoDelay = true
socket.connect(InetSocketAddress(endpoint.host, endpoint.port), 8_000)
socket.soTimeout = 0
startTlsHandshakeIfNeeded(socket)
val reader = BufferedReader(InputStreamReader(socket.getInputStream(), Charsets.UTF_8))
val writer = BufferedWriter(OutputStreamWriter(socket.getOutputStream(), Charsets.UTF_8))
val conn = Connection(socket, reader, writer, writeLock)
currentConnection = conn
try {
conn.sendJson(buildHelloJson(hello))
val firstLine = reader.readLine() ?: throw IllegalStateException("bridge closed connection")
val first = json.parseToJsonElement(firstLine).asObjectOrNull()
?: throw IllegalStateException("unexpected bridge response")
when (first["type"].asStringOrNull()) {
"hello-ok" -> {
val name = first["serverName"].asStringOrNull() ?: "Bridge"
val rawCanvasUrl = first["canvasHostUrl"].asStringOrNull()?.trim()?.ifEmpty { null }
val rawMainSessionKey = first["mainSessionKey"].asStringOrNull()?.trim()?.ifEmpty { null }
canvasHostUrl = normalizeCanvasHostUrl(rawCanvasUrl, endpoint)
mainSessionKey = rawMainSessionKey
if (BuildConfig.DEBUG) {
// Local JVM unit tests use android.jar stubs; Log.d can throw "not mocked".
runCatching {
android.util.Log.d(
"ClawdbotBridge",
"canvasHostUrl resolved=${canvasHostUrl ?: "none"} (raw=${rawCanvasUrl ?: "none"})",
)
}
}
onConnected(name, conn.remoteAddress, rawMainSessionKey)
}
"error" -> {
val code = first["code"].asStringOrNull() ?: "UNAVAILABLE"
val msg = first["message"].asStringOrNull() ?: "connect failed"
throw IllegalStateException("$code: $msg")
}
else -> throw IllegalStateException("unexpected bridge response")
}
while (scope.isActive) {
val line = reader.readLine() ?: break
val frame = json.parseToJsonElement(line).asObjectOrNull() ?: continue
when (frame["type"].asStringOrNull()) {
"event" -> {
val event = frame["event"].asStringOrNull() ?: continue
val payload = frame["payloadJSON"].asStringOrNull()
onEvent(event, payload)
}
"ping" -> {
val id = frame["id"].asStringOrNull() ?: ""
conn.sendJson(buildJsonObject { put("type", JsonPrimitive("pong")); put("id", JsonPrimitive(id)) })
}
"res" -> {
val id = frame["id"].asStringOrNull() ?: continue
val ok = frame["ok"].asBooleanOrNull() ?: false
val payloadJson = frame["payloadJSON"].asStringOrNull()
val error =
frame["error"]?.let {
val obj = it.asObjectOrNull() ?: return@let null
val code = obj["code"].asStringOrNull() ?: "UNAVAILABLE"
val msg = obj["message"].asStringOrNull() ?: "request failed"
ErrorShape(code, msg)
}
pending.remove(id)?.complete(RpcResponse(id, ok, payloadJson, error))
}
"invoke" -> {
val id = frame["id"].asStringOrNull() ?: continue
val command = frame["command"].asStringOrNull() ?: ""
val params = frame["paramsJSON"].asStringOrNull()
val result =
try {
onInvoke(InvokeRequest(id, command, params))
} catch (err: Throwable) {
invokeErrorFromThrowable(err)
}
conn.sendJson(
buildJsonObject {
put("type", JsonPrimitive("invoke-res"))
put("id", JsonPrimitive(id))
put("ok", JsonPrimitive(result.ok))
if (result.payloadJson != null) put("payloadJSON", JsonPrimitive(result.payloadJson))
if (result.error != null) {
put(
"error",
buildJsonObject {
put("code", JsonPrimitive(result.error.code))
put("message", JsonPrimitive(result.error.message))
},
)
}
},
)
}
"invoke-res" -> {
// gateway->node only (ignore)
}
}
}
} finally {
currentConnection = null
for ((_, waiter) in pending) {
waiter.cancel()
}
pending.clear()
conn.closeQuietly()
}
}
private fun buildHelloJson(hello: Hello): JsonObject =
buildJsonObject {
put("type", JsonPrimitive("hello"))
put("nodeId", JsonPrimitive(hello.nodeId))
hello.displayName?.let { put("displayName", JsonPrimitive(it)) }
hello.token?.let { put("token", JsonPrimitive(it)) }
hello.platform?.let { put("platform", JsonPrimitive(it)) }
hello.version?.let { put("version", JsonPrimitive(it)) }
hello.deviceFamily?.let { put("deviceFamily", JsonPrimitive(it)) }
hello.modelIdentifier?.let { put("modelIdentifier", JsonPrimitive(it)) }
hello.caps?.let { put("caps", JsonArray(it.map(::JsonPrimitive))) }
hello.commands?.let { put("commands", JsonArray(it.map(::JsonPrimitive))) }
}
private fun normalizeCanvasHostUrl(raw: String?, endpoint: BridgeEndpoint): String? {
val trimmed = raw?.trim().orEmpty()
val parsed = trimmed.takeIf { it.isNotBlank() }?.let { runCatching { URI(it) }.getOrNull() }
val host = parsed?.host?.trim().orEmpty()
val port = parsed?.port ?: -1
val scheme = parsed?.scheme?.trim().orEmpty().ifBlank { "http" }
if (trimmed.isNotBlank() && !isLoopbackHost(host)) {
return trimmed
}
val fallbackHost =
endpoint.tailnetDns?.trim().takeIf { !it.isNullOrEmpty() }
?: endpoint.lanHost?.trim().takeIf { !it.isNullOrEmpty() }
?: endpoint.host.trim()
if (fallbackHost.isEmpty()) return trimmed.ifBlank { null }
val fallbackPort = endpoint.canvasPort ?: if (port > 0) port else 18793
val formattedHost = if (fallbackHost.contains(":")) "[${fallbackHost}]" else fallbackHost
return "$scheme://$formattedHost:$fallbackPort"
}
private fun isLoopbackHost(raw: String?): Boolean {
val host = raw?.trim()?.lowercase().orEmpty()
if (host.isEmpty()) return false
if (host == "localhost") return true
if (host == "::1") return true
if (host == "0.0.0.0" || host == "::") return true
return host.startsWith("127.")
}
}
private fun JsonElement?.asObjectOrNull(): JsonObject? = this as? JsonObject
private fun JsonElement?.asStringOrNull(): String? =
when (this) {
is JsonNull -> null
is JsonPrimitive -> content
else -> null
}
private fun JsonElement?.asBooleanOrNull(): Boolean? =
when (this) {
is JsonPrimitive -> {
val c = content.trim()
when {
c.equals("true", ignoreCase = true) -> true
c.equals("false", ignoreCase = true) -> false
else -> null
}
}
else -> null
}

View File

@@ -1,81 +0,0 @@
package com.clawdbot.android.bridge
import android.annotation.SuppressLint
import java.net.Socket
import java.security.MessageDigest
import java.security.SecureRandom
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
data class BridgeTlsParams(
val required: Boolean,
val expectedFingerprint: String?,
val allowTOFU: Boolean,
val stableId: String,
)
fun createBridgeSocket(params: BridgeTlsParams?, onStore: ((String) -> Unit)? = null): Socket {
if (params == null) return Socket()
val expected = params.expectedFingerprint?.let(::normalizeFingerprint)
val defaultTrust = defaultTrustManager()
@SuppressLint("CustomX509TrustManager")
val trustManager =
object : X509TrustManager {
override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {
defaultTrust.checkClientTrusted(chain, authType)
}
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {
if (chain.isEmpty()) throw CertificateException("empty certificate chain")
val fingerprint = sha256Hex(chain[0].encoded)
if (expected != null) {
if (fingerprint != expected) {
throw CertificateException("bridge TLS fingerprint mismatch")
}
return
}
if (params.allowTOFU) {
onStore?.invoke(fingerprint)
return
}
defaultTrust.checkServerTrusted(chain, authType)
}
override fun getAcceptedIssuers(): Array<X509Certificate> = defaultTrust.acceptedIssuers
}
val context = SSLContext.getInstance("TLS")
context.init(null, arrayOf(trustManager), SecureRandom())
return context.socketFactory.createSocket()
}
fun startTlsHandshakeIfNeeded(socket: Socket) {
if (socket is SSLSocket) {
socket.startHandshake()
}
}
private fun defaultTrustManager(): X509TrustManager {
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
factory.init(null as java.security.KeyStore?)
val trust =
factory.trustManagers.firstOrNull { it is X509TrustManager } as? X509TrustManager
return trust ?: throw IllegalStateException("No default X509TrustManager found")
}
private fun sha256Hex(data: ByteArray): String {
val digest = MessageDigest.getInstance("SHA-256").digest(data)
val out = StringBuilder(digest.size * 2)
for (byte in digest) {
out.append(String.format("%02x", byte))
}
return out.toString()
}
private fun normalizeFingerprint(raw: String): String {
return raw.lowercase().filter { it in '0'..'9' || it in 'a'..'f' }
}

View File

@@ -1,4 +1,4 @@
package com.clawdbot.android.bridge
package com.clawdbot.android.gateway
object BonjourEscapes {
fun decode(input: String): String {

View File

@@ -8,7 +8,6 @@ import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo
import android.os.CancellationSignal
import android.util.Log
import com.clawdbot.android.bridge.BonjourEscapes
import java.io.IOException
import java.net.InetSocketAddress
import java.nio.ByteBuffer

View File

@@ -1,14 +0,0 @@
package com.clawdbot.android.bridge
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe
class BridgeEndpointKotestTest : StringSpec({
"manual endpoint builds stable id + name" {
val endpoint = BridgeEndpoint.manual("10.0.0.5", 18790)
endpoint.stableId shouldBe "manual|10.0.0.5|18790"
endpoint.name shouldBe "10.0.0.5:18790"
endpoint.host shouldBe "10.0.0.5"
endpoint.port shouldBe 18790
}
})

View File

@@ -1,108 +0,0 @@
package com.clawdbot.android.bridge
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Test
import java.io.BufferedReader
import java.io.BufferedWriter
import java.io.InputStreamReader
import java.io.OutputStreamWriter
import java.net.ServerSocket
class BridgePairingClientTest {
@Test
fun helloOkReturnsExistingToken() = runBlocking {
val serverSocket = ServerSocket(0)
val port = serverSocket.localPort
val server =
async(Dispatchers.IO) {
serverSocket.use { ss ->
val sock = ss.accept()
sock.use { s ->
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
val hello = reader.readLine()
assertTrue(hello.contains("\"type\":\"hello\""))
writer.write("""{"type":"hello-ok","serverName":"Test Bridge"}""")
writer.write("\n")
writer.flush()
}
}
}
val client = BridgePairingClient()
val res =
client.pairAndHello(
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
hello =
BridgePairingClient.Hello(
nodeId = "node-1",
displayName = "Android Node",
token = "token-123",
platform = "Android",
version = "test",
deviceFamily = "Android",
modelIdentifier = "SM-X000",
caps = null,
commands = null,
),
)
assertTrue(res.ok)
assertEquals("token-123", res.token)
server.await()
}
@Test
fun notPairedTriggersPairRequestAndReturnsToken() = runBlocking {
val serverSocket = ServerSocket(0)
val port = serverSocket.localPort
val server =
async(Dispatchers.IO) {
serverSocket.use { ss ->
val sock = ss.accept()
sock.use { s ->
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
reader.readLine() // hello
writer.write("""{"type":"error","code":"NOT_PAIRED","message":"not paired"}""")
writer.write("\n")
writer.flush()
val pairReq = reader.readLine()
assertTrue(pairReq.contains("\"type\":\"pair-request\""))
writer.write("""{"type":"pair-ok","token":"new-token"}""")
writer.write("\n")
writer.flush()
}
}
}
val client = BridgePairingClient()
val res =
client.pairAndHello(
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
hello =
BridgePairingClient.Hello(
nodeId = "node-1",
displayName = "Android Node",
token = null,
platform = "Android",
version = "test",
deviceFamily = "Android",
modelIdentifier = "SM-X000",
caps = null,
commands = null,
),
)
assertTrue(res.ok)
assertEquals("new-token", res.token)
server.await()
}
}

View File

@@ -1,307 +0,0 @@
package com.clawdbot.android.bridge
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.async
import kotlinx.coroutines.cancel
import kotlinx.coroutines.runBlocking
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Test
import java.io.BufferedReader
import java.io.BufferedWriter
import java.io.InputStreamReader
import java.io.OutputStreamWriter
import java.net.ServerSocket
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
class BridgeSessionTest {
@Test
fun requestReturnsPayloadJson() = runBlocking {
val serverSocket = ServerSocket(0)
val port = serverSocket.localPort
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
val connected = CompletableDeferred<Unit>()
val session =
BridgeSession(
scope = scope,
onConnected = { _, _, _ -> connected.complete(Unit) },
onDisconnected = { /* ignore */ },
onEvent = { _, _ -> /* ignore */ },
onInvoke = { BridgeSession.InvokeResult.ok(null) },
)
val server =
async(Dispatchers.IO) {
serverSocket.use { ss ->
val sock = ss.accept()
sock.use { s ->
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
val hello = reader.readLine()
assertTrue(hello.contains("\"type\":\"hello\""))
writer.write("""{"type":"hello-ok","serverName":"Test Bridge","canvasHostUrl":"http://127.0.0.1:18789"}""")
writer.write("\n")
writer.flush()
val req = reader.readLine()
assertTrue(req.contains("\"type\":\"req\""))
val id = extractJsonString(req, "id")
writer.write("""{"type":"res","id":"$id","ok":true,"payloadJSON":"{\"value\":123}"}""")
writer.write("\n")
writer.flush()
}
}
}
session.connect(
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
hello =
BridgeSession.Hello(
nodeId = "node-1",
displayName = "Android Node",
token = null,
platform = "Android",
version = "test",
deviceFamily = null,
modelIdentifier = null,
caps = null,
commands = null,
),
)
connected.await()
assertEquals("http://127.0.0.1:18789", session.currentCanvasHostUrl())
val payload = session.request(method = "health", paramsJson = null)
assertEquals("""{"value":123}""", payload)
server.await()
session.disconnect()
scope.cancel()
}
@Test
fun requestThrowsOnErrorResponse() = runBlocking {
val serverSocket = ServerSocket(0)
val port = serverSocket.localPort
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
val connected = CompletableDeferred<Unit>()
val session =
BridgeSession(
scope = scope,
onConnected = { _, _, _ -> connected.complete(Unit) },
onDisconnected = { /* ignore */ },
onEvent = { _, _ -> /* ignore */ },
onInvoke = { BridgeSession.InvokeResult.ok(null) },
)
val server =
async(Dispatchers.IO) {
serverSocket.use { ss ->
val sock = ss.accept()
sock.use { s ->
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
reader.readLine() // hello
writer.write("""{"type":"hello-ok","serverName":"Test Bridge"}""")
writer.write("\n")
writer.flush()
val req = reader.readLine()
val id = extractJsonString(req, "id")
writer.write(
"""{"type":"res","id":"$id","ok":false,"error":{"code":"FORBIDDEN","message":"nope"}}""",
)
writer.write("\n")
writer.flush()
}
}
}
session.connect(
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
hello =
BridgeSession.Hello(
nodeId = "node-1",
displayName = "Android Node",
token = null,
platform = "Android",
version = "test",
deviceFamily = null,
modelIdentifier = null,
caps = null,
commands = null,
),
)
connected.await()
try {
session.request(method = "chat.history", paramsJson = """{"sessionKey":"main"}""")
throw AssertionError("expected request() to throw")
} catch (e: IllegalStateException) {
assertTrue(e.message?.contains("FORBIDDEN: nope") == true)
}
server.await()
session.disconnect()
scope.cancel()
}
@Test
fun invokeResReturnsErrorWhenHandlerThrows() = runBlocking {
val serverSocket = ServerSocket(0)
val port = serverSocket.localPort
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
val connected = CompletableDeferred<Unit>()
val session =
BridgeSession(
scope = scope,
onConnected = { _, _, _ -> connected.complete(Unit) },
onDisconnected = { /* ignore */ },
onEvent = { _, _ -> /* ignore */ },
onInvoke = { throw IllegalStateException("FOO_BAR: boom") },
)
val invokeResLine = CompletableDeferred<String>()
val server =
async(Dispatchers.IO) {
serverSocket.use { ss ->
val sock = ss.accept()
sock.use { s ->
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
reader.readLine() // hello
writer.write("""{"type":"hello-ok","serverName":"Test Bridge"}""")
writer.write("\n")
writer.flush()
// Ask the node to invoke something; handler will throw.
writer.write("""{"type":"invoke","id":"i1","command":"canvas.snapshot","paramsJSON":null}""")
writer.write("\n")
writer.flush()
val res = reader.readLine()
invokeResLine.complete(res)
}
}
}
session.connect(
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
hello =
BridgeSession.Hello(
nodeId = "node-1",
displayName = "Android Node",
token = null,
platform = "Android",
version = "test",
deviceFamily = null,
modelIdentifier = null,
caps = null,
commands = null,
),
)
connected.await()
// Give the reader loop time to process.
val line = invokeResLine.await()
assertTrue(line.contains("\"type\":\"invoke-res\""))
assertTrue(line.contains("\"ok\":false"))
assertTrue(line.contains("\"code\":\"FOO_BAR\""))
assertTrue(line.contains("\"message\":\"boom\""))
server.await()
session.disconnect()
scope.cancel()
}
@Test(timeout = 12_000)
fun reconnectsAfterBridgeClosesDuringHello() = runBlocking {
val serverSocket = ServerSocket(0)
val port = serverSocket.localPort
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
val connected = CountDownLatch(1)
val connectionsSeen = CountDownLatch(2)
val session =
BridgeSession(
scope = scope,
onConnected = { _, _, _ -> connected.countDown() },
onDisconnected = { /* ignore */ },
onEvent = { _, _ -> /* ignore */ },
onInvoke = { BridgeSession.InvokeResult.ok(null) },
)
val server =
async(Dispatchers.IO) {
serverSocket.use { ss ->
// First connection: read hello, then close (no response).
val sock1 = ss.accept()
sock1.use { s ->
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
reader.readLine() // hello
connectionsSeen.countDown()
}
// Second connection: complete hello.
val sock2 = ss.accept()
sock2.use { s ->
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
reader.readLine() // hello
writer.write("""{"type":"hello-ok","serverName":"Test Bridge"}""")
writer.write("\n")
writer.flush()
connectionsSeen.countDown()
Thread.sleep(200)
}
}
}
session.connect(
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
hello =
BridgeSession.Hello(
nodeId = "node-1",
displayName = "Android Node",
token = null,
platform = "Android",
version = "test",
deviceFamily = null,
modelIdentifier = null,
caps = null,
commands = null,
),
)
assertTrue("expected two connection attempts", connectionsSeen.await(8, TimeUnit.SECONDS))
assertTrue("expected session to connect", connected.await(8, TimeUnit.SECONDS))
session.disconnect()
scope.cancel()
server.await()
}
}
private fun extractJsonString(raw: String, key: String): String {
val needle = "\"$key\":\""
val start = raw.indexOf(needle)
if (start < 0) throw IllegalArgumentException("missing key $key in $raw")
val from = start + needle.length
val end = raw.indexOf('"', from)
if (end < 0) throw IllegalArgumentException("unterminated string for $key in $raw")
return raw.substring(from, end)
}

View File

@@ -1,4 +1,4 @@
package com.clawdbot.android.bridge
package com.clawdbot.android.gateway
import org.junit.Assert.assertEquals
import org.junit.Test