fix(macos): share gateway websocket connection
parent
ce8db12b22
commit
e944a0239d
|
|
@ -11,26 +11,9 @@ actor AgentRPC {
|
||||||
static let shared = AgentRPC()
|
static let shared = AgentRPC()
|
||||||
|
|
||||||
private let logger = Logger(subsystem: "com.steipete.clawdis", category: "agent.rpc")
|
private let logger = Logger(subsystem: "com.steipete.clawdis", category: "agent.rpc")
|
||||||
private let gateway = GatewayChannel()
|
|
||||||
private var configured = false
|
|
||||||
|
|
||||||
private var gatewayURL: URL {
|
|
||||||
let port = GatewayEnvironment.gatewayPort()
|
|
||||||
return URL(string: "ws://127.0.0.1:\(port)")!
|
|
||||||
}
|
|
||||||
|
|
||||||
private var gatewayToken: String? {
|
|
||||||
ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
|
|
||||||
}
|
|
||||||
|
|
||||||
func start() async throws {
|
|
||||||
if self.configured { return }
|
|
||||||
await self.gateway.configure(url: self.gatewayURL, token: self.gatewayToken)
|
|
||||||
self.configured = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func shutdown() async {
|
func shutdown() async {
|
||||||
// no-op for WS; socket managed by GatewayChannel
|
// no-op; socket managed by GatewayConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
func setHeartbeatsEnabled(_ enabled: Bool) async -> Bool {
|
func setHeartbeatsEnabled(_ enabled: Bool) async -> Bool {
|
||||||
|
|
@ -85,8 +68,7 @@ actor AgentRPC {
|
||||||
}
|
}
|
||||||
|
|
||||||
func controlRequest(method: String, params: ControlRequestParams? = nil) async throws -> Data {
|
func controlRequest(method: String, params: ControlRequestParams? = nil) async throws -> Data {
|
||||||
try await self.start()
|
|
||||||
let rawParams = params?.raw.reduce(into: [String: AnyCodable]()) { $0[$1.key] = AnyCodable($1.value) }
|
let rawParams = params?.raw.reduce(into: [String: AnyCodable]()) { $0[$1.key] = AnyCodable($1.value) }
|
||||||
return try await self.gateway.request(method: method, params: rawParams)
|
return try await GatewayConnection.shared.request(method: method, params: rawParams)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -55,35 +55,34 @@ final class ControlChannel: ObservableObject {
|
||||||
@Published private(set) var lastPingMs: Double?
|
@Published private(set) var lastPingMs: Double?
|
||||||
|
|
||||||
private let logger = Logger(subsystem: "com.steipete.clawdis", category: "control")
|
private let logger = Logger(subsystem: "com.steipete.clawdis", category: "control")
|
||||||
private let gateway = GatewayChannel()
|
|
||||||
private var gatewayPort: Int = GatewayEnvironment.gatewayPort()
|
|
||||||
private var gatewayURL: URL { URL(string: "ws://127.0.0.1:\(self.gatewayPort)")! }
|
|
||||||
|
|
||||||
private var gatewayToken: String? {
|
|
||||||
ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
|
|
||||||
}
|
|
||||||
|
|
||||||
private var eventTokens: [NSObjectProtocol] = []
|
private var eventTokens: [NSObjectProtocol] = []
|
||||||
|
|
||||||
|
private init() {
|
||||||
|
self.startEventStream()
|
||||||
|
}
|
||||||
|
|
||||||
func configure() async {
|
func configure() async {
|
||||||
self.state = .connecting
|
self.state = .connecting
|
||||||
await self.gateway.configure(url: self.gatewayURL, token: self.gatewayToken)
|
do {
|
||||||
self.startEventStream()
|
try await GatewayConnection.shared.refresh()
|
||||||
self.state = .connected
|
self.state = .connected
|
||||||
PresenceReporter.shared.sendImmediate(reason: "connect")
|
PresenceReporter.shared.sendImmediate(reason: "connect")
|
||||||
|
} catch {
|
||||||
|
let message = self.friendlyGatewayMessage(error)
|
||||||
|
self.state = .degraded(message)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func configure(mode: Mode = .local) async throws {
|
func configure(mode: Mode = .local) async throws {
|
||||||
switch mode {
|
switch mode {
|
||||||
case .local:
|
case .local:
|
||||||
self.gatewayPort = GatewayEnvironment.gatewayPort()
|
|
||||||
await self.configure()
|
await self.configure()
|
||||||
case let .remote(target, identity):
|
case let .remote(target, identity):
|
||||||
// Create/ensure SSH tunnel, then talk to the forwarded local port.
|
// Create/ensure SSH tunnel, then talk to the forwarded local port.
|
||||||
_ = (target, identity)
|
_ = (target, identity)
|
||||||
do {
|
do {
|
||||||
let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel()
|
_ = try await RemoteTunnelManager.shared.ensureControlTunnel()
|
||||||
self.gatewayPort = Int(forwarded)
|
|
||||||
await self.configure()
|
await self.configure()
|
||||||
} catch {
|
} catch {
|
||||||
self.state = .degraded(error.localizedDescription)
|
self.state = .degraded(error.localizedDescription)
|
||||||
|
|
@ -124,7 +123,7 @@ final class ControlChannel: ObservableObject {
|
||||||
{
|
{
|
||||||
do {
|
do {
|
||||||
let rawParams = params?.reduce(into: [String: AnyCodable]()) { $0[$1.key] = AnyCodable($1.value) }
|
let rawParams = params?.reduce(into: [String: AnyCodable]()) { $0[$1.key] = AnyCodable($1.value) }
|
||||||
let data = try await self.gateway.request(method: method, params: rawParams, timeoutMs: timeoutMs)
|
let data = try await GatewayConnection.shared.request(method: method, params: rawParams, timeoutMs: timeoutMs)
|
||||||
self.state = .connected
|
self.state = .connected
|
||||||
return data
|
return data
|
||||||
} catch {
|
} catch {
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,10 @@ extension URLSession: WebSocketSessioning {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct WebSocketSessionBox: @unchecked Sendable {
|
||||||
|
let session: any WebSocketSessioning
|
||||||
|
}
|
||||||
|
|
||||||
struct GatewayEvent: Codable {
|
struct GatewayEvent: Codable {
|
||||||
let type: String
|
let type: String
|
||||||
let event: String?
|
let event: String?
|
||||||
|
|
@ -81,17 +85,40 @@ actor GatewayChannelActor {
|
||||||
private let decoder = JSONDecoder()
|
private let decoder = JSONDecoder()
|
||||||
private let encoder = JSONEncoder()
|
private let encoder = JSONEncoder()
|
||||||
private var watchdogTask: Task<Void, Never>?
|
private var watchdogTask: Task<Void, Never>?
|
||||||
|
private var tickTask: Task<Void, Never>?
|
||||||
private let defaultRequestTimeoutMs: Double = 15000
|
private let defaultRequestTimeoutMs: Double = 15000
|
||||||
|
|
||||||
init(url: URL, token: String?, session: WebSocketSessioning? = nil) {
|
init(url: URL, token: String?, session: WebSocketSessionBox? = nil) {
|
||||||
self.url = url
|
self.url = url
|
||||||
self.token = token
|
self.token = token
|
||||||
self.session = session ?? URLSession(configuration: .default)
|
self.session = session?.session ?? URLSession(configuration: .default)
|
||||||
Task { [weak self] in
|
Task { [weak self] in
|
||||||
await self?.startWatchdog()
|
await self?.startWatchdog()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shutdown() async {
|
||||||
|
self.shouldReconnect = false
|
||||||
|
self.connected = false
|
||||||
|
|
||||||
|
self.watchdogTask?.cancel()
|
||||||
|
self.watchdogTask = nil
|
||||||
|
|
||||||
|
self.tickTask?.cancel()
|
||||||
|
self.tickTask = nil
|
||||||
|
|
||||||
|
self.task?.cancel(with: .goingAway, reason: nil)
|
||||||
|
self.task = nil
|
||||||
|
|
||||||
|
await self.failPending(NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway channel shutdown"]))
|
||||||
|
|
||||||
|
let waiters = self.connectWaiters
|
||||||
|
self.connectWaiters.removeAll()
|
||||||
|
for waiter in waiters {
|
||||||
|
waiter.resume(throwing: NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway channel shutdown"]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private func startWatchdog() {
|
private func startWatchdog() {
|
||||||
self.watchdogTask?.cancel()
|
self.watchdogTask?.cancel()
|
||||||
self.watchdogTask = Task { [weak self] in
|
self.watchdogTask = Task { [weak self] in
|
||||||
|
|
@ -104,6 +131,7 @@ actor GatewayChannelActor {
|
||||||
// Keep nudging reconnect in case exponential backoff stalls.
|
// Keep nudging reconnect in case exponential backoff stalls.
|
||||||
while self.shouldReconnect {
|
while self.shouldReconnect {
|
||||||
try? await Task.sleep(nanoseconds: 30 * 1_000_000_000) // 30s cadence
|
try? await Task.sleep(nanoseconds: 30 * 1_000_000_000) // 30s cadence
|
||||||
|
guard self.shouldReconnect else { return }
|
||||||
if self.connected { continue }
|
if self.connected { continue }
|
||||||
do {
|
do {
|
||||||
try await self.connect()
|
try await self.connect()
|
||||||
|
|
@ -207,7 +235,11 @@ actor GatewayChannelActor {
|
||||||
self.tickIntervalMs = Double(tick)
|
self.tickIntervalMs = Double(tick)
|
||||||
}
|
}
|
||||||
self.lastTick = Date()
|
self.lastTick = Date()
|
||||||
Task { await self.watchTicks() }
|
self.tickTask?.cancel()
|
||||||
|
self.tickTask = Task { [weak self] in
|
||||||
|
guard let self else { return }
|
||||||
|
await self.watchTicks()
|
||||||
|
}
|
||||||
let frame = GatewayFrame.helloOk(ok)
|
let frame = GatewayFrame.helloOk(ok)
|
||||||
NotificationCenter.default.post(name: .gatewaySnapshot, object: frame)
|
NotificationCenter.default.post(name: .gatewaySnapshot, object: frame)
|
||||||
return
|
return
|
||||||
|
|
@ -314,6 +346,7 @@ actor GatewayChannelActor {
|
||||||
let delay = self.backoffMs / 1000
|
let delay = self.backoffMs / 1000
|
||||||
self.backoffMs = min(self.backoffMs * 2, 30000)
|
self.backoffMs = min(self.backoffMs * 2, 30000)
|
||||||
try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000))
|
try? await Task.sleep(nanoseconds: UInt64(delay * 1_000_000_000))
|
||||||
|
guard self.shouldReconnect else { return }
|
||||||
do {
|
do {
|
||||||
try await self.connect()
|
try await self.connect()
|
||||||
} catch {
|
} catch {
|
||||||
|
|
@ -414,21 +447,4 @@ actor GatewayChannelActor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
actor GatewayChannel {
|
// Intentionally no `GatewayChannel` wrapper: the app should use the single shared `GatewayConnection`.
|
||||||
private var inner: GatewayChannelActor?
|
|
||||||
|
|
||||||
func configure(url: URL, token: String?) {
|
|
||||||
self.inner = GatewayChannelActor(url: url, token: token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func request(
|
|
||||||
method: String,
|
|
||||||
params: [String: AnyCodable]?,
|
|
||||||
timeoutMs: Double? = nil) async throws -> Data
|
|
||||||
{
|
|
||||||
guard let inner else {
|
|
||||||
throw NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "not configured"])
|
|
||||||
}
|
|
||||||
return try await inner.request(method: method, params: params, timeoutMs: timeoutMs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Single, shared Gateway websocket connection for the whole app.
|
||||||
|
///
|
||||||
|
/// This owns exactly one `GatewayChannelActor` and reuses it across all callers
|
||||||
|
/// (ControlChannel, AgentRPC, SwiftUI WebChat, etc.).
|
||||||
|
actor GatewayConnection {
|
||||||
|
static let shared = GatewayConnection()
|
||||||
|
|
||||||
|
typealias Config = (url: URL, token: String?)
|
||||||
|
|
||||||
|
private let configProvider: @Sendable () async throws -> Config
|
||||||
|
private let sessionBox: WebSocketSessionBox?
|
||||||
|
|
||||||
|
private var client: GatewayChannelActor?
|
||||||
|
private var configuredURL: URL?
|
||||||
|
private var configuredToken: String?
|
||||||
|
|
||||||
|
init(
|
||||||
|
configProvider: @escaping @Sendable () async throws -> Config = GatewayConnection.defaultConfigProvider,
|
||||||
|
sessionBox: WebSocketSessionBox? = nil)
|
||||||
|
{
|
||||||
|
self.configProvider = configProvider
|
||||||
|
self.sessionBox = sessionBox
|
||||||
|
}
|
||||||
|
|
||||||
|
func request(
|
||||||
|
method: String,
|
||||||
|
params: [String: AnyCodable]?,
|
||||||
|
timeoutMs: Double? = nil) async throws -> Data
|
||||||
|
{
|
||||||
|
let cfg = try await self.configProvider()
|
||||||
|
await self.configure(url: cfg.url, token: cfg.token)
|
||||||
|
guard let client else {
|
||||||
|
throw NSError(domain: "Gateway", code: 0, userInfo: [NSLocalizedDescriptionKey: "gateway not configured"])
|
||||||
|
}
|
||||||
|
return try await client.request(method: method, params: params, timeoutMs: timeoutMs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ensure the underlying socket is configured (and replaced if config changed).
|
||||||
|
func refresh() async throws {
|
||||||
|
let cfg = try await self.configProvider()
|
||||||
|
await self.configure(url: cfg.url, token: cfg.token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shutdown() async {
|
||||||
|
if let client {
|
||||||
|
await client.shutdown()
|
||||||
|
}
|
||||||
|
self.client = nil
|
||||||
|
self.configuredURL = nil
|
||||||
|
self.configuredToken = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
private func configure(url: URL, token: String?) async {
|
||||||
|
if self.client != nil, self.configuredURL == url, self.configuredToken == token {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if let client {
|
||||||
|
await client.shutdown()
|
||||||
|
}
|
||||||
|
self.client = GatewayChannelActor(url: url, token: token, session: self.sessionBox)
|
||||||
|
self.configuredURL = url
|
||||||
|
self.configuredToken = token
|
||||||
|
}
|
||||||
|
|
||||||
|
private static func defaultConfigProvider() async throws -> Config {
|
||||||
|
let mode = await MainActor.run { AppStateStore.shared.connectionMode }
|
||||||
|
let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
|
||||||
|
switch mode {
|
||||||
|
case .local:
|
||||||
|
let port = GatewayEnvironment.gatewayPort()
|
||||||
|
return (URL(string: "ws://127.0.0.1:\(port)")!, token)
|
||||||
|
case .remote:
|
||||||
|
let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel()
|
||||||
|
return (URL(string: "ws://127.0.0.1:\(Int(forwarded))")!, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -149,12 +149,8 @@ final class GatewayProcessManager: ObservableObject {
|
||||||
/// If successful, mark status as attached and skip spawning a new process.
|
/// If successful, mark status as attached and skip spawning a new process.
|
||||||
private func attachExistingGatewayIfAvailable() async -> Bool {
|
private func attachExistingGatewayIfAvailable() async -> Bool {
|
||||||
let port = GatewayEnvironment.gatewayPort()
|
let port = GatewayEnvironment.gatewayPort()
|
||||||
guard let url = URL(string: "ws://127.0.0.1:\(port)") else { return false }
|
|
||||||
let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
|
|
||||||
let channel = GatewayChannel()
|
|
||||||
await channel.configure(url: url, token: token)
|
|
||||||
do {
|
do {
|
||||||
let data = try await channel.request(method: "health", params: nil)
|
let data = try await GatewayConnection.shared.request(method: "health", params: nil)
|
||||||
let details: String
|
let details: String
|
||||||
if let snap = decodeHealthSnapshot(from: data) {
|
if let snap = decodeHealthSnapshot(from: data) {
|
||||||
let linked = snap.web.linked ? "linked" : "not linked"
|
let linked = snap.web.linked ? "linked" : "not linked"
|
||||||
|
|
|
||||||
|
|
@ -92,16 +92,6 @@ final class HealthStore: ObservableObject {
|
||||||
defer { self.isRefreshing = false }
|
defer { self.isRefreshing = false }
|
||||||
|
|
||||||
do {
|
do {
|
||||||
let mode = AppStateStore.shared.connectionMode
|
|
||||||
switch mode {
|
|
||||||
case .local:
|
|
||||||
try await ControlChannel.shared.configure(mode: .local)
|
|
||||||
case .remote:
|
|
||||||
let target = AppStateStore.shared.remoteTarget
|
|
||||||
let identity = AppStateStore.shared.remoteIdentity
|
|
||||||
try await ControlChannel.shared.configure(mode: .remote(target: target, identity: identity))
|
|
||||||
}
|
|
||||||
|
|
||||||
let data = try await ControlChannel.shared.health(timeout: 15)
|
let data = try await ControlChannel.shared.health(timeout: 15)
|
||||||
if let decoded = decodeHealthSnapshot(from: data) {
|
if let decoded = decodeHealthSnapshot(from: data) {
|
||||||
self.snapshot = decoded
|
self.snapshot = decoded
|
||||||
|
|
|
||||||
|
|
@ -189,6 +189,7 @@ final class AppDelegate: NSObject, NSApplicationDelegate {
|
||||||
WebChatManager.shared.resetTunnels()
|
WebChatManager.shared.resetTunnels()
|
||||||
Task { await RemoteTunnelManager.shared.stopAll() }
|
Task { await RemoteTunnelManager.shared.stopAll() }
|
||||||
Task { await AgentRPC.shared.shutdown() }
|
Task { await AgentRPC.shared.shutdown() }
|
||||||
|
Task { await GatewayConnection.shared.shutdown() }
|
||||||
Task { await self.socketServer.stop() }
|
Task { await self.socketServer.stop() }
|
||||||
Task { await BridgeServer.shared.stop() }
|
Task { await BridgeServer.shared.stop() }
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -79,11 +79,8 @@ final class WebChatViewModel: ObservableObject {
|
||||||
@Published var healthOK: Bool = true
|
@Published var healthOK: Bool = true
|
||||||
|
|
||||||
private let sessionKey: String
|
private let sessionKey: String
|
||||||
private let gateway = GatewayChannel()
|
|
||||||
private var gatewayConfigured = false
|
|
||||||
private var eventToken: NSObjectProtocol?
|
private var eventToken: NSObjectProtocol?
|
||||||
private var pendingRuns = Set<String>()
|
private var pendingRuns = Set<String>()
|
||||||
private var currentPort: Int?
|
|
||||||
|
|
||||||
init(sessionKey: String) {
|
init(sessionKey: String) {
|
||||||
self.sessionKey = sessionKey
|
self.sessionKey = sessionKey
|
||||||
|
|
@ -141,7 +138,6 @@ final class WebChatViewModel: ObservableObject {
|
||||||
self.isLoading = true
|
self.isLoading = true
|
||||||
defer { self.isLoading = false }
|
defer { self.isLoading = false }
|
||||||
do {
|
do {
|
||||||
try await self.ensureGatewayConfigured()
|
|
||||||
let payload = try await self.requestHistory()
|
let payload = try await self.requestHistory()
|
||||||
self.messages = payload.messages ?? []
|
self.messages = payload.messages ?? []
|
||||||
if let level = payload.thinkingLevel, !level.isEmpty {
|
if let level = payload.thinkingLevel, !level.isEmpty {
|
||||||
|
|
@ -157,12 +153,6 @@ final class WebChatViewModel: ObservableObject {
|
||||||
guard !self.isSending else { return }
|
guard !self.isSending else { return }
|
||||||
let trimmed = self.input.trimmingCharacters(in: .whitespacesAndNewlines)
|
let trimmed = self.input.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
guard !trimmed.isEmpty || !self.attachments.isEmpty else { return }
|
guard !trimmed.isEmpty || !self.attachments.isEmpty else { return }
|
||||||
do {
|
|
||||||
try await self.ensureGatewayConfigured()
|
|
||||||
} catch {
|
|
||||||
self.errorText = error.localizedDescription
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
self.isSending = true
|
self.isSending = true
|
||||||
self.errorText = nil
|
self.errorText = nil
|
||||||
|
|
@ -202,7 +192,7 @@ final class WebChatViewModel: ObservableObject {
|
||||||
"idempotencyKey": AnyCodable(runId),
|
"idempotencyKey": AnyCodable(runId),
|
||||||
"timeoutMs": AnyCodable(30_000)
|
"timeoutMs": AnyCodable(30_000)
|
||||||
]
|
]
|
||||||
let data = try await self.gateway.request(method: "chat.send", params: params)
|
let data = try await GatewayConnection.shared.request(method: "chat.send", params: params)
|
||||||
let response = try JSONDecoder().decode(ChatSendResponse.self, from: data)
|
let response = try JSONDecoder().decode(ChatSendResponse.self, from: data)
|
||||||
self.pendingRuns.insert(response.runId)
|
self.pendingRuns.insert(response.runId)
|
||||||
} catch {
|
} catch {
|
||||||
|
|
@ -215,26 +205,8 @@ final class WebChatViewModel: ObservableObject {
|
||||||
self.isSending = false
|
self.isSending = false
|
||||||
}
|
}
|
||||||
|
|
||||||
private func ensureGatewayConfigured() async throws {
|
|
||||||
guard !self.gatewayConfigured else { return }
|
|
||||||
let port = try await self.resolveGatewayPort()
|
|
||||||
self.currentPort = port
|
|
||||||
let url = URL(string: "ws://127.0.0.1:\(port)")!
|
|
||||||
let token = ProcessInfo.processInfo.environment["CLAWDIS_GATEWAY_TOKEN"]
|
|
||||||
await self.gateway.configure(url: url, token: token)
|
|
||||||
self.gatewayConfigured = true
|
|
||||||
}
|
|
||||||
|
|
||||||
private func resolveGatewayPort() async throws -> Int {
|
|
||||||
if CommandResolver.connectionModeIsRemote() {
|
|
||||||
let forwarded = try await RemoteTunnelManager.shared.ensureControlTunnel()
|
|
||||||
return Int(forwarded)
|
|
||||||
}
|
|
||||||
return GatewayEnvironment.gatewayPort()
|
|
||||||
}
|
|
||||||
|
|
||||||
private func requestHistory() async throws -> ChatHistoryPayload {
|
private func requestHistory() async throws -> ChatHistoryPayload {
|
||||||
let data = try await self.gateway.request(
|
let data = try await GatewayConnection.shared.request(
|
||||||
method: "chat.history",
|
method: "chat.history",
|
||||||
params: ["sessionKey": AnyCodable(self.sessionKey)])
|
params: ["sessionKey": AnyCodable(self.sessionKey)])
|
||||||
return try JSONDecoder().decode(ChatHistoryPayload.self, from: data)
|
return try JSONDecoder().decode(ChatHistoryPayload.self, from: data)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
import Foundation
|
||||||
|
import os
|
||||||
|
import Testing
|
||||||
|
@testable import Clawdis
|
||||||
|
|
||||||
|
@Suite struct GatewayConnectionTests {
|
||||||
|
private final class FakeWebSocketTask: WebSocketTasking, @unchecked Sendable {
|
||||||
|
private let pendingReceiveHandler =
|
||||||
|
OSAllocatedUnfairLock<(@Sendable (Result<URLSessionWebSocketTask.Message, Error>) -> Void)?>(initialState: nil)
|
||||||
|
private let cancelCount = OSAllocatedUnfairLock(initialState: 0)
|
||||||
|
private let sendCount = OSAllocatedUnfairLock(initialState: 0)
|
||||||
|
|
||||||
|
var state: URLSessionTask.State = .suspended
|
||||||
|
|
||||||
|
func snapshotCancelCount() -> Int { self.cancelCount.withLock { $0 } }
|
||||||
|
|
||||||
|
func resume() {
|
||||||
|
self.state = .running
|
||||||
|
}
|
||||||
|
|
||||||
|
func cancel(with closeCode: URLSessionWebSocketTask.CloseCode, reason: Data?) {
|
||||||
|
_ = (closeCode, reason)
|
||||||
|
self.state = .canceling
|
||||||
|
self.cancelCount.withLock { $0 += 1 }
|
||||||
|
let handler = self.pendingReceiveHandler.withLock { handler in
|
||||||
|
defer { handler = nil }
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
handler?(Result<URLSessionWebSocketTask.Message, Error>.failure(URLError(.cancelled)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func send(_ message: URLSessionWebSocketTask.Message) async throws {
|
||||||
|
let currentSendCount = self.sendCount.withLock { count in
|
||||||
|
defer { count += 1 }
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// First send is the hello frame. Subsequent sends are request frames.
|
||||||
|
if currentSendCount == 0 { return }
|
||||||
|
|
||||||
|
guard case let .data(data) = message else { return }
|
||||||
|
guard
|
||||||
|
let obj = try? JSONSerialization.jsonObject(with: data) as? [String: Any],
|
||||||
|
(obj["type"] as? String) == "req",
|
||||||
|
let id = obj["id"] as? String
|
||||||
|
else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = Self.responseData(id: id)
|
||||||
|
let handler = self.pendingReceiveHandler.withLock { $0 }
|
||||||
|
handler?(Result<URLSessionWebSocketTask.Message, Error>.success(.data(response)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func receive() async throws -> URLSessionWebSocketTask.Message {
|
||||||
|
.data(Self.helloOkData())
|
||||||
|
}
|
||||||
|
|
||||||
|
func receive(
|
||||||
|
completionHandler: @escaping @Sendable (Result<URLSessionWebSocketTask.Message, Error>) -> Void)
|
||||||
|
{
|
||||||
|
self.pendingReceiveHandler.withLock { $0 = completionHandler }
|
||||||
|
}
|
||||||
|
|
||||||
|
private static func helloOkData() -> Data {
|
||||||
|
let json = """
|
||||||
|
{
|
||||||
|
"type": "hello-ok",
|
||||||
|
"protocol": 1,
|
||||||
|
"server": { "version": "test", "connId": "test" },
|
||||||
|
"features": { "methods": [], "events": [] },
|
||||||
|
"snapshot": {
|
||||||
|
"presence": [ { "ts": 1 } ],
|
||||||
|
"health": {},
|
||||||
|
"stateVersion": { "presence": 0, "health": 0 },
|
||||||
|
"uptimeMs": 0
|
||||||
|
},
|
||||||
|
"policy": { "maxPayload": 1, "maxBufferedBytes": 1, "tickIntervalMs": 30000 }
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
return Data(json.utf8)
|
||||||
|
}
|
||||||
|
|
||||||
|
private static func responseData(id: String) -> Data {
|
||||||
|
let json = """
|
||||||
|
{
|
||||||
|
"type": "res",
|
||||||
|
"id": "\(id)",
|
||||||
|
"ok": true,
|
||||||
|
"payload": { "ok": true }
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
return Data(json.utf8)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final class FakeWebSocketSession: WebSocketSessioning, @unchecked Sendable {
|
||||||
|
private let makeCount = OSAllocatedUnfairLock(initialState: 0)
|
||||||
|
private let tasks = OSAllocatedUnfairLock(initialState: [FakeWebSocketTask]())
|
||||||
|
|
||||||
|
func snapshotMakeCount() -> Int { self.makeCount.withLock { $0 } }
|
||||||
|
func snapshotCancelCount() -> Int {
|
||||||
|
self.tasks.withLock { tasks in
|
||||||
|
tasks.reduce(0) { $0 + $1.snapshotCancelCount() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeWebSocketTask(url: URL) -> WebSocketTaskBox {
|
||||||
|
_ = url
|
||||||
|
self.makeCount.withLock { $0 += 1 }
|
||||||
|
let task = FakeWebSocketTask()
|
||||||
|
self.tasks.withLock { $0.append(task) }
|
||||||
|
return WebSocketTaskBox(task: task)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final class ConfigSource: @unchecked Sendable {
|
||||||
|
private let token = OSAllocatedUnfairLock<String?>(initialState: nil)
|
||||||
|
|
||||||
|
init(token: String?) {
|
||||||
|
self.token.withLock { $0 = token }
|
||||||
|
}
|
||||||
|
|
||||||
|
func snapshotToken() -> String? { self.token.withLock { $0 } }
|
||||||
|
func setToken(_ value: String?) { self.token.withLock { $0 = value } }
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test func requestReusesSingleWebSocketForSameConfig() async throws {
|
||||||
|
let session = FakeWebSocketSession()
|
||||||
|
let url = URL(string: "ws://example.invalid")!
|
||||||
|
let cfg = ConfigSource(token: nil)
|
||||||
|
let conn = GatewayConnection(
|
||||||
|
configProvider: { (url, cfg.snapshotToken()) },
|
||||||
|
sessionBox: WebSocketSessionBox(session: session))
|
||||||
|
|
||||||
|
_ = try await conn.request(method: "status", params: nil)
|
||||||
|
#expect(session.snapshotMakeCount() == 1)
|
||||||
|
|
||||||
|
_ = try await conn.request(method: "status", params: nil)
|
||||||
|
#expect(session.snapshotMakeCount() == 1)
|
||||||
|
#expect(session.snapshotCancelCount() == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test func requestReconfiguresAndCancelsOnTokenChange() async throws {
|
||||||
|
let session = FakeWebSocketSession()
|
||||||
|
let url = URL(string: "ws://example.invalid")!
|
||||||
|
let cfg = ConfigSource(token: "a")
|
||||||
|
let conn = GatewayConnection(
|
||||||
|
configProvider: { (url, cfg.snapshotToken()) },
|
||||||
|
sessionBox: WebSocketSessionBox(session: session))
|
||||||
|
|
||||||
|
_ = try await conn.request(method: "status", params: nil)
|
||||||
|
#expect(session.snapshotMakeCount() == 1)
|
||||||
|
|
||||||
|
cfg.setToken("b")
|
||||||
|
_ = try await conn.request(method: "status", params: nil)
|
||||||
|
#expect(session.snapshotMakeCount() == 2)
|
||||||
|
#expect(session.snapshotCancelCount() == 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -101,7 +101,7 @@ import Testing
|
||||||
let channel = GatewayChannelActor(
|
let channel = GatewayChannelActor(
|
||||||
url: URL(string: "ws://example.invalid")!,
|
url: URL(string: "ws://example.invalid")!,
|
||||||
token: nil,
|
token: nil,
|
||||||
session: session)
|
session: WebSocketSessionBox(session: session))
|
||||||
|
|
||||||
let t1 = Task { try await channel.connect() }
|
let t1 = Task { try await channel.connect() }
|
||||||
let t2 = Task { try await channel.connect() }
|
let t2 = Task { try await channel.connect() }
|
||||||
|
|
@ -117,7 +117,7 @@ import Testing
|
||||||
let channel = GatewayChannelActor(
|
let channel = GatewayChannelActor(
|
||||||
url: URL(string: "ws://example.invalid")!,
|
url: URL(string: "ws://example.invalid")!,
|
||||||
token: nil,
|
token: nil,
|
||||||
session: session)
|
session: WebSocketSessionBox(session: session))
|
||||||
|
|
||||||
let t1 = Task { try await channel.connect() }
|
let t1 = Task { try await channel.connect() }
|
||||||
let t2 = Task { try await channel.connect() }
|
let t2 = Task { try await channel.connect() }
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,7 @@ import Testing
|
||||||
let channel = GatewayChannelActor(
|
let channel = GatewayChannelActor(
|
||||||
url: URL(string: "ws://example.invalid")!,
|
url: URL(string: "ws://example.invalid")!,
|
||||||
token: nil,
|
token: nil,
|
||||||
session: session)
|
session: WebSocketSessionBox(session: session))
|
||||||
|
|
||||||
do {
|
do {
|
||||||
_ = try await channel.request(method: "test", params: nil, timeoutMs: 10)
|
_ = try await channel.request(method: "test", params: nil, timeoutMs: 10)
|
||||||
|
|
@ -108,4 +108,3 @@ import Testing
|
||||||
try? await Task.sleep(nanoseconds: 250 * 1_000_000)
|
try? await Task.sleep(nanoseconds: 250 * 1_000_000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue