From 1673bf2d44835f2365d7950d3eebae5319dbd8b2 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Thu, 18 Dec 2025 01:04:13 +0100 Subject: [PATCH] fix(android): use system DNS for wide-area discovery --- .../clawdis/node/bridge/BridgeDiscovery.kt | 138 ++++++++++-------- 1 file changed, 77 insertions(+), 61 deletions(-) diff --git a/apps/android/app/src/main/java/com/steipete/clawdis/node/bridge/BridgeDiscovery.kt b/apps/android/app/src/main/java/com/steipete/clawdis/node/bridge/BridgeDiscovery.kt index 7b555697c..1480d15bf 100644 --- a/apps/android/app/src/main/java/com/steipete/clawdis/node/bridge/BridgeDiscovery.kt +++ b/apps/android/app/src/main/java/com/steipete/clawdis/node/bridge/BridgeDiscovery.kt @@ -2,13 +2,16 @@ package com.steipete.clawdis.node.bridge import android.content.Context import android.net.ConnectivityManager +import android.net.DnsResolver import android.net.NetworkCapabilities import android.net.nsd.NsdManager import android.net.nsd.NsdServiceInfo import android.os.Build -import java.net.InetSocketAddress -import java.time.Duration +import android.os.CancellationSignal +import java.io.IOException import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executor +import java.util.concurrent.Executors import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job @@ -17,15 +20,20 @@ import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.launch -import org.xbill.DNS.ExtendedResolver -import org.xbill.DNS.Lookup -import org.xbill.DNS.SimpleResolver +import kotlinx.coroutines.suspendCancellableCoroutine import org.xbill.DNS.AAAARecord import org.xbill.DNS.ARecord +import org.xbill.DNS.DClass +import org.xbill.DNS.Message +import org.xbill.DNS.Name import org.xbill.DNS.PTRRecord import org.xbill.DNS.SRVRecord +import org.xbill.DNS.Section +import org.xbill.DNS.TextParseException import org.xbill.DNS.TXTRecord import org.xbill.DNS.Type +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException class BridgeDiscovery( context: Context, @@ -33,6 +41,7 @@ class BridgeDiscovery( ) { private val nsd = context.getSystemService(NsdManager::class.java) private val connectivity = context.getSystemService(ConnectivityManager::class.java) + private val dns = DnsResolver.getInstance() private val serviceType = "_clawdis-bridge._tcp." private val wideAreaDomain = "clawdis.internal." @@ -42,6 +51,7 @@ class BridgeDiscovery( val bridges: StateFlow> = _bridges.asStateFlow() private var unicastJob: Job? = null + private val dnsExecutor: Executor = Executors.newCachedThreadPool() private val discoveryListener = object : NsdManager.DiscoveryListener { @@ -144,22 +154,21 @@ class BridgeDiscovery( } private suspend fun refreshUnicast(domain: String) { - val resolver = createUnicastResolver() val ptrName = "${serviceType}${domain}" - val ptrRecords = lookup(ptrName, Type.PTR, resolver).mapNotNull { it as? PTRRecord } + val ptrRecords = lookupUnicast(ptrName, Type.PTR).mapNotNull { it as? PTRRecord } val next = LinkedHashMap() for (ptr in ptrRecords) { val instanceFqdn = ptr.target.toString() val srv = - lookup(instanceFqdn, Type.SRV, resolver).firstOrNull { it is SRVRecord } as? SRVRecord ?: continue + lookupUnicast(instanceFqdn, Type.SRV).firstOrNull { it is SRVRecord } as? SRVRecord ?: continue val port = srv.port if (port <= 0) continue val targetFqdn = srv.target.toString() - val host = resolveHost(targetFqdn, resolver) ?: continue + val host = resolveHostUnicast(targetFqdn) ?: continue - val txt = lookup(instanceFqdn, Type.TXT, resolver).mapNotNull { it as? TXTRecord } + val txt = lookupUnicast(instanceFqdn, Type.TXT).mapNotNull { it as? TXTRecord } val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain)) val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName) val id = stableId(instanceName, domain) @@ -186,64 +195,71 @@ class BridgeDiscovery( return raw.removeSuffix(".") } - private fun lookup(name: String, type: Int, resolver: org.xbill.DNS.Resolver?): List { - return try { - val lookup = Lookup(name, type) - if (resolver != null) { - lookup.setResolver(resolver) - lookup.setCache(null) + private suspend fun lookupUnicast(name: String, type: Int): List { + val query = + try { + Message.newQuery( + org.xbill.DNS.Record.newRecord( + Name.fromString(name), + type, + DClass.IN, + ), + ) + } catch (_: TextParseException) { + return emptyList() } - val out = lookup.run() ?: return emptyList() - out.toList() - } catch (_: Throwable) { + + val network = preferredDnsNetwork() + val bytes = + try { + rawQuery(network, query.toWire()) + } catch (_: Throwable) { + return emptyList() + } + + return try { + val msg = Message(bytes) + msg.getSectionArray(Section.ANSWER)?.toList() ?: emptyList() + } catch (_: IOException) { emptyList() } } - private fun createUnicastResolver(): org.xbill.DNS.Resolver? { + private fun preferredDnsNetwork(): android.net.Network? { val cm = connectivity ?: return null - // Prefer VPN DNS (Tailscale) when present; fall back to active network DNS. - val candidateNetworks = - buildList { - cm.allNetworks - .firstOrNull { n -> - val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false - caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN) - }?.let(::add) - cm.activeNetwork?.let(::add) - }.distinct() + // Prefer VPN (Tailscale) when present; otherwise use the active network. + cm.allNetworks.firstOrNull { n -> + val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false + caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN) + }?.let { return it } - val servers = - candidateNetworks - .asSequence() - .flatMap { n -> - cm.getLinkProperties(n)?.dnsServers?.asSequence() ?: emptySequence() - } - .distinctBy { it.hostAddress ?: it.toString() } - .toList() - if (servers.isEmpty()) return null - - return try { - val resolvers = - servers.mapNotNull { addr -> - try { - SimpleResolver().apply { setAddress(InetSocketAddress(addr, 53)) } - } catch (_: Throwable) { - null - } - } - if (resolvers.isEmpty()) return null - - ExtendedResolver(resolvers.toTypedArray()).apply { - // Vienna -> London via tailnet: allow a bit more headroom than LAN mDNS. - setTimeout(Duration.ofMillis(3000)) - } - } catch (_: Throwable) { - null - } + return cm.activeNetwork } + private suspend fun rawQuery(network: android.net.Network?, wireQuery: ByteArray): ByteArray = + suspendCancellableCoroutine { cont -> + val signal = CancellationSignal() + cont.invokeOnCancellation { signal.cancel() } + + dns.rawQuery( + network, + wireQuery, + 0, + dnsExecutor, + signal, + object : DnsResolver.Callback { + override fun onAnswer(answer: ByteArray, rcode: Int) { + cont.resume(answer) + } + + override fun onError(error: DnsResolver.DnsException) { + cont.resumeWithException(error) + } + }, + ) + } + private fun txtValue(records: List, key: String): String? { val prefix = "$key=" for (r in records) { @@ -263,13 +279,13 @@ class BridgeDiscovery( return null } - private fun resolveHost(hostname: String, resolver: org.xbill.DNS.Resolver?): String? { + private suspend fun resolveHostUnicast(hostname: String): String? { val a = - lookup(hostname, Type.A, resolver) + lookupUnicast(hostname, Type.A) .mapNotNull { it as? ARecord } .mapNotNull { it.address?.hostAddress } val aaaa = - lookup(hostname, Type.AAAA, resolver) + lookupUnicast(hostname, Type.AAAA) .mapNotNull { it as? AAAARecord } .mapNotNull { it.address?.hostAddress }