diff --git a/src/agents/pi-embedded.ts b/src/agents/pi-embedded.ts index 40b92f8cf..16f9c28a3 100644 --- a/src/agents/pi-embedded.ts +++ b/src/agents/pi-embedded.ts @@ -5,7 +5,6 @@ import { Agent, type AgentEvent, type AppMessage, - ProviderTransport, type ThinkingLevel, } from "@mariozechner/pi-agent-core"; import { @@ -44,6 +43,7 @@ import { createClawdisCodingTools, sanitizeContentBlocksImages, } from "./pi-tools.js"; +import { SteerableProviderTransport } from "./steerable-provider-transport.js"; import { applySkillEnvOverrides, applySkillEnvOverridesFromSnapshot, @@ -82,6 +82,24 @@ export type EmbeddedPiRunResult = { meta: EmbeddedPiRunMeta; }; +type EmbeddedPiQueueHandle = { + queueMessage: (text: string) => Promise; + isStreaming: () => boolean; +}; + +const ACTIVE_EMBEDDED_RUNS = new Map(); + +export function queueEmbeddedPiMessage( + sessionId: string, + text: string, +): boolean { + const handle = ACTIVE_EMBEDDED_RUNS.get(sessionId); + if (!handle) return false; + if (!handle.isStreaming()) return false; + void handle.queueMessage(text); + return true; +} + function mapThinkingLevel(level?: ThinkLevel): ThinkingLevel { // pi-agent-core supports "xhigh" too; Clawdis doesn't surface it for now. if (!level) return "off"; @@ -310,7 +328,7 @@ export async function runEmbeddedPiAgent(params: { }, messageTransformer, queueMode: settingsManager.getQueueMode(), - transport: new ProviderTransport({ + transport: new SteerableProviderTransport({ getApiKey: async (providerName) => { const key = await getApiKeyForProvider(providerName); if (!key) { @@ -338,6 +356,13 @@ export async function runEmbeddedPiAgent(params: { sessionManager, settingsManager, }); + const queueHandle: EmbeddedPiQueueHandle = { + queueMessage: async (text: string) => { + await session.queueMessage(text); + }, + isStreaming: () => session.isStreaming, + }; + ACTIVE_EMBEDDED_RUNS.set(params.sessionId, queueHandle); const assistantTexts: string[] = []; const toolDebouncer = createToolDebouncer((toolName, metas) => { @@ -553,6 +578,9 @@ export async function runEmbeddedPiAgent(params: { clearTimeout(abortTimer); unsubscribe(); toolDebouncer.flush(); + if (ACTIVE_EMBEDDED_RUNS.get(params.sessionId) === queueHandle) { + ACTIVE_EMBEDDED_RUNS.delete(params.sessionId); + } session.dispose(); params.abortSignal?.removeEventListener?.("abort", onAbort); } diff --git a/src/agents/steerable-agent-loop.ts b/src/agents/steerable-agent-loop.ts new file mode 100644 index 000000000..ef1849d33 --- /dev/null +++ b/src/agents/steerable-agent-loop.ts @@ -0,0 +1,437 @@ +import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai"; +import type { + AssistantMessage, + Context, + Message, + ToolResultMessage, + UserMessage, +} from "@mariozechner/pi-ai"; +import type { + AgentContext, + AgentEvent, + AgentLoopConfig, + AgentTool, + AgentToolResult, + QueuedMessage, +} from "@mariozechner/pi-ai"; + +class EventStream implements AsyncIterable { + private queue: T[] = []; + private waiting: ((value: IteratorResult) => void)[] = []; + private done = false; + private finalResultPromise: Promise; + private resolveFinalResult!: (result: R) => void; + + constructor( + private isComplete: (event: T) => boolean, + private extractResult: (event: T) => R, + ) { + this.finalResultPromise = new Promise((resolve) => { + this.resolveFinalResult = resolve; + }); + } + + push(event: T): void { + if (this.done) return; + + if (this.isComplete(event)) { + this.done = true; + this.resolveFinalResult(this.extractResult(event)); + } + + const waiter = this.waiting.shift(); + if (waiter) { + waiter({ value: event, done: false }); + } else { + this.queue.push(event); + } + } + + end(result?: R): void { + this.done = true; + if (result !== undefined) { + this.resolveFinalResult(result); + } + while (this.waiting.length > 0) { + const waiter = this.waiting.shift()!; + waiter({ value: undefined as never, done: true }); + } + } + + async *[Symbol.asyncIterator](): AsyncIterator { + while (true) { + if (this.queue.length > 0) { + yield this.queue.shift()!; + } else if (this.done) { + return; + } else { + const result = await new Promise>((resolve) => + this.waiting.push(resolve), + ); + if (result.done) return; + yield result.value; + } + } + } + + result(): Promise { + return this.finalResultPromise; + } +} + +function createAgentStream(): EventStream { + return new EventStream( + (event) => event.type === "agent_end", + (event) => (event.type === "agent_end" ? event.messages : []), + ); +} + +export function agentLoop( + prompt: UserMessage, + context: AgentContext, + config: AgentLoopConfig, + signal?: AbortSignal, + streamFn?: typeof streamSimple, +): EventStream { + const stream = createAgentStream(); + + void (async () => { + const newMessages: AgentContext["messages"] = [prompt]; + const currentContext: AgentContext = { + ...context, + messages: [...context.messages, prompt], + }; + + stream.push({ type: "agent_start" }); + stream.push({ type: "turn_start" }); + stream.push({ type: "message_start", message: prompt }); + stream.push({ type: "message_end", message: prompt }); + + await runLoop(currentContext, newMessages, config, signal, stream, streamFn); + })(); + + return stream; +} + +export function agentLoopContinue( + context: AgentContext, + config: AgentLoopConfig, + signal?: AbortSignal, + streamFn?: typeof streamSimple, +): EventStream { + const lastMessage = context.messages[context.messages.length - 1]; + if (!lastMessage) { + throw new Error("Cannot continue: no messages in context"); + } + if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") { + throw new Error( + `Cannot continue from message role: ${lastMessage.role}. Expected 'user' or 'toolResult'.`, + ); + } + + const stream = createAgentStream(); + + void (async () => { + const newMessages: AgentContext["messages"] = []; + const currentContext: AgentContext = { ...context }; + + stream.push({ type: "agent_start" }); + stream.push({ type: "turn_start" }); + + await runLoop(currentContext, newMessages, config, signal, stream, streamFn); + })(); + + return stream; +} + +async function runLoop( + currentContext: AgentContext, + newMessages: AgentContext["messages"], + config: AgentLoopConfig, + signal: AbortSignal | undefined, + stream: EventStream, + streamFn?: typeof streamSimple, +): Promise { + let hasMoreToolCalls = true; + let firstTurn = true; + let queuedMessages: QueuedMessage[] = + (await config.getQueuedMessages?.()) || []; + let queuedAfterTools: QueuedMessage[] | null = null; + + while (hasMoreToolCalls || queuedMessages.length > 0) { + if (!firstTurn) { + stream.push({ type: "turn_start" }); + } else { + firstTurn = false; + } + + if (queuedMessages.length > 0) { + for (const { original, llm } of queuedMessages) { + stream.push({ type: "message_start", message: original }); + stream.push({ type: "message_end", message: original }); + if (llm) { + currentContext.messages.push(llm); + newMessages.push(llm); + } + } + queuedMessages = []; + } + + const message = await streamAssistantResponse( + currentContext, + config, + signal, + stream, + streamFn, + ); + newMessages.push(message); + + if (message.stopReason === "error" || message.stopReason === "aborted") { + stream.push({ type: "turn_end", message, toolResults: [] }); + stream.push({ type: "agent_end", messages: newMessages }); + stream.end(newMessages); + return; + } + + const toolCalls = message.content.filter((c) => c.type === "toolCall"); + hasMoreToolCalls = toolCalls.length > 0; + + const toolResults: ToolResultMessage[] = []; + if (hasMoreToolCalls) { + const toolExecution = await executeToolCalls( + currentContext.tools, + message, + signal, + stream, + config.getQueuedMessages, + ); + toolResults.push(...toolExecution.toolResults); + queuedAfterTools = toolExecution.queuedMessages ?? null; + currentContext.messages.push(...toolResults); + newMessages.push(...toolResults); + } + stream.push({ type: "turn_end", message, toolResults: toolResults }); + + if (queuedAfterTools && queuedAfterTools.length > 0) { + queuedMessages = queuedAfterTools; + queuedAfterTools = null; + } else { + queuedMessages = (await config.getQueuedMessages?.()) || []; + } + } + + stream.push({ type: "agent_end", messages: newMessages }); + stream.end(newMessages); +} + +async function streamAssistantResponse( + context: AgentContext, + config: AgentLoopConfig, + signal: AbortSignal | undefined, + stream: EventStream, + streamFn?: typeof streamSimple, +): Promise { + const processedMessages = config.preprocessor + ? await config.preprocessor(context.messages, signal) + : [...context.messages]; + const processedContext: Context = { + systemPrompt: context.systemPrompt, + messages: [...processedMessages].map((m) => { + if (m.role === "toolResult") { + const { details, ...rest } = m; + return rest; + } + return m; + }), + tools: context.tools, + }; + + const streamFunction = streamFn || streamSimple; + const resolvedApiKey = + (config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || + config.apiKey; + + const response = await streamFunction(config.model, processedContext, { + ...config, + apiKey: resolvedApiKey, + signal, + }); + + let partialMessage: AssistantMessage | null = null; + let addedPartial = false; + + for await (const event of response) { + switch (event.type) { + case "start": + partialMessage = event.partial; + context.messages.push(partialMessage); + addedPartial = true; + stream.push({ type: "message_start", message: { ...partialMessage } }); + break; + + case "text_start": + case "text_delta": + case "text_end": + case "thinking_start": + case "thinking_delta": + case "thinking_end": + case "toolcall_start": + case "toolcall_delta": + case "toolcall_end": + if (partialMessage) { + partialMessage = event.partial; + context.messages[context.messages.length - 1] = partialMessage; + stream.push({ + type: "message_update", + assistantMessageEvent: event, + message: { ...partialMessage }, + }); + } + break; + + case "done": + case "error": { + const finalMessage = await response.result(); + if (addedPartial) { + context.messages[context.messages.length - 1] = finalMessage; + } else { + context.messages.push(finalMessage); + } + if (!addedPartial) { + stream.push({ type: "message_start", message: { ...finalMessage } }); + } + stream.push({ type: "message_end", message: finalMessage }); + return finalMessage; + } + } + } + + return await response.result(); +} + +async function executeToolCalls( + tools: AgentTool[] | undefined, + assistantMessage: AssistantMessage, + signal: AbortSignal | undefined, + stream: EventStream, + getQueuedMessages?: AgentLoopConfig["getQueuedMessages"], +): Promise<{ + toolResults: ToolResultMessage[]; + queuedMessages?: QueuedMessage[]; +}> { + const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); + const results: ToolResultMessage[] = []; + let queuedMessages: QueuedMessage[] | undefined; + + for (let index = 0; index < toolCalls.length; index++) { + const toolCall = toolCalls[index]; + const tool = tools?.find((t) => t.name === toolCall.name); + + stream.push({ + type: "tool_execution_start", + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + }); + + let result: AgentToolResult; + let isError = false; + + try { + if (!tool) throw new Error(`Tool ${toolCall.name} not found`); + const validatedArgs = validateToolArguments(tool, toolCall); + result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => { + stream.push({ + type: "tool_execution_update", + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + partialResult, + }); + }); + } catch (err) { + result = { + content: [ + { type: "text", text: err instanceof Error ? err.message : String(err) }, + ], + details: {} as T, + }; + isError = true; + } + + stream.push({ + type: "tool_execution_end", + toolCallId: toolCall.id, + toolName: toolCall.name, + result, + isError, + }); + + const toolResultMessage: ToolResultMessage = { + role: "toolResult", + toolCallId: toolCall.id, + toolName: toolCall.name, + content: result.content, + details: result.details, + isError, + timestamp: Date.now(), + }; + + results.push(toolResultMessage); + stream.push({ type: "message_start", message: toolResultMessage }); + stream.push({ type: "message_end", message: toolResultMessage }); + + if (getQueuedMessages) { + const queued = await getQueuedMessages(); + if (queued.length > 0) { + queuedMessages = queued; + const remainingCalls = toolCalls.slice(index + 1); + for (const skipped of remainingCalls) { + results.push(skipToolCall(skipped, stream)); + } + break; + } + } + } + + return { toolResults: results, queuedMessages }; +} + +function skipToolCall( + toolCall: Extract, + stream: EventStream, +): ToolResultMessage { + const result: AgentToolResult = { + content: [{ type: "text", text: "Skipped due to queued user message." }], + details: {} as T, + }; + + stream.push({ + type: "tool_execution_start", + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + }); + stream.push({ + type: "tool_execution_end", + toolCallId: toolCall.id, + toolName: toolCall.name, + result, + isError: true, + }); + + const toolResultMessage: ToolResultMessage = { + role: "toolResult", + toolCallId: toolCall.id, + toolName: toolCall.name, + content: result.content, + details: result.details, + isError: true, + timestamp: Date.now(), + }; + + stream.push({ type: "message_start", message: toolResultMessage }); + stream.push({ type: "message_end", message: toolResultMessage }); + + return toolResultMessage; +} diff --git a/src/agents/steerable-provider-transport.ts b/src/agents/steerable-provider-transport.ts new file mode 100644 index 000000000..4beecbbbe --- /dev/null +++ b/src/agents/steerable-provider-transport.ts @@ -0,0 +1,88 @@ +import { + agentLoop, + agentLoopContinue, +} from "./steerable-agent-loop.js"; +import type { + AgentContext, + AgentLoopConfig, + Message, + UserMessage, +} from "@mariozechner/pi-ai"; +import type { + AgentRunConfig, + AgentTransport, +} from "@mariozechner/pi-agent-core"; +import type { ProviderTransportOptions } from "@mariozechner/pi-agent-core"; + +export class SteerableProviderTransport implements AgentTransport { + private options: ProviderTransportOptions; + + constructor(options: ProviderTransportOptions = {}) { + this.options = options; + } + + private getModel(cfg: AgentRunConfig) { + let model = cfg.model; + if (this.options.corsProxyUrl && cfg.model.baseUrl) { + model = { + ...cfg.model, + baseUrl: `${this.options.corsProxyUrl}/?url=${encodeURIComponent(cfg.model.baseUrl)}`, + }; + } + return model; + } + + private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext { + return { + systemPrompt: cfg.systemPrompt, + messages, + tools: cfg.tools, + }; + } + + private buildLoopConfig( + model: AgentRunConfig["model"], + cfg: AgentRunConfig, + ): AgentLoopConfig { + return { + model, + reasoning: cfg.reasoning, + getApiKey: this.options.getApiKey, + getQueuedMessages: cfg.getQueuedMessages, + }; + } + + async *run( + messages: Message[], + userMessage: Message, + cfg: AgentRunConfig, + signal?: AbortSignal, + ) { + const model = this.getModel(cfg); + const context = this.buildContext(messages, cfg); + const pc = this.buildLoopConfig(model, cfg); + + for await (const ev of agentLoop( + userMessage as unknown as UserMessage, + context, + pc, + signal, + )) { + yield ev; + } + } + + async *continue( + messages: Message[], + cfg: AgentRunConfig, + signal?: AbortSignal, + ) { + const model = this.getModel(cfg); + const context = this.buildContext(messages, cfg); + const pc = this.buildLoopConfig(model, cfg); + + for await (const ev of agentLoopContinue(context, pc, signal)) { + yield ev; + } + } +} diff --git a/src/auto-reply/reply.ts b/src/auto-reply/reply.ts index 3733eecd8..b0d91b15b 100644 --- a/src/auto-reply/reply.ts +++ b/src/auto-reply/reply.ts @@ -6,7 +6,10 @@ import { DEFAULT_MODEL, DEFAULT_PROVIDER, } from "../agents/defaults.js"; -import { runEmbeddedPiAgent } from "../agents/pi-embedded.js"; +import { + queueEmbeddedPiMessage, + runEmbeddedPiAgent, +} from "../agents/pi-embedded.js"; import { buildWorkspaceSkillSnapshot } from "../agents/skills.js"; import { DEFAULT_AGENT_WORKSPACE_DIR, @@ -750,6 +753,25 @@ export async function getReplyFromConfig( const sessionIdFinal = sessionId ?? crypto.randomUUID(); const sessionFile = resolveSessionTranscriptPath(sessionIdFinal); + const queueBodyBase = transcribedText + ? [baseBodyFinal, `Transcript:\n${transcribedText}`] + .filter(Boolean) + .join("\n\n") + : baseBodyFinal; + const queuedBody = mediaNote + ? [mediaNote, mediaReplyHint, queueBodyBase].filter(Boolean).join("\n").trim() + : queueBodyBase; + + if (queueEmbeddedPiMessage(sessionIdFinal, queuedBody)) { + if (sessionEntry && sessionStore && sessionKey) { + sessionEntry.updatedAt = Date.now(); + sessionStore[sessionKey] = sessionEntry; + await saveSessionStore(storePath, sessionStore); + } + cleanupTyping(); + return undefined; + } + await onReplyStart(); try {