fix(android): use system DNS for wide-area discovery

This commit is contained in:
Peter Steinberger
2025-12-18 01:04:13 +01:00
parent 4c656ea22f
commit 1673bf2d44

View File

@@ -2,13 +2,16 @@ package com.steipete.clawdis.node.bridge
import android.content.Context
import android.net.ConnectivityManager
import android.net.DnsResolver
import android.net.NetworkCapabilities
import android.net.nsd.NsdManager
import android.net.nsd.NsdServiceInfo
import android.os.Build
import java.net.InetSocketAddress
import java.time.Duration
import android.os.CancellationSignal
import java.io.IOException
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executor
import java.util.concurrent.Executors
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
@@ -17,15 +20,20 @@ import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
import org.xbill.DNS.ExtendedResolver
import org.xbill.DNS.Lookup
import org.xbill.DNS.SimpleResolver
import kotlinx.coroutines.suspendCancellableCoroutine
import org.xbill.DNS.AAAARecord
import org.xbill.DNS.ARecord
import org.xbill.DNS.DClass
import org.xbill.DNS.Message
import org.xbill.DNS.Name
import org.xbill.DNS.PTRRecord
import org.xbill.DNS.SRVRecord
import org.xbill.DNS.Section
import org.xbill.DNS.TextParseException
import org.xbill.DNS.TXTRecord
import org.xbill.DNS.Type
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException
class BridgeDiscovery(
context: Context,
@@ -33,6 +41,7 @@ class BridgeDiscovery(
) {
private val nsd = context.getSystemService(NsdManager::class.java)
private val connectivity = context.getSystemService(ConnectivityManager::class.java)
private val dns = DnsResolver.getInstance()
private val serviceType = "_clawdis-bridge._tcp."
private val wideAreaDomain = "clawdis.internal."
@@ -42,6 +51,7 @@ class BridgeDiscovery(
val bridges: StateFlow<List<BridgeEndpoint>> = _bridges.asStateFlow()
private var unicastJob: Job? = null
private val dnsExecutor: Executor = Executors.newCachedThreadPool()
private val discoveryListener =
object : NsdManager.DiscoveryListener {
@@ -144,22 +154,21 @@ class BridgeDiscovery(
}
private suspend fun refreshUnicast(domain: String) {
val resolver = createUnicastResolver()
val ptrName = "${serviceType}${domain}"
val ptrRecords = lookup(ptrName, Type.PTR, resolver).mapNotNull { it as? PTRRecord }
val ptrRecords = lookupUnicast(ptrName, Type.PTR).mapNotNull { it as? PTRRecord }
val next = LinkedHashMap<String, BridgeEndpoint>()
for (ptr in ptrRecords) {
val instanceFqdn = ptr.target.toString()
val srv =
lookup(instanceFqdn, Type.SRV, resolver).firstOrNull { it is SRVRecord } as? SRVRecord ?: continue
lookupUnicast(instanceFqdn, Type.SRV).firstOrNull { it is SRVRecord } as? SRVRecord ?: continue
val port = srv.port
if (port <= 0) continue
val targetFqdn = srv.target.toString()
val host = resolveHost(targetFqdn, resolver) ?: continue
val host = resolveHostUnicast(targetFqdn) ?: continue
val txt = lookup(instanceFqdn, Type.TXT, resolver).mapNotNull { it as? TXTRecord }
val txt = lookupUnicast(instanceFqdn, Type.TXT).mapNotNull { it as? TXTRecord }
val instanceName = BonjourEscapes.decode(decodeInstanceName(instanceFqdn, domain))
val displayName = BonjourEscapes.decode(txtValue(txt, "displayName") ?: instanceName)
val id = stableId(instanceName, domain)
@@ -186,64 +195,71 @@ class BridgeDiscovery(
return raw.removeSuffix(".")
}
private fun lookup(name: String, type: Int, resolver: org.xbill.DNS.Resolver?): List<org.xbill.DNS.Record> {
return try {
val lookup = Lookup(name, type)
if (resolver != null) {
lookup.setResolver(resolver)
lookup.setCache(null)
private suspend fun lookupUnicast(name: String, type: Int): List<org.xbill.DNS.Record> {
val query =
try {
Message.newQuery(
org.xbill.DNS.Record.newRecord(
Name.fromString(name),
type,
DClass.IN,
),
)
} catch (_: TextParseException) {
return emptyList()
}
val out = lookup.run() ?: return emptyList()
out.toList()
} catch (_: Throwable) {
val network = preferredDnsNetwork()
val bytes =
try {
rawQuery(network, query.toWire())
} catch (_: Throwable) {
return emptyList()
}
return try {
val msg = Message(bytes)
msg.getSectionArray(Section.ANSWER)?.toList() ?: emptyList()
} catch (_: IOException) {
emptyList()
}
}
private fun createUnicastResolver(): org.xbill.DNS.Resolver? {
private fun preferredDnsNetwork(): android.net.Network? {
val cm = connectivity ?: return null
// Prefer VPN DNS (Tailscale) when present; fall back to active network DNS.
val candidateNetworks =
buildList {
cm.allNetworks
.firstOrNull { n ->
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
}?.let(::add)
cm.activeNetwork?.let(::add)
}.distinct()
// Prefer VPN (Tailscale) when present; otherwise use the active network.
cm.allNetworks.firstOrNull { n ->
val caps = cm.getNetworkCapabilities(n) ?: return@firstOrNull false
caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN)
}?.let { return it }
val servers =
candidateNetworks
.asSequence()
.flatMap { n ->
cm.getLinkProperties(n)?.dnsServers?.asSequence() ?: emptySequence()
}
.distinctBy { it.hostAddress ?: it.toString() }
.toList()
if (servers.isEmpty()) return null
return try {
val resolvers =
servers.mapNotNull { addr ->
try {
SimpleResolver().apply { setAddress(InetSocketAddress(addr, 53)) }
} catch (_: Throwable) {
null
}
}
if (resolvers.isEmpty()) return null
ExtendedResolver(resolvers.toTypedArray()).apply {
// Vienna -> London via tailnet: allow a bit more headroom than LAN mDNS.
setTimeout(Duration.ofMillis(3000))
}
} catch (_: Throwable) {
null
}
return cm.activeNetwork
}
private suspend fun rawQuery(network: android.net.Network?, wireQuery: ByteArray): ByteArray =
suspendCancellableCoroutine { cont ->
val signal = CancellationSignal()
cont.invokeOnCancellation { signal.cancel() }
dns.rawQuery(
network,
wireQuery,
0,
dnsExecutor,
signal,
object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) {
cont.resume(answer)
}
override fun onError(error: DnsResolver.DnsException) {
cont.resumeWithException(error)
}
},
)
}
private fun txtValue(records: List<TXTRecord>, key: String): String? {
val prefix = "$key="
for (r in records) {
@@ -263,13 +279,13 @@ class BridgeDiscovery(
return null
}
private fun resolveHost(hostname: String, resolver: org.xbill.DNS.Resolver?): String? {
private suspend fun resolveHostUnicast(hostname: String): String? {
val a =
lookup(hostname, Type.A, resolver)
lookupUnicast(hostname, Type.A)
.mapNotNull { it as? ARecord }
.mapNotNull { it.address?.hostAddress }
val aaaa =
lookup(hostname, Type.AAAA, resolver)
lookupUnicast(hostname, Type.AAAA)
.mapNotNull { it as? AAAARecord }
.mapNotNull { it.address?.hostAddress }