chore(pi): bump deps, drop steerable transport
parent
7aeacdcc6c
commit
b635e83651
|
|
@ -41,9 +41,9 @@
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@grammyjs/transformer-throttler": "^1.2.1",
|
"@grammyjs/transformer-throttler": "^1.2.1",
|
||||||
"@homebridge/ciao": "^1.3.4",
|
"@homebridge/ciao": "^1.3.4",
|
||||||
"@mariozechner/pi-agent-core": "^0.24.5",
|
"@mariozechner/pi-agent-core": "^0.25.0",
|
||||||
"@mariozechner/pi-ai": "^0.24.5",
|
"@mariozechner/pi-ai": "^0.25.0",
|
||||||
"@mariozechner/pi-coding-agent": "^0.24.5",
|
"@mariozechner/pi-coding-agent": "^0.25.0",
|
||||||
"@sinclair/typebox": "^0.34.41",
|
"@sinclair/typebox": "^0.34.41",
|
||||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||||
"ajv": "^8.17.1",
|
"ajv": "^8.17.1",
|
||||||
|
|
|
||||||
|
|
@ -15,14 +15,14 @@ importers:
|
||||||
specifier: ^1.3.4
|
specifier: ^1.3.4
|
||||||
version: 1.3.4
|
version: 1.3.4
|
||||||
'@mariozechner/pi-agent-core':
|
'@mariozechner/pi-agent-core':
|
||||||
specifier: ^0.24.5
|
specifier: ^0.25.0
|
||||||
version: 0.24.5(ws@8.18.3)(zod@4.1.13)
|
version: 0.25.0(ws@8.18.3)(zod@4.1.13)
|
||||||
'@mariozechner/pi-ai':
|
'@mariozechner/pi-ai':
|
||||||
specifier: ^0.24.5
|
specifier: ^0.25.0
|
||||||
version: 0.24.5(ws@8.18.3)(zod@4.1.13)
|
version: 0.25.0(ws@8.18.3)(zod@4.1.13)
|
||||||
'@mariozechner/pi-coding-agent':
|
'@mariozechner/pi-coding-agent':
|
||||||
specifier: ^0.24.5
|
specifier: ^0.25.0
|
||||||
version: 0.24.5(ws@8.18.3)(zod@4.1.13)
|
version: 0.25.0(ws@8.18.3)(zod@4.1.13)
|
||||||
'@sinclair/typebox':
|
'@sinclair/typebox':
|
||||||
specifier: ^0.34.41
|
specifier: ^0.34.41
|
||||||
version: 0.34.41
|
version: 0.34.41
|
||||||
|
|
@ -811,21 +811,21 @@ packages:
|
||||||
peerDependencies:
|
peerDependencies:
|
||||||
lit: ^3.3.1
|
lit: ^3.3.1
|
||||||
|
|
||||||
'@mariozechner/pi-agent-core@0.24.5':
|
'@mariozechner/pi-agent-core@0.25.0':
|
||||||
resolution: {integrity: sha512-36rj74NcZItzgsWCEpAlrcDxrTNsi5NlYgkU2/tnkBABItvFCHCSJyDUFns5c4Swv/iPkkeqKWPrpTReO1ta9A==}
|
resolution: {integrity: sha512-aiM0GvkmHJtFudNGlXiuLr/IqRot1Sus9vqrarVf/gF5ooubYyGYhP6QotAfbFqI0z6HpFa2O3mx8KEp0AiBKg==}
|
||||||
engines: {node: '>=20.0.0'}
|
engines: {node: '>=20.0.0'}
|
||||||
|
|
||||||
'@mariozechner/pi-ai@0.24.5':
|
'@mariozechner/pi-ai@0.25.0':
|
||||||
resolution: {integrity: sha512-7DKydy/xgOwDr3uZgFl41jAc3CNLsZebW0Z19bNQ9iI0Y1f5UmYTTELuirwaAa0vlKeS0+HoHU2t0e5TFo/I6g==}
|
resolution: {integrity: sha512-N3INs/PNIEYx/U8tM6NaV75Gpx263o4b+YYxsD1Ag9ratdzz+JxL2ATYENi+Ma+BjsMaowPCMO2oeotHdsr/cA==}
|
||||||
engines: {node: '>=20.0.0'}
|
engines: {node: '>=20.0.0'}
|
||||||
|
|
||||||
'@mariozechner/pi-coding-agent@0.24.5':
|
'@mariozechner/pi-coding-agent@0.25.0':
|
||||||
resolution: {integrity: sha512-e0M4zoNWsXL2FinABlUEBKgVOWwShTjN3sOBqcvjssKndtpouHqAZDLSyvXbcFycQjImsh8Iq7/l1x2UjmylwA==}
|
resolution: {integrity: sha512-docYKq6zEVZcO5ngb0NTpayeipr+pLCMCeNfwdiC55zNI5nKMg1O4s6aMv2clJ4fUisHP0uhyK9URIohqSadbw==}
|
||||||
engines: {node: '>=20.0.0'}
|
engines: {node: '>=20.0.0'}
|
||||||
hasBin: true
|
hasBin: true
|
||||||
|
|
||||||
'@mariozechner/pi-tui@0.24.5':
|
'@mariozechner/pi-tui@0.25.0':
|
||||||
resolution: {integrity: sha512-ajh81L/qLk4HczTrxPwNXluFa2tRGhAESiczroH0P5qYOG578KNFfVw6UxVm1QhQVjynWXjE0fJacAFeHu/leQ==}
|
resolution: {integrity: sha512-7pU/EPFTYgyEsfcDBb+fzp6BQWr6tmykgMMGZx3Pxvet3NF5HmphAdLBitjmThri+M7lrGaJVrpIRHjQM1CPVQ==}
|
||||||
engines: {node: '>=20.0.0'}
|
engines: {node: '>=20.0.0'}
|
||||||
|
|
||||||
'@mistralai/mistralai@1.10.0':
|
'@mistralai/mistralai@1.10.0':
|
||||||
|
|
@ -3142,10 +3142,10 @@ snapshots:
|
||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
- tailwindcss
|
- tailwindcss
|
||||||
|
|
||||||
'@mariozechner/pi-agent-core@0.24.5(ws@8.18.3)(zod@4.1.13)':
|
'@mariozechner/pi-agent-core@0.25.0(ws@8.18.3)(zod@4.1.13)':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@mariozechner/pi-ai': 0.24.5(ws@8.18.3)(zod@4.1.13)
|
'@mariozechner/pi-ai': 0.25.0(ws@8.18.3)(zod@4.1.13)
|
||||||
'@mariozechner/pi-tui': 0.24.5
|
'@mariozechner/pi-tui': 0.25.0
|
||||||
transitivePeerDependencies:
|
transitivePeerDependencies:
|
||||||
- '@modelcontextprotocol/sdk'
|
- '@modelcontextprotocol/sdk'
|
||||||
- bufferutil
|
- bufferutil
|
||||||
|
|
@ -3154,7 +3154,7 @@ snapshots:
|
||||||
- ws
|
- ws
|
||||||
- zod
|
- zod
|
||||||
|
|
||||||
'@mariozechner/pi-ai@0.24.5(ws@8.18.3)(zod@4.1.13)':
|
'@mariozechner/pi-ai@0.25.0(ws@8.18.3)(zod@4.1.13)':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@anthropic-ai/sdk': 0.71.2(zod@4.1.13)
|
'@anthropic-ai/sdk': 0.71.2(zod@4.1.13)
|
||||||
'@google/genai': 1.34.0
|
'@google/genai': 1.34.0
|
||||||
|
|
@ -3174,11 +3174,11 @@ snapshots:
|
||||||
- ws
|
- ws
|
||||||
- zod
|
- zod
|
||||||
|
|
||||||
'@mariozechner/pi-coding-agent@0.24.5(ws@8.18.3)(zod@4.1.13)':
|
'@mariozechner/pi-coding-agent@0.25.0(ws@8.18.3)(zod@4.1.13)':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@mariozechner/pi-agent-core': 0.24.5(ws@8.18.3)(zod@4.1.13)
|
'@mariozechner/pi-agent-core': 0.25.0(ws@8.18.3)(zod@4.1.13)
|
||||||
'@mariozechner/pi-ai': 0.24.5(ws@8.18.3)(zod@4.1.13)
|
'@mariozechner/pi-ai': 0.25.0(ws@8.18.3)(zod@4.1.13)
|
||||||
'@mariozechner/pi-tui': 0.24.5
|
'@mariozechner/pi-tui': 0.25.0
|
||||||
chalk: 5.6.2
|
chalk: 5.6.2
|
||||||
cli-highlight: 2.1.11
|
cli-highlight: 2.1.11
|
||||||
diff: 8.0.2
|
diff: 8.0.2
|
||||||
|
|
@ -3193,7 +3193,7 @@ snapshots:
|
||||||
- ws
|
- ws
|
||||||
- zod
|
- zod
|
||||||
|
|
||||||
'@mariozechner/pi-tui@0.24.5':
|
'@mariozechner/pi-tui@0.25.0':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@types/mime-types': 2.1.4
|
'@types/mime-types': 2.1.4
|
||||||
chalk: 5.6.2
|
chalk: 5.6.2
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import {
|
||||||
Agent,
|
Agent,
|
||||||
type AgentEvent,
|
type AgentEvent,
|
||||||
type AppMessage,
|
type AppMessage,
|
||||||
|
ProviderTransport,
|
||||||
type ThinkingLevel,
|
type ThinkingLevel,
|
||||||
} from "@mariozechner/pi-agent-core";
|
} from "@mariozechner/pi-agent-core";
|
||||||
import {
|
import {
|
||||||
|
|
@ -50,7 +51,6 @@ import {
|
||||||
loadWorkspaceSkillEntries,
|
loadWorkspaceSkillEntries,
|
||||||
type SkillSnapshot,
|
type SkillSnapshot,
|
||||||
} from "./skills.js";
|
} from "./skills.js";
|
||||||
import { SteerableProviderTransport } from "./steerable-provider-transport.js";
|
|
||||||
import { buildAgentSystemPrompt } from "./system-prompt.js";
|
import { buildAgentSystemPrompt } from "./system-prompt.js";
|
||||||
import { loadWorkspaceBootstrapFiles } from "./workspace.js";
|
import { loadWorkspaceBootstrapFiles } from "./workspace.js";
|
||||||
|
|
||||||
|
|
@ -317,7 +317,6 @@ export async function runEmbeddedPiAgent(params: {
|
||||||
const sessionManager = new SessionManager(false, params.sessionFile);
|
const sessionManager = new SessionManager(false, params.sessionFile);
|
||||||
const settingsManager = new SettingsManager();
|
const settingsManager = new SettingsManager();
|
||||||
|
|
||||||
// TODO(steipete): Drop the steerable transport after pi-mono PR #259 lands and deps are bumped.
|
|
||||||
const agent = new Agent({
|
const agent = new Agent({
|
||||||
initialState: {
|
initialState: {
|
||||||
systemPrompt: systemPromptWithSkills,
|
systemPrompt: systemPromptWithSkills,
|
||||||
|
|
@ -329,7 +328,7 @@ export async function runEmbeddedPiAgent(params: {
|
||||||
},
|
},
|
||||||
messageTransformer,
|
messageTransformer,
|
||||||
queueMode: settingsManager.getQueueMode(),
|
queueMode: settingsManager.getQueueMode(),
|
||||||
transport: new SteerableProviderTransport({
|
transport: new ProviderTransport({
|
||||||
getApiKey: async (providerName) => {
|
getApiKey: async (providerName) => {
|
||||||
const key = await getApiKeyForProvider(providerName);
|
const key = await getApiKeyForProvider(providerName);
|
||||||
if (!key) {
|
if (!key) {
|
||||||
|
|
|
||||||
|
|
@ -1,473 +0,0 @@
|
||||||
import type {
|
|
||||||
AgentContext,
|
|
||||||
AgentEvent,
|
|
||||||
AgentLoopConfig,
|
|
||||||
AgentTool,
|
|
||||||
AgentToolResult,
|
|
||||||
AssistantMessage,
|
|
||||||
Context,
|
|
||||||
Message,
|
|
||||||
QueuedMessage,
|
|
||||||
ToolResultMessage,
|
|
||||||
UserMessage,
|
|
||||||
} from "@mariozechner/pi-ai";
|
|
||||||
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
|
|
||||||
import type { TSchema } from "@sinclair/typebox";
|
|
||||||
|
|
||||||
class EventStream<T, R = T> implements AsyncIterable<T> {
|
|
||||||
private queue: T[] = [];
|
|
||||||
private waiting: ((value: IteratorResult<T>) => void)[] = [];
|
|
||||||
private done = false;
|
|
||||||
private finalResultPromise: Promise<R>;
|
|
||||||
private resolveFinalResult!: (result: R) => void;
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
private isComplete: (event: T) => boolean,
|
|
||||||
private extractResult: (event: T) => R,
|
|
||||||
) {
|
|
||||||
this.finalResultPromise = new Promise((resolve) => {
|
|
||||||
this.resolveFinalResult = resolve;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
push(event: T): void {
|
|
||||||
if (this.done) return;
|
|
||||||
|
|
||||||
if (this.isComplete(event)) {
|
|
||||||
this.done = true;
|
|
||||||
this.resolveFinalResult(this.extractResult(event));
|
|
||||||
}
|
|
||||||
|
|
||||||
const waiter = this.waiting.shift();
|
|
||||||
if (waiter) {
|
|
||||||
waiter({ value: event, done: false });
|
|
||||||
} else {
|
|
||||||
this.queue.push(event);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
end(result?: R): void {
|
|
||||||
this.done = true;
|
|
||||||
if (result !== undefined) {
|
|
||||||
this.resolveFinalResult(result);
|
|
||||||
}
|
|
||||||
while (this.waiting.length > 0) {
|
|
||||||
const waiter = this.waiting.shift();
|
|
||||||
if (waiter) {
|
|
||||||
waiter({ value: undefined as never, done: true });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async *[Symbol.asyncIterator](): AsyncIterator<T> {
|
|
||||||
while (true) {
|
|
||||||
if (this.queue.length > 0) {
|
|
||||||
const next = this.queue.shift();
|
|
||||||
if (next !== undefined) {
|
|
||||||
yield next;
|
|
||||||
}
|
|
||||||
} else if (this.done) {
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
const result = await new Promise<IteratorResult<T>>((resolve) =>
|
|
||||||
this.waiting.push(resolve),
|
|
||||||
);
|
|
||||||
if (result.done) return;
|
|
||||||
yield result.value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result(): Promise<R> {
|
|
||||||
return this.finalResultPromise;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function createAgentStream(): EventStream<
|
|
||||||
AgentEvent,
|
|
||||||
AgentContext["messages"]
|
|
||||||
> {
|
|
||||||
return new EventStream<AgentEvent, AgentContext["messages"]>(
|
|
||||||
(event) => event.type === "agent_end",
|
|
||||||
(event) => (event.type === "agent_end" ? event.messages : []),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export function agentLoop(
|
|
||||||
prompt: UserMessage,
|
|
||||||
context: AgentContext,
|
|
||||||
config: AgentLoopConfig,
|
|
||||||
signal?: AbortSignal,
|
|
||||||
streamFn?: typeof streamSimple,
|
|
||||||
): EventStream<AgentEvent, AgentContext["messages"]> {
|
|
||||||
const stream = createAgentStream();
|
|
||||||
|
|
||||||
void (async () => {
|
|
||||||
const newMessages: AgentContext["messages"] = [prompt];
|
|
||||||
const currentContext: AgentContext = {
|
|
||||||
...context,
|
|
||||||
messages: [...context.messages, prompt],
|
|
||||||
};
|
|
||||||
|
|
||||||
stream.push({ type: "agent_start" });
|
|
||||||
stream.push({ type: "turn_start" });
|
|
||||||
stream.push({ type: "message_start", message: prompt });
|
|
||||||
stream.push({ type: "message_end", message: prompt });
|
|
||||||
|
|
||||||
await runLoop(
|
|
||||||
currentContext,
|
|
||||||
newMessages,
|
|
||||||
config,
|
|
||||||
signal,
|
|
||||||
stream,
|
|
||||||
streamFn,
|
|
||||||
);
|
|
||||||
})();
|
|
||||||
|
|
||||||
return stream;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function agentLoopContinue(
|
|
||||||
context: AgentContext,
|
|
||||||
config: AgentLoopConfig,
|
|
||||||
signal?: AbortSignal,
|
|
||||||
streamFn?: typeof streamSimple,
|
|
||||||
): EventStream<AgentEvent, AgentContext["messages"]> {
|
|
||||||
const lastMessage = context.messages[context.messages.length - 1];
|
|
||||||
if (!lastMessage) {
|
|
||||||
throw new Error("Cannot continue: no messages in context");
|
|
||||||
}
|
|
||||||
if (lastMessage.role !== "user" && lastMessage.role !== "toolResult") {
|
|
||||||
throw new Error(
|
|
||||||
`Cannot continue from message role: ${lastMessage.role}. Expected 'user' or 'toolResult'.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
const stream = createAgentStream();
|
|
||||||
|
|
||||||
void (async () => {
|
|
||||||
const newMessages: AgentContext["messages"] = [];
|
|
||||||
const currentContext: AgentContext = { ...context };
|
|
||||||
|
|
||||||
stream.push({ type: "agent_start" });
|
|
||||||
stream.push({ type: "turn_start" });
|
|
||||||
|
|
||||||
await runLoop(
|
|
||||||
currentContext,
|
|
||||||
newMessages,
|
|
||||||
config,
|
|
||||||
signal,
|
|
||||||
stream,
|
|
||||||
streamFn,
|
|
||||||
);
|
|
||||||
})();
|
|
||||||
|
|
||||||
return stream;
|
|
||||||
}
|
|
||||||
|
|
||||||
async function runLoop(
|
|
||||||
currentContext: AgentContext,
|
|
||||||
newMessages: AgentContext["messages"],
|
|
||||||
config: AgentLoopConfig,
|
|
||||||
signal: AbortSignal | undefined,
|
|
||||||
stream: EventStream<AgentEvent, AgentContext["messages"]>,
|
|
||||||
streamFn?: typeof streamSimple,
|
|
||||||
): Promise<void> {
|
|
||||||
let hasMoreToolCalls = true;
|
|
||||||
let firstTurn = true;
|
|
||||||
const getQueuedMessages = config.getQueuedMessages;
|
|
||||||
let queuedMessages: QueuedMessage<Message>[] = getQueuedMessages
|
|
||||||
? await getQueuedMessages<Message>()
|
|
||||||
: [];
|
|
||||||
let queuedAfterTools: QueuedMessage<Message>[] | null = null;
|
|
||||||
|
|
||||||
while (hasMoreToolCalls || queuedMessages.length > 0) {
|
|
||||||
if (!firstTurn) {
|
|
||||||
stream.push({ type: "turn_start" });
|
|
||||||
} else {
|
|
||||||
firstTurn = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queuedMessages.length > 0) {
|
|
||||||
for (const { original, llm } of queuedMessages) {
|
|
||||||
stream.push({ type: "message_start", message: original });
|
|
||||||
stream.push({ type: "message_end", message: original });
|
|
||||||
if (llm) {
|
|
||||||
currentContext.messages.push(llm);
|
|
||||||
newMessages.push(llm);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
queuedMessages = [];
|
|
||||||
}
|
|
||||||
|
|
||||||
const message = await streamAssistantResponse(
|
|
||||||
currentContext,
|
|
||||||
config,
|
|
||||||
signal,
|
|
||||||
stream,
|
|
||||||
streamFn,
|
|
||||||
);
|
|
||||||
newMessages.push(message);
|
|
||||||
|
|
||||||
if (message.stopReason === "error" || message.stopReason === "aborted") {
|
|
||||||
stream.push({ type: "turn_end", message, toolResults: [] });
|
|
||||||
stream.push({ type: "agent_end", messages: newMessages });
|
|
||||||
stream.end(newMessages);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const toolCalls = message.content.filter((c) => c.type === "toolCall");
|
|
||||||
hasMoreToolCalls = toolCalls.length > 0;
|
|
||||||
|
|
||||||
const toolResults: ToolResultMessage[] = [];
|
|
||||||
if (hasMoreToolCalls) {
|
|
||||||
const toolExecution = await executeToolCalls(
|
|
||||||
currentContext.tools,
|
|
||||||
message,
|
|
||||||
signal,
|
|
||||||
stream,
|
|
||||||
config.getQueuedMessages,
|
|
||||||
);
|
|
||||||
toolResults.push(...toolExecution.toolResults);
|
|
||||||
queuedAfterTools = toolExecution.queuedMessages ?? null;
|
|
||||||
currentContext.messages.push(...toolResults);
|
|
||||||
newMessages.push(...toolResults);
|
|
||||||
}
|
|
||||||
stream.push({ type: "turn_end", message, toolResults: toolResults });
|
|
||||||
|
|
||||||
if (queuedAfterTools && queuedAfterTools.length > 0) {
|
|
||||||
queuedMessages = queuedAfterTools;
|
|
||||||
queuedAfterTools = null;
|
|
||||||
} else {
|
|
||||||
queuedMessages = getQueuedMessages
|
|
||||||
? await getQueuedMessages<Message>()
|
|
||||||
: [];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.push({ type: "agent_end", messages: newMessages });
|
|
||||||
stream.end(newMessages);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function streamAssistantResponse(
|
|
||||||
context: AgentContext,
|
|
||||||
config: AgentLoopConfig,
|
|
||||||
signal: AbortSignal | undefined,
|
|
||||||
stream: EventStream<AgentEvent, AgentContext["messages"]>,
|
|
||||||
streamFn?: typeof streamSimple,
|
|
||||||
): Promise<AssistantMessage> {
|
|
||||||
const processedMessages = config.preprocessor
|
|
||||||
? await config.preprocessor(context.messages, signal)
|
|
||||||
: [...context.messages];
|
|
||||||
const processedContext: Context = {
|
|
||||||
systemPrompt: context.systemPrompt,
|
|
||||||
messages: [...processedMessages].map((m) => {
|
|
||||||
if (m.role === "toolResult") {
|
|
||||||
const { details: _details, ...rest } = m;
|
|
||||||
return rest;
|
|
||||||
}
|
|
||||||
return m;
|
|
||||||
}),
|
|
||||||
tools: context.tools,
|
|
||||||
};
|
|
||||||
|
|
||||||
const streamFunction = streamFn || streamSimple;
|
|
||||||
const resolvedApiKey =
|
|
||||||
(config.getApiKey
|
|
||||||
? await config.getApiKey(config.model.provider)
|
|
||||||
: undefined) || config.apiKey;
|
|
||||||
|
|
||||||
const response = await streamFunction(config.model, processedContext, {
|
|
||||||
...config,
|
|
||||||
apiKey: resolvedApiKey,
|
|
||||||
signal,
|
|
||||||
});
|
|
||||||
|
|
||||||
let partialMessage: AssistantMessage | null = null;
|
|
||||||
let addedPartial = false;
|
|
||||||
|
|
||||||
for await (const event of response) {
|
|
||||||
switch (event.type) {
|
|
||||||
case "start":
|
|
||||||
partialMessage = event.partial;
|
|
||||||
context.messages.push(partialMessage);
|
|
||||||
addedPartial = true;
|
|
||||||
stream.push({ type: "message_start", message: { ...partialMessage } });
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "text_start":
|
|
||||||
case "text_delta":
|
|
||||||
case "text_end":
|
|
||||||
case "thinking_start":
|
|
||||||
case "thinking_delta":
|
|
||||||
case "thinking_end":
|
|
||||||
case "toolcall_start":
|
|
||||||
case "toolcall_delta":
|
|
||||||
case "toolcall_end":
|
|
||||||
if (partialMessage) {
|
|
||||||
partialMessage = event.partial;
|
|
||||||
context.messages[context.messages.length - 1] = partialMessage;
|
|
||||||
stream.push({
|
|
||||||
type: "message_update",
|
|
||||||
assistantMessageEvent: event,
|
|
||||||
message: { ...partialMessage },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case "done":
|
|
||||||
case "error": {
|
|
||||||
const finalMessage = await response.result();
|
|
||||||
if (addedPartial) {
|
|
||||||
context.messages[context.messages.length - 1] = finalMessage;
|
|
||||||
} else {
|
|
||||||
context.messages.push(finalMessage);
|
|
||||||
}
|
|
||||||
if (!addedPartial) {
|
|
||||||
stream.push({ type: "message_start", message: { ...finalMessage } });
|
|
||||||
}
|
|
||||||
stream.push({ type: "message_end", message: finalMessage });
|
|
||||||
return finalMessage;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return await response.result();
|
|
||||||
}
|
|
||||||
|
|
||||||
async function executeToolCalls<T>(
|
|
||||||
tools: AgentTool<TSchema, T>[] | undefined,
|
|
||||||
assistantMessage: AssistantMessage,
|
|
||||||
signal: AbortSignal | undefined,
|
|
||||||
stream: EventStream<AgentEvent, Message[]>,
|
|
||||||
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
|
|
||||||
): Promise<{
|
|
||||||
toolResults: ToolResultMessage<T>[];
|
|
||||||
queuedMessages?: QueuedMessage<Message>[];
|
|
||||||
}> {
|
|
||||||
const toolCalls = assistantMessage.content.filter(
|
|
||||||
(c) => c.type === "toolCall",
|
|
||||||
);
|
|
||||||
const results: ToolResultMessage<T>[] = [];
|
|
||||||
let queuedMessages: QueuedMessage<Message>[] | undefined;
|
|
||||||
|
|
||||||
for (let index = 0; index < toolCalls.length; index++) {
|
|
||||||
const toolCall = toolCalls[index];
|
|
||||||
const tool = tools?.find((t) => t.name === toolCall.name);
|
|
||||||
|
|
||||||
stream.push({
|
|
||||||
type: "tool_execution_start",
|
|
||||||
toolCallId: toolCall.id,
|
|
||||||
toolName: toolCall.name,
|
|
||||||
args: toolCall.arguments,
|
|
||||||
});
|
|
||||||
|
|
||||||
let result: AgentToolResult<T>;
|
|
||||||
let isError = false;
|
|
||||||
|
|
||||||
try {
|
|
||||||
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
|
||||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
|
||||||
result = await tool.execute(
|
|
||||||
toolCall.id,
|
|
||||||
validatedArgs,
|
|
||||||
signal,
|
|
||||||
(partialResult) => {
|
|
||||||
stream.push({
|
|
||||||
type: "tool_execution_update",
|
|
||||||
toolCallId: toolCall.id,
|
|
||||||
toolName: toolCall.name,
|
|
||||||
args: toolCall.arguments,
|
|
||||||
partialResult,
|
|
||||||
});
|
|
||||||
},
|
|
||||||
);
|
|
||||||
} catch (err) {
|
|
||||||
result = {
|
|
||||||
content: [
|
|
||||||
{
|
|
||||||
type: "text",
|
|
||||||
text: err instanceof Error ? err.message : String(err),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
details: {} as T,
|
|
||||||
};
|
|
||||||
isError = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
stream.push({
|
|
||||||
type: "tool_execution_end",
|
|
||||||
toolCallId: toolCall.id,
|
|
||||||
toolName: toolCall.name,
|
|
||||||
result,
|
|
||||||
isError,
|
|
||||||
});
|
|
||||||
|
|
||||||
const toolResultMessage: ToolResultMessage<T> = {
|
|
||||||
role: "toolResult",
|
|
||||||
toolCallId: toolCall.id,
|
|
||||||
toolName: toolCall.name,
|
|
||||||
content: result.content,
|
|
||||||
details: result.details,
|
|
||||||
isError,
|
|
||||||
timestamp: Date.now(),
|
|
||||||
};
|
|
||||||
|
|
||||||
results.push(toolResultMessage);
|
|
||||||
stream.push({ type: "message_start", message: toolResultMessage });
|
|
||||||
stream.push({ type: "message_end", message: toolResultMessage });
|
|
||||||
|
|
||||||
if (getQueuedMessages) {
|
|
||||||
const queued = await getQueuedMessages<Message>();
|
|
||||||
if (queued.length > 0) {
|
|
||||||
queuedMessages = queued;
|
|
||||||
const remainingCalls = toolCalls.slice(index + 1);
|
|
||||||
for (const skipped of remainingCalls) {
|
|
||||||
results.push(skipToolCall(skipped, stream));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return { toolResults: results, queuedMessages };
|
|
||||||
}
|
|
||||||
|
|
||||||
function skipToolCall<T>(
|
|
||||||
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
|
|
||||||
stream: EventStream<AgentEvent, Message[]>,
|
|
||||||
): ToolResultMessage<T> {
|
|
||||||
const result: AgentToolResult<T> = {
|
|
||||||
content: [{ type: "text", text: "Skipped due to queued user message." }],
|
|
||||||
details: {} as T,
|
|
||||||
};
|
|
||||||
|
|
||||||
stream.push({
|
|
||||||
type: "tool_execution_start",
|
|
||||||
toolCallId: toolCall.id,
|
|
||||||
toolName: toolCall.name,
|
|
||||||
args: toolCall.arguments,
|
|
||||||
});
|
|
||||||
stream.push({
|
|
||||||
type: "tool_execution_end",
|
|
||||||
toolCallId: toolCall.id,
|
|
||||||
toolName: toolCall.name,
|
|
||||||
result,
|
|
||||||
isError: true,
|
|
||||||
});
|
|
||||||
|
|
||||||
const toolResultMessage: ToolResultMessage<T> = {
|
|
||||||
role: "toolResult",
|
|
||||||
toolCallId: toolCall.id,
|
|
||||||
toolName: toolCall.name,
|
|
||||||
content: result.content,
|
|
||||||
details: result.details,
|
|
||||||
isError: true,
|
|
||||||
timestamp: Date.now(),
|
|
||||||
};
|
|
||||||
|
|
||||||
stream.push({ type: "message_start", message: toolResultMessage });
|
|
||||||
stream.push({ type: "message_end", message: toolResultMessage });
|
|
||||||
|
|
||||||
return toolResultMessage;
|
|
||||||
}
|
|
||||||
|
|
@ -1,85 +0,0 @@
|
||||||
import type {
|
|
||||||
AgentRunConfig,
|
|
||||||
AgentTransport,
|
|
||||||
ProviderTransportOptions,
|
|
||||||
} from "@mariozechner/pi-agent-core";
|
|
||||||
import type {
|
|
||||||
AgentContext,
|
|
||||||
AgentLoopConfig,
|
|
||||||
Message,
|
|
||||||
UserMessage,
|
|
||||||
} from "@mariozechner/pi-ai";
|
|
||||||
import { agentLoop, agentLoopContinue } from "./steerable-agent-loop.js";
|
|
||||||
|
|
||||||
export class SteerableProviderTransport implements AgentTransport {
|
|
||||||
private options: ProviderTransportOptions;
|
|
||||||
|
|
||||||
constructor(options: ProviderTransportOptions = {}) {
|
|
||||||
this.options = options;
|
|
||||||
}
|
|
||||||
|
|
||||||
private getModel(cfg: AgentRunConfig) {
|
|
||||||
let model = cfg.model;
|
|
||||||
if (this.options.corsProxyUrl && cfg.model.baseUrl) {
|
|
||||||
model = {
|
|
||||||
...cfg.model,
|
|
||||||
baseUrl: `${this.options.corsProxyUrl}/?url=${encodeURIComponent(cfg.model.baseUrl)}`,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return model;
|
|
||||||
}
|
|
||||||
|
|
||||||
private buildContext(messages: Message[], cfg: AgentRunConfig): AgentContext {
|
|
||||||
return {
|
|
||||||
systemPrompt: cfg.systemPrompt,
|
|
||||||
messages,
|
|
||||||
tools: cfg.tools,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
private buildLoopConfig(
|
|
||||||
model: AgentRunConfig["model"],
|
|
||||||
cfg: AgentRunConfig,
|
|
||||||
): AgentLoopConfig {
|
|
||||||
return {
|
|
||||||
model,
|
|
||||||
reasoning: cfg.reasoning,
|
|
||||||
getApiKey: this.options.getApiKey,
|
|
||||||
getQueuedMessages: cfg.getQueuedMessages,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
async *run(
|
|
||||||
messages: Message[],
|
|
||||||
userMessage: Message,
|
|
||||||
cfg: AgentRunConfig,
|
|
||||||
signal?: AbortSignal,
|
|
||||||
) {
|
|
||||||
const model = this.getModel(cfg);
|
|
||||||
const context = this.buildContext(messages, cfg);
|
|
||||||
const pc = this.buildLoopConfig(model, cfg);
|
|
||||||
|
|
||||||
for await (const ev of agentLoop(
|
|
||||||
userMessage as unknown as UserMessage,
|
|
||||||
context,
|
|
||||||
pc,
|
|
||||||
signal,
|
|
||||||
)) {
|
|
||||||
yield ev;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async *continue(
|
|
||||||
messages: Message[],
|
|
||||||
cfg: AgentRunConfig,
|
|
||||||
signal?: AbortSignal,
|
|
||||||
) {
|
|
||||||
const model = this.getModel(cfg);
|
|
||||||
const context = this.buildContext(messages, cfg);
|
|
||||||
const pc = this.buildLoopConfig(model, cfg);
|
|
||||||
|
|
||||||
for await (const ev of agentLoopContinue(context, pc, signal)) {
|
|
||||||
yield ev;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -2407,7 +2407,8 @@ export async function startGatewayServer(
|
||||||
const remoteAddr = (
|
const remoteAddr = (
|
||||||
socket as WebSocket & { _socket?: { remoteAddress?: string } }
|
socket as WebSocket & { _socket?: { remoteAddress?: string } }
|
||||||
)._socket?.remoteAddress;
|
)._socket?.remoteAddress;
|
||||||
const canvasHostPortForWs = canvasHostServer?.port ?? (canvasHost ? port : undefined);
|
const canvasHostPortForWs =
|
||||||
|
canvasHostServer?.port ?? (canvasHost ? port : undefined);
|
||||||
const canvasHostOverride =
|
const canvasHostOverride =
|
||||||
bridgeHost && bridgeHost !== "0.0.0.0" && bridgeHost !== "::"
|
bridgeHost && bridgeHost !== "0.0.0.0" && bridgeHost !== "::"
|
||||||
? bridgeHost
|
? bridgeHost
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue