diff --git a/src/agents/pi-embedded-helpers.ts b/src/agents/pi-embedded-helpers.ts index baafe7ef6..581400357 100644 --- a/src/agents/pi-embedded-helpers.ts +++ b/src/agents/pi-embedded-helpers.ts @@ -106,6 +106,10 @@ export async function sanitizeSessionMessagesImages( const GOOGLE_TURN_ORDER_BOOTSTRAP_TEXT = "(session bootstrap)"; +export function isGoogleModelApi(api?: string | null): boolean { + return api === "google-gemini-cli" || api === "google-generative-ai"; +} + export function sanitizeGoogleTurnOrdering( messages: AgentMessage[], ): AgentMessage[] { diff --git a/src/agents/pi-embedded-runner.test.ts b/src/agents/pi-embedded-runner.test.ts index ac5b75a76..e2fc92541 100644 --- a/src/agents/pi-embedded-runner.test.ts +++ b/src/agents/pi-embedded-runner.test.ts @@ -1,7 +1,9 @@ -import type { AgentTool } from "@mariozechner/pi-agent-core"; +import type { AgentMessage, AgentTool } from "@mariozechner/pi-agent-core"; +import { SessionManager } from "@mariozechner/pi-coding-agent"; import { Type } from "@sinclair/typebox"; -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import { + applyGoogleTurnOrderingFix, buildEmbeddedSandboxInfo, splitSdkTools, } from "./pi-embedded-runner.js"; @@ -102,3 +104,64 @@ describe("splitSdkTools", () => { expect(customTools.map((tool) => tool.name)).toEqual(["browser"]); }); }); + +describe("applyGoogleTurnOrderingFix", () => { + const makeAssistantFirst = () => + [ + { + role: "assistant", + content: [ + { type: "toolCall", id: "call_1", name: "bash", arguments: {} }, + ], + }, + ] satisfies AgentMessage[]; + + it("prepends a bootstrap once and records a marker for Google models", () => { + const sessionManager = SessionManager.inMemory(); + const warn = vi.fn(); + const input = makeAssistantFirst(); + const first = applyGoogleTurnOrderingFix({ + messages: input, + modelApi: "google-generative-ai", + sessionManager, + sessionId: "session:1", + warn, + }); + expect(first.messages[0]?.role).toBe("user"); + expect(first.messages[1]?.role).toBe("assistant"); + expect(warn).toHaveBeenCalledTimes(1); + expect( + sessionManager + .getEntries() + .some( + (entry) => + entry.type === "custom" && + entry.customType === "google-turn-ordering-bootstrap", + ), + ).toBe(true); + + applyGoogleTurnOrderingFix({ + messages: input, + modelApi: "google-generative-ai", + sessionManager, + sessionId: "session:1", + warn, + }); + expect(warn).toHaveBeenCalledTimes(1); + }); + + it("skips non-Google models", () => { + const sessionManager = SessionManager.inMemory(); + const warn = vi.fn(); + const input = makeAssistantFirst(); + const result = applyGoogleTurnOrderingFix({ + messages: input, + modelApi: "openai", + sessionManager, + sessionId: "session:2", + warn, + }); + expect(result.messages).toBe(input); + expect(warn).not.toHaveBeenCalled(); + }); +}); diff --git a/src/agents/pi-embedded-runner.ts b/src/agents/pi-embedded-runner.ts index 5432f8b7b..ee7077fa7 100644 --- a/src/agents/pi-embedded-runner.ts +++ b/src/agents/pi-embedded-runner.ts @@ -59,6 +59,7 @@ import { isAuthAssistantError, isAuthErrorMessage, isContextOverflowError, + isGoogleModelApi, isRateLimitAssistantError, isRateLimitErrorMessage, pickFallbackThinkingLevel, @@ -243,6 +244,80 @@ type EmbeddedPiQueueHandle = { }; const log = createSubsystemLogger("agent/embedded"); +const GOOGLE_TURN_ORDERING_CUSTOM_TYPE = "google-turn-ordering-bootstrap"; + +type CustomEntryLike = { type?: unknown; customType?: unknown }; + +function hasGoogleTurnOrderingMarker(sessionManager: SessionManager): boolean { + try { + return sessionManager + .getEntries() + .some( + (entry) => + (entry as CustomEntryLike)?.type === "custom" && + (entry as CustomEntryLike)?.customType === + GOOGLE_TURN_ORDERING_CUSTOM_TYPE, + ); + } catch { + return false; + } +} + +function markGoogleTurnOrderingMarker(sessionManager: SessionManager): void { + try { + sessionManager.appendCustomEntry(GOOGLE_TURN_ORDERING_CUSTOM_TYPE, { + timestamp: Date.now(), + }); + } catch { + // ignore marker persistence failures + } +} + +export function applyGoogleTurnOrderingFix(params: { + messages: AgentMessage[]; + modelApi?: string | null; + sessionManager: SessionManager; + sessionId: string; + warn?: (message: string) => void; +}): { messages: AgentMessage[]; didPrepend: boolean } { + if (!isGoogleModelApi(params.modelApi)) { + return { messages: params.messages, didPrepend: false }; + } + const first = params.messages[0] as + | { role?: unknown; content?: unknown } + | undefined; + if (first?.role !== "assistant") { + return { messages: params.messages, didPrepend: false }; + } + const sanitized = sanitizeGoogleTurnOrdering(params.messages); + const didPrepend = sanitized !== params.messages; + if (didPrepend && !hasGoogleTurnOrderingMarker(params.sessionManager)) { + const warn = params.warn ?? ((message: string) => log.warn(message)); + warn( + `google turn ordering fixup: prepended user bootstrap (sessionId=${params.sessionId})`, + ); + markGoogleTurnOrderingMarker(params.sessionManager); + } + return { messages: sanitized, didPrepend }; +} + +async function sanitizeSessionHistory(params: { + messages: AgentMessage[]; + modelApi?: string | null; + sessionManager: SessionManager; + sessionId: string; +}): Promise { + const sanitizedImages = await sanitizeSessionMessagesImages( + params.messages, + "session:history", + ); + return applyGoogleTurnOrderingFix({ + messages: sanitizedImages, + modelApi: params.modelApi, + sessionManager: params.sessionManager, + sessionId: params.sessionId, + }).messages; +} const ACTIVE_EMBEDDED_RUNS = new Map(); type EmbeddedRunWaiter = { @@ -699,27 +774,12 @@ export async function compactEmbeddedPiSession(params: { })); try { - const sanitizedImages = await sanitizeSessionMessagesImages( - session.messages, - "session:history", - ); - const needsGoogleBootstrap = - (model.api === "google-gemini-cli" || - model.api === "google-generative-ai") && - sanitizedImages[0] && - typeof sanitizedImages[0] === "object" && - "role" in sanitizedImages[0] && - sanitizedImages[0].role === "assistant"; - const prior = - model.api === "google-gemini-cli" || - model.api === "google-generative-ai" - ? sanitizeGoogleTurnOrdering(sanitizedImages) - : sanitizedImages; - if (needsGoogleBootstrap) { - log.warn( - `google turn ordering fixup: prepended user bootstrap (sessionId=${params.sessionId})`, - ); - } + const prior = await sanitizeSessionHistory({ + messages: session.messages, + modelApi: model.api, + sessionManager, + sessionId: params.sessionId, + }); if (prior.length > 0) { session.agent.replaceMessages(prior); } @@ -1039,29 +1099,14 @@ export async function runEmbeddedPiAgent(params: { })); try { - const prior = await sanitizeSessionMessagesImages( - session.messages, - "session:history", - ); - const needsGoogleBootstrap = - (model.api === "google-gemini-cli" || - model.api === "google-generative-ai") && - prior[0] && - typeof prior[0] === "object" && - "role" in prior[0] && - prior[0].role === "assistant"; - const sanitizedPrior = - model.api === "google-gemini-cli" || - model.api === "google-generative-ai" - ? sanitizeGoogleTurnOrdering(prior) - : prior; - if (needsGoogleBootstrap) { - log.warn( - `google turn ordering fixup: prepended user bootstrap (sessionId=${params.sessionId})`, - ); - } - if (sanitizedPrior.length > 0) { - session.agent.replaceMessages(sanitizedPrior); + const prior = await sanitizeSessionHistory({ + messages: session.messages, + modelApi: model.api, + sessionManager, + sessionId: params.sessionId, + }); + if (prior.length > 0) { + session.agent.replaceMessages(prior); } } catch (err) { session.dispose();