From d28265cfbe56ac8b57f5904f21e454aaec78305a Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Fri, 26 Dec 2025 10:16:50 +0100 Subject: [PATCH] fix: handle embedded agent overflow --- src/agents/pi-embedded-runner.ts | 66 +++++++++++------- src/agents/pi-embedded-subscribe.test.ts | 56 +++++++++++++++ src/agents/pi-embedded-subscribe.ts | 83 ++++++++++++++++++++++ src/auto-reply/reply.triggers.test.ts | 24 +++++++ src/auto-reply/reply.ts | 89 ++++++++++++++---------- 5 files changed, 255 insertions(+), 63 deletions(-) diff --git a/src/agents/pi-embedded-runner.ts b/src/agents/pi-embedded-runner.ts index 399bc7c46..4d0f0cfa4 100644 --- a/src/agents/pi-embedded-runner.ts +++ b/src/agents/pi-embedded-runner.ts @@ -8,14 +8,16 @@ import { type Api, type AssistantMessage, type Model, - type OAuthStorage, - setOAuthStorage, + type OAuthCredentials, + type OAuthProvider, + getEnvApiKey, + getOAuthApiKey, } from "@mariozechner/pi-ai"; import { buildSystemPrompt, createAgentSession, - defaultGetApiKey, - findModelByProviderAndId, + discoverAuthStorage, + discoverModels, SessionManager, SettingsManager, type Skill, @@ -91,7 +93,8 @@ const ACTIVE_EMBEDDED_RUNS = new Map(); const OAUTH_FILENAME = "oauth.json"; const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials"); let oauthStorageConfigured = false; -let cachedDefaultApiKey: ReturnType | null = null; + +type OAuthStorage = Record; function resolveSessionLane(key: string) { const cleaned = key.trim() || "main"; @@ -178,18 +181,15 @@ function ensureOAuthStorage(): void { oauthStorageConfigured = true; const oauthPath = resolveClawdisOAuthPath(); importLegacyOAuthIfNeeded(oauthPath); - setOAuthStorage({ - load: () => loadOAuthStorageAt(oauthPath) ?? {}, - save: (storage) => saveOAuthStorageAt(oauthPath, storage), - }); } -function getDefaultApiKey() { - if (!cachedDefaultApiKey) { - ensureOAuthStorage(); - cachedDefaultApiKey = defaultGetApiKey(); - } - return cachedDefaultApiKey; +function isOAuthProvider(provider: string): provider is OAuthProvider { + return ( + provider === "anthropic" || + provider === "github-copilot" || + provider === "google-gemini-cli" || + provider === "google-antigravity" + ); } export function queueEmbeddedPiMessage( @@ -214,11 +214,10 @@ function resolveModel( modelId: string, agentDir?: string, ): { model?: Model; error?: string } { - const model = findModelByProviderAndId( - provider, - modelId, - agentDir, - ) as Model | null; + const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir(); + const authStorage = discoverAuthStorage(resolvedAgentDir); + const modelRegistry = discoverModels(authStorage, resolvedAgentDir); + const model = modelRegistry.find(provider, modelId) as Model | null; if (!model) return { error: `Unknown model: ${provider}/${modelId}` }; return { model }; } @@ -229,8 +228,24 @@ async function getApiKeyForModel(model: Model): Promise { const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN; if (oauthEnv?.trim()) return oauthEnv.trim(); } - const key = await getDefaultApiKey()(model); - if (key) return key; + const envKey = getEnvApiKey(model.provider); + if (envKey) return envKey; + if (isOAuthProvider(model.provider)) { + const oauthPath = resolveClawdisOAuthPath(); + const storage = loadOAuthStorageAt(oauthPath); + if (storage) { + try { + const result = await getOAuthApiKey(model.provider, storage); + if (result?.apiKey) { + storage[model.provider] = result.newCredentials; + saveOAuthStorageAt(oauthPath, storage); + return result.apiKey; + } + } catch { + // fall through to error below + } + } + } throw new Error(`No API key found for provider "${model.provider}"`); } @@ -423,6 +438,7 @@ export async function runEmbeddedPiAgent(params: { toolMetas, unsubscribe, flush: flushToolDebouncer, + waitForCompactionRetry, } = subscribeEmbeddedPiSession({ session, runId: params.runId, @@ -463,10 +479,10 @@ export async function runEmbeddedPiAgent(params: { await session.prompt(params.prompt); } catch (err) { promptError = err; - } finally { - messagesSnapshot = session.messages.slice(); - sessionIdUsed = session.sessionId; } + await waitForCompactionRetry(); + messagesSnapshot = session.messages.slice(); + sessionIdUsed = session.sessionId; } finally { clearTimeout(abortTimer); unsubscribe(); diff --git a/src/agents/pi-embedded-subscribe.test.ts b/src/agents/pi-embedded-subscribe.test.ts index e5cd3e75f..9162cab46 100644 --- a/src/agents/pi-embedded-subscribe.test.ts +++ b/src/agents/pi-embedded-subscribe.test.ts @@ -1,4 +1,7 @@ import { describe, expect, it, vi } from "vitest"; + +import type { AssistantMessage } from "@mariozechner/pi-ai"; + import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js"; type StubSession = { @@ -92,4 +95,57 @@ describe("subscribeEmbeddedPiSession", () => { const payload = onPartialReply.mock.calls[0][0]; expect(payload.text).toBe("Hello world"); }); + + it("waits for auto-compaction retry and clears buffered text", async () => { + const listeners: Array<(evt: any) => void> = []; + const session = { + subscribe: (listener: (evt: any) => void) => { + listeners.push(listener); + return () => { + const index = listeners.indexOf(listener); + if (index !== -1) listeners.splice(index, 1); + }; + }, + } as any; + + const subscription = subscribeEmbeddedPiSession({ + session, + runId: "run-1", + }); + + const assistantMessage = { + role: "assistant", + content: [{ type: "text", text: "oops" }], + } as AssistantMessage; + + for (const listener of listeners) { + listener({ type: "message_end", message: assistantMessage }); + } + + expect(subscription.assistantTexts.length).toBe(1); + + for (const listener of listeners) { + listener({ + type: "auto_compaction_end", + willRetry: true, + }); + } + + expect(subscription.assistantTexts.length).toBe(0); + + let resolved = false; + const waitPromise = subscription.waitForCompactionRetry().then(() => { + resolved = true; + }); + + await Promise.resolve(); + expect(resolved).toBe(false); + + for (const listener of listeners) { + listener({ type: "agent_end" }); + } + + await waitPromise; + expect(resolved).toBe(true); + }); }); diff --git a/src/agents/pi-embedded-subscribe.ts b/src/agents/pi-embedded-subscribe.ts index 165c96e86..f38f95742 100644 --- a/src/agents/pi-embedded-subscribe.ts +++ b/src/agents/pi-embedded-subscribe.ts @@ -72,6 +72,41 @@ export function subscribeEmbeddedPiSession(params: { const toolMetaById = new Map(); let deltaBuffer = ""; let lastStreamedAssistant: string | undefined; + let compactionInFlight = false; + let pendingCompactionRetry = 0; + let compactionRetryResolve: (() => void) | undefined; + let compactionRetryPromise: Promise | null = null; + + const ensureCompactionPromise = () => { + if (!compactionRetryPromise) { + compactionRetryPromise = new Promise((resolve) => { + compactionRetryResolve = resolve; + }); + } + }; + + const noteCompactionRetry = () => { + pendingCompactionRetry += 1; + ensureCompactionPromise(); + }; + + const resolveCompactionRetry = () => { + if (pendingCompactionRetry <= 0) return; + pendingCompactionRetry -= 1; + if (pendingCompactionRetry === 0 && !compactionInFlight) { + compactionRetryResolve?.(); + compactionRetryResolve = undefined; + compactionRetryPromise = null; + } + }; + + const maybeResolveCompactionWait = () => { + if (pendingCompactionRetry === 0 && !compactionInFlight) { + compactionRetryResolve?.(); + compactionRetryResolve = undefined; + compactionRetryPromise = null; + } + }; const FINAL_START_RE = /<\s*final\s*>/i; const FINAL_END_RE = /<\s*\/\s*final\s*>/i; // Local providers sometimes emit malformed tags; normalize before filtering. @@ -104,6 +139,15 @@ export function subscribeEmbeddedPiSession(params: { }); }); + const resetForCompactionRetry = () => { + assistantTexts.length = 0; + toolMetas.length = 0; + toolMetaById.clear(); + deltaBuffer = ""; + lastStreamedAssistant = undefined; + toolDebouncer.flush(); + }; + const unsubscribe = params.session.subscribe( (evt: AgentEvent | { type: string; [k: string]: unknown }) => { if (evt.type === "tool_execution_start") { @@ -274,8 +318,31 @@ export function subscribeEmbeddedPiSession(params: { } } + if (evt.type === "auto_compaction_start") { + compactionInFlight = true; + ensureCompactionPromise(); + } + + if (evt.type === "auto_compaction_end") { + compactionInFlight = false; + const willRetry = Boolean( + (evt as { willRetry?: unknown }).willRetry, + ); + if (willRetry) { + noteCompactionRetry(); + resetForCompactionRetry(); + } else { + maybeResolveCompactionWait(); + } + } + if (evt.type === "agent_end") { toolDebouncer.flush(); + if (pendingCompactionRetry > 0) { + resolveCompactionRetry(); + } else { + maybeResolveCompactionWait(); + } } }, ); @@ -285,5 +352,21 @@ export function subscribeEmbeddedPiSession(params: { toolMetas, unsubscribe, flush: () => toolDebouncer.flush(), + waitForCompactionRetry: () => { + if (compactionInFlight || pendingCompactionRetry > 0) { + ensureCompactionPromise(); + return compactionRetryPromise ?? Promise.resolve(); + } + return new Promise((resolve) => { + queueMicrotask(() => { + if (compactionInFlight || pendingCompactionRetry > 0) { + ensureCompactionPromise(); + void (compactionRetryPromise ?? Promise.resolve()).then(resolve); + } else { + resolve(); + } + }); + }); + }, }; } diff --git a/src/auto-reply/reply.triggers.test.ts b/src/auto-reply/reply.triggers.test.ts index d5780fb40..53f7e0dbd 100644 --- a/src/auto-reply/reply.triggers.test.ts +++ b/src/auto-reply/reply.triggers.test.ts @@ -101,6 +101,30 @@ describe("trigger handling", () => { }); }); + it("returns a context overflow fallback when the embedded agent throws", async () => { + await withTempHome(async (home) => { + vi.mocked(runEmbeddedPiAgent).mockRejectedValue( + new Error("Context window exceeded"), + ); + + const res = await getReplyFromConfig( + { + Body: "hello", + From: "+1002", + To: "+2000", + }, + {}, + makeCfg(home), + ); + + const text = Array.isArray(res) ? res[0]?.text : res?.text; + expect(text).toBe( + "⚠️ Context overflow - conversation too long. Starting fresh might help!", + ); + expect(runEmbeddedPiAgent).toHaveBeenCalledOnce(); + }); + }); + it("uses heartbeat model override for heartbeat runs", async () => { await withTempHome(async (home) => { vi.mocked(runEmbeddedPiAgent).mockResolvedValue({ diff --git a/src/auto-reply/reply.ts b/src/auto-reply/reply.ts index 7c6d88f53..4655d40f2 100644 --- a/src/auto-reply/reply.ts +++ b/src/auto-reply/reply.ts @@ -996,44 +996,57 @@ export async function getReplyFromConfig( await startTypingLoop(); } const runId = crypto.randomUUID(); - const runResult = await runEmbeddedPiAgent({ - sessionId: sessionIdFinal, - sessionKey, - sessionFile, - workspaceDir, - config: cfg, - skillsSnapshot, - prompt: commandBody, - extraSystemPrompt: groupIntro || undefined, - ownerNumbers: ownerList.length > 0 ? ownerList : undefined, - enforceFinalTag: - provider === "lmstudio" || provider === "ollama" ? true : undefined, - provider, - model, - thinkLevel: resolvedThinkLevel, - verboseLevel: resolvedVerboseLevel, - timeoutMs, - runId, - onPartialReply: opts?.onPartialReply - ? async (payload) => { - await startTypingOnText(payload.text); - await opts.onPartialReply?.({ - text: payload.text, - mediaUrls: payload.mediaUrls, - }); - } - : undefined, - shouldEmitToolResult, - onToolResult: opts?.onToolResult - ? async (payload) => { - await startTypingOnText(payload.text); - await opts.onToolResult?.({ - text: payload.text, - mediaUrls: payload.mediaUrls, - }); - } - : undefined, - }); + let runResult: Awaited>; + try { + runResult = await runEmbeddedPiAgent({ + sessionId: sessionIdFinal, + sessionKey, + sessionFile, + workspaceDir, + config: cfg, + skillsSnapshot, + prompt: commandBody, + extraSystemPrompt: groupIntro || undefined, + ownerNumbers: ownerList.length > 0 ? ownerList : undefined, + enforceFinalTag: + provider === "lmstudio" || provider === "ollama" ? true : undefined, + provider, + model, + thinkLevel: resolvedThinkLevel, + verboseLevel: resolvedVerboseLevel, + timeoutMs, + runId, + onPartialReply: opts?.onPartialReply + ? async (payload) => { + await startTypingOnText(payload.text); + await opts.onPartialReply?.({ + text: payload.text, + mediaUrls: payload.mediaUrls, + }); + } + : undefined, + shouldEmitToolResult, + onToolResult: opts?.onToolResult + ? async (payload) => { + await startTypingOnText(payload.text); + await opts.onToolResult?.({ + text: payload.text, + mediaUrls: payload.mediaUrls, + }); + } + : undefined, + }); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + const isContextOverflow = + /context.*overflow|too large|context window/i.test(message); + defaultRuntime.error(`Embedded agent failed before reply: ${message}`); + return { + text: isContextOverflow + ? "⚠️ Context overflow - conversation too long. Starting fresh might help!" + : "⚠️ Agent failed. Check gateway logs.", + }; + } if ( shouldInjectGroupIntro &&