fix(android): use system DNS for wide-area discovery

This commit is contained in:
Peter Steinberger
2025-12-18 01:04:13 +01:00
parent 4c656ea22f
commit 1673bf2d44

View File

@@ -2,13 +2,16 @@ package com.steipete.clawdis.node.bridge
import android.content.Context import android.content.Context
import android.net.ConnectivityManager import android.net.ConnectivityManager
import android.net.DnsResolver
import android.net.NetworkCapabilities import android.net.NetworkCapabilities
import android.net.nsd.NsdManager import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo import android.net.nsd.NsdServiceInfo
import android.os.Build import android.os.Build
import java.net.InetSocketAddress import android.os.CancellationSignal
import java.time.Duration import java.io.IOException
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executor
import java.util.concurrent.Executors
import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
@@ -17,15 +20,20 @@ import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import org.xbill.DNS.ExtendedResolver import kotlinx.coroutines.suspendCancellableCoroutine
import org.xbill.DNS.Lookup
import org.xbill.DNS.SimpleResolver
import org.xbill.DNS.AAAARecord import org.xbill.DNS.AAAARecord
import org.xbill.DNS.ARecord import org.xbill.DNS.ARecord
import org.xbill.DNS.DClass
import org.xbill.DNS.Message
import org.xbill.DNS.Name
import org.xbill.DNS.PTRRecord import org.xbill.DNS.PTRRecord
import org.xbill.DNS.SRVRecord import org.xbill.DNS.SRVRecord
import org.xbill.DNS.Section
import org.xbill.DNS.TextParseException
import org.xbill.DNS.TXTRecord import org.xbill.DNS.TXTRecord
import org.xbill.DNS.Type import org.xbill.DNS.Type
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
class BridgeDiscovery( class BridgeDiscovery(
context: Context, context: Context,
@@ -33,6 +41,7 @@ class BridgeDiscovery(
) { ) {
private val nsd = context.getSystemService(NsdManager::class.java) private val nsd = context.getSystemService(NsdManager::class.java)
private val connectivity = context.getSystemService(ConnectivityManager::class.java) private val connectivity = context.getSystemService(ConnectivityManager::class.java)
private val dns = DnsResolver.getInstance()
private val serviceType = "_clawdis-bridge._tcp." private val serviceType = "_clawdis-bridge._tcp."
private val wideAreaDomain = "clawdis.internal." private val wideAreaDomain = "clawdis.internal."
@@ -42,6 +51,7 @@ class BridgeDiscovery(
val bridges: StateFlow<List<BridgeEndpoint>> = _bridges.asStateFlow() val bridges: StateFlow<List<BridgeEndpoint>> = _bridges.asStateFlow()
private var unicastJob: Job? = null private var unicastJob: Job? = null
private val dnsExecutor: Executor = Executors.newCachedThreadPool()
private val discoveryListener = private val discoveryListener =
object : NsdManager.DiscoveryListener { object : NsdManager.DiscoveryListener {
@@ -144,22 +154,21 @@ class BridgeDiscovery(
} }
private suspend fun refreshUnicast(domain: String) { private suspend fun refreshUnicast(domain: String) {
val resolver = createUnicastResolver()
val ptrName = "${serviceType}${domain}" val ptrName = "${serviceType}${domain}"
val ptrRecords = lookup(ptrName, Type.PTR, resolver).mapNotNull { it as? PTRRecord } val ptrRecords = lookupUnicast(ptrName, Type.PTR).mapNotNull { it as? PTRRecord }
val next = LinkedHashMap<String, BridgeEndpoint>() val next = LinkedHashMap<String, BridgeEndpoint>()
for (ptr in ptrRecords) { for (ptr in ptrRecords) {
val instanceFqdn = ptr.target.toString() val instanceFqdn = ptr.target.toString()
val srv = val srv =
lookup(instanceFqdn, Type.SRV, resolver).firstOrNull { it is SRVRecord } as? SRVRecord ?: continue lookupUnicast(instanceFqdn, Type.SRV).firstOrNull { it is SRVRecord } as? SRVRecord ?: continue
val port = srv.port val port = srv.port
if (port <= 0) continue if (port <= 0) continue
val targetFqdn = srv.target.toString() val targetFqdn = srv.target.toString()
val host = resolveHost(targetFqdn, resolver) ?: continue val host = resolveHostUnicast(targetFqdn) ?: continue
val txt = lookup(instanceFqdn, Type.TXT, resolver).mapNotNull { it as? TXTRecord } val txt = lookupUnicast(instanceFqdn, Type.TXT).mapNotNull { it as? TXTRecord }
val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain)) val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain))
val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName) val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName)
val id = stableId(instanceName, domain) val id = stableId(instanceName, domain)
@@ -186,64 +195,71 @@ class BridgeDiscovery(
return raw.removeSuffix(".") return raw.removeSuffix(".")
} }
private fun lookup(name: String, type: Int, resolver: org.xbill.DNS.Resolver?): List<org.xbill.DNS.Record> { private suspend fun lookupUnicast(name: String, type: Int): List<org.xbill.DNS.Record> {
return try { val query =
val lookup = Lookup(name, type) try {
if (resolver != null) { Message.newQuery(
lookup.setResolver(resolver) org.xbill.DNS.Record.newRecord(
lookup.setCache(null) Name.fromString(name),
type,
DClass.IN,
),
)
} catch (_: TextParseException) {
return emptyList()
} }
val out = lookup.run() ?: return emptyList()
out.toList() val network = preferredDnsNetwork()
} catch (_: Throwable) { val bytes =
try {
rawQuery(network, query.toWire())
} catch (_: Throwable) {
return emptyList()
}
return try {
val msg = Message(bytes)
msg.getSectionArray(Section.ANSWER)?.toList() ?: emptyList()
} catch (_: IOException) {
emptyList() emptyList()
} }
} }
private fun createUnicastResolver(): org.xbill.DNS.Resolver? { private fun preferredDnsNetwork(): android.net.Network? {
val cm = connectivity ?: return null val cm = connectivity ?: return null
// Prefer VPN DNS (Tailscale) when present; fall back to active network DNS. // Prefer VPN (Tailscale) when present; otherwise use the active network.
val candidateNetworks = cm.allNetworks.firstOrNull { n ->
buildList { val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
cm.allNetworks caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
.firstOrNull { n -> }?.let { return it }
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
}?.let(::add)
cm.activeNetwork?.let(::add)
}.distinct()
val servers = return cm.activeNetwork
candidateNetworks
.asSequence()
.flatMap { n ->
cm.getLinkProperties(n)?.dnsServers?.asSequence() ?: emptySequence()
}
.distinctBy { it.hostAddress ?: it.toString() }
.toList()
if (servers.isEmpty()) return null
return try {
val resolvers =
servers.mapNotNull { addr ->
try {
SimpleResolver().apply { setAddress(InetSocketAddress(addr, 53)) }
} catch (_: Throwable) {
null
}
}
if (resolvers.isEmpty()) return null
ExtendedResolver(resolvers.toTypedArray()).apply {
// Vienna -> London via tailnet: allow a bit more headroom than LAN mDNS.
setTimeout(Duration.ofMillis(3000))
}
} catch (_: Throwable) {
null
}
} }
private suspend fun rawQuery(network: android.net.Network?, wireQuery: ByteArray): ByteArray =
suspendCancellableCoroutine { cont ->
val signal = CancellationSignal()
cont.invokeOnCancellation { signal.cancel() }
dns.rawQuery(
network,
wireQuery,
0,
dnsExecutor,
signal,
object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) {
cont.resume(answer)
}
override fun onError(error: DnsResolver.DnsException) {
cont.resumeWithException(error)
}
},
)
}
private fun txtValue(records: List<TXTRecord>, key: String): String? { private fun txtValue(records: List<TXTRecord>, key: String): String? {
val prefix = "$key=" val prefix = "$key="
for (r in records) { for (r in records) {
@@ -263,13 +279,13 @@ class BridgeDiscovery(
return null return null
} }
private fun resolveHost(hostname: String, resolver: org.xbill.DNS.Resolver?): String? { private suspend fun resolveHostUnicast(hostname: String): String? {
val a = val a =
lookup(hostname, Type.A, resolver) lookupUnicast(hostname, Type.A)
.mapNotNull { it as? ARecord } .mapNotNull { it as? ARecord }
.mapNotNull { it.address?.hostAddress } .mapNotNull { it.address?.hostAddress }
val aaaa = val aaaa =
lookup(hostname, Type.AAAA, resolver) lookupUnicast(hostname, Type.AAAA)
.mapNotNull { it as? AAAARecord } .mapNotNull { it as? AAAARecord }
.mapNotNull { it.address?.hostAddress } .mapNotNull { it.address?.hostAddress }