Merge pull request #992 from tyler6204/fix/tool-typing-race-condition

fix: send text between tool calls to channel immediately
This commit is contained in:
Peter Steinberger
2026-01-16 07:26:10 +00:00
committed by GitHub
11 changed files with 121 additions and 71 deletions

View File

@@ -14,7 +14,7 @@ export function handleAgentStart(ctx: EmbeddedPiSubscribeContext) {
startedAt: Date.now(), startedAt: Date.now(),
}, },
}); });
ctx.params.onAgentEvent?.({ void ctx.params.onAgentEvent?.({
stream: "lifecycle", stream: "lifecycle",
data: { phase: "start" }, data: { phase: "start" },
}); });
@@ -24,7 +24,7 @@ export function handleAutoCompactionStart(ctx: EmbeddedPiSubscribeContext) {
ctx.state.compactionInFlight = true; ctx.state.compactionInFlight = true;
ctx.ensureCompactionPromise(); ctx.ensureCompactionPromise();
ctx.log.debug(`embedded run compaction start: runId=${ctx.params.runId}`); ctx.log.debug(`embedded run compaction start: runId=${ctx.params.runId}`);
ctx.params.onAgentEvent?.({ void ctx.params.onAgentEvent?.({
stream: "compaction", stream: "compaction",
data: { phase: "start" }, data: { phase: "start" },
}); });
@@ -43,7 +43,7 @@ export function handleAutoCompactionEnd(
} else { } else {
ctx.maybeResolveCompactionWait(); ctx.maybeResolveCompactionWait();
} }
ctx.params.onAgentEvent?.({ void ctx.params.onAgentEvent?.({
stream: "compaction", stream: "compaction",
data: { phase: "end", willRetry }, data: { phase: "end", willRetry },
}); });
@@ -59,7 +59,7 @@ export function handleAgentEnd(ctx: EmbeddedPiSubscribeContext) {
endedAt: Date.now(), endedAt: Date.now(),
}, },
}); });
ctx.params.onAgentEvent?.({ void ctx.params.onAgentEvent?.({
stream: "lifecycle", stream: "lifecycle",
data: { phase: "end" }, data: { phase: "end" },
}); });

View File

@@ -118,7 +118,7 @@ export function handleMessageUpdate(
mediaUrls: mediaUrls?.length ? mediaUrls : undefined, mediaUrls: mediaUrls?.length ? mediaUrls : undefined,
}, },
}); });
ctx.params.onAgentEvent?.({ void ctx.params.onAgentEvent?.({
stream: "assistant", stream: "assistant",
data: { data: {
text: cleanedText, text: cleanedText,

View File

@@ -11,7 +11,7 @@ import {
} from "./pi-embedded-subscribe.tools.js"; } from "./pi-embedded-subscribe.tools.js";
import { inferToolMetaFromArgs } from "./pi-embedded-utils.js"; import { inferToolMetaFromArgs } from "./pi-embedded-utils.js";
export function handleToolExecutionStart( export async function handleToolExecutionStart(
ctx: EmbeddedPiSubscribeContext, ctx: EmbeddedPiSubscribeContext,
evt: AgentEvent & { toolName: string; toolCallId: string; args: unknown }, evt: AgentEvent & { toolName: string; toolCallId: string; args: unknown },
) { ) {
@@ -53,7 +53,8 @@ export function handleToolExecutionStart(
args: args as Record<string, unknown>, args: args as Record<string, unknown>,
}, },
}); });
ctx.params.onAgentEvent?.({ // Await onAgentEvent to ensure typing indicator starts before tool summaries are emitted.
await ctx.params.onAgentEvent?.({
stream: "tool", stream: "tool",
data: { phase: "start", name: toolName, toolCallId }, data: { phase: "start", name: toolName, toolCallId },
}); });
@@ -108,7 +109,7 @@ export function handleToolExecutionUpdate(
partialResult: sanitized, partialResult: sanitized,
}, },
}); });
ctx.params.onAgentEvent?.({ void ctx.params.onAgentEvent?.({
stream: "tool", stream: "tool",
data: { data: {
phase: "update", phase: "update",
@@ -170,7 +171,7 @@ export function handleToolExecutionEnd(
result: sanitizedResult, result: sanitizedResult,
}, },
}); });
ctx.params.onAgentEvent?.({ void ctx.params.onAgentEvent?.({
stream: "tool", stream: "tool",
data: { data: {
phase: "result", phase: "result",

View File

@@ -32,7 +32,11 @@ export function createEmbeddedPiSessionEventHandler(ctx: EmbeddedPiSubscribeCont
handleMessageEnd(ctx, evt as never); handleMessageEnd(ctx, evt as never);
return; return;
case "tool_execution_start": case "tool_execution_start":
handleToolExecutionStart(ctx, evt as never); // Async handler - awaits typing indicator before emitting tool summaries.
// Catch rejections to avoid unhandled promise rejection crashes.
handleToolExecutionStart(ctx, evt as never).catch((err) => {
ctx.log.debug(`tool_execution_start handler failed: ${String(err)}`);
});
return; return;
case "tool_execution_update": case "tool_execution_update":
handleToolExecutionUpdate(ctx, evt as never); handleToolExecutionUpdate(ctx, evt as never);

View File

@@ -13,7 +13,7 @@ describe("subscribeEmbeddedPiSession", () => {
{ tag: "antthinking", open: "<antthinking>", close: "</antthinking>" }, { tag: "antthinking", open: "<antthinking>", close: "</antthinking>" },
] as const; ] as const;
it("includes canvas action metadata in tool summaries", () => { it("includes canvas action metadata in tool summaries", async () => {
let handler: ((evt: unknown) => void) | undefined; let handler: ((evt: unknown) => void) | undefined;
const session: StubSession = { const session: StubSession = {
subscribe: (fn) => { subscribe: (fn) => {
@@ -38,6 +38,9 @@ describe("subscribeEmbeddedPiSession", () => {
args: { action: "a2ui_push", jsonlPath: "/tmp/a2ui.jsonl" }, args: { action: "a2ui_push", jsonlPath: "/tmp/a2ui.jsonl" },
}); });
// Wait for async handler to complete
await Promise.resolve();
expect(onToolResult).toHaveBeenCalledTimes(1); expect(onToolResult).toHaveBeenCalledTimes(1);
const payload = onToolResult.mock.calls[0][0]; const payload = onToolResult.mock.calls[0][0];
expect(payload.text).toContain("🖼️"); expect(payload.text).toContain("🖼️");
@@ -72,7 +75,7 @@ describe("subscribeEmbeddedPiSession", () => {
expect(onToolResult).not.toHaveBeenCalled(); expect(onToolResult).not.toHaveBeenCalled();
}); });
it("emits tool summaries when shouldEmitToolResult overrides verbose", () => { it("emits tool summaries when shouldEmitToolResult overrides verbose", async () => {
let handler: ((evt: unknown) => void) | undefined; let handler: ((evt: unknown) => void) | undefined;
const session: StubSession = { const session: StubSession = {
subscribe: (fn) => { subscribe: (fn) => {
@@ -98,6 +101,9 @@ describe("subscribeEmbeddedPiSession", () => {
args: { path: "/tmp/c.txt" }, args: { path: "/tmp/c.txt" },
}); });
// Wait for async handler to complete
await Promise.resolve();
expect(onToolResult).toHaveBeenCalledTimes(1); expect(onToolResult).toHaveBeenCalledTimes(1);
}); });
}); });

View File

@@ -14,7 +14,7 @@ describe("subscribeEmbeddedPiSession", () => {
{ tag: "antthinking", open: "<antthinking>", close: "</antthinking>" }, { tag: "antthinking", open: "<antthinking>", close: "</antthinking>" },
] as const; ] as const;
it("suppresses message_end block replies when the message tool already sent", () => { it("suppresses message_end block replies when the message tool already sent", async () => {
let handler: ((evt: unknown) => void) | undefined; let handler: ((evt: unknown) => void) | undefined;
const session: StubSession = { const session: StubSession = {
subscribe: (fn) => { subscribe: (fn) => {
@@ -41,6 +41,9 @@ describe("subscribeEmbeddedPiSession", () => {
args: { action: "send", to: "+1555", message: messageText }, args: { action: "send", to: "+1555", message: messageText },
}); });
// Wait for async handler to complete
await Promise.resolve();
handler?.({ handler?.({
type: "tool_execution_end", type: "tool_execution_end",
toolName: "message", toolName: "message",
@@ -58,7 +61,7 @@ describe("subscribeEmbeddedPiSession", () => {
expect(onBlockReply).not.toHaveBeenCalled(); expect(onBlockReply).not.toHaveBeenCalled();
}); });
it("does not suppress message_end replies when message tool reports error", () => { it("does not suppress message_end replies when message tool reports error", async () => {
let handler: ((evt: unknown) => void) | undefined; let handler: ((evt: unknown) => void) | undefined;
const session: StubSession = { const session: StubSession = {
subscribe: (fn) => { subscribe: (fn) => {
@@ -85,6 +88,9 @@ describe("subscribeEmbeddedPiSession", () => {
args: { action: "send", to: "+1555", message: messageText }, args: { action: "send", to: "+1555", message: messageText },
}); });
// Wait for async handler to complete
await Promise.resolve();
handler?.({ handler?.({
type: "tool_execution_end", type: "tool_execution_end",
toolName: "message", toolName: "message",

View File

@@ -54,7 +54,7 @@ describe("subscribeEmbeddedPiSession", () => {
await waitPromise; await waitPromise;
expect(resolved).toBe(true); expect(resolved).toBe(true);
}); });
it("emits tool summaries at tool start when verbose is on", () => { it("emits tool summaries at tool start when verbose is on", async () => {
let handler: ((evt: unknown) => void) | undefined; let handler: ((evt: unknown) => void) | undefined;
const session: StubSession = { const session: StubSession = {
subscribe: (fn) => { subscribe: (fn) => {
@@ -79,6 +79,9 @@ describe("subscribeEmbeddedPiSession", () => {
args: { path: "/tmp/a.txt" }, args: { path: "/tmp/a.txt" },
}); });
// Wait for async handler to complete
await Promise.resolve();
expect(onToolResult).toHaveBeenCalledTimes(1); expect(onToolResult).toHaveBeenCalledTimes(1);
const payload = onToolResult.mock.calls[0][0]; const payload = onToolResult.mock.calls[0][0];
expect(payload.text).toContain("/tmp/a.txt"); expect(payload.text).toContain("/tmp/a.txt");
@@ -93,7 +96,7 @@ describe("subscribeEmbeddedPiSession", () => {
expect(onToolResult).toHaveBeenCalledTimes(1); expect(onToolResult).toHaveBeenCalledTimes(1);
}); });
it("includes browser action metadata in tool summaries", () => { it("includes browser action metadata in tool summaries", async () => {
let handler: ((evt: unknown) => void) | undefined; let handler: ((evt: unknown) => void) | undefined;
const session: StubSession = { const session: StubSession = {
subscribe: (fn) => { subscribe: (fn) => {
@@ -118,6 +121,9 @@ describe("subscribeEmbeddedPiSession", () => {
args: { action: "snapshot", targetUrl: "https://example.com" }, args: { action: "snapshot", targetUrl: "https://example.com" },
}); });
// Wait for async handler to complete
await Promise.resolve();
expect(onToolResult).toHaveBeenCalledTimes(1); expect(onToolResult).toHaveBeenCalledTimes(1);
const payload = onToolResult.mock.calls[0][0]; const payload = onToolResult.mock.calls[0][0];
expect(payload.text).toContain("🌐"); expect(payload.text).toContain("🌐");

View File

@@ -22,7 +22,10 @@ export type SubscribeEmbeddedPiSessionParams = {
blockReplyChunking?: BlockReplyChunking; blockReplyChunking?: BlockReplyChunking;
onPartialReply?: (payload: { text?: string; mediaUrls?: string[] }) => void | Promise<void>; onPartialReply?: (payload: { text?: string; mediaUrls?: string[] }) => void | Promise<void>;
onAssistantMessageStart?: () => void | Promise<void>; onAssistantMessageStart?: () => void | Promise<void>;
onAgentEvent?: (evt: { stream: string; data: Record<string, unknown> }) => void; onAgentEvent?: (evt: {
stream: string;
data: Record<string, unknown>;
}) => void | Promise<void>;
enforceFinalTag?: boolean; enforceFinalTag?: boolean;
}; };

View File

@@ -26,7 +26,7 @@ import type { VerboseLevel } from "../thinking.js";
import { isSilentReplyText, SILENT_REPLY_TOKEN } from "../tokens.js"; import { isSilentReplyText, SILENT_REPLY_TOKEN } from "../tokens.js";
import type { GetReplyOptions, ReplyPayload } from "../types.js"; import type { GetReplyOptions, ReplyPayload } from "../types.js";
import { buildThreadingToolContext, resolveEnforceFinalTag } from "./agent-runner-utils.js"; import { buildThreadingToolContext, resolveEnforceFinalTag } from "./agent-runner-utils.js";
import type { BlockReplyPipeline } from "./block-reply-pipeline.js"; import { createBlockReplyPayloadKey, type BlockReplyPipeline } from "./block-reply-pipeline.js";
import type { FollowupRun } from "./queue.js"; import type { FollowupRun } from "./queue.js";
import { parseReplyDirectives } from "./reply-directives.js"; import { parseReplyDirectives } from "./reply-directives.js";
import { applyReplyTagsToPayload, isRenderablePayload } from "./reply-payloads.js"; import { applyReplyTagsToPayload, isRenderablePayload } from "./reply-payloads.js";
@@ -40,6 +40,8 @@ export type AgentRunLoopResult =
fallbackModel?: string; fallbackModel?: string;
didLogHeartbeatStrip: boolean; didLogHeartbeatStrip: boolean;
autoCompactionCompleted: boolean; autoCompactionCompleted: boolean;
/** Payload keys sent directly (not via pipeline) during tool flush. */
directlySentBlockKeys?: Set<string>;
} }
| { kind: "final"; payload: ReplyPayload }; | { kind: "final"; payload: ReplyPayload };
@@ -70,6 +72,8 @@ export async function runAgentTurnWithFallback(params: {
}): Promise<AgentRunLoopResult> { }): Promise<AgentRunLoopResult> {
let didLogHeartbeatStrip = false; let didLogHeartbeatStrip = false;
let autoCompactionCompleted = false; let autoCompactionCompleted = false;
// Track payloads sent directly (not via pipeline) during tool flush to avoid duplicates.
const directlySentBlockKeys = new Set<string>();
const runId = crypto.randomUUID(); const runId = crypto.randomUUID();
if (params.sessionKey) { if (params.sessionKey) {
@@ -244,12 +248,13 @@ export async function runAgentTurnWithFallback(params: {
}); });
} }
: undefined, : undefined,
onAgentEvent: (evt) => { onAgentEvent: async (evt) => {
// Trigger typing when tools start executing // Trigger typing when tools start executing.
// Must await to ensure typing indicator starts before tool summaries are emitted.
if (evt.stream === "tool") { if (evt.stream === "tool") {
const phase = typeof evt.data.phase === "string" ? evt.data.phase : ""; const phase = typeof evt.data.phase === "string" ? evt.data.phase : "";
if (phase === "start" || phase === "update") { if (phase === "start" || phase === "update") {
void params.typingSignals.signalToolStart(); await params.typingSignals.signalToolStart();
} }
} }
// Track auto-compaction completion // Track auto-compaction completion
@@ -261,57 +266,67 @@ export async function runAgentTurnWithFallback(params: {
} }
} }
}, },
onBlockReply: // Always pass onBlockReply so flushBlockReplyBuffer works before tool execution,
params.blockStreamingEnabled && params.opts?.onBlockReply // even when regular block streaming is disabled. The handler sends directly
? async (payload) => { // via opts.onBlockReply when the pipeline isn't available.
const { text, skip } = normalizeStreamingText(payload); onBlockReply: params.opts?.onBlockReply
const hasPayloadMedia = (payload.mediaUrls?.length ?? 0) > 0; ? async (payload) => {
if (skip && !hasPayloadMedia) return; const { text, skip } = normalizeStreamingText(payload);
const taggedPayload = applyReplyTagsToPayload( const hasPayloadMedia = (payload.mediaUrls?.length ?? 0) > 0;
{ if (skip && !hasPayloadMedia) return;
text, const taggedPayload = applyReplyTagsToPayload(
mediaUrls: payload.mediaUrls, {
mediaUrl: payload.mediaUrls?.[0], text,
}, mediaUrls: payload.mediaUrls,
params.sessionCtx.MessageSid, mediaUrl: payload.mediaUrls?.[0],
); },
// Let through payloads with audioAsVoice flag even if empty (need to track it) params.sessionCtx.MessageSid,
if (!isRenderablePayload(taggedPayload) && !payload.audioAsVoice) return; );
const parsed = parseReplyDirectives(taggedPayload.text ?? "", { // Let through payloads with audioAsVoice flag even if empty (need to track it)
currentMessageId: params.sessionCtx.MessageSid, if (!isRenderablePayload(taggedPayload) && !payload.audioAsVoice) return;
silentToken: SILENT_REPLY_TOKEN, const parsed = parseReplyDirectives(taggedPayload.text ?? "", {
}); currentMessageId: params.sessionCtx.MessageSid,
const cleaned = parsed.text || undefined; silentToken: SILENT_REPLY_TOKEN,
const hasRenderableMedia = });
Boolean(taggedPayload.mediaUrl) || (taggedPayload.mediaUrls?.length ?? 0) > 0; const cleaned = parsed.text || undefined;
// Skip empty payloads unless they have audioAsVoice flag (need to track it) const hasRenderableMedia =
if ( Boolean(taggedPayload.mediaUrl) || (taggedPayload.mediaUrls?.length ?? 0) > 0;
!cleaned && // Skip empty payloads unless they have audioAsVoice flag (need to track it)
!hasRenderableMedia && if (
!payload.audioAsVoice && !cleaned &&
!parsed.audioAsVoice !hasRenderableMedia &&
) !payload.audioAsVoice &&
return; !parsed.audioAsVoice
if (parsed.isSilent && !hasRenderableMedia) return; )
return;
if (parsed.isSilent && !hasRenderableMedia) return;
const blockPayload: ReplyPayload = params.applyReplyToMode({ const blockPayload: ReplyPayload = params.applyReplyToMode({
...taggedPayload, ...taggedPayload,
text: cleaned, text: cleaned,
audioAsVoice: Boolean(parsed.audioAsVoice || payload.audioAsVoice), audioAsVoice: Boolean(parsed.audioAsVoice || payload.audioAsVoice),
replyToId: taggedPayload.replyToId ?? parsed.replyToId, replyToId: taggedPayload.replyToId ?? parsed.replyToId,
replyToTag: taggedPayload.replyToTag || parsed.replyToTag, replyToTag: taggedPayload.replyToTag || parsed.replyToTag,
replyToCurrent: taggedPayload.replyToCurrent || parsed.replyToCurrent, replyToCurrent: taggedPayload.replyToCurrent || parsed.replyToCurrent,
});
void params.typingSignals
.signalTextDelta(cleaned ?? taggedPayload.text)
.catch((err) => {
logVerbose(`block reply typing signal failed: ${String(err)}`);
}); });
void params.typingSignals // Use pipeline if available (block streaming enabled), otherwise send directly
.signalTextDelta(cleaned ?? taggedPayload.text) if (params.blockStreamingEnabled && params.blockReplyPipeline) {
.catch((err) => { params.blockReplyPipeline.enqueue(blockPayload);
logVerbose(`block reply typing signal failed: ${String(err)}`); } else {
}); // Send directly when flushing before tool execution (no streaming).
// Track sent key to avoid duplicate in final payloads.
params.blockReplyPipeline?.enqueue(blockPayload); directlySentBlockKeys.add(createBlockReplyPayloadKey(blockPayload));
await params.opts?.onBlockReply?.(blockPayload);
} }
: undefined, }
: undefined,
onBlockReplyFlush: onBlockReplyFlush:
params.blockStreamingEnabled && blockReplyPipeline params.blockStreamingEnabled && blockReplyPipeline
? async () => { ? async () => {
@@ -447,5 +462,6 @@ export async function runAgentTurnWithFallback(params: {
fallbackModel, fallbackModel,
didLogHeartbeatStrip, didLogHeartbeatStrip,
autoCompactionCompleted, autoCompactionCompleted,
directlySentBlockKeys: directlySentBlockKeys.size > 0 ? directlySentBlockKeys : undefined,
}; };
} }

View File

@@ -5,7 +5,7 @@ import type { OriginatingChannelType } from "../templating.js";
import { SILENT_REPLY_TOKEN } from "../tokens.js"; import { SILENT_REPLY_TOKEN } from "../tokens.js";
import type { ReplyPayload } from "../types.js"; import type { ReplyPayload } from "../types.js";
import { formatBunFetchSocketError, isBunFetchSocketError } from "./agent-runner-utils.js"; import { formatBunFetchSocketError, isBunFetchSocketError } from "./agent-runner-utils.js";
import type { BlockReplyPipeline } from "./block-reply-pipeline.js"; import { createBlockReplyPayloadKey, type BlockReplyPipeline } from "./block-reply-pipeline.js";
import { parseReplyDirectives } from "./reply-directives.js"; import { parseReplyDirectives } from "./reply-directives.js";
import { import {
applyReplyThreading, applyReplyThreading,
@@ -20,6 +20,8 @@ export function buildReplyPayloads(params: {
didLogHeartbeatStrip: boolean; didLogHeartbeatStrip: boolean;
blockStreamingEnabled: boolean; blockStreamingEnabled: boolean;
blockReplyPipeline: BlockReplyPipeline | null; blockReplyPipeline: BlockReplyPipeline | null;
/** Payload keys sent directly (not via pipeline) during tool flush. */
directlySentBlockKeys?: Set<string>;
replyToMode: ReplyToMode; replyToMode: ReplyToMode;
replyToChannel?: OriginatingChannelType; replyToChannel?: OriginatingChannelType;
currentMessageId?: string; currentMessageId?: string;
@@ -98,11 +100,16 @@ export function buildReplyPayloads(params: {
payloads: replyTaggedPayloads, payloads: replyTaggedPayloads,
sentTexts: messagingToolSentTexts, sentTexts: messagingToolSentTexts,
}); });
// Filter out payloads already sent via pipeline or directly during tool flush.
const filteredPayloads = shouldDropFinalPayloads const filteredPayloads = shouldDropFinalPayloads
? [] ? []
: params.blockStreamingEnabled : params.blockStreamingEnabled
? dedupedPayloads.filter((payload) => !params.blockReplyPipeline?.hasSentPayload(payload)) ? dedupedPayloads.filter((payload) => !params.blockReplyPipeline?.hasSentPayload(payload))
: dedupedPayloads; : params.directlySentBlockKeys?.size
? dedupedPayloads.filter(
(payload) => !params.directlySentBlockKeys!.has(createBlockReplyPayloadKey(payload)),
)
: dedupedPayloads;
const replyPayloads = suppressMessagingToolReplies ? [] : filteredPayloads; const replyPayloads = suppressMessagingToolReplies ? [] : filteredPayloads;
return { return {

View File

@@ -272,7 +272,7 @@ export async function runReplyAgent(params: {
return finalizeWithFollowup(runOutcome.payload, queueKey, runFollowupTurn); return finalizeWithFollowup(runOutcome.payload, queueKey, runFollowupTurn);
} }
const { runResult, fallbackProvider, fallbackModel } = runOutcome; const { runResult, fallbackProvider, fallbackModel, directlySentBlockKeys } = runOutcome;
let { didLogHeartbeatStrip, autoCompactionCompleted } = runOutcome; let { didLogHeartbeatStrip, autoCompactionCompleted } = runOutcome;
if ( if (
@@ -314,6 +314,7 @@ export async function runReplyAgent(params: {
didLogHeartbeatStrip, didLogHeartbeatStrip,
blockStreamingEnabled, blockStreamingEnabled,
blockReplyPipeline, blockReplyPipeline,
directlySentBlockKeys,
replyToMode, replyToMode,
replyToChannel, replyToChannel,
currentMessageId: sessionCtx.MessageSid, currentMessageId: sessionCtx.MessageSid,