refactor(android): drop legacy bridge transport
This commit is contained in:
@@ -1,523 +0,0 @@
|
||||
package com.clawdbot.android.bridge
|
||||
|
||||
import android.content.Context
|
||||
import android.net.ConnectivityManager
|
||||
import android.net.DnsResolver
|
||||
import android.net.NetworkCapabilities
|
||||
import android.net.nsd.NsdManager
|
||||
import android.net.nsd.NsdServiceInfo
|
||||
import android.os.CancellationSignal
|
||||
import android.util.Log
|
||||
import java.io.IOException
|
||||
import java.net.InetSocketAddress
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.charset.CodingErrorAction
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.Executor
|
||||
import java.util.concurrent.Executors
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.suspendCancellableCoroutine
|
||||
import org.xbill.DNS.AAAARecord
|
||||
import org.xbill.DNS.ARecord
|
||||
import org.xbill.DNS.DClass
|
||||
import org.xbill.DNS.ExtendedResolver
|
||||
import org.xbill.DNS.Message
|
||||
import org.xbill.DNS.Name
|
||||
import org.xbill.DNS.PTRRecord
|
||||
import org.xbill.DNS.Record
|
||||
import org.xbill.DNS.Rcode
|
||||
import org.xbill.DNS.Resolver
|
||||
import org.xbill.DNS.SRVRecord
|
||||
import org.xbill.DNS.Section
|
||||
import org.xbill.DNS.SimpleResolver
|
||||
import org.xbill.DNS.TextParseException
|
||||
import org.xbill.DNS.TXTRecord
|
||||
import org.xbill.DNS.Type
|
||||
import kotlin.coroutines.resume
|
||||
import kotlin.coroutines.resumeWithException
|
||||
|
||||
@Suppress("DEPRECATION")
|
||||
class BridgeDiscovery(
|
||||
context: Context,
|
||||
private val scope: CoroutineScope,
|
||||
) {
|
||||
private val nsd = context.getSystemService(NsdManager::class.java)
|
||||
private val connectivity = context.getSystemService(ConnectivityManager::class.java)
|
||||
private val dns = DnsResolver.getInstance()
|
||||
private val serviceType = "_clawdbot-bridge._tcp."
|
||||
private val wideAreaDomain = "clawdbot.internal."
|
||||
private val logTag = "Clawdbot/BridgeDiscovery"
|
||||
|
||||
private val localById = ConcurrentHashMap<String, BridgeEndpoint>()
|
||||
private val unicastById = ConcurrentHashMap<String, BridgeEndpoint>()
|
||||
private val _bridges = MutableStateFlow<List<BridgeEndpoint>>(emptyList())
|
||||
val bridges: StateFlow<List<BridgeEndpoint>> = _bridges.asStateFlow()
|
||||
|
||||
private val _statusText = MutableStateFlow("Searching…")
|
||||
val statusText: StateFlow<String> = _statusText.asStateFlow()
|
||||
|
||||
private var unicastJob: Job? = null
|
||||
private val dnsExecutor: Executor = Executors.newCachedThreadPool()
|
||||
|
||||
@Volatile private var lastWideAreaRcode: Int? = null
|
||||
@Volatile private var lastWideAreaCount: Int = 0
|
||||
|
||||
private val discoveryListener =
|
||||
object : NsdManager.DiscoveryListener {
|
||||
override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {}
|
||||
override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) {}
|
||||
override fun onDiscoveryStarted(serviceType: String) {}
|
||||
override fun onDiscoveryStopped(serviceType: String) {}
|
||||
|
||||
override fun onServiceFound(serviceInfo: NsdServiceInfo) {
|
||||
if (serviceInfo.serviceType != this@BridgeDiscovery.serviceType) return
|
||||
resolve(serviceInfo)
|
||||
}
|
||||
|
||||
override fun onServiceLost(serviceInfo: NsdServiceInfo) {
|
||||
val serviceName = BonjourEscapes.decode(serviceInfo.serviceName)
|
||||
val id = stableId(serviceName, "local.")
|
||||
localById.remove(id)
|
||||
publish()
|
||||
}
|
||||
}
|
||||
|
||||
init {
|
||||
startLocalDiscovery()
|
||||
startUnicastDiscovery(wideAreaDomain)
|
||||
}
|
||||
|
||||
private fun startLocalDiscovery() {
|
||||
try {
|
||||
nsd.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, discoveryListener)
|
||||
} catch (_: Throwable) {
|
||||
// ignore (best-effort)
|
||||
}
|
||||
}
|
||||
|
||||
private fun stopLocalDiscovery() {
|
||||
try {
|
||||
nsd.stopServiceDiscovery(discoveryListener)
|
||||
} catch (_: Throwable) {
|
||||
// ignore (best-effort)
|
||||
}
|
||||
}
|
||||
|
||||
private fun startUnicastDiscovery(domain: String) {
|
||||
unicastJob =
|
||||
scope.launch(Dispatchers.IO) {
|
||||
while (true) {
|
||||
try {
|
||||
refreshUnicast(domain)
|
||||
} catch (_: Throwable) {
|
||||
// ignore (best-effort)
|
||||
}
|
||||
delay(5000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun resolve(serviceInfo: NsdServiceInfo) {
|
||||
nsd.resolveService(
|
||||
serviceInfo,
|
||||
object : NsdManager.ResolveListener {
|
||||
override fun onResolveFailed(serviceInfo: NsdServiceInfo, errorCode: Int) {}
|
||||
|
||||
override fun onServiceResolved(resolved: NsdServiceInfo) {
|
||||
val host = resolved.host?.hostAddress ?: return
|
||||
val port = resolved.port
|
||||
if (port <= 0) return
|
||||
|
||||
val rawServiceName = resolved.serviceName
|
||||
val serviceName = BonjourEscapes.decode(rawServiceName)
|
||||
val displayName = BonjourEscapes.decode(txt(resolved, "displayName") ?: serviceName)
|
||||
val lanHost = txt(resolved, "lanHost")
|
||||
val tailnetDns = txt(resolved, "tailnetDns")
|
||||
val gatewayPort = txtInt(resolved, "gatewayPort")
|
||||
val bridgePort = txtInt(resolved, "bridgePort")
|
||||
val canvasPort = txtInt(resolved, "canvasPort")
|
||||
val tlsEnabled = txtBool(resolved, "bridgeTls")
|
||||
val tlsFingerprint = txt(resolved, "bridgeTlsSha256")
|
||||
val id = stableId(serviceName, "local.")
|
||||
localById[id] =
|
||||
BridgeEndpoint(
|
||||
stableId = id,
|
||||
name = displayName,
|
||||
host = host,
|
||||
port = port,
|
||||
lanHost = lanHost,
|
||||
tailnetDns = tailnetDns,
|
||||
gatewayPort = gatewayPort,
|
||||
bridgePort = bridgePort,
|
||||
canvasPort = canvasPort,
|
||||
tlsEnabled = tlsEnabled,
|
||||
tlsFingerprintSha256 = tlsFingerprint,
|
||||
)
|
||||
publish()
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
private fun publish() {
|
||||
_bridges.value =
|
||||
(localById.values + unicastById.values).sortedBy { it.name.lowercase() }
|
||||
_statusText.value = buildStatusText()
|
||||
}
|
||||
|
||||
private fun buildStatusText(): String {
|
||||
val localCount = localById.size
|
||||
val wideRcode = lastWideAreaRcode
|
||||
val wideCount = lastWideAreaCount
|
||||
|
||||
val wide =
|
||||
when (wideRcode) {
|
||||
null -> "Wide: ?"
|
||||
Rcode.NOERROR -> "Wide: $wideCount"
|
||||
Rcode.NXDOMAIN -> "Wide: NXDOMAIN"
|
||||
else -> "Wide: ${Rcode.string(wideRcode)}"
|
||||
}
|
||||
|
||||
return when {
|
||||
localCount == 0 && wideRcode == null -> "Searching for bridges…"
|
||||
localCount == 0 -> "$wide"
|
||||
else -> "Local: $localCount • $wide"
|
||||
}
|
||||
}
|
||||
|
||||
private fun stableId(serviceName: String, domain: String): String {
|
||||
return "${serviceType}|${domain}|${normalizeName(serviceName)}"
|
||||
}
|
||||
|
||||
private fun normalizeName(raw: String): String {
|
||||
return raw.trim().split(Regex("\\s+")).joinToString(" ")
|
||||
}
|
||||
|
||||
private fun txt(info: NsdServiceInfo, key: String): String? {
|
||||
val bytes = info.attributes[key] ?: return null
|
||||
return try {
|
||||
String(bytes, Charsets.UTF_8).trim().ifEmpty { null }
|
||||
} catch (_: Throwable) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private fun txtInt(info: NsdServiceInfo, key: String): Int? {
|
||||
return txt(info, key)?.toIntOrNull()
|
||||
}
|
||||
|
||||
private fun txtBool(info: NsdServiceInfo, key: String): Boolean {
|
||||
val raw = txt(info, key)?.trim()?.lowercase() ?: return false
|
||||
return raw == "1" || raw == "true" || raw == "yes"
|
||||
}
|
||||
|
||||
private suspend fun refreshUnicast(domain: String) {
|
||||
val ptrName = "${serviceType}${domain}"
|
||||
val ptrMsg = lookupUnicastMessage(ptrName, Type.PTR) ?: return
|
||||
val ptrRecords = records(ptrMsg, Section.ANSWER).mapNotNull { it as? PTRRecord }
|
||||
|
||||
val next = LinkedHashMap<String, BridgeEndpoint>()
|
||||
for (ptr in ptrRecords) {
|
||||
val instanceFqdn = ptr.target.toString()
|
||||
val srv =
|
||||
recordByName(ptrMsg, instanceFqdn, Type.SRV) as? SRVRecord
|
||||
?: run {
|
||||
val msg = lookupUnicastMessage(instanceFqdn, Type.SRV) ?: return@run null
|
||||
recordByName(msg, instanceFqdn, Type.SRV) as? SRVRecord
|
||||
}
|
||||
?: continue
|
||||
val port = srv.port
|
||||
if (port <= 0) continue
|
||||
|
||||
val targetFqdn = srv.target.toString()
|
||||
val host =
|
||||
resolveHostFromMessage(ptrMsg, targetFqdn)
|
||||
?: resolveHostFromMessage(lookupUnicastMessage(instanceFqdn, Type.SRV), targetFqdn)
|
||||
?: resolveHostUnicast(targetFqdn)
|
||||
?: continue
|
||||
|
||||
val txtFromPtr =
|
||||
recordsByName(ptrMsg, Section.ADDITIONAL)[keyName(instanceFqdn)]
|
||||
.orEmpty()
|
||||
.mapNotNull { it as? TXTRecord }
|
||||
val txt =
|
||||
if (txtFromPtr.isNotEmpty()) {
|
||||
txtFromPtr
|
||||
} else {
|
||||
val msg = lookupUnicastMessage(instanceFqdn, Type.TXT)
|
||||
records(msg, Section.ANSWER).mapNotNull { it as? TXTRecord }
|
||||
}
|
||||
val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain))
|
||||
val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName)
|
||||
val lanHost = txtValue(txt, "lanHost")
|
||||
val tailnetDns = txtValue(txt, "tailnetDns")
|
||||
val gatewayPort = txtIntValue(txt, "gatewayPort")
|
||||
val bridgePort = txtIntValue(txt, "bridgePort")
|
||||
val canvasPort = txtIntValue(txt, "canvasPort")
|
||||
val tlsEnabled = txtBoolValue(txt, "bridgeTls")
|
||||
val tlsFingerprint = txtValue(txt, "bridgeTlsSha256")
|
||||
val id = stableId(instanceName, domain)
|
||||
next[id] =
|
||||
BridgeEndpoint(
|
||||
stableId = id,
|
||||
name = displayName,
|
||||
host = host,
|
||||
port = port,
|
||||
lanHost = lanHost,
|
||||
tailnetDns = tailnetDns,
|
||||
gatewayPort = gatewayPort,
|
||||
bridgePort = bridgePort,
|
||||
canvasPort = canvasPort,
|
||||
tlsEnabled = tlsEnabled,
|
||||
tlsFingerprintSha256 = tlsFingerprint,
|
||||
)
|
||||
}
|
||||
|
||||
unicastById.clear()
|
||||
unicastById.putAll(next)
|
||||
lastWideAreaRcode = ptrMsg.header.rcode
|
||||
lastWideAreaCount = next.size
|
||||
publish()
|
||||
|
||||
if (next.isEmpty()) {
|
||||
Log.d(
|
||||
logTag,
|
||||
"wide-area discovery: 0 results for $ptrName (rcode=${Rcode.string(ptrMsg.header.rcode)})",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private fun decodeInstanceName(instanceFqdn: String, domain: String): String {
|
||||
val suffix = "${serviceType}${domain}"
|
||||
val withoutSuffix =
|
||||
if (instanceFqdn.endsWith(suffix)) {
|
||||
instanceFqdn.removeSuffix(suffix)
|
||||
} else {
|
||||
instanceFqdn.substringBefore(serviceType)
|
||||
}
|
||||
return normalizeName(stripTrailingDot(withoutSuffix))
|
||||
}
|
||||
|
||||
private fun stripTrailingDot(raw: String): String {
|
||||
return raw.removeSuffix(".")
|
||||
}
|
||||
|
||||
private suspend fun lookupUnicastMessage(name: String, type: Int): Message? {
|
||||
val query =
|
||||
try {
|
||||
Message.newQuery(
|
||||
org.xbill.DNS.Record.newRecord(
|
||||
Name.fromString(name),
|
||||
type,
|
||||
DClass.IN,
|
||||
),
|
||||
)
|
||||
} catch (_: TextParseException) {
|
||||
return null
|
||||
}
|
||||
|
||||
val system = queryViaSystemDns(query)
|
||||
if (records(system, Section.ANSWER).any { it.type == type }) return system
|
||||
|
||||
val direct = createDirectResolver() ?: return system
|
||||
return try {
|
||||
val msg = direct.send(query)
|
||||
if (records(msg, Section.ANSWER).any { it.type == type }) msg else system
|
||||
} catch (_: Throwable) {
|
||||
system
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun queryViaSystemDns(query: Message): Message? {
|
||||
val network = preferredDnsNetwork()
|
||||
val bytes =
|
||||
try {
|
||||
rawQuery(network, query.toWire())
|
||||
} catch (_: Throwable) {
|
||||
return null
|
||||
}
|
||||
|
||||
return try {
|
||||
Message(bytes)
|
||||
} catch (_: IOException) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private fun records(msg: Message?, section: Int): List<Record> {
|
||||
return msg?.getSectionArray(section)?.toList() ?: emptyList()
|
||||
}
|
||||
|
||||
private fun keyName(raw: String): String {
|
||||
return raw.trim().lowercase()
|
||||
}
|
||||
|
||||
private fun recordsByName(msg: Message, section: Int): Map<String, List<Record>> {
|
||||
val next = LinkedHashMap<String, MutableList<Record>>()
|
||||
for (r in records(msg, section)) {
|
||||
val name = r.name?.toString() ?: continue
|
||||
next.getOrPut(keyName(name)) { mutableListOf() }.add(r)
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
private fun recordByName(msg: Message, fqdn: String, type: Int): Record? {
|
||||
val key = keyName(fqdn)
|
||||
val byNameAnswer = recordsByName(msg, Section.ANSWER)
|
||||
val fromAnswer = byNameAnswer[key].orEmpty().firstOrNull { it.type == type }
|
||||
if (fromAnswer != null) return fromAnswer
|
||||
|
||||
val byNameAdditional = recordsByName(msg, Section.ADDITIONAL)
|
||||
return byNameAdditional[key].orEmpty().firstOrNull { it.type == type }
|
||||
}
|
||||
|
||||
private fun resolveHostFromMessage(msg: Message?, hostname: String): String? {
|
||||
val m = msg ?: return null
|
||||
val key = keyName(hostname)
|
||||
val additional = recordsByName(m, Section.ADDITIONAL)[key].orEmpty()
|
||||
val a = additional.mapNotNull { it as? ARecord }.mapNotNull { it.address?.hostAddress }
|
||||
val aaaa = additional.mapNotNull { it as? AAAARecord }.mapNotNull { it.address?.hostAddress }
|
||||
return a.firstOrNull() ?: aaaa.firstOrNull()
|
||||
}
|
||||
|
||||
private fun preferredDnsNetwork(): android.net.Network? {
|
||||
val cm = connectivity ?: return null
|
||||
|
||||
// Prefer VPN (Tailscale) when present; otherwise use the active network.
|
||||
cm.allNetworks.firstOrNull { n ->
|
||||
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
|
||||
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
|
||||
}?.let { return it }
|
||||
|
||||
return cm.activeNetwork
|
||||
}
|
||||
|
||||
private fun createDirectResolver(): Resolver? {
|
||||
val cm = connectivity ?: return null
|
||||
|
||||
val candidateNetworks =
|
||||
buildList {
|
||||
cm.allNetworks
|
||||
.firstOrNull { n ->
|
||||
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
|
||||
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
|
||||
}?.let(::add)
|
||||
cm.activeNetwork?.let(::add)
|
||||
}.distinct()
|
||||
|
||||
val servers =
|
||||
candidateNetworks
|
||||
.asSequence()
|
||||
.flatMap { n ->
|
||||
cm.getLinkProperties(n)?.dnsServers?.asSequence() ?: emptySequence()
|
||||
}
|
||||
.distinctBy { it.hostAddress ?: it.toString() }
|
||||
.toList()
|
||||
if (servers.isEmpty()) return null
|
||||
|
||||
return try {
|
||||
val resolvers =
|
||||
servers.mapNotNull { addr ->
|
||||
try {
|
||||
SimpleResolver().apply {
|
||||
setAddress(InetSocketAddress(addr, 53))
|
||||
setTimeout(3)
|
||||
}
|
||||
} catch (_: Throwable) {
|
||||
null
|
||||
}
|
||||
}
|
||||
if (resolvers.isEmpty()) return null
|
||||
ExtendedResolver(resolvers.toTypedArray()).apply { setTimeout(3) }
|
||||
} catch (_: Throwable) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun rawQuery(network: android.net.Network?, wireQuery: ByteArray): ByteArray =
|
||||
suspendCancellableCoroutine { cont ->
|
||||
val signal = CancellationSignal()
|
||||
cont.invokeOnCancellation { signal.cancel() }
|
||||
|
||||
dns.rawQuery(
|
||||
network,
|
||||
wireQuery,
|
||||
DnsResolver.FLAG_EMPTY,
|
||||
dnsExecutor,
|
||||
signal,
|
||||
object : DnsResolver.Callback<ByteArray> {
|
||||
override fun onAnswer(answer: ByteArray, rcode: Int) {
|
||||
cont.resume(answer)
|
||||
}
|
||||
|
||||
override fun onError(error: DnsResolver.DnsException) {
|
||||
cont.resumeWithException(error)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
private fun txtValue(records: List<TXTRecord>, key: String): String? {
|
||||
val prefix = "$key="
|
||||
for (r in records) {
|
||||
val strings: List<String> =
|
||||
try {
|
||||
r.strings.mapNotNull { it as? String }
|
||||
} catch (_: Throwable) {
|
||||
emptyList()
|
||||
}
|
||||
for (s in strings) {
|
||||
val trimmed = decodeDnsTxtString(s).trim()
|
||||
if (trimmed.startsWith(prefix)) {
|
||||
return trimmed.removePrefix(prefix).trim().ifEmpty { null }
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private fun txtIntValue(records: List<TXTRecord>, key: String): Int? {
|
||||
return txtValue(records, key)?.toIntOrNull()
|
||||
}
|
||||
|
||||
private fun txtBoolValue(records: List<TXTRecord>, key: String): Boolean {
|
||||
val raw = txtValue(records, key)?.trim()?.lowercase() ?: return false
|
||||
return raw == "1" || raw == "true" || raw == "yes"
|
||||
}
|
||||
|
||||
private fun decodeDnsTxtString(raw: String): String {
|
||||
// dnsjava treats TXT as opaque bytes and decodes as ISO-8859-1 to preserve bytes.
|
||||
// Our TXT payload is UTF-8 (written by the gateway), so re-decode when possible.
|
||||
val bytes = raw.toByteArray(Charsets.ISO_8859_1)
|
||||
val decoder =
|
||||
Charsets.UTF_8
|
||||
.newDecoder()
|
||||
.onMalformedInput(CodingErrorAction.REPORT)
|
||||
.onUnmappableCharacter(CodingErrorAction.REPORT)
|
||||
return try {
|
||||
decoder.decode(ByteBuffer.wrap(bytes)).toString()
|
||||
} catch (_: Throwable) {
|
||||
raw
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun resolveHostUnicast(hostname: String): String? {
|
||||
val a =
|
||||
records(lookupUnicastMessage(hostname, Type.A), Section.ANSWER)
|
||||
.mapNotNull { it as? ARecord }
|
||||
.mapNotNull { it.address?.hostAddress }
|
||||
val aaaa =
|
||||
records(lookupUnicastMessage(hostname, Type.AAAA), Section.ANSWER)
|
||||
.mapNotNull { it as? AAAARecord }
|
||||
.mapNotNull { it.address?.hostAddress }
|
||||
|
||||
return a.firstOrNull() ?: aaaa.firstOrNull()
|
||||
}
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package com.clawdbot.android.bridge
|
||||
|
||||
data class BridgeEndpoint(
|
||||
val stableId: String,
|
||||
val name: String,
|
||||
val host: String,
|
||||
val port: Int,
|
||||
val lanHost: String? = null,
|
||||
val tailnetDns: String? = null,
|
||||
val gatewayPort: Int? = null,
|
||||
val bridgePort: Int? = null,
|
||||
val canvasPort: Int? = null,
|
||||
val tlsEnabled: Boolean = false,
|
||||
val tlsFingerprintSha256: String? = null,
|
||||
) {
|
||||
companion object {
|
||||
fun manual(host: String, port: Int): BridgeEndpoint =
|
||||
BridgeEndpoint(
|
||||
stableId = "manual|$host|$port",
|
||||
name = "$host:$port",
|
||||
host = host,
|
||||
port = port,
|
||||
tlsEnabled = false,
|
||||
tlsFingerprintSha256 = null,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
package com.clawdbot.android.bridge
|
||||
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonArray
|
||||
import kotlinx.serialization.json.JsonObject
|
||||
import kotlinx.serialization.json.JsonPrimitive
|
||||
import kotlinx.serialization.json.JsonNull
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.buildJsonObject
|
||||
import java.io.BufferedReader
|
||||
import java.io.BufferedWriter
|
||||
import java.io.InputStreamReader
|
||||
import java.io.OutputStreamWriter
|
||||
import java.net.InetSocketAddress
|
||||
|
||||
class BridgePairingClient {
|
||||
private val json = Json { ignoreUnknownKeys = true }
|
||||
|
||||
data class Hello(
|
||||
val nodeId: String,
|
||||
val displayName: String?,
|
||||
val token: String?,
|
||||
val platform: String?,
|
||||
val version: String?,
|
||||
val deviceFamily: String?,
|
||||
val modelIdentifier: String?,
|
||||
val caps: List<String>?,
|
||||
val commands: List<String>?,
|
||||
)
|
||||
|
||||
data class PairResult(val ok: Boolean, val token: String?, val error: String? = null)
|
||||
|
||||
suspend fun pairAndHello(
|
||||
endpoint: BridgeEndpoint,
|
||||
hello: Hello,
|
||||
tls: BridgeTlsParams? = null,
|
||||
onTlsFingerprint: ((String) -> Unit)? = null,
|
||||
): PairResult =
|
||||
withContext(Dispatchers.IO) {
|
||||
if (tls != null) {
|
||||
try {
|
||||
return@withContext pairAndHelloWithTls(endpoint, hello, tls, onTlsFingerprint)
|
||||
} catch (e: Exception) {
|
||||
if (tls.required) throw e
|
||||
}
|
||||
}
|
||||
pairAndHelloWithTls(endpoint, hello, null, null)
|
||||
}
|
||||
|
||||
private fun pairAndHelloWithTls(
|
||||
endpoint: BridgeEndpoint,
|
||||
hello: Hello,
|
||||
tls: BridgeTlsParams?,
|
||||
onTlsFingerprint: ((String) -> Unit)?,
|
||||
): PairResult {
|
||||
val socket =
|
||||
createBridgeSocket(tls) { fingerprint ->
|
||||
onTlsFingerprint?.invoke(fingerprint)
|
||||
}
|
||||
socket.tcpNoDelay = true
|
||||
try {
|
||||
socket.connect(InetSocketAddress(endpoint.host, endpoint.port), 8_000)
|
||||
socket.soTimeout = 60_000
|
||||
startTlsHandshakeIfNeeded(socket)
|
||||
|
||||
val reader = BufferedReader(InputStreamReader(socket.getInputStream(), Charsets.UTF_8))
|
||||
val writer = BufferedWriter(OutputStreamWriter(socket.getOutputStream(), Charsets.UTF_8))
|
||||
|
||||
fun send(line: String) {
|
||||
writer.write(line)
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
}
|
||||
|
||||
fun sendJson(obj: JsonObject) = send(obj.toString())
|
||||
|
||||
sendJson(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("hello"))
|
||||
put("nodeId", JsonPrimitive(hello.nodeId))
|
||||
hello.displayName?.let { put("displayName", JsonPrimitive(it)) }
|
||||
hello.token?.let { put("token", JsonPrimitive(it)) }
|
||||
hello.platform?.let { put("platform", JsonPrimitive(it)) }
|
||||
hello.version?.let { put("version", JsonPrimitive(it)) }
|
||||
hello.deviceFamily?.let { put("deviceFamily", JsonPrimitive(it)) }
|
||||
hello.modelIdentifier?.let { put("modelIdentifier", JsonPrimitive(it)) }
|
||||
hello.caps?.let { put("caps", JsonArray(it.map(::JsonPrimitive))) }
|
||||
hello.commands?.let { put("commands", JsonArray(it.map(::JsonPrimitive))) }
|
||||
},
|
||||
)
|
||||
|
||||
val firstObj = json.parseToJsonElement(reader.readLine()).asObjectOrNull()
|
||||
?: return PairResult(ok = false, token = null, error = "unexpected bridge response")
|
||||
return when (firstObj["type"].asStringOrNull()) {
|
||||
"hello-ok" -> PairResult(ok = true, token = hello.token)
|
||||
"error" -> {
|
||||
val code = firstObj["code"].asStringOrNull() ?: "UNAVAILABLE"
|
||||
val message = firstObj["message"].asStringOrNull() ?: "pairing required"
|
||||
if (code != "NOT_PAIRED" && code != "UNAUTHORIZED") {
|
||||
return PairResult(ok = false, token = null, error = "$code: $message")
|
||||
}
|
||||
|
||||
sendJson(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("pair-request"))
|
||||
put("nodeId", JsonPrimitive(hello.nodeId))
|
||||
hello.displayName?.let { put("displayName", JsonPrimitive(it)) }
|
||||
hello.platform?.let { put("platform", JsonPrimitive(it)) }
|
||||
hello.version?.let { put("version", JsonPrimitive(it)) }
|
||||
hello.deviceFamily?.let { put("deviceFamily", JsonPrimitive(it)) }
|
||||
hello.modelIdentifier?.let { put("modelIdentifier", JsonPrimitive(it)) }
|
||||
hello.caps?.let { put("caps", JsonArray(it.map(::JsonPrimitive))) }
|
||||
hello.commands?.let { put("commands", JsonArray(it.map(::JsonPrimitive))) }
|
||||
},
|
||||
)
|
||||
|
||||
while (true) {
|
||||
val nextLine = reader.readLine() ?: break
|
||||
val next = json.parseToJsonElement(nextLine).asObjectOrNull() ?: continue
|
||||
when (next["type"].asStringOrNull()) {
|
||||
"pair-ok" -> {
|
||||
val token = next["token"].asStringOrNull()
|
||||
return PairResult(ok = !token.isNullOrBlank(), token = token)
|
||||
}
|
||||
"error" -> {
|
||||
val c = next["code"].asStringOrNull() ?: "UNAVAILABLE"
|
||||
val m = next["message"].asStringOrNull() ?: "pairing failed"
|
||||
return PairResult(ok = false, token = null, error = "$c: $m")
|
||||
}
|
||||
}
|
||||
}
|
||||
PairResult(ok = false, token = null, error = "pairing failed")
|
||||
}
|
||||
else -> PairResult(ok = false, token = null, error = "unexpected bridge response")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
val message = e.message?.trim().orEmpty().ifEmpty { "gateway unreachable" }
|
||||
return PairResult(ok = false, token = null, error = message)
|
||||
} finally {
|
||||
try {
|
||||
socket.close()
|
||||
} catch (_: Throwable) {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun JsonElement?.asObjectOrNull(): JsonObject? = this as? JsonObject
|
||||
|
||||
private fun JsonElement?.asStringOrNull(): String? =
|
||||
when (this) {
|
||||
is JsonNull -> null
|
||||
is JsonPrimitive -> content
|
||||
else -> null
|
||||
}
|
||||
@@ -1,398 +0,0 @@
|
||||
package com.clawdbot.android.bridge
|
||||
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.cancelAndJoin
|
||||
import kotlinx.coroutines.delay
|
||||
import kotlinx.coroutines.isActive
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import kotlinx.coroutines.withContext
|
||||
import com.clawdbot.android.BuildConfig
|
||||
import kotlinx.serialization.json.Json
|
||||
import kotlinx.serialization.json.JsonArray
|
||||
import kotlinx.serialization.json.JsonObject
|
||||
import kotlinx.serialization.json.JsonNull
|
||||
import kotlinx.serialization.json.JsonElement
|
||||
import kotlinx.serialization.json.JsonPrimitive
|
||||
import kotlinx.serialization.json.buildJsonObject
|
||||
import java.io.BufferedReader
|
||||
import java.io.BufferedWriter
|
||||
import java.io.InputStreamReader
|
||||
import java.io.OutputStreamWriter
|
||||
import java.net.InetSocketAddress
|
||||
import java.net.URI
|
||||
import java.net.Socket
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
class BridgeSession(
|
||||
private val scope: CoroutineScope,
|
||||
private val onConnected: (serverName: String, remoteAddress: String?, mainSessionKey: String?) -> Unit,
|
||||
private val onDisconnected: (message: String) -> Unit,
|
||||
private val onEvent: (event: String, payloadJson: String?) -> Unit,
|
||||
private val onInvoke: suspend (InvokeRequest) -> InvokeResult,
|
||||
private val onTlsFingerprint: ((stableId: String, fingerprint: String) -> Unit)? = null,
|
||||
) {
|
||||
data class Hello(
|
||||
val nodeId: String,
|
||||
val displayName: String?,
|
||||
val token: String?,
|
||||
val platform: String?,
|
||||
val version: String?,
|
||||
val deviceFamily: String?,
|
||||
val modelIdentifier: String?,
|
||||
val caps: List<String>?,
|
||||
val commands: List<String>?,
|
||||
)
|
||||
|
||||
data class InvokeRequest(val id: String, val command: String, val paramsJson: String?)
|
||||
|
||||
data class InvokeResult(val ok: Boolean, val payloadJson: String?, val error: ErrorShape?) {
|
||||
companion object {
|
||||
fun ok(payloadJson: String?) = InvokeResult(ok = true, payloadJson = payloadJson, error = null)
|
||||
fun error(code: String, message: String) =
|
||||
InvokeResult(ok = false, payloadJson = null, error = ErrorShape(code = code, message = message))
|
||||
}
|
||||
}
|
||||
|
||||
data class ErrorShape(val code: String, val message: String)
|
||||
|
||||
private val json = Json { ignoreUnknownKeys = true }
|
||||
private val writeLock = Mutex()
|
||||
private val pending = ConcurrentHashMap<String, CompletableDeferred<RpcResponse>>()
|
||||
@Volatile private var canvasHostUrl: String? = null
|
||||
@Volatile private var mainSessionKey: String? = null
|
||||
|
||||
private data class DesiredConnection(
|
||||
val endpoint: BridgeEndpoint,
|
||||
val hello: Hello,
|
||||
val tls: BridgeTlsParams?,
|
||||
)
|
||||
|
||||
private var desired: DesiredConnection? = null
|
||||
private var job: Job? = null
|
||||
|
||||
fun connect(endpoint: BridgeEndpoint, hello: Hello, tls: BridgeTlsParams? = null) {
|
||||
desired = DesiredConnection(endpoint, hello, tls)
|
||||
if (job == null) {
|
||||
job = scope.launch(Dispatchers.IO) { runLoop() }
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun updateHello(hello: Hello) {
|
||||
val target = desired ?: return
|
||||
desired = target.copy(hello = hello)
|
||||
val conn = currentConnection ?: return
|
||||
conn.sendJson(buildHelloJson(hello))
|
||||
}
|
||||
|
||||
fun disconnect() {
|
||||
desired = null
|
||||
// Unblock connectOnce() read loop. Coroutine cancellation alone won't interrupt BufferedReader.readLine().
|
||||
currentConnection?.closeQuietly()
|
||||
scope.launch(Dispatchers.IO) {
|
||||
job?.cancelAndJoin()
|
||||
job = null
|
||||
canvasHostUrl = null
|
||||
mainSessionKey = null
|
||||
onDisconnected("Offline")
|
||||
}
|
||||
}
|
||||
|
||||
fun currentCanvasHostUrl(): String? = canvasHostUrl
|
||||
fun currentMainSessionKey(): String? = mainSessionKey
|
||||
|
||||
suspend fun sendEvent(event: String, payloadJson: String?) {
|
||||
val conn = currentConnection ?: return
|
||||
conn.sendJson(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("event"))
|
||||
put("event", JsonPrimitive(event))
|
||||
if (payloadJson != null) put("payloadJSON", JsonPrimitive(payloadJson)) else put("payloadJSON", JsonNull)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
suspend fun request(method: String, paramsJson: String?): String {
|
||||
val conn = currentConnection ?: throw IllegalStateException("not connected")
|
||||
val id = UUID.randomUUID().toString()
|
||||
val deferred = CompletableDeferred<RpcResponse>()
|
||||
pending[id] = deferred
|
||||
conn.sendJson(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("req"))
|
||||
put("id", JsonPrimitive(id))
|
||||
put("method", JsonPrimitive(method))
|
||||
if (paramsJson != null) put("paramsJSON", JsonPrimitive(paramsJson)) else put("paramsJSON", JsonNull)
|
||||
},
|
||||
)
|
||||
val res = deferred.await()
|
||||
if (res.ok) return res.payloadJson ?: ""
|
||||
val err = res.error
|
||||
throw IllegalStateException("${err?.code ?: "UNAVAILABLE"}: ${err?.message ?: "request failed"}")
|
||||
}
|
||||
|
||||
private data class RpcResponse(val id: String, val ok: Boolean, val payloadJson: String?, val error: ErrorShape?)
|
||||
|
||||
private class Connection(private val socket: Socket, private val reader: BufferedReader, private val writer: BufferedWriter, private val writeLock: Mutex) {
|
||||
val remoteAddress: String? =
|
||||
socket.inetAddress?.hostAddress?.takeIf { it.isNotBlank() }?.let { "${it}:${socket.port}" }
|
||||
|
||||
suspend fun sendJson(obj: JsonObject) {
|
||||
writeLock.withLock {
|
||||
writer.write(obj.toString())
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
}
|
||||
}
|
||||
|
||||
fun closeQuietly() {
|
||||
try {
|
||||
socket.close()
|
||||
} catch (_: Throwable) {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Volatile private var currentConnection: Connection? = null
|
||||
|
||||
private suspend fun runLoop() {
|
||||
var attempt = 0
|
||||
while (scope.isActive) {
|
||||
val target = desired
|
||||
if (target == null) {
|
||||
currentConnection?.closeQuietly()
|
||||
currentConnection = null
|
||||
delay(250)
|
||||
continue
|
||||
}
|
||||
|
||||
val (endpoint, hello, tls) = target
|
||||
try {
|
||||
onDisconnected(if (attempt == 0) "Connecting…" else "Reconnecting…")
|
||||
connectOnce(endpoint, hello, tls)
|
||||
attempt = 0
|
||||
} catch (err: Throwable) {
|
||||
attempt += 1
|
||||
onDisconnected("Bridge error: ${err.message ?: err::class.java.simpleName}")
|
||||
val sleepMs = minOf(8_000L, (350.0 * Math.pow(1.7, attempt.toDouble())).toLong())
|
||||
delay(sleepMs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun invokeErrorFromThrowable(err: Throwable): InvokeResult {
|
||||
val msg = err.message?.trim().takeIf { !it.isNullOrEmpty() } ?: err::class.java.simpleName
|
||||
val parts = msg.split(":", limit = 2)
|
||||
if (parts.size == 2) {
|
||||
val code = parts[0].trim()
|
||||
val rest = parts[1].trim()
|
||||
if (code.isNotEmpty() && code.all { it.isUpperCase() || it == '_' }) {
|
||||
return InvokeResult.error(code = code, message = rest.ifEmpty { msg })
|
||||
}
|
||||
}
|
||||
return InvokeResult.error(code = "UNAVAILABLE", message = msg)
|
||||
}
|
||||
|
||||
private suspend fun connectOnce(endpoint: BridgeEndpoint, hello: Hello, tls: BridgeTlsParams?) =
|
||||
withContext(Dispatchers.IO) {
|
||||
if (tls != null) {
|
||||
try {
|
||||
connectWithSocket(endpoint, hello, tls)
|
||||
return@withContext
|
||||
} catch (err: Throwable) {
|
||||
if (tls.required) throw err
|
||||
}
|
||||
}
|
||||
connectWithSocket(endpoint, hello, null)
|
||||
}
|
||||
|
||||
private suspend fun connectWithSocket(endpoint: BridgeEndpoint, hello: Hello, tls: BridgeTlsParams?) {
|
||||
val socket =
|
||||
createBridgeSocket(tls) { fingerprint ->
|
||||
onTlsFingerprint?.invoke(tls?.stableId ?: endpoint.stableId, fingerprint)
|
||||
}
|
||||
socket.tcpNoDelay = true
|
||||
socket.connect(InetSocketAddress(endpoint.host, endpoint.port), 8_000)
|
||||
socket.soTimeout = 0
|
||||
startTlsHandshakeIfNeeded(socket)
|
||||
|
||||
val reader = BufferedReader(InputStreamReader(socket.getInputStream(), Charsets.UTF_8))
|
||||
val writer = BufferedWriter(OutputStreamWriter(socket.getOutputStream(), Charsets.UTF_8))
|
||||
|
||||
val conn = Connection(socket, reader, writer, writeLock)
|
||||
currentConnection = conn
|
||||
|
||||
try {
|
||||
conn.sendJson(buildHelloJson(hello))
|
||||
|
||||
val firstLine = reader.readLine() ?: throw IllegalStateException("bridge closed connection")
|
||||
val first = json.parseToJsonElement(firstLine).asObjectOrNull()
|
||||
?: throw IllegalStateException("unexpected bridge response")
|
||||
when (first["type"].asStringOrNull()) {
|
||||
"hello-ok" -> {
|
||||
val name = first["serverName"].asStringOrNull() ?: "Bridge"
|
||||
val rawCanvasUrl = first["canvasHostUrl"].asStringOrNull()?.trim()?.ifEmpty { null }
|
||||
val rawMainSessionKey = first["mainSessionKey"].asStringOrNull()?.trim()?.ifEmpty { null }
|
||||
canvasHostUrl = normalizeCanvasHostUrl(rawCanvasUrl, endpoint)
|
||||
mainSessionKey = rawMainSessionKey
|
||||
if (BuildConfig.DEBUG) {
|
||||
// Local JVM unit tests use android.jar stubs; Log.d can throw "not mocked".
|
||||
runCatching {
|
||||
android.util.Log.d(
|
||||
"ClawdbotBridge",
|
||||
"canvasHostUrl resolved=${canvasHostUrl ?: "none"} (raw=${rawCanvasUrl ?: "none"})",
|
||||
)
|
||||
}
|
||||
}
|
||||
onConnected(name, conn.remoteAddress, rawMainSessionKey)
|
||||
}
|
||||
"error" -> {
|
||||
val code = first["code"].asStringOrNull() ?: "UNAVAILABLE"
|
||||
val msg = first["message"].asStringOrNull() ?: "connect failed"
|
||||
throw IllegalStateException("$code: $msg")
|
||||
}
|
||||
else -> throw IllegalStateException("unexpected bridge response")
|
||||
}
|
||||
|
||||
while (scope.isActive) {
|
||||
val line = reader.readLine() ?: break
|
||||
val frame = json.parseToJsonElement(line).asObjectOrNull() ?: continue
|
||||
when (frame["type"].asStringOrNull()) {
|
||||
"event" -> {
|
||||
val event = frame["event"].asStringOrNull() ?: continue
|
||||
val payload = frame["payloadJSON"].asStringOrNull()
|
||||
onEvent(event, payload)
|
||||
}
|
||||
"ping" -> {
|
||||
val id = frame["id"].asStringOrNull() ?: ""
|
||||
conn.sendJson(buildJsonObject { put("type", JsonPrimitive("pong")); put("id", JsonPrimitive(id)) })
|
||||
}
|
||||
"res" -> {
|
||||
val id = frame["id"].asStringOrNull() ?: continue
|
||||
val ok = frame["ok"].asBooleanOrNull() ?: false
|
||||
val payloadJson = frame["payloadJSON"].asStringOrNull()
|
||||
val error =
|
||||
frame["error"]?.let {
|
||||
val obj = it.asObjectOrNull() ?: return@let null
|
||||
val code = obj["code"].asStringOrNull() ?: "UNAVAILABLE"
|
||||
val msg = obj["message"].asStringOrNull() ?: "request failed"
|
||||
ErrorShape(code, msg)
|
||||
}
|
||||
pending.remove(id)?.complete(RpcResponse(id, ok, payloadJson, error))
|
||||
}
|
||||
"invoke" -> {
|
||||
val id = frame["id"].asStringOrNull() ?: continue
|
||||
val command = frame["command"].asStringOrNull() ?: ""
|
||||
val params = frame["paramsJSON"].asStringOrNull()
|
||||
val result =
|
||||
try {
|
||||
onInvoke(InvokeRequest(id, command, params))
|
||||
} catch (err: Throwable) {
|
||||
invokeErrorFromThrowable(err)
|
||||
}
|
||||
conn.sendJson(
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("invoke-res"))
|
||||
put("id", JsonPrimitive(id))
|
||||
put("ok", JsonPrimitive(result.ok))
|
||||
if (result.payloadJson != null) put("payloadJSON", JsonPrimitive(result.payloadJson))
|
||||
if (result.error != null) {
|
||||
put(
|
||||
"error",
|
||||
buildJsonObject {
|
||||
put("code", JsonPrimitive(result.error.code))
|
||||
put("message", JsonPrimitive(result.error.message))
|
||||
},
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
"invoke-res" -> {
|
||||
// gateway->node only (ignore)
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
currentConnection = null
|
||||
for ((_, waiter) in pending) {
|
||||
waiter.cancel()
|
||||
}
|
||||
pending.clear()
|
||||
conn.closeQuietly()
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildHelloJson(hello: Hello): JsonObject =
|
||||
buildJsonObject {
|
||||
put("type", JsonPrimitive("hello"))
|
||||
put("nodeId", JsonPrimitive(hello.nodeId))
|
||||
hello.displayName?.let { put("displayName", JsonPrimitive(it)) }
|
||||
hello.token?.let { put("token", JsonPrimitive(it)) }
|
||||
hello.platform?.let { put("platform", JsonPrimitive(it)) }
|
||||
hello.version?.let { put("version", JsonPrimitive(it)) }
|
||||
hello.deviceFamily?.let { put("deviceFamily", JsonPrimitive(it)) }
|
||||
hello.modelIdentifier?.let { put("modelIdentifier", JsonPrimitive(it)) }
|
||||
hello.caps?.let { put("caps", JsonArray(it.map(::JsonPrimitive))) }
|
||||
hello.commands?.let { put("commands", JsonArray(it.map(::JsonPrimitive))) }
|
||||
}
|
||||
|
||||
private fun normalizeCanvasHostUrl(raw: String?, endpoint: BridgeEndpoint): String? {
|
||||
val trimmed = raw?.trim().orEmpty()
|
||||
val parsed = trimmed.takeIf { it.isNotBlank() }?.let { runCatching { URI(it) }.getOrNull() }
|
||||
val host = parsed?.host?.trim().orEmpty()
|
||||
val port = parsed?.port ?: -1
|
||||
val scheme = parsed?.scheme?.trim().orEmpty().ifBlank { "http" }
|
||||
|
||||
if (trimmed.isNotBlank() && !isLoopbackHost(host)) {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
val fallbackHost =
|
||||
endpoint.tailnetDns?.trim().takeIf { !it.isNullOrEmpty() }
|
||||
?: endpoint.lanHost?.trim().takeIf { !it.isNullOrEmpty() }
|
||||
?: endpoint.host.trim()
|
||||
if (fallbackHost.isEmpty()) return trimmed.ifBlank { null }
|
||||
|
||||
val fallbackPort = endpoint.canvasPort ?: if (port > 0) port else 18793
|
||||
val formattedHost = if (fallbackHost.contains(":")) "[${fallbackHost}]" else fallbackHost
|
||||
return "$scheme://$formattedHost:$fallbackPort"
|
||||
}
|
||||
|
||||
private fun isLoopbackHost(raw: String?): Boolean {
|
||||
val host = raw?.trim()?.lowercase().orEmpty()
|
||||
if (host.isEmpty()) return false
|
||||
if (host == "localhost") return true
|
||||
if (host == "::1") return true
|
||||
if (host == "0.0.0.0" || host == "::") return true
|
||||
return host.startsWith("127.")
|
||||
}
|
||||
}
|
||||
|
||||
private fun JsonElement?.asObjectOrNull(): JsonObject? = this as? JsonObject
|
||||
|
||||
private fun JsonElement?.asStringOrNull(): String? =
|
||||
when (this) {
|
||||
is JsonNull -> null
|
||||
is JsonPrimitive -> content
|
||||
else -> null
|
||||
}
|
||||
|
||||
private fun JsonElement?.asBooleanOrNull(): Boolean? =
|
||||
when (this) {
|
||||
is JsonPrimitive -> {
|
||||
val c = content.trim()
|
||||
when {
|
||||
c.equals("true", ignoreCase = true) -> true
|
||||
c.equals("false", ignoreCase = true) -> false
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
else -> null
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package com.clawdbot.android.bridge
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import java.net.Socket
|
||||
import java.security.MessageDigest
|
||||
import java.security.SecureRandom
|
||||
import java.security.cert.CertificateException
|
||||
import java.security.cert.X509Certificate
|
||||
import javax.net.ssl.SSLContext
|
||||
import javax.net.ssl.SSLSocket
|
||||
import javax.net.ssl.TrustManagerFactory
|
||||
import javax.net.ssl.X509TrustManager
|
||||
|
||||
data class BridgeTlsParams(
|
||||
val required: Boolean,
|
||||
val expectedFingerprint: String?,
|
||||
val allowTOFU: Boolean,
|
||||
val stableId: String,
|
||||
)
|
||||
|
||||
fun createBridgeSocket(params: BridgeTlsParams?, onStore: ((String) -> Unit)? = null): Socket {
|
||||
if (params == null) return Socket()
|
||||
val expected = params.expectedFingerprint?.let(::normalizeFingerprint)
|
||||
val defaultTrust = defaultTrustManager()
|
||||
@SuppressLint("CustomX509TrustManager")
|
||||
val trustManager =
|
||||
object : X509TrustManager {
|
||||
override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {
|
||||
defaultTrust.checkClientTrusted(chain, authType)
|
||||
}
|
||||
|
||||
override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {
|
||||
if (chain.isEmpty()) throw CertificateException("empty certificate chain")
|
||||
val fingerprint = sha256Hex(chain[0].encoded)
|
||||
if (expected != null) {
|
||||
if (fingerprint != expected) {
|
||||
throw CertificateException("bridge TLS fingerprint mismatch")
|
||||
}
|
||||
return
|
||||
}
|
||||
if (params.allowTOFU) {
|
||||
onStore?.invoke(fingerprint)
|
||||
return
|
||||
}
|
||||
defaultTrust.checkServerTrusted(chain, authType)
|
||||
}
|
||||
|
||||
override fun getAcceptedIssuers(): Array<X509Certificate> = defaultTrust.acceptedIssuers
|
||||
}
|
||||
|
||||
val context = SSLContext.getInstance("TLS")
|
||||
context.init(null, arrayOf(trustManager), SecureRandom())
|
||||
return context.socketFactory.createSocket()
|
||||
}
|
||||
|
||||
fun startTlsHandshakeIfNeeded(socket: Socket) {
|
||||
if (socket is SSLSocket) {
|
||||
socket.startHandshake()
|
||||
}
|
||||
}
|
||||
|
||||
private fun defaultTrustManager(): X509TrustManager {
|
||||
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
|
||||
factory.init(null as java.security.KeyStore?)
|
||||
val trust =
|
||||
factory.trustManagers.firstOrNull { it is X509TrustManager } as? X509TrustManager
|
||||
return trust ?: throw IllegalStateException("No default X509TrustManager found")
|
||||
}
|
||||
|
||||
private fun sha256Hex(data: ByteArray): String {
|
||||
val digest = MessageDigest.getInstance("SHA-256").digest(data)
|
||||
val out = StringBuilder(digest.size * 2)
|
||||
for (byte in digest) {
|
||||
out.append(String.format("%02x", byte))
|
||||
}
|
||||
return out.toString()
|
||||
}
|
||||
|
||||
private fun normalizeFingerprint(raw: String): String {
|
||||
return raw.lowercase().filter { it in '0'..'9' || it in 'a'..'f' }
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.clawdbot.android.bridge
|
||||
package com.clawdbot.android.gateway
|
||||
|
||||
object BonjourEscapes {
|
||||
fun decode(input: String): String {
|
||||
@@ -8,7 +8,6 @@ import android.net.nsd.NsdManager
|
||||
import android.net.nsd.NsdServiceInfo
|
||||
import android.os.CancellationSignal
|
||||
import android.util.Log
|
||||
import com.clawdbot.android.bridge.BonjourEscapes
|
||||
import java.io.IOException
|
||||
import java.net.InetSocketAddress
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.clawdbot.android.bridge
|
||||
|
||||
import io.kotest.core.spec.style.StringSpec
|
||||
import io.kotest.matchers.shouldBe
|
||||
|
||||
class BridgeEndpointKotestTest : StringSpec({
|
||||
"manual endpoint builds stable id + name" {
|
||||
val endpoint = BridgeEndpoint.manual("10.0.0.5", 18790)
|
||||
endpoint.stableId shouldBe "manual|10.0.0.5|18790"
|
||||
endpoint.name shouldBe "10.0.0.5:18790"
|
||||
endpoint.host shouldBe "10.0.0.5"
|
||||
endpoint.port shouldBe 18790
|
||||
}
|
||||
})
|
||||
@@ -1,108 +0,0 @@
|
||||
package com.clawdbot.android.bridge
|
||||
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Assert.assertTrue
|
||||
import org.junit.Test
|
||||
import java.io.BufferedReader
|
||||
import java.io.BufferedWriter
|
||||
import java.io.InputStreamReader
|
||||
import java.io.OutputStreamWriter
|
||||
import java.net.ServerSocket
|
||||
|
||||
class BridgePairingClientTest {
|
||||
@Test
|
||||
fun helloOkReturnsExistingToken() = runBlocking {
|
||||
val serverSocket = ServerSocket(0)
|
||||
val port = serverSocket.localPort
|
||||
|
||||
val server =
|
||||
async(Dispatchers.IO) {
|
||||
serverSocket.use { ss ->
|
||||
val sock = ss.accept()
|
||||
sock.use { s ->
|
||||
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
|
||||
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
|
||||
|
||||
val hello = reader.readLine()
|
||||
assertTrue(hello.contains("\"type\":\"hello\""))
|
||||
writer.write("""{"type":"hello-ok","serverName":"Test Bridge"}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val client = BridgePairingClient()
|
||||
val res =
|
||||
client.pairAndHello(
|
||||
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
|
||||
hello =
|
||||
BridgePairingClient.Hello(
|
||||
nodeId = "node-1",
|
||||
displayName = "Android Node",
|
||||
token = "token-123",
|
||||
platform = "Android",
|
||||
version = "test",
|
||||
deviceFamily = "Android",
|
||||
modelIdentifier = "SM-X000",
|
||||
caps = null,
|
||||
commands = null,
|
||||
),
|
||||
)
|
||||
assertTrue(res.ok)
|
||||
assertEquals("token-123", res.token)
|
||||
server.await()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun notPairedTriggersPairRequestAndReturnsToken() = runBlocking {
|
||||
val serverSocket = ServerSocket(0)
|
||||
val port = serverSocket.localPort
|
||||
|
||||
val server =
|
||||
async(Dispatchers.IO) {
|
||||
serverSocket.use { ss ->
|
||||
val sock = ss.accept()
|
||||
sock.use { s ->
|
||||
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
|
||||
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
|
||||
|
||||
reader.readLine() // hello
|
||||
writer.write("""{"type":"error","code":"NOT_PAIRED","message":"not paired"}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
|
||||
val pairReq = reader.readLine()
|
||||
assertTrue(pairReq.contains("\"type\":\"pair-request\""))
|
||||
writer.write("""{"type":"pair-ok","token":"new-token"}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val client = BridgePairingClient()
|
||||
val res =
|
||||
client.pairAndHello(
|
||||
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
|
||||
hello =
|
||||
BridgePairingClient.Hello(
|
||||
nodeId = "node-1",
|
||||
displayName = "Android Node",
|
||||
token = null,
|
||||
platform = "Android",
|
||||
version = "test",
|
||||
deviceFamily = "Android",
|
||||
modelIdentifier = "SM-X000",
|
||||
caps = null,
|
||||
commands = null,
|
||||
),
|
||||
)
|
||||
assertTrue(res.ok)
|
||||
assertEquals("new-token", res.token)
|
||||
server.await()
|
||||
}
|
||||
}
|
||||
@@ -1,307 +0,0 @@
|
||||
package com.clawdbot.android.bridge
|
||||
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Assert.assertTrue
|
||||
import org.junit.Test
|
||||
import java.io.BufferedReader
|
||||
import java.io.BufferedWriter
|
||||
import java.io.InputStreamReader
|
||||
import java.io.OutputStreamWriter
|
||||
import java.net.ServerSocket
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
class BridgeSessionTest {
|
||||
@Test
|
||||
fun requestReturnsPayloadJson() = runBlocking {
|
||||
val serverSocket = ServerSocket(0)
|
||||
val port = serverSocket.localPort
|
||||
|
||||
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
val connected = CompletableDeferred<Unit>()
|
||||
|
||||
val session =
|
||||
BridgeSession(
|
||||
scope = scope,
|
||||
onConnected = { _, _, _ -> connected.complete(Unit) },
|
||||
onDisconnected = { /* ignore */ },
|
||||
onEvent = { _, _ -> /* ignore */ },
|
||||
onInvoke = { BridgeSession.InvokeResult.ok(null) },
|
||||
)
|
||||
|
||||
val server =
|
||||
async(Dispatchers.IO) {
|
||||
serverSocket.use { ss ->
|
||||
val sock = ss.accept()
|
||||
sock.use { s ->
|
||||
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
|
||||
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
|
||||
|
||||
val hello = reader.readLine()
|
||||
assertTrue(hello.contains("\"type\":\"hello\""))
|
||||
writer.write("""{"type":"hello-ok","serverName":"Test Bridge","canvasHostUrl":"http://127.0.0.1:18789"}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
|
||||
val req = reader.readLine()
|
||||
assertTrue(req.contains("\"type\":\"req\""))
|
||||
val id = extractJsonString(req, "id")
|
||||
writer.write("""{"type":"res","id":"$id","ok":true,"payloadJSON":"{\"value\":123}"}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session.connect(
|
||||
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
|
||||
hello =
|
||||
BridgeSession.Hello(
|
||||
nodeId = "node-1",
|
||||
displayName = "Android Node",
|
||||
token = null,
|
||||
platform = "Android",
|
||||
version = "test",
|
||||
deviceFamily = null,
|
||||
modelIdentifier = null,
|
||||
caps = null,
|
||||
commands = null,
|
||||
),
|
||||
)
|
||||
|
||||
connected.await()
|
||||
assertEquals("http://127.0.0.1:18789", session.currentCanvasHostUrl())
|
||||
val payload = session.request(method = "health", paramsJson = null)
|
||||
assertEquals("""{"value":123}""", payload)
|
||||
server.await()
|
||||
|
||||
session.disconnect()
|
||||
scope.cancel()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun requestThrowsOnErrorResponse() = runBlocking {
|
||||
val serverSocket = ServerSocket(0)
|
||||
val port = serverSocket.localPort
|
||||
|
||||
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
val connected = CompletableDeferred<Unit>()
|
||||
|
||||
val session =
|
||||
BridgeSession(
|
||||
scope = scope,
|
||||
onConnected = { _, _, _ -> connected.complete(Unit) },
|
||||
onDisconnected = { /* ignore */ },
|
||||
onEvent = { _, _ -> /* ignore */ },
|
||||
onInvoke = { BridgeSession.InvokeResult.ok(null) },
|
||||
)
|
||||
|
||||
val server =
|
||||
async(Dispatchers.IO) {
|
||||
serverSocket.use { ss ->
|
||||
val sock = ss.accept()
|
||||
sock.use { s ->
|
||||
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
|
||||
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
|
||||
|
||||
reader.readLine() // hello
|
||||
writer.write("""{"type":"hello-ok","serverName":"Test Bridge"}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
|
||||
val req = reader.readLine()
|
||||
val id = extractJsonString(req, "id")
|
||||
writer.write(
|
||||
"""{"type":"res","id":"$id","ok":false,"error":{"code":"FORBIDDEN","message":"nope"}}""",
|
||||
)
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session.connect(
|
||||
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
|
||||
hello =
|
||||
BridgeSession.Hello(
|
||||
nodeId = "node-1",
|
||||
displayName = "Android Node",
|
||||
token = null,
|
||||
platform = "Android",
|
||||
version = "test",
|
||||
deviceFamily = null,
|
||||
modelIdentifier = null,
|
||||
caps = null,
|
||||
commands = null,
|
||||
),
|
||||
)
|
||||
connected.await()
|
||||
|
||||
try {
|
||||
session.request(method = "chat.history", paramsJson = """{"sessionKey":"main"}""")
|
||||
throw AssertionError("expected request() to throw")
|
||||
} catch (e: IllegalStateException) {
|
||||
assertTrue(e.message?.contains("FORBIDDEN: nope") == true)
|
||||
}
|
||||
server.await()
|
||||
|
||||
session.disconnect()
|
||||
scope.cancel()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun invokeResReturnsErrorWhenHandlerThrows() = runBlocking {
|
||||
val serverSocket = ServerSocket(0)
|
||||
val port = serverSocket.localPort
|
||||
|
||||
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
val connected = CompletableDeferred<Unit>()
|
||||
|
||||
val session =
|
||||
BridgeSession(
|
||||
scope = scope,
|
||||
onConnected = { _, _, _ -> connected.complete(Unit) },
|
||||
onDisconnected = { /* ignore */ },
|
||||
onEvent = { _, _ -> /* ignore */ },
|
||||
onInvoke = { throw IllegalStateException("FOO_BAR: boom") },
|
||||
)
|
||||
|
||||
val invokeResLine = CompletableDeferred<String>()
|
||||
val server =
|
||||
async(Dispatchers.IO) {
|
||||
serverSocket.use { ss ->
|
||||
val sock = ss.accept()
|
||||
sock.use { s ->
|
||||
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
|
||||
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
|
||||
|
||||
reader.readLine() // hello
|
||||
writer.write("""{"type":"hello-ok","serverName":"Test Bridge"}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
|
||||
// Ask the node to invoke something; handler will throw.
|
||||
writer.write("""{"type":"invoke","id":"i1","command":"canvas.snapshot","paramsJSON":null}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
|
||||
val res = reader.readLine()
|
||||
invokeResLine.complete(res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session.connect(
|
||||
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
|
||||
hello =
|
||||
BridgeSession.Hello(
|
||||
nodeId = "node-1",
|
||||
displayName = "Android Node",
|
||||
token = null,
|
||||
platform = "Android",
|
||||
version = "test",
|
||||
deviceFamily = null,
|
||||
modelIdentifier = null,
|
||||
caps = null,
|
||||
commands = null,
|
||||
),
|
||||
)
|
||||
connected.await()
|
||||
|
||||
// Give the reader loop time to process.
|
||||
val line = invokeResLine.await()
|
||||
assertTrue(line.contains("\"type\":\"invoke-res\""))
|
||||
assertTrue(line.contains("\"ok\":false"))
|
||||
assertTrue(line.contains("\"code\":\"FOO_BAR\""))
|
||||
assertTrue(line.contains("\"message\":\"boom\""))
|
||||
server.await()
|
||||
|
||||
session.disconnect()
|
||||
scope.cancel()
|
||||
}
|
||||
|
||||
@Test(timeout = 12_000)
|
||||
fun reconnectsAfterBridgeClosesDuringHello() = runBlocking {
|
||||
val serverSocket = ServerSocket(0)
|
||||
val port = serverSocket.localPort
|
||||
|
||||
val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
|
||||
val connected = CountDownLatch(1)
|
||||
val connectionsSeen = CountDownLatch(2)
|
||||
|
||||
val session =
|
||||
BridgeSession(
|
||||
scope = scope,
|
||||
onConnected = { _, _, _ -> connected.countDown() },
|
||||
onDisconnected = { /* ignore */ },
|
||||
onEvent = { _, _ -> /* ignore */ },
|
||||
onInvoke = { BridgeSession.InvokeResult.ok(null) },
|
||||
)
|
||||
|
||||
val server =
|
||||
async(Dispatchers.IO) {
|
||||
serverSocket.use { ss ->
|
||||
// First connection: read hello, then close (no response).
|
||||
val sock1 = ss.accept()
|
||||
sock1.use { s ->
|
||||
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
|
||||
reader.readLine() // hello
|
||||
connectionsSeen.countDown()
|
||||
}
|
||||
|
||||
// Second connection: complete hello.
|
||||
val sock2 = ss.accept()
|
||||
sock2.use { s ->
|
||||
val reader = BufferedReader(InputStreamReader(s.getInputStream(), Charsets.UTF_8))
|
||||
val writer = BufferedWriter(OutputStreamWriter(s.getOutputStream(), Charsets.UTF_8))
|
||||
reader.readLine() // hello
|
||||
writer.write("""{"type":"hello-ok","serverName":"Test Bridge"}""")
|
||||
writer.write("\n")
|
||||
writer.flush()
|
||||
connectionsSeen.countDown()
|
||||
Thread.sleep(200)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session.connect(
|
||||
endpoint = BridgeEndpoint.manual(host = "127.0.0.1", port = port),
|
||||
hello =
|
||||
BridgeSession.Hello(
|
||||
nodeId = "node-1",
|
||||
displayName = "Android Node",
|
||||
token = null,
|
||||
platform = "Android",
|
||||
version = "test",
|
||||
deviceFamily = null,
|
||||
modelIdentifier = null,
|
||||
caps = null,
|
||||
commands = null,
|
||||
),
|
||||
)
|
||||
|
||||
assertTrue("expected two connection attempts", connectionsSeen.await(8, TimeUnit.SECONDS))
|
||||
assertTrue("expected session to connect", connected.await(8, TimeUnit.SECONDS))
|
||||
|
||||
session.disconnect()
|
||||
scope.cancel()
|
||||
server.await()
|
||||
}
|
||||
}
|
||||
|
||||
private fun extractJsonString(raw: String, key: String): String {
|
||||
val needle = "\"$key\":\""
|
||||
val start = raw.indexOf(needle)
|
||||
if (start < 0) throw IllegalArgumentException("missing key $key in $raw")
|
||||
val from = start + needle.length
|
||||
val end = raw.indexOf('"', from)
|
||||
if (end < 0) throw IllegalArgumentException("unterminated string for $key in $raw")
|
||||
return raw.substring(from, end)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.clawdbot.android.bridge
|
||||
package com.clawdbot.android.gateway
|
||||
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Test
|
||||
Reference in New Issue
Block a user