refactor: tighten steerable agent loop typing

This commit is contained in:
Peter Steinberger
2025-12-20 17:50:35 +01:00
parent 4c054917ef
commit 0f271ab535

View File

@@ -1,19 +1,18 @@
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
import type {
AssistantMessage,
Context,
Message,
ToolResultMessage,
UserMessage,
} from "@mariozechner/pi-ai";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentTool,
AgentToolResult,
AssistantMessage,
Context,
Message,
QueuedMessage,
ToolResultMessage,
UserMessage,
} from "@mariozechner/pi-ai";
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
import type { TSchema } from "@sinclair/typebox";
class EventStream<T, R = T> implements AsyncIterable<T> {
private queue: T[] = [];
@@ -53,15 +52,20 @@ class EventStream<T, R = T> implements AsyncIterable<T> {
this.resolveFinalResult(result);
}
while (this.waiting.length > 0) {
const waiter = this.waiting.shift()!;
waiter({ value: undefined as never, done: true });
const waiter = this.waiting.shift();
if (waiter) {
waiter({ value: undefined as never, done: true });
}
}
}
async *[Symbol.asyncIterator](): AsyncIterator<T> {
while (true) {
if (this.queue.length > 0) {
yield this.queue.shift()!;
const next = this.queue.shift();
if (next !== undefined) {
yield next;
}
} else if (this.done) {
return;
} else {
@@ -79,7 +83,10 @@ class EventStream<T, R = T> implements AsyncIterable<T> {
}
}
function createAgentStream(): EventStream<AgentEvent, AgentContext["messages"]> {
function createAgentStream(): EventStream<
AgentEvent,
AgentContext["messages"]
> {
return new EventStream<AgentEvent, AgentContext["messages"]>(
(event) => event.type === "agent_end",
(event) => (event.type === "agent_end" ? event.messages : []),
@@ -107,7 +114,14 @@ export function agentLoop(
stream.push({ type: "message_start", message: prompt });
stream.push({ type: "message_end", message: prompt });
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
await runLoop(
currentContext,
newMessages,
config,
signal,
stream,
streamFn,
);
})();
return stream;
@@ -138,7 +152,14 @@ export function agentLoopContinue(
stream.push({ type: "agent_start" });
stream.push({ type: "turn_start" });
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
await runLoop(
currentContext,
newMessages,
config,
signal,
stream,
streamFn,
);
})();
return stream;
@@ -154,9 +175,11 @@ async function runLoop(
): Promise<void> {
let hasMoreToolCalls = true;
let firstTurn = true;
let queuedMessages: QueuedMessage<any>[] =
(await config.getQueuedMessages?.()) || [];
let queuedAfterTools: QueuedMessage<any>[] | null = null;
const getQueuedMessages = config.getQueuedMessages;
let queuedMessages: QueuedMessage<Message>[] = getQueuedMessages
? await getQueuedMessages<Message>()
: [];
let queuedAfterTools: QueuedMessage<Message>[] | null = null;
while (hasMoreToolCalls || queuedMessages.length > 0) {
if (!firstTurn) {
@@ -216,7 +239,9 @@ async function runLoop(
queuedMessages = queuedAfterTools;
queuedAfterTools = null;
} else {
queuedMessages = (await config.getQueuedMessages?.()) || [];
queuedMessages = getQueuedMessages
? await getQueuedMessages<Message>()
: [];
}
}
@@ -238,7 +263,7 @@ async function streamAssistantResponse(
systemPrompt: context.systemPrompt,
messages: [...processedMessages].map((m) => {
if (m.role === "toolResult") {
const { details, ...rest } = m;
const { details: _details, ...rest } = m;
return rest;
}
return m;
@@ -248,8 +273,9 @@ async function streamAssistantResponse(
const streamFunction = streamFn || streamSimple;
const resolvedApiKey =
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) ||
config.apiKey;
(config.getApiKey
? await config.getApiKey(config.model.provider)
: undefined) || config.apiKey;
const response = await streamFunction(config.model, processedContext, {
...config,
@@ -310,18 +336,20 @@ async function streamAssistantResponse(
}
async function executeToolCalls<T>(
tools: AgentTool<any, T>[] | undefined,
tools: AgentTool<TSchema, T>[] | undefined,
assistantMessage: AssistantMessage,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, Message[]>,
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
): Promise<{
toolResults: ToolResultMessage<T>[];
queuedMessages?: QueuedMessage<any>[];
queuedMessages?: QueuedMessage<Message>[];
}> {
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
const results: ToolResultMessage<any>[] = [];
let queuedMessages: QueuedMessage<any>[] | undefined;
const toolCalls = assistantMessage.content.filter(
(c) => c.type === "toolCall",
);
const results: ToolResultMessage<T>[] = [];
let queuedMessages: QueuedMessage<Message>[] | undefined;
for (let index = 0; index < toolCalls.length; index++) {
const toolCall = toolCalls[index];
@@ -340,19 +368,27 @@ async function executeToolCalls<T>(
try {
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
const validatedArgs = validateToolArguments(tool, toolCall);
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
stream.push({
type: "tool_execution_update",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
partialResult,
});
});
result = await tool.execute(
toolCall.id,
validatedArgs,
signal,
(partialResult) => {
stream.push({
type: "tool_execution_update",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
partialResult,
});
},
);
} catch (err) {
result = {
content: [
{ type: "text", text: err instanceof Error ? err.message : String(err) },
{
type: "text",
text: err instanceof Error ? err.message : String(err),
},
],
details: {} as T,
};
@@ -382,7 +418,7 @@ async function executeToolCalls<T>(
stream.push({ type: "message_end", message: toolResultMessage });
if (getQueuedMessages) {
const queued = await getQueuedMessages();
const queued = await getQueuedMessages<Message>();
if (queued.length > 0) {
queuedMessages = queued;
const remainingCalls = toolCalls.slice(index + 1);