From 0f271ab535c9c2e337af515e781dd9372aa90a71 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sat, 20 Dec 2025 17:50:35 +0100 Subject: [PATCH] refactor: tighten steerable agent loop typing --- src/agents/steerable-agent-loop.ts | 110 +++++++++++++++++++---------- 1 file changed, 73 insertions(+), 37 deletions(-) diff --git a/src/agents/steerable-agent-loop.ts b/src/agents/steerable-agent-loop.ts index ef1849d33..69b2170b6 100644 --- a/src/agents/steerable-agent-loop.ts +++ b/src/agents/steerable-agent-loop.ts @@ -1,19 +1,18 @@ -import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai"; -import type { - AssistantMessage, - Context, - Message, - ToolResultMessage, - UserMessage, -} from "@mariozechner/pi-ai"; import type { AgentContext, AgentEvent, AgentLoopConfig, AgentTool, AgentToolResult, + AssistantMessage, + Context, + Message, QueuedMessage, + ToolResultMessage, + UserMessage, } from "@mariozechner/pi-ai"; +import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai"; +import type { TSchema } from "@sinclair/typebox"; class EventStream implements AsyncIterable { private queue: T[] = []; @@ -53,15 +52,20 @@ class EventStream implements AsyncIterable { this.resolveFinalResult(result); } while (this.waiting.length > 0) { - const waiter = this.waiting.shift()!; - waiter({ value: undefined as never, done: true }); + const waiter = this.waiting.shift(); + if (waiter) { + waiter({ value: undefined as never, done: true }); + } } } async *[Symbol.asyncIterator](): AsyncIterator { while (true) { if (this.queue.length > 0) { - yield this.queue.shift()!; + const next = this.queue.shift(); + if (next !== undefined) { + yield next; + } } else if (this.done) { return; } else { @@ -79,7 +83,10 @@ class EventStream implements AsyncIterable { } } -function createAgentStream(): EventStream { +function createAgentStream(): EventStream< + AgentEvent, + AgentContext["messages"] +> { return new EventStream( (event) => event.type === "agent_end", (event) => (event.type === "agent_end" ? event.messages : []), @@ -107,7 +114,14 @@ export function agentLoop( stream.push({ type: "message_start", message: prompt }); stream.push({ type: "message_end", message: prompt }); - await runLoop(currentContext, newMessages, config, signal, stream, streamFn); + await runLoop( + currentContext, + newMessages, + config, + signal, + stream, + streamFn, + ); })(); return stream; @@ -138,7 +152,14 @@ export function agentLoopContinue( stream.push({ type: "agent_start" }); stream.push({ type: "turn_start" }); - await runLoop(currentContext, newMessages, config, signal, stream, streamFn); + await runLoop( + currentContext, + newMessages, + config, + signal, + stream, + streamFn, + ); })(); return stream; @@ -154,9 +175,11 @@ async function runLoop( ): Promise { let hasMoreToolCalls = true; let firstTurn = true; - let queuedMessages: QueuedMessage[] = - (await config.getQueuedMessages?.()) || []; - let queuedAfterTools: QueuedMessage[] | null = null; + const getQueuedMessages = config.getQueuedMessages; + let queuedMessages: QueuedMessage[] = getQueuedMessages + ? await getQueuedMessages() + : []; + let queuedAfterTools: QueuedMessage[] | null = null; while (hasMoreToolCalls || queuedMessages.length > 0) { if (!firstTurn) { @@ -216,7 +239,9 @@ async function runLoop( queuedMessages = queuedAfterTools; queuedAfterTools = null; } else { - queuedMessages = (await config.getQueuedMessages?.()) || []; + queuedMessages = getQueuedMessages + ? await getQueuedMessages() + : []; } } @@ -238,7 +263,7 @@ async function streamAssistantResponse( systemPrompt: context.systemPrompt, messages: [...processedMessages].map((m) => { if (m.role === "toolResult") { - const { details, ...rest } = m; + const { details: _details, ...rest } = m; return rest; } return m; @@ -248,8 +273,9 @@ async function streamAssistantResponse( const streamFunction = streamFn || streamSimple; const resolvedApiKey = - (config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || - config.apiKey; + (config.getApiKey + ? await config.getApiKey(config.model.provider) + : undefined) || config.apiKey; const response = await streamFunction(config.model, processedContext, { ...config, @@ -310,18 +336,20 @@ async function streamAssistantResponse( } async function executeToolCalls( - tools: AgentTool[] | undefined, + tools: AgentTool[] | undefined, assistantMessage: AssistantMessage, signal: AbortSignal | undefined, stream: EventStream, getQueuedMessages?: AgentLoopConfig["getQueuedMessages"], ): Promise<{ toolResults: ToolResultMessage[]; - queuedMessages?: QueuedMessage[]; + queuedMessages?: QueuedMessage[]; }> { - const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); - const results: ToolResultMessage[] = []; - let queuedMessages: QueuedMessage[] | undefined; + const toolCalls = assistantMessage.content.filter( + (c) => c.type === "toolCall", + ); + const results: ToolResultMessage[] = []; + let queuedMessages: QueuedMessage[] | undefined; for (let index = 0; index < toolCalls.length; index++) { const toolCall = toolCalls[index]; @@ -340,19 +368,27 @@ async function executeToolCalls( try { if (!tool) throw new Error(`Tool ${toolCall.name} not found`); const validatedArgs = validateToolArguments(tool, toolCall); - result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => { - stream.push({ - type: "tool_execution_update", - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - partialResult, - }); - }); + result = await tool.execute( + toolCall.id, + validatedArgs, + signal, + (partialResult) => { + stream.push({ + type: "tool_execution_update", + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + partialResult, + }); + }, + ); } catch (err) { result = { content: [ - { type: "text", text: err instanceof Error ? err.message : String(err) }, + { + type: "text", + text: err instanceof Error ? err.message : String(err), + }, ], details: {} as T, }; @@ -382,7 +418,7 @@ async function executeToolCalls( stream.push({ type: "message_end", message: toolResultMessage }); if (getQueuedMessages) { - const queued = await getQueuedMessages(); + const queued = await getQueuedMessages(); if (queued.length > 0) { queuedMessages = queued; const remainingCalls = toolCalls.slice(index + 1);