diff --git a/apps/macos/Sources/Clawdis/GatewayDiscoveryModel.swift b/apps/macos/Sources/Clawdis/GatewayDiscoveryModel.swift index a0bdb7f69..33545fba4 100644 --- a/apps/macos/Sources/Clawdis/GatewayDiscoveryModel.swift +++ b/apps/macos/Sources/Clawdis/GatewayDiscoveryModel.swift @@ -2,6 +2,7 @@ import ClawdisKit import Foundation import Network import Observation +import OSLog @MainActor @Observable @@ -27,9 +28,13 @@ final class GatewayDiscoveryModel { var statusText: String = "Idle" private var browsers: [String: NWBrowser] = [:] + private var resultsByDomain: [String: Set] = [:] private var gatewaysByDomain: [String: [DiscoveredGateway]] = [:] private var statesByDomain: [String: NWBrowser.State] = [:] private var localIdentity: LocalIdentity + private var resolvedTXTByID: [String: [String: String]] = [:] + private var pendingTXTResolvers: [String: GatewayTXTResolver] = [:] + private let logger = Logger(subsystem: "com.steipete.clawdis", category: "gateway-discovery") init() { self.localIdentity = Self.buildLocalIdentityFast() @@ -57,60 +62,8 @@ final class GatewayDiscoveryModel { browser.browseResultsChangedHandler = { [weak self] results, _ in Task { @MainActor in guard let self else { return } - self.gatewaysByDomain[domain] = results.compactMap { result -> DiscoveredGateway? in - guard case let .service(name, _, _, _) = result.endpoint else { return nil } - - let decodedName = BonjourEscapes.decode(name) - let txt = Self.txtDictionary(from: result) - - let advertisedName = txt["displayName"] - .map(Self.prettifyInstanceName) - .flatMap { $0.isEmpty ? nil : $0 } - let prettyName = - advertisedName ?? Self.prettifyServiceName(decodedName) - - var lanHost: String? - var tailnetDns: String? - var sshPort = 22 - var cliPath: String? - - if let value = txt["lanHost"] { - let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) - lanHost = trimmed.isEmpty ? nil : trimmed - } - if let value = txt["tailnetDns"] { - let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) - tailnetDns = trimmed.isEmpty ? nil : trimmed - } - if let value = txt["sshPort"], - let parsed = Int(value.trimmingCharacters(in: .whitespacesAndNewlines)), - parsed > 0 - { - sshPort = parsed - } - if let value = txt["cliPath"] { - let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) - cliPath = trimmed.isEmpty ? nil : trimmed - } - - let isLocal = Self.isLocalGateway( - lanHost: lanHost, - tailnetDns: tailnetDns, - displayName: prettyName, - serviceName: decodedName, - local: self.localIdentity) - return DiscoveredGateway( - displayName: prettyName, - lanHost: lanHost, - tailnetDns: tailnetDns, - sshPort: sshPort, - cliPath: cliPath, - stableID: BridgeEndpointID.stableID(result.endpoint), - debugID: BridgeEndpointID.prettyDescription(result.endpoint), - isLocal: isLocal) - } - .sorted { $0.displayName.localizedCaseInsensitiveCompare($1.displayName) == .orderedAscending } - + self.resultsByDomain[domain] = results + self.updateGateways(for: domain) self.recomputeGateways() } } @@ -125,8 +78,12 @@ final class GatewayDiscoveryModel { browser.cancel() } self.browsers = [:] + self.resultsByDomain = [:] self.gatewaysByDomain = [:] self.statesByDomain = [:] + self.resolvedTXTByID = [:] + self.pendingTXTResolvers.values.forEach { $0.cancel() } + self.pendingTXTResolvers = [:] self.gateways = [] self.statusText = "Stopped" } @@ -138,6 +95,85 @@ final class GatewayDiscoveryModel { .sorted { $0.displayName.localizedCaseInsensitiveCompare($1.displayName) == .orderedAscending } } + private func updateGateways(for domain: String) { + guard let results = self.resultsByDomain[domain] else { + self.gatewaysByDomain[domain] = [] + return + } + + self.gatewaysByDomain[domain] = results.compactMap { result -> DiscoveredGateway? in + guard case let .service(name, type, resultDomain, _) = result.endpoint else { return nil } + + let decodedName = BonjourEscapes.decode(name) + let stableID = BridgeEndpointID.stableID(result.endpoint) + let resolvedTXT = self.resolvedTXTByID[stableID] ?? [:] + let txt = Self.txtDictionary(from: result).merging( + resolvedTXT, + uniquingKeysWith: { _, new in new }) + + let advertisedName = txt["displayName"] + .map(Self.prettifyInstanceName) + .flatMap { $0.isEmpty ? nil : $0 } + let prettyName = + advertisedName ?? Self.prettifyServiceName(decodedName) + + var lanHost: String? + var tailnetDns: String? + var sshPort = 22 + var cliPath: String? + + if let value = txt["lanHost"] { + let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) + lanHost = trimmed.isEmpty ? nil : trimmed + } + if let value = txt["tailnetDns"] { + let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) + tailnetDns = trimmed.isEmpty ? nil : trimmed + } + if let value = txt["sshPort"], + let parsed = Int(value.trimmingCharacters(in: .whitespacesAndNewlines)), + parsed > 0 + { + sshPort = parsed + } + if let value = txt["cliPath"] { + let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) + cliPath = trimmed.isEmpty ? nil : trimmed + } + + if lanHost == nil || tailnetDns == nil { + self.ensureTXTResolution( + stableID: stableID, + serviceName: name, + type: type, + domain: resultDomain) + } + + let isLocal = Self.isLocalGateway( + lanHost: lanHost, + tailnetDns: tailnetDns, + displayName: prettyName, + serviceName: decodedName, + local: self.localIdentity) + return DiscoveredGateway( + displayName: prettyName, + lanHost: lanHost, + tailnetDns: tailnetDns, + sshPort: sshPort, + cliPath: cliPath, + stableID: stableID, + debugID: BridgeEndpointID.prettyDescription(result.endpoint), + isLocal: isLocal) + } + .sorted { $0.displayName.localizedCaseInsensitiveCompare($1.displayName) == .orderedAscending } + } + + private func updateGatewaysForAllDomains() { + for domain in self.resultsByDomain.keys { + self.updateGateways(for: domain) + } + } + private func updateStatusText() { let states = Array(self.statesByDomain.values) if states.isEmpty { @@ -192,6 +228,39 @@ final class GatewayDiscoveryModel { return merged } + private func ensureTXTResolution( + stableID: String, + serviceName: String, + type: String, + domain: String) + { + guard self.resolvedTXTByID[stableID] == nil else { return } + guard self.pendingTXTResolvers[stableID] == nil else { return } + + let resolver = GatewayTXTResolver( + name: serviceName, + type: type, + domain: domain, + logger: self.logger) + { [weak self] result in + Task { @MainActor in + guard let self else { return } + self.pendingTXTResolvers[stableID] = nil + switch result { + case let .success(txt): + self.resolvedTXTByID[stableID] = txt + self.updateGatewaysForAllDomains() + self.recomputeGateways() + case .failure: + break + } + } + } + + self.pendingTXTResolvers[stableID] = resolver + resolver.start() + } + private static func prettifyInstanceName(_ decodedName: String) -> String { let normalized = decodedName.split(whereSeparator: \.isWhitespace).joined(separator: " ") let stripped = normalized.replacingOccurrences(of: " (Clawdis)", with: "") @@ -339,3 +408,78 @@ final class GatewayDiscoveryModel { return trimmed.lowercased() } } + +final class GatewayTXTResolver: NSObject, NetServiceDelegate { + private let service: NetService + private let completion: (Result<[String: String], Error>) -> Void + private let logger: Logger + private var didFinish = false + + init( + name: String, + type: String, + domain: String, + logger: Logger, + completion: @escaping (Result<[String: String], Error>) -> Void) + { + self.service = NetService(domain: domain, type: type, name: name) + self.completion = completion + self.logger = logger + super.init() + self.service.delegate = self + } + + func start(timeout: TimeInterval = 2.0) { + self.service.schedule(in: .main, forMode: .common) + self.service.resolve(withTimeout: timeout) + } + + func cancel() { + self.finish(result: .failure(GatewayTXTResolverError.cancelled)) + } + + func netServiceDidResolveAddress(_ sender: NetService) { + let txt = Self.decodeTXT(sender.txtRecordData()) + if !txt.isEmpty { + self.logger.debug( + "discovery: resolved TXT for \(sender.name, privacy: .public): \(self.formatTXT(txt), privacy: .public)") + } + self.finish(result: .success(txt)) + } + + func netService(_ sender: NetService, didNotResolve errorDict: [String: NSNumber]) { + self.finish(result: .failure(GatewayTXTResolverError.resolveFailed(errorDict))) + } + + private func finish(result: Result<[String: String], Error>) { + guard !self.didFinish else { return } + self.didFinish = true + self.service.stop() + self.service.remove(from: .main, forMode: .common) + self.completion(result) + } + + private static func decodeTXT(_ data: Data?) -> [String: String] { + guard let data else { return [:] } + let dict = NetService.dictionary(fromTXTRecord: data) + var out: [String: String] = [:] + out.reserveCapacity(dict.count) + for (key, value) in dict { + if let str = String(data: value, encoding: .utf8) { + out[key] = str + } + } + return out + } + + private func formatTXT(_ txt: [String: String]) -> String { + txt.sorted(by: { $0.key < $1.key }) + .map { "\($0.key)=\($0.value)" } + .joined(separator: " ") + } +} + +enum GatewayTXTResolverError: Error { + case cancelled + case resolveFailed([String: NSNumber]) +}