refactor: tighten steerable agent loop typing
This commit is contained in:
@@ -1,19 +1,18 @@
|
||||
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
ToolResultMessage,
|
||||
UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentTool,
|
||||
AgentToolResult,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
QueuedMessage,
|
||||
ToolResultMessage,
|
||||
UserMessage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
|
||||
import type { TSchema } from "@sinclair/typebox";
|
||||
|
||||
class EventStream<T, R = T> implements AsyncIterable<T> {
|
||||
private queue: T[] = [];
|
||||
@@ -53,15 +52,20 @@ class EventStream<T, R = T> implements AsyncIterable<T> {
|
||||
this.resolveFinalResult(result);
|
||||
}
|
||||
while (this.waiting.length > 0) {
|
||||
const waiter = this.waiting.shift()!;
|
||||
waiter({ value: undefined as never, done: true });
|
||||
const waiter = this.waiting.shift();
|
||||
if (waiter) {
|
||||
waiter({ value: undefined as never, done: true });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async *[Symbol.asyncIterator](): AsyncIterator<T> {
|
||||
while (true) {
|
||||
if (this.queue.length > 0) {
|
||||
yield this.queue.shift()!;
|
||||
const next = this.queue.shift();
|
||||
if (next !== undefined) {
|
||||
yield next;
|
||||
}
|
||||
} else if (this.done) {
|
||||
return;
|
||||
} else {
|
||||
@@ -79,7 +83,10 @@ class EventStream<T, R = T> implements AsyncIterable<T> {
|
||||
}
|
||||
}
|
||||
|
||||
function createAgentStream(): EventStream<AgentEvent, AgentContext["messages"]> {
|
||||
function createAgentStream(): EventStream<
|
||||
AgentEvent,
|
||||
AgentContext["messages"]
|
||||
> {
|
||||
return new EventStream<AgentEvent, AgentContext["messages"]>(
|
||||
(event) => event.type === "agent_end",
|
||||
(event) => (event.type === "agent_end" ? event.messages : []),
|
||||
@@ -107,7 +114,14 @@ export function agentLoop(
|
||||
stream.push({ type: "message_start", message: prompt });
|
||||
stream.push({ type: "message_end", message: prompt });
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
await runLoop(
|
||||
currentContext,
|
||||
newMessages,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
streamFn,
|
||||
);
|
||||
})();
|
||||
|
||||
return stream;
|
||||
@@ -138,7 +152,14 @@ export function agentLoopContinue(
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
await runLoop(
|
||||
currentContext,
|
||||
newMessages,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
streamFn,
|
||||
);
|
||||
})();
|
||||
|
||||
return stream;
|
||||
@@ -154,9 +175,11 @@ async function runLoop(
|
||||
): Promise<void> {
|
||||
let hasMoreToolCalls = true;
|
||||
let firstTurn = true;
|
||||
let queuedMessages: QueuedMessage<any>[] =
|
||||
(await config.getQueuedMessages?.()) || [];
|
||||
let queuedAfterTools: QueuedMessage<any>[] | null = null;
|
||||
const getQueuedMessages = config.getQueuedMessages;
|
||||
let queuedMessages: QueuedMessage<Message>[] = getQueuedMessages
|
||||
? await getQueuedMessages<Message>()
|
||||
: [];
|
||||
let queuedAfterTools: QueuedMessage<Message>[] | null = null;
|
||||
|
||||
while (hasMoreToolCalls || queuedMessages.length > 0) {
|
||||
if (!firstTurn) {
|
||||
@@ -216,7 +239,9 @@ async function runLoop(
|
||||
queuedMessages = queuedAfterTools;
|
||||
queuedAfterTools = null;
|
||||
} else {
|
||||
queuedMessages = (await config.getQueuedMessages?.()) || [];
|
||||
queuedMessages = getQueuedMessages
|
||||
? await getQueuedMessages<Message>()
|
||||
: [];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +263,7 @@ async function streamAssistantResponse(
|
||||
systemPrompt: context.systemPrompt,
|
||||
messages: [...processedMessages].map((m) => {
|
||||
if (m.role === "toolResult") {
|
||||
const { details, ...rest } = m;
|
||||
const { details: _details, ...rest } = m;
|
||||
return rest;
|
||||
}
|
||||
return m;
|
||||
@@ -248,8 +273,9 @@ async function streamAssistantResponse(
|
||||
|
||||
const streamFunction = streamFn || streamSimple;
|
||||
const resolvedApiKey =
|
||||
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) ||
|
||||
config.apiKey;
|
||||
(config.getApiKey
|
||||
? await config.getApiKey(config.model.provider)
|
||||
: undefined) || config.apiKey;
|
||||
|
||||
const response = await streamFunction(config.model, processedContext, {
|
||||
...config,
|
||||
@@ -310,18 +336,20 @@ async function streamAssistantResponse(
|
||||
}
|
||||
|
||||
async function executeToolCalls<T>(
|
||||
tools: AgentTool<any, T>[] | undefined,
|
||||
tools: AgentTool<TSchema, T>[] | undefined,
|
||||
assistantMessage: AssistantMessage,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, Message[]>,
|
||||
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
|
||||
): Promise<{
|
||||
toolResults: ToolResultMessage<T>[];
|
||||
queuedMessages?: QueuedMessage<any>[];
|
||||
queuedMessages?: QueuedMessage<Message>[];
|
||||
}> {
|
||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
||||
const results: ToolResultMessage<any>[] = [];
|
||||
let queuedMessages: QueuedMessage<any>[] | undefined;
|
||||
const toolCalls = assistantMessage.content.filter(
|
||||
(c) => c.type === "toolCall",
|
||||
);
|
||||
const results: ToolResultMessage<T>[] = [];
|
||||
let queuedMessages: QueuedMessage<Message>[] | undefined;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
@@ -340,19 +368,27 @@ async function executeToolCalls<T>(
|
||||
try {
|
||||
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
partialResult,
|
||||
});
|
||||
});
|
||||
result = await tool.execute(
|
||||
toolCall.id,
|
||||
validatedArgs,
|
||||
signal,
|
||||
(partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
partialResult,
|
||||
});
|
||||
},
|
||||
);
|
||||
} catch (err) {
|
||||
result = {
|
||||
content: [
|
||||
{ type: "text", text: err instanceof Error ? err.message : String(err) },
|
||||
{
|
||||
type: "text",
|
||||
text: err instanceof Error ? err.message : String(err),
|
||||
},
|
||||
],
|
||||
details: {} as T,
|
||||
};
|
||||
@@ -382,7 +418,7 @@ async function executeToolCalls<T>(
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
if (getQueuedMessages) {
|
||||
const queued = await getQueuedMessages();
|
||||
const queued = await getQueuedMessages<Message>();
|
||||
if (queued.length > 0) {
|
||||
queuedMessages = queued;
|
||||
const remainingCalls = toolCalls.slice(index + 1);
|
||||
|
||||
Reference in New Issue
Block a user