fix(android): improve wide-area bridge discovery

This commit is contained in:
Peter Steinberger
2025-12-18 01:40:08 +01:00
parent 3351c972e7
commit 86225d0eb6
4 changed files with 170 additions and 14 deletions

View File

@@ -15,6 +15,7 @@ class MainViewModel(app: Application) : AndroidViewModel(app) {
val camera: CameraCaptureManager = runtime.camera val camera: CameraCaptureManager = runtime.camera
val bridges: StateFlow<List<BridgeEndpoint>> = runtime.bridges val bridges: StateFlow<List<BridgeEndpoint>> = runtime.bridges
val discoveryStatusText: StateFlow<String> = runtime.discoveryStatusText
val isConnected: StateFlow<Boolean> = runtime.isConnected val isConnected: StateFlow<Boolean> = runtime.isConnected
val statusText: StateFlow<String> = runtime.statusText val statusText: StateFlow<String> = runtime.statusText

View File

@@ -41,6 +41,7 @@ class NodeRuntime(context: Context) {
private val discovery = BridgeDiscovery(appContext, scope = scope) private val discovery = BridgeDiscovery(appContext, scope = scope)
val bridges: StateFlow<List<BridgeEndpoint>> = discovery.bridges val bridges: StateFlow<List<BridgeEndpoint>> = discovery.bridges
val discoveryStatusText: StateFlow<String> = discovery.statusText
private val _isConnected = MutableStateFlow(false) private val _isConnected = MutableStateFlow(false)
val isConnected: StateFlow<Boolean> = _isConnected.asStateFlow() val isConnected: StateFlow<Boolean> = _isConnected.asStateFlow()

View File

@@ -8,7 +8,9 @@ import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo import android.net.nsd.NsdServiceInfo
import android.os.Build import android.os.Build
import android.os.CancellationSignal import android.os.CancellationSignal
import android.util.Log
import java.io.IOException import java.io.IOException
import java.net.InetSocketAddress
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.Executors import java.util.concurrent.Executors
@@ -24,11 +26,16 @@ import kotlinx.coroutines.suspendCancellableCoroutine
import org.xbill.DNS.AAAARecord import org.xbill.DNS.AAAARecord
import org.xbill.DNS.ARecord import org.xbill.DNS.ARecord
import org.xbill.DNS.DClass import org.xbill.DNS.DClass
import org.xbill.DNS.ExtendedResolver
import org.xbill.DNS.Message import org.xbill.DNS.Message
import org.xbill.DNS.Name import org.xbill.DNS.Name
import org.xbill.DNS.PTRRecord import org.xbill.DNS.PTRRecord
import org.xbill.DNS.Record
import org.xbill.DNS.Rcode
import org.xbill.DNS.Resolver
import org.xbill.DNS.SRVRecord import org.xbill.DNS.SRVRecord
import org.xbill.DNS.Section import org.xbill.DNS.Section
import org.xbill.DNS.SimpleResolver
import org.xbill.DNS.TextParseException import org.xbill.DNS.TextParseException
import org.xbill.DNS.TXTRecord import org.xbill.DNS.TXTRecord
import org.xbill.DNS.Type import org.xbill.DNS.Type
@@ -44,15 +51,22 @@ class BridgeDiscovery(
private val dns = DnsResolver.getInstance() private val dns = DnsResolver.getInstance()
private val serviceType = "_clawdis-bridge._tcp." private val serviceType = "_clawdis-bridge._tcp."
private val wideAreaDomain = "clawdis.internal." private val wideAreaDomain = "clawdis.internal."
private val logTag = "Clawdis/BridgeDiscovery"
private val localById = ConcurrentHashMap<String, BridgeEndpoint>() private val localById = ConcurrentHashMap<String, BridgeEndpoint>()
private val unicastById = ConcurrentHashMap<String, BridgeEndpoint>() private val unicastById = ConcurrentHashMap<String, BridgeEndpoint>()
private val _bridges = MutableStateFlow<List<BridgeEndpoint>>(emptyList()) private val _bridges = MutableStateFlow<List<BridgeEndpoint>>(emptyList())
val bridges: StateFlow<List<BridgeEndpoint>> = _bridges.asStateFlow() val bridges: StateFlow<List<BridgeEndpoint>> = _bridges.asStateFlow()
private val _statusText = MutableStateFlow("Searching…")
val statusText: StateFlow<String> = _statusText.asStateFlow()
private var unicastJob: Job? = null private var unicastJob: Job? = null
private val dnsExecutor: Executor = Executors.newCachedThreadPool() private val dnsExecutor: Executor = Executors.newCachedThreadPool()
@Volatile private var lastWideAreaRcode: Int? = null
@Volatile private var lastWideAreaCount: Int = 0
private val discoveryListener = private val discoveryListener =
object : NsdManager.DiscoveryListener { object : NsdManager.DiscoveryListener {
override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {} override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {}
@@ -133,6 +147,27 @@ class BridgeDiscovery(
private fun publish() { private fun publish() {
_bridges.value = _bridges.value =
(localById.values + unicastById.values).sortedBy { it.name.lowercase() } (localById.values + unicastById.values).sortedBy { it.name.lowercase() }
_statusText.value = buildStatusText()
}
private fun buildStatusText(): String {
val localCount = localById.size
val wideRcode = lastWideAreaRcode
val wideCount = lastWideAreaCount
val wide =
when (wideRcode) {
null -> "Wide: ?"
Rcode.NOERROR -> "Wide: $wideCount"
Rcode.NXDOMAIN -> "Wide: NXDOMAIN"
else -> "Wide: ${Rcode.string(wideRcode)}"
}
return when {
localCount == 0 && wideRcode == null -> "Searching for bridges…"
localCount == 0 -> "$wide"
else -> "Local: $localCount$wide"
}
} }
private fun stableId(serviceName: String, domain: String): String { private fun stableId(serviceName: String, domain: String): String {
@@ -155,20 +190,40 @@ class BridgeDiscovery(
private suspend fun refreshUnicast(domain: String) { private suspend fun refreshUnicast(domain: String) {
val ptrName = "${serviceType}${domain}" val ptrName = "${serviceType}${domain}"
val ptrRecords = lookupUnicast(ptrName, Type.PTR).mapNotNull { it as? PTRRecord } val ptrMsg = lookupUnicastMessage(ptrName, Type.PTR) ?: return
val ptrRecords = records(ptrMsg, Section.ANSWER).mapNotNull { it as? PTRRecord }
val next = LinkedHashMap<String, BridgeEndpoint>() val next = LinkedHashMap<String, BridgeEndpoint>()
for (ptr in ptrRecords) { for (ptr in ptrRecords) {
val instanceFqdn = ptr.target.toString() val instanceFqdn = ptr.target.toString()
val srv = val srv =
lookupUnicast(instanceFqdn, Type.SRV).firstOrNull { it is SRVRecord } as? SRVRecord ?: continue recordByName(ptrMsg, instanceFqdn, Type.SRV) as? SRVRecord
?: run {
val msg = lookupUnicastMessage(instanceFqdn, Type.SRV) ?: return@run null
recordByName(msg, instanceFqdn, Type.SRV) as? SRVRecord
}
?: continue
val port = srv.port val port = srv.port
if (port <= 0) continue if (port <= 0) continue
val targetFqdn = srv.target.toString() val targetFqdn = srv.target.toString()
val host = resolveHostUnicast(targetFqdn) ?: continue val host =
resolveHostFromMessage(ptrMsg, targetFqdn)
?: resolveHostFromMessage(lookupUnicastMessage(instanceFqdn, Type.SRV), targetFqdn)
?: resolveHostUnicast(targetFqdn)
?: continue
val txt = lookupUnicast(instanceFqdn, Type.TXT).mapNotNull { it as? TXTRecord } val txtFromPtr =
recordsByName(ptrMsg, Section.ADDITIONAL)[keyName(instanceFqdn)]
.orEmpty()
.mapNotNull { it as? TXTRecord }
val txt =
if (txtFromPtr.isNotEmpty()) {
txtFromPtr
} else {
val msg = lookupUnicastMessage(instanceFqdn, Type.TXT)
records(msg, Section.ANSWER).mapNotNull { it as? TXTRecord }
}
val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain)) val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain))
val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName) val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName)
val id = stableId(instanceName, domain) val id = stableId(instanceName, domain)
@@ -177,7 +232,16 @@ class BridgeDiscovery(
unicastById.clear() unicastById.clear()
unicastById.putAll(next) unicastById.putAll(next)
lastWideAreaRcode = ptrMsg.header.rcode
lastWideAreaCount = next.size
publish() publish()
if (next.isEmpty()) {
Log.d(
logTag,
"wide-area discovery: 0 results for $ptrName (rcode=${Rcode.string(ptrMsg.header.rcode)})",
)
}
} }
private fun decodeInstanceName(instanceFqdn: String, domain: String): String { private fun decodeInstanceName(instanceFqdn: String, domain: String): String {
@@ -195,7 +259,7 @@ class BridgeDiscovery(
return raw.removeSuffix(".") return raw.removeSuffix(".")
} }
private suspend fun lookupUnicast(name: String, type: Int): List<org.xbill.DNS.Record> { private suspend fun lookupUnicastMessage(name: String, type: Int): Message? {
val query = val query =
try { try {
Message.newQuery( Message.newQuery(
@@ -206,25 +270,73 @@ class BridgeDiscovery(
), ),
) )
} catch (_: TextParseException) { } catch (_: TextParseException) {
return emptyList() return null
} }
val system = queryViaSystemDns(query)
if (records(system, Section.ANSWER).any { it.type == type }) return system
val direct = createDirectResolver() ?: return system
return try {
val msg = direct.send(query)
if (records(msg, Section.ANSWER).any { it.type == type }) msg else system
} catch (_: Throwable) {
system
}
}
private suspend fun queryViaSystemDns(query: Message): Message? {
val network = preferredDnsNetwork() val network = preferredDnsNetwork()
val bytes = val bytes =
try { try {
rawQuery(network, query.toWire()) rawQuery(network, query.toWire())
} catch (_: Throwable) { } catch (_: Throwable) {
return emptyList() return null
} }
return try { return try {
val msg = Message(bytes) Message(bytes)
msg.getSectionArray(Section.ANSWER)?.toList() ?: emptyList()
} catch (_: IOException) { } catch (_: IOException) {
emptyList() null
} }
} }
private fun records(msg: Message?, section: Int): List<Record> {
return msg?.getSectionArray(section)?.toList() ?: emptyList()
}
private fun keyName(raw: String): String {
return raw.trim().lowercase()
}
private fun recordsByName(msg: Message, section: Int): Map<String, List<Record>> {
val next = LinkedHashMap<String, MutableList<Record>>()
for (r in records(msg, section)) {
val name = r.name?.toString() ?: continue
next.getOrPut(keyName(name)) { mutableListOf() }.add(r)
}
return next
}
private fun recordByName(msg: Message, fqdn: String, type: Int): Record? {
val key = keyName(fqdn)
val byNameAnswer = recordsByName(msg, Section.ANSWER)
val fromAnswer = byNameAnswer[key].orEmpty().firstOrNull { it.type == type }
if (fromAnswer != null) return fromAnswer
val byNameAdditional = recordsByName(msg, Section.ADDITIONAL)
return byNameAdditional[key].orEmpty().firstOrNull { it.type == type }
}
private fun resolveHostFromMessage(msg: Message?, hostname: String): String? {
val m = msg ?: return null
val key = keyName(hostname)
val additional = recordsByName(m, Section.ADDITIONAL)[key].orEmpty()
val a = additional.mapNotNull { it as? ARecord }.mapNotNull { it.address?.hostAddress }
val aaaa = additional.mapNotNull { it as? AAAARecord }.mapNotNull { it.address?.hostAddress }
return a.firstOrNull() ?: aaaa.firstOrNull()
}
private fun preferredDnsNetwork(): android.net.Network? { private fun preferredDnsNetwork(): android.net.Network? {
val cm = connectivity ?: return null val cm = connectivity ?: return null
@@ -237,6 +349,48 @@ class BridgeDiscovery(
return cm.activeNetwork return cm.activeNetwork
} }
private fun createDirectResolver(): Resolver? {
val cm = connectivity ?: return null
val candidateNetworks =
buildList {
cm.allNetworks
.firstOrNull { n ->
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
}?.let(::add)
cm.activeNetwork?.let(::add)
}.distinct()
val servers =
candidateNetworks
.asSequence()
.flatMap { n ->
cm.getLinkProperties(n)?.dnsServers?.asSequence() ?: emptySequence()
}
.distinctBy { it.hostAddress ?: it.toString() }
.toList()
if (servers.isEmpty()) return null
return try {
val resolvers =
servers.mapNotNull { addr ->
try {
SimpleResolver().apply {
setAddress(InetSocketAddress(addr, 53))
setTimeout(3)
}
} catch (_: Throwable) {
null
}
}
if (resolvers.isEmpty()) return null
ExtendedResolver(resolvers.toTypedArray()).apply { setTimeout(3) }
} catch (_: Throwable) {
null
}
}
private suspend fun rawQuery(network: android.net.Network?, wireQuery: ByteArray): ByteArray = private suspend fun rawQuery(network: android.net.Network?, wireQuery: ByteArray): ByteArray =
suspendCancellableCoroutine { cont -> suspendCancellableCoroutine { cont ->
val signal = CancellationSignal() val signal = CancellationSignal()
@@ -281,11 +435,11 @@ class BridgeDiscovery(
private suspend fun resolveHostUnicast(hostname: String): String? { private suspend fun resolveHostUnicast(hostname: String): String? {
val a = val a =
lookupUnicast(hostname, Type.A) records(lookupUnicastMessage(hostname, Type.A), Section.ANSWER)
.mapNotNull { it as? ARecord } .mapNotNull { it as? ARecord }
.mapNotNull { it.address?.hostAddress } .mapNotNull { it.address?.hostAddress }
val aaaa = val aaaa =
lookupUnicast(hostname, Type.AAAA) records(lookupUnicastMessage(hostname, Type.AAAA), Section.ANSWER)
.mapNotNull { it as? AAAARecord } .mapNotNull { it as? AAAARecord }
.mapNotNull { it.address?.hostAddress } .mapNotNull { it.address?.hostAddress }

View File

@@ -64,6 +64,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
val serverName by viewModel.serverName.collectAsState() val serverName by viewModel.serverName.collectAsState()
val remoteAddress by viewModel.remoteAddress.collectAsState() val remoteAddress by viewModel.remoteAddress.collectAsState()
val bridges by viewModel.bridges.collectAsState() val bridges by viewModel.bridges.collectAsState()
val discoveryStatusText by viewModel.discoveryStatusText.collectAsState()
val listState = rememberLazyListState() val listState = rememberLazyListState()
val (wakeWordsText, setWakeWordsText) = remember { mutableStateOf("") } val (wakeWordsText, setWakeWordsText) = remember { mutableStateOf("") }
@@ -95,7 +96,7 @@ fun SettingsSheet(viewModel: MainViewModel) {
val bridgeDiscoveryFooterText = val bridgeDiscoveryFooterText =
if (bridges.isEmpty()) { if (bridges.isEmpty()) {
"Searching for bridges…" discoveryStatusText
} else { } else {
"Discovery active • ${bridges.size} bridge${if (bridges.size == 1) "" else "s"} found" "Discovery active • ${bridges.size} bridge${if (bridges.size == 1) "" else "s"} found"
} }
@@ -309,4 +310,3 @@ fun SettingsSheet(viewModel: MainViewModel) {
item { Spacer(modifier = Modifier.height(20.dp)) } item { Spacer(modifier = Modifier.height(20.dp)) }
} }
} }