fix: guard session tool results

This commit is contained in:
Peter Steinberger
2026-01-12 17:28:39 +00:00
parent f83fb70360
commit f5d5661adf
8 changed files with 414 additions and 13 deletions

View File

@@ -18,6 +18,10 @@ import {
sanitizeToolCallIdsForCloudCodeAssist,
} from "./tool-call-id.js";
import { sanitizeContentBlocksImages } from "./tool-images.js";
import {
repairToolUseResultPairing,
sanitizeToolUseResultPairing,
} from "./session-transcript-repair.js";
import type { WorkspaceBootstrapFile } from "./workspace.js";
export type EmbeddedContextFile = { path: string; content: string };
@@ -98,8 +102,10 @@ export async function sanitizeSessionMessagesImages(
const sanitizedIds = options?.sanitizeToolCallIds
? sanitizeToolCallIdsForCloudCodeAssist(messages)
: messages;
const repaired = repairToolUseResultPairing(sanitizedIds);
const base = repaired.messages;
const out: AgentMessage[] = [];
for (const msg of sanitizedIds) {
for (const msg of base) {
if (!msg || typeof msg !== "object") {
out.push(msg);
continue;

View File

@@ -0,0 +1,42 @@
import { describe, expect, it } from "vitest";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import { SessionManager } from "@mariozechner/pi-coding-agent";
import { guardSessionManager } from "./session-tool-result-guard-wrapper.js";
import { sanitizeToolUseResultPairing } from "./session-transcript-repair.js";
function assistantToolCall(id: string): AgentMessage {
return {
role: "assistant",
content: [{ type: "toolCall", id, name: "n", arguments: {} }],
} as AgentMessage;
}
describe("guardSessionManager integration", () => {
it("persists synthetic toolResult before subsequent assistant message", () => {
const sm = guardSessionManager(SessionManager.inMemory());
sm.appendMessage(assistantToolCall("call_1"));
sm.appendMessage({
role: "assistant",
content: [{ type: "text", text: "followup" }],
} as AgentMessage);
const messages = sm
.getEntries()
.filter((e) => e.type === "message")
.map((e) => (e as { message: AgentMessage }).message);
expect(messages.map((m) => m.role)).toEqual([
"assistant",
"toolResult",
"assistant",
]);
expect((messages[1] as { toolCallId?: string }).toolCallId).toBe("call_1");
expect(
sanitizeToolUseResultPairing(messages).map((m) => m.role),
).toEqual(["assistant", "toolResult", "assistant"]);
});
});

View File

@@ -118,6 +118,10 @@ import { toToolDefinitions } from "./pi-tool-definition-adapter.js";
import { createClawdbotCodingTools } from "./pi-tools.js";
import { resolveSandboxContext } from "./sandbox.js";
import { sanitizeToolUseResultPairing } from "./session-transcript-repair.js";
import {
guardSessionManager,
type GuardedSessionManager,
} from "./session-tool-result-guard-wrapper.js";
import {
applySkillEnvOverrides,
applySkillEnvOverridesFromSnapshot,
@@ -1227,7 +1231,9 @@ export async function compactEmbeddedPiSession(params: {
try {
// Pre-warm session file to bring it into OS page cache
await prewarmSessionFile(params.sessionFile);
const sessionManager = SessionManager.open(params.sessionFile);
const sessionManager = guardSessionManager(
SessionManager.open(params.sessionFile),
);
trackSessionManagerAccess(params.sessionFile);
const settingsManager = SettingsManager.create(
effectiveWorkspace,
@@ -1308,6 +1314,7 @@ export async function compactEmbeddedPiSession(params: {
},
};
} finally {
sessionManager.flushPendingToolResults?.();
session.dispose();
}
} finally {
@@ -1665,6 +1672,9 @@ export async function runEmbeddedPiAgent(params: {
model,
});
const toolResultGuard =
installSessionToolResultGuard(sessionManager);
const { builtInTools, customTools } = splitSdkTools({
tools,
sandboxEnabled: !!sandbox?.enabled,
@@ -1717,6 +1727,7 @@ export async function runEmbeddedPiAgent(params: {
session.agent.replaceMessages(limited);
}
} catch (err) {
toolResultGuard.flushPendingToolResults();
session.dispose();
await sessionLock.release();
throw err;
@@ -1748,6 +1759,7 @@ export async function runEmbeddedPiAgent(params: {
enforceFinalTag: params.enforceFinalTag,
});
} catch (err) {
toolResultGuard.flushPendingToolResults();
session.dispose();
await sessionLock.release();
throw err;
@@ -1845,6 +1857,7 @@ export async function runEmbeddedPiAgent(params: {
ACTIVE_EMBEDDED_RUNS.delete(params.sessionId);
notifyEmbeddedRunEnded(params.sessionId);
}
sessionManager.flushPendingToolResults?.();
session.dispose();
await sessionLock.release();
params.abortSignal?.removeEventListener?.("abort", onAbort);

View File

@@ -14,15 +14,18 @@ import type {
} from "@mariozechner/pi-coding-agent";
import { isGoogleModelApi } from "../pi-embedded-helpers.js";
import { sanitizeToolUseResultPairing } from "../session-transcript-repair.js";
import {
repairToolUseResultPairing,
sanitizeToolUseResultPairing,
} from "../session-transcript-repair.js";
import { sanitizeToolCallIdsForCloudCodeAssist } from "../tool-call-id.js";
export default function transcriptSanitizeExtension(api: ExtensionAPI): void {
api.on("context", (event: ContextEvent, ctx: ExtensionContext) => {
let next = event.messages as AgentMessage[];
const repairedTools = sanitizeToolUseResultPairing(next);
if (repairedTools !== next) next = repairedTools;
const repaired = repairToolUseResultPairing(next);
if (repaired.messages !== next) next = repaired.messages;
if (isGoogleModelApi(ctx.model?.api)) {
const repairedIds = sanitizeToolCallIdsForCloudCodeAssist(next);

View File

@@ -0,0 +1,26 @@
import type { SessionManager } from "@mariozechner/pi-coding-agent";
import { installSessionToolResultGuard } from "./session-tool-result-guard.js";
export type GuardedSessionManager = SessionManager & {
/** Flush any synthetic tool results for pending tool calls. Idempotent. */
flushPendingToolResults?: () => void;
};
/**
* Apply the tool-result guard to a SessionManager exactly once and expose
* a flush method on the instance for easy teardown handling.
*/
export function guardSessionManager(
sessionManager: SessionManager,
): GuardedSessionManager {
if (typeof (sessionManager as GuardedSessionManager).flushPendingToolResults === "function") {
return sessionManager as GuardedSessionManager;
}
const guard = installSessionToolResultGuard(sessionManager);
(sessionManager as GuardedSessionManager).flushPendingToolResults =
guard.flushPendingToolResults;
return sessionManager as GuardedSessionManager;
}

View File

@@ -0,0 +1,152 @@
import { describe, expect, it } from "vitest";
import { SessionManager } from "@mariozechner/pi-coding-agent";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import { installSessionToolResultGuard } from "./session-tool-result-guard.js";
const toolCallMessage = {
role: "assistant",
content: [{ type: "toolCall", id: "call_1", name: "read", arguments: {} }],
} satisfies AgentMessage;
describe("installSessionToolResultGuard", () => {
it("inserts synthetic toolResult before non-tool message when pending", () => {
const sm = SessionManager.inMemory();
installSessionToolResultGuard(sm);
sm.appendMessage(toolCallMessage);
sm.appendMessage({
role: "assistant",
content: [{ type: "text", text: "error" }],
stopReason: "error",
} as AgentMessage);
const entries = sm
.getEntries()
.filter((e) => e.type === "message")
.map((e) => (e as { message: AgentMessage }).message);
expect(entries.map((m) => m.role)).toEqual([
"assistant",
"toolResult",
"assistant",
]);
const synthetic = entries[1] as {
toolCallId?: string;
isError?: boolean;
content?: Array<{ type?: string; text?: string }>;
};
expect(synthetic.toolCallId).toBe("call_1");
expect(synthetic.isError).toBe(true);
expect(synthetic.content?.[0]?.text).toContain("missing tool result");
});
it("flushes pending tool calls when asked explicitly", () => {
const sm = SessionManager.inMemory();
const guard = installSessionToolResultGuard(sm);
sm.appendMessage(toolCallMessage);
guard.flushPendingToolResults();
const messages = sm
.getEntries()
.filter((e) => e.type === "message")
.map((e) => (e as { message: AgentMessage }).message);
expect(messages.map((m) => m.role)).toEqual(["assistant", "toolResult"]);
});
it("does not add synthetic toolResult when a matching one exists", () => {
const sm = SessionManager.inMemory();
installSessionToolResultGuard(sm);
sm.appendMessage(toolCallMessage);
sm.appendMessage({
role: "toolResult",
toolCallId: "call_1",
content: [{ type: "text", text: "ok" }],
isError: false,
} as AgentMessage);
const messages = sm
.getEntries()
.filter((e) => e.type === "message")
.map((e) => (e as { message: AgentMessage }).message);
expect(messages.map((m) => m.role)).toEqual(["assistant", "toolResult"]);
});
it("preserves ordering with multiple tool calls and partial results", () => {
const sm = SessionManager.inMemory();
const guard = installSessionToolResultGuard(sm);
sm.appendMessage({
role: "assistant",
content: [
{ type: "toolCall", id: "call_a", name: "one", arguments: {} },
{ type: "toolUse", id: "call_b", name: "two", arguments: {} },
],
} as AgentMessage);
sm.appendMessage({
role: "toolResult",
toolUseId: "call_a",
content: [{ type: "text", text: "a" }],
isError: false,
} as AgentMessage);
sm.appendMessage({
role: "assistant",
content: [{ type: "text", text: "after tools" }],
} as AgentMessage);
const messages = sm
.getEntries()
.filter((e) => e.type === "message")
.map((e) => (e as { message: AgentMessage }).message);
expect(messages.map((m) => m.role)).toEqual([
"assistant", // tool calls
"toolResult", // call_a real
"toolResult", // synthetic for call_b
"assistant", // text
]);
expect(
(messages[2] as { toolCallId?: string }).toolCallId,
).toBe("call_b");
expect(guard.getPendingIds()).toEqual([]);
});
it("flushes pending on guard when no toolResult arrived", () => {
const sm = SessionManager.inMemory();
const guard = installSessionToolResultGuard(sm);
sm.appendMessage(toolCallMessage);
sm.appendMessage({
role: "assistant",
content: [{ type: "text", text: "hard error" }],
stopReason: "error",
} as AgentMessage);
expect(guard.getPendingIds()).toEqual([]);
});
it("handles toolUseId on toolResult", () => {
const sm = SessionManager.inMemory();
installSessionToolResultGuard(sm);
sm.appendMessage({
role: "assistant",
content: [{ type: "toolUse", id: "use_1", name: "f", arguments: {} }],
} as AgentMessage);
sm.appendMessage({
role: "toolResult",
toolUseId: "use_1",
content: [{ type: "text", text: "ok" }],
} as AgentMessage);
const messages = sm
.getEntries()
.filter((e) => e.type === "message")
.map((e) => (e as { message: AgentMessage }).message);
expect(messages.map((m) => m.role)).toEqual(["assistant", "toolResult"]);
});
});

View File

@@ -0,0 +1,103 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { SessionManager } from "@mariozechner/pi-coding-agent";
import { makeMissingToolResult } from "./session-transcript-repair.js";
type ToolCall = { id: string; name?: string };
function extractAssistantToolCalls(
msg: Extract<AgentMessage, { role: "assistant" }>,
): ToolCall[] {
const content = msg.content;
if (!Array.isArray(content)) return [];
const toolCalls: ToolCall[] = [];
for (const block of content) {
if (!block || typeof block !== "object") continue;
const rec = block as { type?: unknown; id?: unknown; name?: unknown };
if (typeof rec.id !== "string" || !rec.id) continue;
if (
rec.type === "toolCall" ||
rec.type === "toolUse" ||
rec.type === "functionCall"
) {
toolCalls.push({
id: rec.id,
name: typeof rec.name === "string" ? rec.name : undefined,
});
}
}
return toolCalls;
}
function extractToolResultId(
msg: Extract<AgentMessage, { role: "toolResult" }>,
): string | null {
const toolCallId = (msg as { toolCallId?: unknown }).toolCallId;
if (typeof toolCallId === "string" && toolCallId) return toolCallId;
const toolUseId = (msg as { toolUseId?: unknown }).toolUseId;
if (typeof toolUseId === "string" && toolUseId) return toolUseId;
return null;
}
export function installSessionToolResultGuard(sessionManager: SessionManager): {
flushPendingToolResults: () => void;
getPendingIds: () => string[];
} {
const originalAppend = sessionManager.appendMessage.bind(sessionManager);
const pending = new Map<string, string | undefined>();
const flushPendingToolResults = () => {
if (pending.size === 0) return;
for (const [id, name] of pending.entries()) {
originalAppend(makeMissingToolResult({ toolCallId: id, toolName: name }));
}
pending.clear();
};
const guardedAppend = (message: AgentMessage) => {
const role = (message as { role?: unknown }).role;
if (role === "toolResult") {
const id = extractToolResultId(
message as Extract<AgentMessage, { role: "toolResult" }>,
);
if (id) pending.delete(id);
return originalAppend(message as never);
}
const toolCalls =
role === "assistant"
? extractAssistantToolCalls(
message as Extract<AgentMessage, { role: "assistant" }>,
)
: [];
// If previous tool calls are still pending, flush before non-tool results.
if (pending.size > 0 && (toolCalls.length === 0 || role !== "assistant")) {
flushPendingToolResults();
}
// If new tool calls arrive while older ones are pending, flush the old ones first.
if (pending.size > 0 && toolCalls.length > 0) {
flushPendingToolResults();
}
const result = originalAppend(message as never);
if (toolCalls.length > 0) {
for (const call of toolCalls) {
pending.set(call.id, call.name);
}
}
return result;
};
// Monkey-patch appendMessage with our guarded version.
sessionManager.appendMessage = guardedAppend as SessionManager["appendMessage"];
return {
flushPendingToolResults,
getPendingIds: () => Array.from(pending.keys()),
};
}

View File

@@ -60,9 +60,25 @@ function makeMissingToolResult(params: {
} as Extract<AgentMessage, { role: "toolResult" }>;
}
export { makeMissingToolResult };
export function sanitizeToolUseResultPairing(
messages: AgentMessage[],
): AgentMessage[] {
return repairToolUseResultPairing(messages).messages;
}
export type ToolUseRepairReport = {
messages: AgentMessage[];
added: Array<Extract<AgentMessage, { role: "toolResult" }>>;
droppedDuplicateCount: number;
droppedOrphanCount: number;
moved: boolean;
};
export function repairToolUseResultPairing(
messages: AgentMessage[],
): ToolUseRepairReport {
// Anthropic (and Cloud Code Assist) reject transcripts where assistant tool calls are not
// immediately followed by matching tool results. Session files can end up with results
// displaced (e.g. after user turns) or duplicated. Repair by:
@@ -70,13 +86,22 @@ export function sanitizeToolUseResultPairing(
// - inserting synthetic error toolResults for missing ids
// - dropping duplicate toolResults for the same id (anywhere in the transcript)
const out: AgentMessage[] = [];
const added: Array<Extract<AgentMessage, { role: "toolResult" }>> = [];
const seenToolResultIds = new Set<string>();
let droppedDuplicateCount = 0;
let droppedOrphanCount = 0;
let moved = false;
let changed = false;
const pushToolResult = (
msg: Extract<AgentMessage, { role: "toolResult" }>,
) => {
const id = extractToolResultId(msg);
if (id && seenToolResultIds.has(id)) return;
if (id && seenToolResultIds.has(id)) {
droppedDuplicateCount += 1;
changed = true;
return;
}
if (id) seenToolResultIds.add(id);
out.push(msg);
};
@@ -93,7 +118,12 @@ export function sanitizeToolUseResultPairing(
// Tool results must only appear directly after the matching assistant tool call turn.
// Any "free-floating" toolResult entries in session history can make strict providers
// (Anthropic-compatible APIs, MiniMax, Cloud Code Assist) reject the entire request.
if (role !== "toolResult") out.push(msg);
if (role !== "toolResult") {
out.push(msg);
} else {
droppedOrphanCount += 1;
changed = true;
}
continue;
}
@@ -131,6 +161,8 @@ export function sanitizeToolUseResultPairing(
const id = extractToolResultId(toolResult);
if (id && toolCallIds.has(id)) {
if (seenToolResultIds.has(id)) {
droppedDuplicateCount += 1;
changed = true;
continue;
}
if (!spanResultsById.has(id)) {
@@ -141,17 +173,34 @@ export function sanitizeToolUseResultPairing(
}
// Drop tool results that don't match the current assistant tool calls.
if (nextRole !== "toolResult") remainder.push(next);
if (nextRole !== "toolResult") {
remainder.push(next);
} else {
droppedOrphanCount += 1;
changed = true;
}
}
out.push(msg);
if (spanResultsById.size > 0 && remainder.length > 0) {
moved = true;
changed = true;
}
for (const call of toolCalls) {
const existing = spanResultsById.get(call.id);
pushToolResult(
existing ??
makeMissingToolResult({ toolCallId: call.id, toolName: call.name }),
);
if (existing) {
pushToolResult(existing);
} else {
const missing = makeMissingToolResult({
toolCallId: call.id,
toolName: call.name,
});
added.push(missing);
changed = true;
pushToolResult(missing);
}
}
for (const rem of remainder) {
@@ -164,5 +213,12 @@ export function sanitizeToolUseResultPairing(
i = j - 1;
}
return out;
const changedOrMoved = changed || moved;
return {
messages: changedOrMoved ? out : messages,
added,
droppedDuplicateCount,
droppedOrphanCount,
moved: changedOrMoved,
};
}