fix: handle embedded agent overflow

This commit is contained in:
Peter Steinberger
2025-12-26 10:16:50 +01:00
parent 8059e83c49
commit d28265cfbe
5 changed files with 255 additions and 63 deletions

View File

@@ -8,14 +8,16 @@ import {
type Api, type Api,
type AssistantMessage, type AssistantMessage,
type Model, type Model,
type OAuthStorage, type OAuthCredentials,
setOAuthStorage, type OAuthProvider,
getEnvApiKey,
getOAuthApiKey,
} from "@mariozechner/pi-ai"; } from "@mariozechner/pi-ai";
import { import {
buildSystemPrompt, buildSystemPrompt,
createAgentSession, createAgentSession,
defaultGetApiKey, discoverAuthStorage,
findModelByProviderAndId, discoverModels,
SessionManager, SessionManager,
SettingsManager, SettingsManager,
type Skill, type Skill,
@@ -91,7 +93,8 @@ const ACTIVE_EMBEDDED_RUNS = new Map<string, EmbeddedPiQueueHandle>();
const OAUTH_FILENAME = "oauth.json"; const OAUTH_FILENAME = "oauth.json";
const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials"); const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials");
let oauthStorageConfigured = false; let oauthStorageConfigured = false;
let cachedDefaultApiKey: ReturnType<typeof defaultGetApiKey> | null = null;
type OAuthStorage = Record<string, OAuthCredentials>;
function resolveSessionLane(key: string) { function resolveSessionLane(key: string) {
const cleaned = key.trim() || "main"; const cleaned = key.trim() || "main";
@@ -178,18 +181,15 @@ function ensureOAuthStorage(): void {
oauthStorageConfigured = true; oauthStorageConfigured = true;
const oauthPath = resolveClawdisOAuthPath(); const oauthPath = resolveClawdisOAuthPath();
importLegacyOAuthIfNeeded(oauthPath); importLegacyOAuthIfNeeded(oauthPath);
setOAuthStorage({
load: () => loadOAuthStorageAt(oauthPath) ?? {},
save: (storage) => saveOAuthStorageAt(oauthPath, storage),
});
} }
function getDefaultApiKey() { function isOAuthProvider(provider: string): provider is OAuthProvider {
if (!cachedDefaultApiKey) { return (
ensureOAuthStorage(); provider === "anthropic" ||
cachedDefaultApiKey = defaultGetApiKey(); provider === "github-copilot" ||
} provider === "google-gemini-cli" ||
return cachedDefaultApiKey; provider === "google-antigravity"
);
} }
export function queueEmbeddedPiMessage( export function queueEmbeddedPiMessage(
@@ -214,11 +214,10 @@ function resolveModel(
modelId: string, modelId: string,
agentDir?: string, agentDir?: string,
): { model?: Model<Api>; error?: string } { ): { model?: Model<Api>; error?: string } {
const model = findModelByProviderAndId( const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
provider, const authStorage = discoverAuthStorage(resolvedAgentDir);
modelId, const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
agentDir, const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
) as Model<Api> | null;
if (!model) return { error: `Unknown model: ${provider}/${modelId}` }; if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
return { model }; return { model };
} }
@@ -229,8 +228,24 @@ async function getApiKeyForModel(model: Model<Api>): Promise<string> {
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN; const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
if (oauthEnv?.trim()) return oauthEnv.trim(); if (oauthEnv?.trim()) return oauthEnv.trim();
} }
const key = await getDefaultApiKey()(model); const envKey = getEnvApiKey(model.provider);
if (key) return key; if (envKey) return envKey;
if (isOAuthProvider(model.provider)) {
const oauthPath = resolveClawdisOAuthPath();
const storage = loadOAuthStorageAt(oauthPath);
if (storage) {
try {
const result = await getOAuthApiKey(model.provider, storage);
if (result?.apiKey) {
storage[model.provider] = result.newCredentials;
saveOAuthStorageAt(oauthPath, storage);
return result.apiKey;
}
} catch {
// fall through to error below
}
}
}
throw new Error(`No API key found for provider "${model.provider}"`); throw new Error(`No API key found for provider "${model.provider}"`);
} }
@@ -423,6 +438,7 @@ export async function runEmbeddedPiAgent(params: {
toolMetas, toolMetas,
unsubscribe, unsubscribe,
flush: flushToolDebouncer, flush: flushToolDebouncer,
waitForCompactionRetry,
} = subscribeEmbeddedPiSession({ } = subscribeEmbeddedPiSession({
session, session,
runId: params.runId, runId: params.runId,
@@ -463,10 +479,10 @@ export async function runEmbeddedPiAgent(params: {
await session.prompt(params.prompt); await session.prompt(params.prompt);
} catch (err) { } catch (err) {
promptError = err; promptError = err;
} finally {
messagesSnapshot = session.messages.slice();
sessionIdUsed = session.sessionId;
} }
await waitForCompactionRetry();
messagesSnapshot = session.messages.slice();
sessionIdUsed = session.sessionId;
} finally { } finally {
clearTimeout(abortTimer); clearTimeout(abortTimer);
unsubscribe(); unsubscribe();

View File

@@ -1,4 +1,7 @@
import { describe, expect, it, vi } from "vitest"; import { describe, expect, it, vi } from "vitest";
import type { AssistantMessage } from "@mariozechner/pi-ai";
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js"; import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
type StubSession = { type StubSession = {
@@ -92,4 +95,57 @@ describe("subscribeEmbeddedPiSession", () => {
const payload = onPartialReply.mock.calls[0][0]; const payload = onPartialReply.mock.calls[0][0];
expect(payload.text).toBe("Hello world"); expect(payload.text).toBe("Hello world");
}); });
// Checks two things about a retrying auto-compaction: text collected before the
// retry is discarded, and waitForCompactionRetry() only settles once the
// retried run reports "agent_end".
it("waits for auto-compaction retry and clears buffered text", async () => {
  const listeners: Array<(evt: any) => void> = [];

  // Deliver one event to every listener that is currently registered.
  const emit = (evt: any) => {
    for (const listener of listeners) {
      listener(evt);
    }
  };

  // Minimal session stub: only subscribe/unsubscribe are implemented.
  const session = {
    subscribe: (listener: (evt: any) => void) => {
      listeners.push(listener);
      return () => {
        const position = listeners.indexOf(listener);
        if (position !== -1) listeners.splice(position, 1);
      };
    },
  } as any;

  const subscription = subscribeEmbeddedPiSession({
    session,
    runId: "run-1",
  });

  const assistantMessage = {
    role: "assistant",
    content: [{ type: "text", text: "oops" }],
  } as AssistantMessage;

  // A finished assistant message is collected...
  emit({ type: "message_end", message: assistantMessage });
  expect(subscription.assistantTexts.length).toBe(1);

  // ...and dropped again when compaction announces a retry.
  emit({
    type: "auto_compaction_end",
    willRetry: true,
  });
  expect(subscription.assistantTexts.length).toBe(0);

  let resolved = false;
  const waitPromise = subscription.waitForCompactionRetry().then(() => {
    resolved = true;
  });

  // One microtask tick later the wait must still be pending.
  await Promise.resolve();
  expect(resolved).toBe(false);

  // The end of the retried run releases the waiter.
  emit({ type: "agent_end" });
  await waitPromise;
  expect(resolved).toBe(true);
});
}); });

View File

@@ -72,6 +72,41 @@ export function subscribeEmbeddedPiSession(params: {
const toolMetaById = new Map<string, string | undefined>(); const toolMetaById = new Map<string, string | undefined>();
let deltaBuffer = ""; let deltaBuffer = "";
let lastStreamedAssistant: string | undefined; let lastStreamedAssistant: string | undefined;
// ---- Auto-compaction retry bookkeeping ----
// Flag consulted by the helpers below: while true, the shared wait promise is
// never released.
let compactionInFlight = false;
// Retries recorded through noteCompactionRetry() and not yet resolved.
let pendingCompactionRetry = 0;
// Resolver of the shared wait promise; undefined while no promise is open.
let compactionRetryResolve: (() => void) | undefined;
// Shared promise handed to waiters; null while nothing is being awaited.
let compactionRetryPromise: Promise<void> | null = null;

/** Creates the shared wait promise on first use; later calls are no-ops. */
const ensureCompactionPromise = () => {
  if (compactionRetryPromise) return;
  compactionRetryPromise = new Promise((settle) => {
    compactionRetryResolve = settle;
  });
};

/** Records one more pending retry and makes sure waiters have a promise. */
const noteCompactionRetry = () => {
  pendingCompactionRetry += 1;
  ensureCompactionPromise();
};

/**
 * Releases the shared wait promise, but only when no retry is pending and no
 * compaction is in flight. Afterwards the promise/resolver pair is cleared so
 * the next wait starts with a fresh promise.
 */
const maybeResolveCompactionWait = () => {
  if (compactionInFlight || pendingCompactionRetry !== 0) return;
  compactionRetryResolve?.();
  compactionRetryResolve = undefined;
  compactionRetryPromise = null;
};

/**
 * Marks one pending retry as finished (no-op when none is pending) and then
 * releases waiters if that was the last outstanding piece of work.
 */
const resolveCompactionRetry = () => {
  if (pendingCompactionRetry <= 0) return;
  pendingCompactionRetry -= 1;
  maybeResolveCompactionWait();
};
const FINAL_START_RE = /<\s*final\s*>/i; const FINAL_START_RE = /<\s*final\s*>/i;
const FINAL_END_RE = /<\s*\/\s*final\s*>/i; const FINAL_END_RE = /<\s*\/\s*final\s*>/i;
// Local providers sometimes emit malformed tags; normalize before filtering. // Local providers sometimes emit malformed tags; normalize before filtering.
@@ -104,6 +139,15 @@ export function subscribeEmbeddedPiSession(params: {
}); });
}); });
// Throws away everything gathered so far for the current run. It is invoked
// when an "auto_compaction_end" event reports that the run will be retried.
const resetForCompactionRetry = () => {
// Empty the collected assistant texts and tool metadata in place (the arrays
// themselves are kept, only their contents are dropped).
assistantTexts.length = 0;
toolMetas.length = 0;
toolMetaById.clear();
// Forget any partially streamed assistant text.
deltaBuffer = "";
lastStreamedAssistant = undefined;
// NOTE(review): flush runs after the collections were cleared; presumably it
// drains pending debounced tool output — confirm it cannot repopulate them.
toolDebouncer.flush();
};
const unsubscribe = params.session.subscribe( const unsubscribe = params.session.subscribe(
(evt: AgentEvent | { type: string; [k: string]: unknown }) => { (evt: AgentEvent | { type: string; [k: string]: unknown }) => {
if (evt.type === "tool_execution_start") { if (evt.type === "tool_execution_start") {
@@ -274,8 +318,31 @@ export function subscribeEmbeddedPiSession(params: {
} }
} }
if (evt.type === "auto_compaction_start") {
compactionInFlight = true;
ensureCompactionPromise();
}
if (evt.type === "auto_compaction_end") {
compactionInFlight = false;
const willRetry = Boolean(
(evt as { willRetry?: unknown }).willRetry,
);
if (willRetry) {
noteCompactionRetry();
resetForCompactionRetry();
} else {
maybeResolveCompactionWait();
}
}
if (evt.type === "agent_end") { if (evt.type === "agent_end") {
toolDebouncer.flush(); toolDebouncer.flush();
if (pendingCompactionRetry > 0) {
resolveCompactionRetry();
} else {
maybeResolveCompactionWait();
}
} }
}, },
); );
@@ -285,5 +352,21 @@ export function subscribeEmbeddedPiSession(params: {
toolMetas, toolMetas,
unsubscribe, unsubscribe,
flush: () => toolDebouncer.flush(), flush: () => toolDebouncer.flush(),
// Returns a promise that settles once no auto-compaction is in flight and no
// compaction retry is pending.
waitForCompactionRetry: () => {
// Work is already outstanding: hand out the shared wait promise directly.
if (compactionInFlight || pendingCompactionRetry > 0) {
ensureCompactionPromise();
// The `??` fallback is defensive; ensureCompactionPromise() has just made
// sure the shared promise exists.
return compactionRetryPromise ?? Promise.resolve();
}
// Nothing outstanding right now: re-check once in a microtask before
// resolving. NOTE(review): presumably this catches compaction events that
// are emitted right after the caller's prompt settles — confirm.
return new Promise((resolve) => {
queueMicrotask(() => {
if (compactionInFlight || pendingCompactionRetry > 0) {
// Compaction started in the meantime: chain onto the shared promise.
ensureCompactionPromise();
void (compactionRetryPromise ?? Promise.resolve()).then(resolve);
} else {
// Still idle after the microtask tick: settle immediately.
resolve();
}
});
});
},
}; };
} }

View File

@@ -101,6 +101,30 @@ describe("trigger handling", () => {
}); });
}); });
// When the embedded agent rejects with a context-window error, the reply must
// be the user-facing overflow warning and the agent must have been called once.
it("returns a context overflow fallback when the embedded agent throws", async () => {
  await withTempHome(async (homeDir) => {
    const overflowError = new Error("Context window exceeded");
    vi.mocked(runEmbeddedPiAgent).mockRejectedValue(overflowError);

    const reply = await getReplyFromConfig(
      {
        Body: "hello",
        From: "+1002",
        To: "+2000",
      },
      {},
      makeCfg(homeDir),
    );

    // The reply may be a single payload or a list; inspect the first text.
    const replyText = Array.isArray(reply) ? reply[0]?.text : reply?.text;
    expect(replyText).toBe(
      "⚠️ Context overflow - conversation too long. Starting fresh might help!",
    );
    expect(runEmbeddedPiAgent).toHaveBeenCalledOnce();
  });
});
it("uses heartbeat model override for heartbeat runs", async () => { it("uses heartbeat model override for heartbeat runs", async () => {
await withTempHome(async (home) => { await withTempHome(async (home) => {
vi.mocked(runEmbeddedPiAgent).mockResolvedValue({ vi.mocked(runEmbeddedPiAgent).mockResolvedValue({

View File

@@ -996,44 +996,57 @@ export async function getReplyFromConfig(
await startTypingLoop(); await startTypingLoop();
} }
const runId = crypto.randomUUID(); const runId = crypto.randomUUID();
const runResult = await runEmbeddedPiAgent({ let runResult: Awaited<ReturnType<typeof runEmbeddedPiAgent>>;
sessionId: sessionIdFinal, try {
sessionKey, runResult = await runEmbeddedPiAgent({
sessionFile, sessionId: sessionIdFinal,
workspaceDir, sessionKey,
config: cfg, sessionFile,
skillsSnapshot, workspaceDir,
prompt: commandBody, config: cfg,
extraSystemPrompt: groupIntro || undefined, skillsSnapshot,
ownerNumbers: ownerList.length > 0 ? ownerList : undefined, prompt: commandBody,
enforceFinalTag: extraSystemPrompt: groupIntro || undefined,
provider === "lmstudio" || provider === "ollama" ? true : undefined, ownerNumbers: ownerList.length > 0 ? ownerList : undefined,
provider, enforceFinalTag:
model, provider === "lmstudio" || provider === "ollama" ? true : undefined,
thinkLevel: resolvedThinkLevel, provider,
verboseLevel: resolvedVerboseLevel, model,
timeoutMs, thinkLevel: resolvedThinkLevel,
runId, verboseLevel: resolvedVerboseLevel,
onPartialReply: opts?.onPartialReply timeoutMs,
? async (payload) => { runId,
await startTypingOnText(payload.text); onPartialReply: opts?.onPartialReply
await opts.onPartialReply?.({ ? async (payload) => {
text: payload.text, await startTypingOnText(payload.text);
mediaUrls: payload.mediaUrls, await opts.onPartialReply?.({
}); text: payload.text,
} mediaUrls: payload.mediaUrls,
: undefined, });
shouldEmitToolResult, }
onToolResult: opts?.onToolResult : undefined,
? async (payload) => { shouldEmitToolResult,
await startTypingOnText(payload.text); onToolResult: opts?.onToolResult
await opts.onToolResult?.({ ? async (payload) => {
text: payload.text, await startTypingOnText(payload.text);
mediaUrls: payload.mediaUrls, await opts.onToolResult?.({
}); text: payload.text,
} mediaUrls: payload.mediaUrls,
: undefined, });
}); }
: undefined,
});
} catch (err) {
const message = err instanceof Error ? err.message : String(err);
const isContextOverflow =
/context.*overflow|too large|context window/i.test(message);
defaultRuntime.error(`Embedded agent failed before reply: ${message}`);
return {
text: isContextOverflow
? "⚠️ Context overflow - conversation too long. Starting fresh might help!"
: "⚠️ Agent failed. Check gateway logs.",
};
}
if ( if (
shouldInjectGroupIntro && shouldInjectGroupIntro &&