refactor: tighten steerable agent loop typing

This commit is contained in:
Peter Steinberger
2025-12-20 17:50:35 +01:00
parent 4c054917ef
commit 0f271ab535

View File

@@ -1,19 +1,18 @@
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
import type {
AssistantMessage,
Context,
Message,
ToolResultMessage,
UserMessage,
} from "@mariozechner/pi-ai";
import type { import type {
AgentContext, AgentContext,
AgentEvent, AgentEvent,
AgentLoopConfig, AgentLoopConfig,
AgentTool, AgentTool,
AgentToolResult, AgentToolResult,
AssistantMessage,
Context,
Message,
QueuedMessage, QueuedMessage,
ToolResultMessage,
UserMessage,
} from "@mariozechner/pi-ai"; } from "@mariozechner/pi-ai";
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
import type { TSchema } from "@sinclair/typebox";
class EventStream<T, R = T> implements AsyncIterable<T> { class EventStream<T, R = T> implements AsyncIterable<T> {
private queue: T[] = []; private queue: T[] = [];
@@ -53,15 +52,20 @@ class EventStream<T, R = T> implements AsyncIterable<T> {
this.resolveFinalResult(result); this.resolveFinalResult(result);
} }
while (this.waiting.length > 0) { while (this.waiting.length > 0) {
const waiter = this.waiting.shift()!; const waiter = this.waiting.shift();
waiter({ value: undefined as never, done: true }); if (waiter) {
waiter({ value: undefined as never, done: true });
}
} }
} }
async *[Symbol.asyncIterator](): AsyncIterator<T> { async *[Symbol.asyncIterator](): AsyncIterator<T> {
while (true) { while (true) {
if (this.queue.length > 0) { if (this.queue.length > 0) {
yield this.queue.shift()!; const next = this.queue.shift();
if (next !== undefined) {
yield next;
}
} else if (this.done) { } else if (this.done) {
return; return;
} else { } else {
@@ -79,7 +83,10 @@ class EventStream<T, R = T> implements AsyncIterable<T> {
} }
} }
function createAgentStream(): EventStream<AgentEvent, AgentContext["messages"]> { function createAgentStream(): EventStream<
AgentEvent,
AgentContext["messages"]
> {
return new EventStream<AgentEvent, AgentContext["messages"]>( return new EventStream<AgentEvent, AgentContext["messages"]>(
(event) => event.type === "agent_end", (event) => event.type === "agent_end",
(event) => (event.type === "agent_end" ? event.messages : []), (event) => (event.type === "agent_end" ? event.messages : []),
@@ -107,7 +114,14 @@ export function agentLoop(
stream.push({ type: "message_start", message: prompt }); stream.push({ type: "message_start", message: prompt });
stream.push({ type: "message_end", message: prompt }); stream.push({ type: "message_end", message: prompt });
await runLoop(currentContext, newMessages, config, signal, stream, streamFn); await runLoop(
currentContext,
newMessages,
config,
signal,
stream,
streamFn,
);
})(); })();
return stream; return stream;
@@ -138,7 +152,14 @@ export function agentLoopContinue(
stream.push({ type: "agent_start" }); stream.push({ type: "agent_start" });
stream.push({ type: "turn_start" }); stream.push({ type: "turn_start" });
await runLoop(currentContext, newMessages, config, signal, stream, streamFn); await runLoop(
currentContext,
newMessages,
config,
signal,
stream,
streamFn,
);
})(); })();
return stream; return stream;
@@ -154,9 +175,11 @@ async function runLoop(
): Promise<void> { ): Promise<void> {
let hasMoreToolCalls = true; let hasMoreToolCalls = true;
let firstTurn = true; let firstTurn = true;
let queuedMessages: QueuedMessage<any>[] = const getQueuedMessages = config.getQueuedMessages;
(await config.getQueuedMessages?.()) || []; let queuedMessages: QueuedMessage<Message>[] = getQueuedMessages
let queuedAfterTools: QueuedMessage<any>[] | null = null; ? await getQueuedMessages<Message>()
: [];
let queuedAfterTools: QueuedMessage<Message>[] | null = null;
while (hasMoreToolCalls || queuedMessages.length > 0) { while (hasMoreToolCalls || queuedMessages.length > 0) {
if (!firstTurn) { if (!firstTurn) {
@@ -216,7 +239,9 @@ async function runLoop(
queuedMessages = queuedAfterTools; queuedMessages = queuedAfterTools;
queuedAfterTools = null; queuedAfterTools = null;
} else { } else {
queuedMessages = (await config.getQueuedMessages?.()) || []; queuedMessages = getQueuedMessages
? await getQueuedMessages<Message>()
: [];
} }
} }
@@ -238,7 +263,7 @@ async function streamAssistantResponse(
systemPrompt: context.systemPrompt, systemPrompt: context.systemPrompt,
messages: [...processedMessages].map((m) => { messages: [...processedMessages].map((m) => {
if (m.role === "toolResult") { if (m.role === "toolResult") {
const { details, ...rest } = m; const { details: _details, ...rest } = m;
return rest; return rest;
} }
return m; return m;
@@ -248,8 +273,9 @@ async function streamAssistantResponse(
const streamFunction = streamFn || streamSimple; const streamFunction = streamFn || streamSimple;
const resolvedApiKey = const resolvedApiKey =
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || (config.getApiKey
config.apiKey; ? await config.getApiKey(config.model.provider)
: undefined) || config.apiKey;
const response = await streamFunction(config.model, processedContext, { const response = await streamFunction(config.model, processedContext, {
...config, ...config,
@@ -310,18 +336,20 @@ async function streamAssistantResponse(
} }
async function executeToolCalls<T>( async function executeToolCalls<T>(
tools: AgentTool<any, T>[] | undefined, tools: AgentTool<TSchema, T>[] | undefined,
assistantMessage: AssistantMessage, assistantMessage: AssistantMessage,
signal: AbortSignal | undefined, signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, Message[]>, stream: EventStream<AgentEvent, Message[]>,
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"], getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
): Promise<{ ): Promise<{
toolResults: ToolResultMessage<T>[]; toolResults: ToolResultMessage<T>[];
queuedMessages?: QueuedMessage<any>[]; queuedMessages?: QueuedMessage<Message>[];
}> { }> {
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); const toolCalls = assistantMessage.content.filter(
const results: ToolResultMessage<any>[] = []; (c) => c.type === "toolCall",
let queuedMessages: QueuedMessage<any>[] | undefined; );
const results: ToolResultMessage<T>[] = [];
let queuedMessages: QueuedMessage<Message>[] | undefined;
for (let index = 0; index < toolCalls.length; index++) { for (let index = 0; index < toolCalls.length; index++) {
const toolCall = toolCalls[index]; const toolCall = toolCalls[index];
@@ -340,19 +368,27 @@ async function executeToolCalls<T>(
try { try {
if (!tool) throw new Error(`Tool ${toolCall.name} not found`); if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
const validatedArgs = validateToolArguments(tool, toolCall); const validatedArgs = validateToolArguments(tool, toolCall);
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => { result = await tool.execute(
stream.push({ toolCall.id,
type: "tool_execution_update", validatedArgs,
toolCallId: toolCall.id, signal,
toolName: toolCall.name, (partialResult) => {
args: toolCall.arguments, stream.push({
partialResult, type: "tool_execution_update",
}); toolCallId: toolCall.id,
}); toolName: toolCall.name,
args: toolCall.arguments,
partialResult,
});
},
);
} catch (err) { } catch (err) {
result = { result = {
content: [ content: [
{ type: "text", text: err instanceof Error ? err.message : String(err) }, {
type: "text",
text: err instanceof Error ? err.message : String(err),
},
], ],
details: {} as T, details: {} as T,
}; };
@@ -382,7 +418,7 @@ async function executeToolCalls<T>(
stream.push({ type: "message_end", message: toolResultMessage }); stream.push({ type: "message_end", message: toolResultMessage });
if (getQueuedMessages) { if (getQueuedMessages) {
const queued = await getQueuedMessages(); const queued = await getQueuedMessages<Message>();
if (queued.length > 0) { if (queued.length > 0) {
queuedMessages = queued; queuedMessages = queued;
const remainingCalls = toolCalls.slice(index + 1); const remainingCalls = toolCalls.slice(index + 1);