refactor: tighten steerable agent loop typing
This commit is contained in:
@@ -1,19 +1,18 @@
|
|||||||
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
|
|
||||||
import type {
|
|
||||||
AssistantMessage,
|
|
||||||
Context,
|
|
||||||
Message,
|
|
||||||
ToolResultMessage,
|
|
||||||
UserMessage,
|
|
||||||
} from "@mariozechner/pi-ai";
|
|
||||||
import type {
|
import type {
|
||||||
AgentContext,
|
AgentContext,
|
||||||
AgentEvent,
|
AgentEvent,
|
||||||
AgentLoopConfig,
|
AgentLoopConfig,
|
||||||
AgentTool,
|
AgentTool,
|
||||||
AgentToolResult,
|
AgentToolResult,
|
||||||
|
AssistantMessage,
|
||||||
|
Context,
|
||||||
|
Message,
|
||||||
QueuedMessage,
|
QueuedMessage,
|
||||||
|
ToolResultMessage,
|
||||||
|
UserMessage,
|
||||||
} from "@mariozechner/pi-ai";
|
} from "@mariozechner/pi-ai";
|
||||||
|
import { streamSimple, validateToolArguments } from "@mariozechner/pi-ai";
|
||||||
|
import type { TSchema } from "@sinclair/typebox";
|
||||||
|
|
||||||
class EventStream<T, R = T> implements AsyncIterable<T> {
|
class EventStream<T, R = T> implements AsyncIterable<T> {
|
||||||
private queue: T[] = [];
|
private queue: T[] = [];
|
||||||
@@ -53,15 +52,20 @@ class EventStream<T, R = T> implements AsyncIterable<T> {
|
|||||||
this.resolveFinalResult(result);
|
this.resolveFinalResult(result);
|
||||||
}
|
}
|
||||||
while (this.waiting.length > 0) {
|
while (this.waiting.length > 0) {
|
||||||
const waiter = this.waiting.shift()!;
|
const waiter = this.waiting.shift();
|
||||||
waiter({ value: undefined as never, done: true });
|
if (waiter) {
|
||||||
|
waiter({ value: undefined as never, done: true });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async *[Symbol.asyncIterator](): AsyncIterator<T> {
|
async *[Symbol.asyncIterator](): AsyncIterator<T> {
|
||||||
while (true) {
|
while (true) {
|
||||||
if (this.queue.length > 0) {
|
if (this.queue.length > 0) {
|
||||||
yield this.queue.shift()!;
|
const next = this.queue.shift();
|
||||||
|
if (next !== undefined) {
|
||||||
|
yield next;
|
||||||
|
}
|
||||||
} else if (this.done) {
|
} else if (this.done) {
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
@@ -79,7 +83,10 @@ class EventStream<T, R = T> implements AsyncIterable<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function createAgentStream(): EventStream<AgentEvent, AgentContext["messages"]> {
|
function createAgentStream(): EventStream<
|
||||||
|
AgentEvent,
|
||||||
|
AgentContext["messages"]
|
||||||
|
> {
|
||||||
return new EventStream<AgentEvent, AgentContext["messages"]>(
|
return new EventStream<AgentEvent, AgentContext["messages"]>(
|
||||||
(event) => event.type === "agent_end",
|
(event) => event.type === "agent_end",
|
||||||
(event) => (event.type === "agent_end" ? event.messages : []),
|
(event) => (event.type === "agent_end" ? event.messages : []),
|
||||||
@@ -107,7 +114,14 @@ export function agentLoop(
|
|||||||
stream.push({ type: "message_start", message: prompt });
|
stream.push({ type: "message_start", message: prompt });
|
||||||
stream.push({ type: "message_end", message: prompt });
|
stream.push({ type: "message_end", message: prompt });
|
||||||
|
|
||||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
await runLoop(
|
||||||
|
currentContext,
|
||||||
|
newMessages,
|
||||||
|
config,
|
||||||
|
signal,
|
||||||
|
stream,
|
||||||
|
streamFn,
|
||||||
|
);
|
||||||
})();
|
})();
|
||||||
|
|
||||||
return stream;
|
return stream;
|
||||||
@@ -138,7 +152,14 @@ export function agentLoopContinue(
|
|||||||
stream.push({ type: "agent_start" });
|
stream.push({ type: "agent_start" });
|
||||||
stream.push({ type: "turn_start" });
|
stream.push({ type: "turn_start" });
|
||||||
|
|
||||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
await runLoop(
|
||||||
|
currentContext,
|
||||||
|
newMessages,
|
||||||
|
config,
|
||||||
|
signal,
|
||||||
|
stream,
|
||||||
|
streamFn,
|
||||||
|
);
|
||||||
})();
|
})();
|
||||||
|
|
||||||
return stream;
|
return stream;
|
||||||
@@ -154,9 +175,11 @@ async function runLoop(
|
|||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
let hasMoreToolCalls = true;
|
let hasMoreToolCalls = true;
|
||||||
let firstTurn = true;
|
let firstTurn = true;
|
||||||
let queuedMessages: QueuedMessage<any>[] =
|
const getQueuedMessages = config.getQueuedMessages;
|
||||||
(await config.getQueuedMessages?.()) || [];
|
let queuedMessages: QueuedMessage<Message>[] = getQueuedMessages
|
||||||
let queuedAfterTools: QueuedMessage<any>[] | null = null;
|
? await getQueuedMessages<Message>()
|
||||||
|
: [];
|
||||||
|
let queuedAfterTools: QueuedMessage<Message>[] | null = null;
|
||||||
|
|
||||||
while (hasMoreToolCalls || queuedMessages.length > 0) {
|
while (hasMoreToolCalls || queuedMessages.length > 0) {
|
||||||
if (!firstTurn) {
|
if (!firstTurn) {
|
||||||
@@ -216,7 +239,9 @@ async function runLoop(
|
|||||||
queuedMessages = queuedAfterTools;
|
queuedMessages = queuedAfterTools;
|
||||||
queuedAfterTools = null;
|
queuedAfterTools = null;
|
||||||
} else {
|
} else {
|
||||||
queuedMessages = (await config.getQueuedMessages?.()) || [];
|
queuedMessages = getQueuedMessages
|
||||||
|
? await getQueuedMessages<Message>()
|
||||||
|
: [];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,7 +263,7 @@ async function streamAssistantResponse(
|
|||||||
systemPrompt: context.systemPrompt,
|
systemPrompt: context.systemPrompt,
|
||||||
messages: [...processedMessages].map((m) => {
|
messages: [...processedMessages].map((m) => {
|
||||||
if (m.role === "toolResult") {
|
if (m.role === "toolResult") {
|
||||||
const { details, ...rest } = m;
|
const { details: _details, ...rest } = m;
|
||||||
return rest;
|
return rest;
|
||||||
}
|
}
|
||||||
return m;
|
return m;
|
||||||
@@ -248,8 +273,9 @@ async function streamAssistantResponse(
|
|||||||
|
|
||||||
const streamFunction = streamFn || streamSimple;
|
const streamFunction = streamFn || streamSimple;
|
||||||
const resolvedApiKey =
|
const resolvedApiKey =
|
||||||
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) ||
|
(config.getApiKey
|
||||||
config.apiKey;
|
? await config.getApiKey(config.model.provider)
|
||||||
|
: undefined) || config.apiKey;
|
||||||
|
|
||||||
const response = await streamFunction(config.model, processedContext, {
|
const response = await streamFunction(config.model, processedContext, {
|
||||||
...config,
|
...config,
|
||||||
@@ -310,18 +336,20 @@ async function streamAssistantResponse(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function executeToolCalls<T>(
|
async function executeToolCalls<T>(
|
||||||
tools: AgentTool<any, T>[] | undefined,
|
tools: AgentTool<TSchema, T>[] | undefined,
|
||||||
assistantMessage: AssistantMessage,
|
assistantMessage: AssistantMessage,
|
||||||
signal: AbortSignal | undefined,
|
signal: AbortSignal | undefined,
|
||||||
stream: EventStream<AgentEvent, Message[]>,
|
stream: EventStream<AgentEvent, Message[]>,
|
||||||
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
|
getQueuedMessages?: AgentLoopConfig["getQueuedMessages"],
|
||||||
): Promise<{
|
): Promise<{
|
||||||
toolResults: ToolResultMessage<T>[];
|
toolResults: ToolResultMessage<T>[];
|
||||||
queuedMessages?: QueuedMessage<any>[];
|
queuedMessages?: QueuedMessage<Message>[];
|
||||||
}> {
|
}> {
|
||||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
const toolCalls = assistantMessage.content.filter(
|
||||||
const results: ToolResultMessage<any>[] = [];
|
(c) => c.type === "toolCall",
|
||||||
let queuedMessages: QueuedMessage<any>[] | undefined;
|
);
|
||||||
|
const results: ToolResultMessage<T>[] = [];
|
||||||
|
let queuedMessages: QueuedMessage<Message>[] | undefined;
|
||||||
|
|
||||||
for (let index = 0; index < toolCalls.length; index++) {
|
for (let index = 0; index < toolCalls.length; index++) {
|
||||||
const toolCall = toolCalls[index];
|
const toolCall = toolCalls[index];
|
||||||
@@ -340,19 +368,27 @@ async function executeToolCalls<T>(
|
|||||||
try {
|
try {
|
||||||
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
||||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||||
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
|
result = await tool.execute(
|
||||||
stream.push({
|
toolCall.id,
|
||||||
type: "tool_execution_update",
|
validatedArgs,
|
||||||
toolCallId: toolCall.id,
|
signal,
|
||||||
toolName: toolCall.name,
|
(partialResult) => {
|
||||||
args: toolCall.arguments,
|
stream.push({
|
||||||
partialResult,
|
type: "tool_execution_update",
|
||||||
});
|
toolCallId: toolCall.id,
|
||||||
});
|
toolName: toolCall.name,
|
||||||
|
args: toolCall.arguments,
|
||||||
|
partialResult,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
result = {
|
result = {
|
||||||
content: [
|
content: [
|
||||||
{ type: "text", text: err instanceof Error ? err.message : String(err) },
|
{
|
||||||
|
type: "text",
|
||||||
|
text: err instanceof Error ? err.message : String(err),
|
||||||
|
},
|
||||||
],
|
],
|
||||||
details: {} as T,
|
details: {} as T,
|
||||||
};
|
};
|
||||||
@@ -382,7 +418,7 @@ async function executeToolCalls<T>(
|
|||||||
stream.push({ type: "message_end", message: toolResultMessage });
|
stream.push({ type: "message_end", message: toolResultMessage });
|
||||||
|
|
||||||
if (getQueuedMessages) {
|
if (getQueuedMessages) {
|
||||||
const queued = await getQueuedMessages();
|
const queued = await getQueuedMessages<Message>();
|
||||||
if (queued.length > 0) {
|
if (queued.length > 0) {
|
||||||
queuedMessages = queued;
|
queuedMessages = queued;
|
||||||
const remainingCalls = toolCalls.slice(index + 1);
|
const remainingCalls = toolCalls.slice(index + 1);
|
||||||
|
|||||||
Reference in New Issue
Block a user