diff --git a/src/agents/pi-embedded-helpers.ts b/src/agents/pi-embedded-helpers.ts index 4d7539937..7cd1817c3 100644 --- a/src/agents/pi-embedded-helpers.ts +++ b/src/agents/pi-embedded-helpers.ts @@ -18,6 +18,10 @@ import { sanitizeToolCallIdsForCloudCodeAssist, } from "./tool-call-id.js"; import { sanitizeContentBlocksImages } from "./tool-images.js"; +import { + repairToolUseResultPairing, + sanitizeToolUseResultPairing, +} from "./session-transcript-repair.js"; import type { WorkspaceBootstrapFile } from "./workspace.js"; export type EmbeddedContextFile = { path: string; content: string }; @@ -98,8 +102,10 @@ export async function sanitizeSessionMessagesImages( const sanitizedIds = options?.sanitizeToolCallIds ? sanitizeToolCallIdsForCloudCodeAssist(messages) : messages; + const repaired = repairToolUseResultPairing(sanitizedIds); + const base = repaired.messages; const out: AgentMessage[] = []; - for (const msg of sanitizedIds) { + for (const msg of base) { if (!msg || typeof msg !== "object") { out.push(msg); continue; diff --git a/src/agents/pi-embedded-runner.guard.test.ts b/src/agents/pi-embedded-runner.guard.test.ts new file mode 100644 index 000000000..4e1816b81 --- /dev/null +++ b/src/agents/pi-embedded-runner.guard.test.ts @@ -0,0 +1,42 @@ +import { describe, expect, it } from "vitest"; + +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import { SessionManager } from "@mariozechner/pi-coding-agent"; + +import { guardSessionManager } from "./session-tool-result-guard-wrapper.js"; +import { sanitizeToolUseResultPairing } from "./session-transcript-repair.js"; + +function assistantToolCall(id: string): AgentMessage { + return { + role: "assistant", + content: [{ type: "toolCall", id, name: "n", arguments: {} }], + } as AgentMessage; +} + +describe("guardSessionManager integration", () => { + it("persists synthetic toolResult before subsequent assistant message", () => { + const sm = guardSessionManager(SessionManager.inMemory()); + + sm.appendMessage(assistantToolCall("call_1")); + sm.appendMessage({ + role: "assistant", + content: [{ type: "text", text: "followup" }], + } as AgentMessage); + + const messages = sm + .getEntries() + .filter((e) => e.type === "message") + .map((e) => (e as { message: AgentMessage }).message); + + expect(messages.map((m) => m.role)).toEqual([ + "assistant", + "toolResult", + "assistant", + ]); + expect((messages[1] as { toolCallId?: string }).toolCallId).toBe("call_1"); + expect( + sanitizeToolUseResultPairing(messages).map((m) => m.role), + ).toEqual(["assistant", "toolResult", "assistant"]); + }); +}); + diff --git a/src/agents/pi-embedded-runner.ts b/src/agents/pi-embedded-runner.ts index 1fa0cfc07..9a2949872 100644 --- a/src/agents/pi-embedded-runner.ts +++ b/src/agents/pi-embedded-runner.ts @@ -118,6 +118,10 @@ import { toToolDefinitions } from "./pi-tool-definition-adapter.js"; import { createClawdbotCodingTools } from "./pi-tools.js"; import { resolveSandboxContext } from "./sandbox.js"; import { sanitizeToolUseResultPairing } from "./session-transcript-repair.js"; +import { + guardSessionManager, + type GuardedSessionManager, +} from "./session-tool-result-guard-wrapper.js"; import { applySkillEnvOverrides, applySkillEnvOverridesFromSnapshot, @@ -1227,7 +1231,9 @@ export async function compactEmbeddedPiSession(params: { try { // Pre-warm session file to bring it into OS page cache await prewarmSessionFile(params.sessionFile); - const sessionManager = SessionManager.open(params.sessionFile); + const sessionManager = guardSessionManager( + SessionManager.open(params.sessionFile), + ); trackSessionManagerAccess(params.sessionFile); const settingsManager = SettingsManager.create( effectiveWorkspace, @@ -1308,6 +1314,7 @@ export async function compactEmbeddedPiSession(params: { }, }; } finally { + sessionManager.flushPendingToolResults?.(); session.dispose(); } } finally { @@ -1665,6 +1672,9 @@ export async function runEmbeddedPiAgent(params: { model, }); + const toolResultGuard = + installSessionToolResultGuard(sessionManager); + const { builtInTools, customTools } = splitSdkTools({ tools, sandboxEnabled: !!sandbox?.enabled, @@ -1717,6 +1727,7 @@ export async function runEmbeddedPiAgent(params: { session.agent.replaceMessages(limited); } } catch (err) { + toolResultGuard.flushPendingToolResults(); session.dispose(); await sessionLock.release(); throw err; @@ -1748,6 +1759,7 @@ export async function runEmbeddedPiAgent(params: { enforceFinalTag: params.enforceFinalTag, }); } catch (err) { + toolResultGuard.flushPendingToolResults(); session.dispose(); await sessionLock.release(); throw err; @@ -1845,6 +1857,7 @@ export async function runEmbeddedPiAgent(params: { ACTIVE_EMBEDDED_RUNS.delete(params.sessionId); notifyEmbeddedRunEnded(params.sessionId); } + sessionManager.flushPendingToolResults?.(); session.dispose(); await sessionLock.release(); params.abortSignal?.removeEventListener?.("abort", onAbort); diff --git a/src/agents/pi-extensions/transcript-sanitize.ts b/src/agents/pi-extensions/transcript-sanitize.ts index fd819bfe7..5ac8333ac 100644 --- a/src/agents/pi-extensions/transcript-sanitize.ts +++ b/src/agents/pi-extensions/transcript-sanitize.ts @@ -14,15 +14,18 @@ import type { } from "@mariozechner/pi-coding-agent"; import { isGoogleModelApi } from "../pi-embedded-helpers.js"; -import { sanitizeToolUseResultPairing } from "../session-transcript-repair.js"; +import { + repairToolUseResultPairing, + sanitizeToolUseResultPairing, +} from "../session-transcript-repair.js"; import { sanitizeToolCallIdsForCloudCodeAssist } from "../tool-call-id.js"; export default function transcriptSanitizeExtension(api: ExtensionAPI): void { api.on("context", (event: ContextEvent, ctx: ExtensionContext) => { let next = event.messages as AgentMessage[]; - const repairedTools = sanitizeToolUseResultPairing(next); - if (repairedTools !== next) next = repairedTools; + const repaired = repairToolUseResultPairing(next); + if (repaired.messages !== next) next = repaired.messages; if (isGoogleModelApi(ctx.model?.api)) { const repairedIds = sanitizeToolCallIdsForCloudCodeAssist(next); diff --git a/src/agents/session-tool-result-guard-wrapper.ts b/src/agents/session-tool-result-guard-wrapper.ts new file mode 100644 index 000000000..4e11dc9fd --- /dev/null +++ b/src/agents/session-tool-result-guard-wrapper.ts @@ -0,0 +1,26 @@ +import type { SessionManager } from "@mariozechner/pi-coding-agent"; + +import { installSessionToolResultGuard } from "./session-tool-result-guard.js"; + +export type GuardedSessionManager = SessionManager & { + /** Flush any synthetic tool results for pending tool calls. Idempotent. */ + flushPendingToolResults?: () => void; +}; + +/** + * Apply the tool-result guard to a SessionManager exactly once and expose + * a flush method on the instance for easy teardown handling. + */ +export function guardSessionManager( + sessionManager: SessionManager, +): GuardedSessionManager { + if (typeof (sessionManager as GuardedSessionManager).flushPendingToolResults === "function") { + return sessionManager as GuardedSessionManager; + } + + const guard = installSessionToolResultGuard(sessionManager); + (sessionManager as GuardedSessionManager).flushPendingToolResults = + guard.flushPendingToolResults; + return sessionManager as GuardedSessionManager; +} + diff --git a/src/agents/session-tool-result-guard.test.ts b/src/agents/session-tool-result-guard.test.ts new file mode 100644 index 000000000..ed0d3959a --- /dev/null +++ b/src/agents/session-tool-result-guard.test.ts @@ -0,0 +1,152 @@ +import { describe, expect, it } from "vitest"; + +import { SessionManager } from "@mariozechner/pi-coding-agent"; +import type { AgentMessage } from "@mariozechner/pi-agent-core"; + +import { installSessionToolResultGuard } from "./session-tool-result-guard.js"; + +const toolCallMessage = { + role: "assistant", + content: [{ type: "toolCall", id: "call_1", name: "read", arguments: {} }], +} satisfies AgentMessage; + +describe("installSessionToolResultGuard", () => { + it("inserts synthetic toolResult before non-tool message when pending", () => { + const sm = SessionManager.inMemory(); + installSessionToolResultGuard(sm); + + sm.appendMessage(toolCallMessage); + sm.appendMessage({ + role: "assistant", + content: [{ type: "text", text: "error" }], + stopReason: "error", + } as AgentMessage); + + const entries = sm + .getEntries() + .filter((e) => e.type === "message") + .map((e) => (e as { message: AgentMessage }).message); + + expect(entries.map((m) => m.role)).toEqual([ + "assistant", + "toolResult", + "assistant", + ]); + const synthetic = entries[1] as { + toolCallId?: string; + isError?: boolean; + content?: Array<{ type?: string; text?: string }>; + }; + expect(synthetic.toolCallId).toBe("call_1"); + expect(synthetic.isError).toBe(true); + expect(synthetic.content?.[0]?.text).toContain("missing tool result"); + }); + + it("flushes pending tool calls when asked explicitly", () => { + const sm = SessionManager.inMemory(); + const guard = installSessionToolResultGuard(sm); + + sm.appendMessage(toolCallMessage); + guard.flushPendingToolResults(); + + const messages = sm + .getEntries() + .filter((e) => e.type === "message") + .map((e) => (e as { message: AgentMessage }).message); + + expect(messages.map((m) => m.role)).toEqual(["assistant", "toolResult"]); + }); + + it("does not add synthetic toolResult when a matching one exists", () => { + const sm = SessionManager.inMemory(); + installSessionToolResultGuard(sm); + + sm.appendMessage(toolCallMessage); + sm.appendMessage({ + role: "toolResult", + toolCallId: "call_1", + content: [{ type: "text", text: "ok" }], + isError: false, + } as AgentMessage); + + const messages = sm + .getEntries() + .filter((e) => e.type === "message") + .map((e) => (e as { message: AgentMessage }).message); + + expect(messages.map((m) => m.role)).toEqual(["assistant", "toolResult"]); + }); + + it("preserves ordering with multiple tool calls and partial results", () => { + const sm = SessionManager.inMemory(); + const guard = installSessionToolResultGuard(sm); + + sm.appendMessage({ + role: "assistant", + content: [ + { type: "toolCall", id: "call_a", name: "one", arguments: {} }, + { type: "toolUse", id: "call_b", name: "two", arguments: {} }, + ], + } as AgentMessage); + sm.appendMessage({ + role: "toolResult", + toolUseId: "call_a", + content: [{ type: "text", text: "a" }], + isError: false, + } as AgentMessage); + sm.appendMessage({ + role: "assistant", + content: [{ type: "text", text: "after tools" }], + } as AgentMessage); + + const messages = sm + .getEntries() + .filter((e) => e.type === "message") + .map((e) => (e as { message: AgentMessage }).message); + + expect(messages.map((m) => m.role)).toEqual([ + "assistant", // tool calls + "toolResult", // call_a real + "toolResult", // synthetic for call_b + "assistant", // text + ]); + expect( + (messages[2] as { toolCallId?: string }).toolCallId, + ).toBe("call_b"); + expect(guard.getPendingIds()).toEqual([]); + }); + + it("flushes pending on guard when no toolResult arrived", () => { + const sm = SessionManager.inMemory(); + const guard = installSessionToolResultGuard(sm); + + sm.appendMessage(toolCallMessage); + sm.appendMessage({ + role: "assistant", + content: [{ type: "text", text: "hard error" }], + stopReason: "error", + } as AgentMessage); + expect(guard.getPendingIds()).toEqual([]); + }); + + it("handles toolUseId on toolResult", () => { + const sm = SessionManager.inMemory(); + installSessionToolResultGuard(sm); + + sm.appendMessage({ + role: "assistant", + content: [{ type: "toolUse", id: "use_1", name: "f", arguments: {} }], + } as AgentMessage); + sm.appendMessage({ + role: "toolResult", + toolUseId: "use_1", + content: [{ type: "text", text: "ok" }], + } as AgentMessage); + + const messages = sm + .getEntries() + .filter((e) => e.type === "message") + .map((e) => (e as { message: AgentMessage }).message); + expect(messages.map((m) => m.role)).toEqual(["assistant", "toolResult"]); + }); +}); diff --git a/src/agents/session-tool-result-guard.ts b/src/agents/session-tool-result-guard.ts new file mode 100644 index 000000000..b672318a2 --- /dev/null +++ b/src/agents/session-tool-result-guard.ts @@ -0,0 +1,103 @@ +import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { SessionManager } from "@mariozechner/pi-coding-agent"; + +import { makeMissingToolResult } from "./session-transcript-repair.js"; + +type ToolCall = { id: string; name?: string }; + +function extractAssistantToolCalls( + msg: Extract, +): ToolCall[] { + const content = msg.content; + if (!Array.isArray(content)) return []; + + const toolCalls: ToolCall[] = []; + for (const block of content) { + if (!block || typeof block !== "object") continue; + const rec = block as { type?: unknown; id?: unknown; name?: unknown }; + if (typeof rec.id !== "string" || !rec.id) continue; + if ( + rec.type === "toolCall" || + rec.type === "toolUse" || + rec.type === "functionCall" + ) { + toolCalls.push({ + id: rec.id, + name: typeof rec.name === "string" ? rec.name : undefined, + }); + } + } + return toolCalls; +} + +function extractToolResultId( + msg: Extract, +): string | null { + const toolCallId = (msg as { toolCallId?: unknown }).toolCallId; + if (typeof toolCallId === "string" && toolCallId) return toolCallId; + const toolUseId = (msg as { toolUseId?: unknown }).toolUseId; + if (typeof toolUseId === "string" && toolUseId) return toolUseId; + return null; +} + +export function installSessionToolResultGuard(sessionManager: SessionManager): { + flushPendingToolResults: () => void; + getPendingIds: () => string[]; +} { + const originalAppend = sessionManager.appendMessage.bind(sessionManager); + const pending = new Map(); + + const flushPendingToolResults = () => { + if (pending.size === 0) return; + for (const [id, name] of pending.entries()) { + originalAppend(makeMissingToolResult({ toolCallId: id, toolName: name })); + } + pending.clear(); + }; + + const guardedAppend = (message: AgentMessage) => { + const role = (message as { role?: unknown }).role; + + if (role === "toolResult") { + const id = extractToolResultId( + message as Extract, + ); + if (id) pending.delete(id); + return originalAppend(message as never); + } + + const toolCalls = + role === "assistant" + ? extractAssistantToolCalls( + message as Extract, + ) + : []; + + // If previous tool calls are still pending, flush before non-tool results. + if (pending.size > 0 && (toolCalls.length === 0 || role !== "assistant")) { + flushPendingToolResults(); + } + // If new tool calls arrive while older ones are pending, flush the old ones first. + if (pending.size > 0 && toolCalls.length > 0) { + flushPendingToolResults(); + } + + const result = originalAppend(message as never); + + if (toolCalls.length > 0) { + for (const call of toolCalls) { + pending.set(call.id, call.name); + } + } + + return result; + }; + + // Monkey-patch appendMessage with our guarded version. + sessionManager.appendMessage = guardedAppend as SessionManager["appendMessage"]; + + return { + flushPendingToolResults, + getPendingIds: () => Array.from(pending.keys()), + }; +} diff --git a/src/agents/session-transcript-repair.ts b/src/agents/session-transcript-repair.ts index 7e7f86fe8..b214f9964 100644 --- a/src/agents/session-transcript-repair.ts +++ b/src/agents/session-transcript-repair.ts @@ -60,9 +60,25 @@ function makeMissingToolResult(params: { } as Extract; } +export { makeMissingToolResult }; + export function sanitizeToolUseResultPairing( messages: AgentMessage[], ): AgentMessage[] { + return repairToolUseResultPairing(messages).messages; +} + +export type ToolUseRepairReport = { + messages: AgentMessage[]; + added: Array>; + droppedDuplicateCount: number; + droppedOrphanCount: number; + moved: boolean; +}; + +export function repairToolUseResultPairing( + messages: AgentMessage[], +): ToolUseRepairReport { // Anthropic (and Cloud Code Assist) reject transcripts where assistant tool calls are not // immediately followed by matching tool results. Session files can end up with results // displaced (e.g. after user turns) or duplicated. Repair by: @@ -70,13 +86,22 @@ export function sanitizeToolUseResultPairing( // - inserting synthetic error toolResults for missing ids // - dropping duplicate toolResults for the same id (anywhere in the transcript) const out: AgentMessage[] = []; + const added: Array> = []; const seenToolResultIds = new Set(); + let droppedDuplicateCount = 0; + let droppedOrphanCount = 0; + let moved = false; + let changed = false; const pushToolResult = ( msg: Extract, ) => { const id = extractToolResultId(msg); - if (id && seenToolResultIds.has(id)) return; + if (id && seenToolResultIds.has(id)) { + droppedDuplicateCount += 1; + changed = true; + return; + } if (id) seenToolResultIds.add(id); out.push(msg); }; @@ -93,7 +118,12 @@ export function sanitizeToolUseResultPairing( // Tool results must only appear directly after the matching assistant tool call turn. // Any "free-floating" toolResult entries in session history can make strict providers // (Anthropic-compatible APIs, MiniMax, Cloud Code Assist) reject the entire request. - if (role !== "toolResult") out.push(msg); + if (role !== "toolResult") { + out.push(msg); + } else { + droppedOrphanCount += 1; + changed = true; + } continue; } @@ -131,6 +161,8 @@ export function sanitizeToolUseResultPairing( const id = extractToolResultId(toolResult); if (id && toolCallIds.has(id)) { if (seenToolResultIds.has(id)) { + droppedDuplicateCount += 1; + changed = true; continue; } if (!spanResultsById.has(id)) { @@ -141,17 +173,34 @@ export function sanitizeToolUseResultPairing( } // Drop tool results that don't match the current assistant tool calls. - if (nextRole !== "toolResult") remainder.push(next); + if (nextRole !== "toolResult") { + remainder.push(next); + } else { + droppedOrphanCount += 1; + changed = true; + } } out.push(msg); + if (spanResultsById.size > 0 && remainder.length > 0) { + moved = true; + changed = true; + } + for (const call of toolCalls) { const existing = spanResultsById.get(call.id); - pushToolResult( - existing ?? - makeMissingToolResult({ toolCallId: call.id, toolName: call.name }), - ); + if (existing) { + pushToolResult(existing); + } else { + const missing = makeMissingToolResult({ + toolCallId: call.id, + toolName: call.name, + }); + added.push(missing); + changed = true; + pushToolResult(missing); + } } for (const rem of remainder) { @@ -164,5 +213,12 @@ export function sanitizeToolUseResultPairing( i = j - 1; } - return out; + const changedOrMoved = changed || moved; + return { + messages: changedOrMoved ? out : messages, + added, + droppedDuplicateCount, + droppedOrphanCount, + moved: changedOrMoved, + }; }