fix: handle embedded agent overflow

Peter Steinberger
2025-12-26 10:16:50 +01:00
parent 8059e83c49
commit d28265cfbe
5 changed files with 255 additions and 63 deletions

View File

@@ -8,14 +8,16 @@ import {
   type Api,
   type AssistantMessage,
   type Model,
-  type OAuthStorage,
-  setOAuthStorage,
+  type OAuthCredentials,
+  type OAuthProvider,
+  getEnvApiKey,
+  getOAuthApiKey,
 } from "@mariozechner/pi-ai";
 import {
   buildSystemPrompt,
   createAgentSession,
-  defaultGetApiKey,
-  findModelByProviderAndId,
+  discoverAuthStorage,
+  discoverModels,
   SessionManager,
   SettingsManager,
   type Skill,
@@ -91,7 +93,8 @@ const ACTIVE_EMBEDDED_RUNS = new Map<string, EmbeddedPiQueueHandle>();
 const OAUTH_FILENAME = "oauth.json";
 const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials");
 let oauthStorageConfigured = false;
-let cachedDefaultApiKey: ReturnType<typeof defaultGetApiKey> | null = null;
+type OAuthStorage = Record<string, OAuthCredentials>;
 
 function resolveSessionLane(key: string) {
   const cleaned = key.trim() || "main";
@@ -178,18 +181,15 @@ function ensureOAuthStorage(): void {
   oauthStorageConfigured = true;
   const oauthPath = resolveClawdisOAuthPath();
   importLegacyOAuthIfNeeded(oauthPath);
-  setOAuthStorage({
-    load: () => loadOAuthStorageAt(oauthPath) ?? {},
-    save: (storage) => saveOAuthStorageAt(oauthPath, storage),
-  });
 }
 
-function getDefaultApiKey() {
-  if (!cachedDefaultApiKey) {
-    ensureOAuthStorage();
-    cachedDefaultApiKey = defaultGetApiKey();
-  }
-  return cachedDefaultApiKey;
+function isOAuthProvider(provider: string): provider is OAuthProvider {
+  return (
+    provider === "anthropic" ||
+    provider === "github-copilot" ||
+    provider === "google-gemini-cli" ||
+    provider === "google-antigravity"
+  );
 }
 
 export function queueEmbeddedPiMessage(
@@ -214,11 +214,10 @@ function resolveModel(
   modelId: string,
   agentDir?: string,
 ): { model?: Model<Api>; error?: string } {
-  const model = findModelByProviderAndId(
-    provider,
-    modelId,
-    agentDir,
-  ) as Model<Api> | null;
+  const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
+  const authStorage = discoverAuthStorage(resolvedAgentDir);
+  const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
+  const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
   if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
   return { model };
 }
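
Review note: model lookup now goes through registry discovery instead of the direct findModelByProviderAndId call. Callers consume the result as a { model?, error? } pair; a minimal sketch of that pattern, with hypothetical provider/model ids:

// Sketch: consuming resolveModel's tagged result (ids are illustrative).
const resolved = resolveModel("anthropic", "claude-sonnet-4");
if (resolved.error) {
  throw new Error(resolved.error); // e.g. Unknown model: anthropic/claude-sonnet-4
}
const model = resolved.model!; // present whenever error is absent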
@@ -229,8 +228,24 @@ async function getApiKeyForModel(model: Model<Api>): Promise<string> {
     const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
     if (oauthEnv?.trim()) return oauthEnv.trim();
   }
-  const key = await getDefaultApiKey()(model);
-  if (key) return key;
+  const envKey = getEnvApiKey(model.provider);
+  if (envKey) return envKey;
+  if (isOAuthProvider(model.provider)) {
+    const oauthPath = resolveClawdisOAuthPath();
+    const storage = loadOAuthStorageAt(oauthPath);
+    if (storage) {
+      try {
+        const result = await getOAuthApiKey(model.provider, storage);
+        if (result?.apiKey) {
+          storage[model.provider] = result.newCredentials;
+          saveOAuthStorageAt(oauthPath, storage);
+          return result.apiKey;
+        }
+      } catch {
+        // fall through to error below
+      }
+    }
+  }
   throw new Error(`No API key found for provider "${model.provider}"`);
 }
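
Review note: the key lookup is now an explicit fallback chain instead of the cached defaultGetApiKey callback: environment key first, then the on-disk OAuth store, persisting any rotated credentials it gets back. A minimal sketch of that ordering; loadStore, saveStore, and refresh are hypothetical stand-ins for the Clawdis helpers, not the project's real API:

// Sketch: provider key resolution order (env var, then OAuth store, then fail).
type Credentials = { apiKey: string };
type Store = Record<string, Credentials>;

async function resolveKey(
  provider: string,
  envKey: string | undefined,
  loadStore: () => Store | null,
  saveStore: (s: Store) => void,
  refresh: (c: Credentials) => Promise<Credentials>,
): Promise<string> {
  if (envKey?.trim()) return envKey.trim(); // 1. environment always wins
  const store = loadStore(); // 2. OAuth credentials on disk
  const creds = store?.[provider];
  if (store && creds) {
    const fresh = await refresh(creds); // refresh may rotate the token
    store[provider] = fresh; // write rotated credentials back
    saveStore(store);
    return fresh.apiKey;
  }
  throw new Error(`No API key found for provider "${provider}"`); // 3. fail loudly
}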
@@ -423,6 +438,7 @@ export async function runEmbeddedPiAgent(params: {
     toolMetas,
     unsubscribe,
     flush: flushToolDebouncer,
+    waitForCompactionRetry,
   } = subscribeEmbeddedPiSession({
     session,
     runId: params.runId,
@@ -463,10 +479,10 @@ export async function runEmbeddedPiAgent(params: {
       await session.prompt(params.prompt);
     } catch (err) {
       promptError = err;
-    } finally {
-      messagesSnapshot = session.messages.slice();
-      sessionIdUsed = session.sessionId;
     }
+    await waitForCompactionRetry();
+    messagesSnapshot = session.messages.slice();
+    sessionIdUsed = session.sessionId;
   } finally {
     clearTimeout(abortTimer);
     unsubscribe();
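
Review note: the prompt flow now awaits any pending auto-compaction retry before snapshotting session.messages, so the snapshot reflects the retried turn rather than the transcript that overflowed. A compressed sketch of the intended sequencing, assuming only the subscription API introduced in this diff:

// Sketch: snapshot only after a possible compaction retry has settled.
async function runOnce(
  session: { prompt(text: string): Promise<void>; messages: unknown[] },
  waitForCompactionRetry: () => Promise<void>,
  text: string,
) {
  let promptError: unknown;
  try {
    await session.prompt(text);
  } catch (err) {
    promptError = err; // keep the error; the snapshot still happens
  }
  // If auto-compaction scheduled a retry, the transcript is still changing;
  // wait for the retried turn to finish before reading it.
  await waitForCompactionRetry();
  return { snapshot: session.messages.slice(), promptError };
}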

View File

@@ -1,4 +1,7 @@
 import { describe, expect, it, vi } from "vitest";
+import type { AssistantMessage } from "@mariozechner/pi-ai";
 import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
 
 type StubSession = {
type StubSession = {
@@ -92,4 +95,57 @@ describe("subscribeEmbeddedPiSession", () => {
     const payload = onPartialReply.mock.calls[0][0];
     expect(payload.text).toBe("Hello world");
   });
+
+  it("waits for auto-compaction retry and clears buffered text", async () => {
+    const listeners: Array<(evt: any) => void> = [];
+    const session = {
+      subscribe: (listener: (evt: any) => void) => {
+        listeners.push(listener);
+        return () => {
+          const index = listeners.indexOf(listener);
+          if (index !== -1) listeners.splice(index, 1);
+        };
+      },
+    } as any;
+
+    const subscription = subscribeEmbeddedPiSession({
+      session,
+      runId: "run-1",
+    });
+
+    const assistantMessage = {
+      role: "assistant",
+      content: [{ type: "text", text: "oops" }],
+    } as AssistantMessage;
+    for (const listener of listeners) {
+      listener({ type: "message_end", message: assistantMessage });
+    }
+    expect(subscription.assistantTexts.length).toBe(1);
+
+    for (const listener of listeners) {
+      listener({
+        type: "auto_compaction_end",
+        willRetry: true,
+      });
+    }
+    expect(subscription.assistantTexts.length).toBe(0);
+
+    let resolved = false;
+    const waitPromise = subscription.waitForCompactionRetry().then(() => {
+      resolved = true;
+    });
+    await Promise.resolve();
+    expect(resolved).toBe(false);
+
+    for (const listener of listeners) {
+      listener({ type: "agent_end" });
+    }
+    await waitPromise;
+    expect(resolved).toBe(true);
+  });
 });

View File

@@ -72,6 +72,41 @@ export function subscribeEmbeddedPiSession(params: {
   const toolMetaById = new Map<string, string | undefined>();
   let deltaBuffer = "";
   let lastStreamedAssistant: string | undefined;
+  let compactionInFlight = false;
+  let pendingCompactionRetry = 0;
+  let compactionRetryResolve: (() => void) | undefined;
+  let compactionRetryPromise: Promise<void> | null = null;
+
+  const ensureCompactionPromise = () => {
+    if (!compactionRetryPromise) {
+      compactionRetryPromise = new Promise((resolve) => {
+        compactionRetryResolve = resolve;
+      });
+    }
+  };
+
+  const noteCompactionRetry = () => {
+    pendingCompactionRetry += 1;
+    ensureCompactionPromise();
+  };
+
+  const resolveCompactionRetry = () => {
+    if (pendingCompactionRetry <= 0) return;
+    pendingCompactionRetry -= 1;
+    if (pendingCompactionRetry === 0 && !compactionInFlight) {
+      compactionRetryResolve?.();
+      compactionRetryResolve = undefined;
+      compactionRetryPromise = null;
+    }
+  };
+
+  const maybeResolveCompactionWait = () => {
+    if (pendingCompactionRetry === 0 && !compactionInFlight) {
+      compactionRetryResolve?.();
+      compactionRetryResolve = undefined;
+      compactionRetryPromise = null;
+    }
+  };
+
 const FINAL_START_RE = /<\s*final\s*>/i;
 const FINAL_END_RE = /<\s*\/\s*final\s*>/i;
 // Local providers sometimes emit malformed tags; normalize before filtering.
@@ -104,6 +139,15 @@ export function subscribeEmbeddedPiSession(params: {
     });
   });
 
+  const resetForCompactionRetry = () => {
+    assistantTexts.length = 0;
+    toolMetas.length = 0;
+    toolMetaById.clear();
+    deltaBuffer = "";
+    lastStreamedAssistant = undefined;
+    toolDebouncer.flush();
+  };
+
   const unsubscribe = params.session.subscribe(
     (evt: AgentEvent | { type: string; [k: string]: unknown }) => {
       if (evt.type === "tool_execution_start") {
@@ -274,8 +318,31 @@ export function subscribeEmbeddedPiSession(params: {
         }
       }
 
+      if (evt.type === "auto_compaction_start") {
+        compactionInFlight = true;
+        ensureCompactionPromise();
+      }
+
+      if (evt.type === "auto_compaction_end") {
+        compactionInFlight = false;
+        const willRetry = Boolean(
+          (evt as { willRetry?: unknown }).willRetry,
+        );
+        if (willRetry) {
+          noteCompactionRetry();
+          resetForCompactionRetry();
+        } else {
+          maybeResolveCompactionWait();
+        }
+      }
+
       if (evt.type === "agent_end") {
         toolDebouncer.flush();
+        if (pendingCompactionRetry > 0) {
+          resolveCompactionRetry();
+        } else {
+          maybeResolveCompactionWait();
+        }
       }
     },
   );
@@ -285,5 +352,21 @@ export function subscribeEmbeddedPiSession(params: {
     toolMetas,
     unsubscribe,
     flush: () => toolDebouncer.flush(),
+    waitForCompactionRetry: () => {
+      if (compactionInFlight || pendingCompactionRetry > 0) {
+        ensureCompactionPromise();
+        return compactionRetryPromise ?? Promise.resolve();
+      }
+      return new Promise<void>((resolve) => {
+        queueMicrotask(() => {
+          if (compactionInFlight || pendingCompactionRetry > 0) {
+            ensureCompactionPromise();
+            void (compactionRetryPromise ?? Promise.resolve()).then(resolve);
+          } else {
+            resolve();
+          }
+        });
+      });
+    },
   };
 }
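
Review note: the retry bookkeeping is a counting latch. A shared promise is created lazily; each auto_compaction_end with willRetry increments a pending count, and the latch releases only once the count drains to zero with no compaction in flight. The same pattern in isolation, as a sketch rather than the project's exact code:

// Sketch: a reusable counting latch with the same shape as the bookkeeping above.
class Latch {
  private pending = 0;
  private release: (() => void) | undefined;
  private gate: Promise<void> | null = null;

  private ensure() {
    if (!this.gate) {
      this.gate = new Promise((resolve) => {
        this.release = resolve;
      });
    }
  }

  add() {
    this.pending += 1;
    this.ensure();
  }

  done() {
    if (this.pending <= 0) return;
    this.pending -= 1;
    if (this.pending === 0) {
      this.release?.();
      this.release = undefined;
      this.gate = null; // reset so the latch can be reused
    }
  }

  wait(): Promise<void> {
    if (this.pending > 0) {
      this.ensure();
      return this.gate ?? Promise.resolve();
    }
    return Promise.resolve(); // nothing outstanding
  }
}

waitForCompactionRetry adds one wrinkle on top of this shape: it defers the zero-pending check by a microtask, so a caller that awaits immediately after prompt() resolves does not race an auto_compaction_start event that is already queued behind it.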

View File

@@ -101,6 +101,30 @@ describe("trigger handling", () => {
     });
   });
 
+  it("returns a context overflow fallback when the embedded agent throws", async () => {
+    await withTempHome(async (home) => {
+      vi.mocked(runEmbeddedPiAgent).mockRejectedValue(
+        new Error("Context window exceeded"),
+      );
+      const res = await getReplyFromConfig(
+        {
+          Body: "hello",
+          From: "+1002",
+          To: "+2000",
+        },
+        {},
+        makeCfg(home),
+      );
+      const text = Array.isArray(res) ? res[0]?.text : res?.text;
+      expect(text).toBe(
+        "⚠️ Context overflow - conversation too long. Starting fresh might help!",
+      );
+      expect(runEmbeddedPiAgent).toHaveBeenCalledOnce();
+    });
+  });
+
   it("uses heartbeat model override for heartbeat runs", async () => {
     await withTempHome(async (home) => {
       vi.mocked(runEmbeddedPiAgent).mockResolvedValue({
View File

@@ -996,44 +996,57 @@ export async function getReplyFromConfig(
     await startTypingLoop();
   }
   const runId = crypto.randomUUID();
-  const runResult = await runEmbeddedPiAgent({
-    sessionId: sessionIdFinal,
-    sessionKey,
-    sessionFile,
-    workspaceDir,
-    config: cfg,
-    skillsSnapshot,
-    prompt: commandBody,
-    extraSystemPrompt: groupIntro || undefined,
-    ownerNumbers: ownerList.length > 0 ? ownerList : undefined,
-    enforceFinalTag:
-      provider === "lmstudio" || provider === "ollama" ? true : undefined,
-    provider,
-    model,
-    thinkLevel: resolvedThinkLevel,
-    verboseLevel: resolvedVerboseLevel,
-    timeoutMs,
-    runId,
-    onPartialReply: opts?.onPartialReply
-      ? async (payload) => {
-          await startTypingOnText(payload.text);
-          await opts.onPartialReply?.({
-            text: payload.text,
-            mediaUrls: payload.mediaUrls,
-          });
-        }
-      : undefined,
-    shouldEmitToolResult,
-    onToolResult: opts?.onToolResult
-      ? async (payload) => {
-          await startTypingOnText(payload.text);
-          await opts.onToolResult?.({
-            text: payload.text,
-            mediaUrls: payload.mediaUrls,
-          });
-        }
-      : undefined,
-  });
+  let runResult: Awaited<ReturnType<typeof runEmbeddedPiAgent>>;
+  try {
+    runResult = await runEmbeddedPiAgent({
+      sessionId: sessionIdFinal,
+      sessionKey,
+      sessionFile,
+      workspaceDir,
+      config: cfg,
+      skillsSnapshot,
+      prompt: commandBody,
+      extraSystemPrompt: groupIntro || undefined,
+      ownerNumbers: ownerList.length > 0 ? ownerList : undefined,
+      enforceFinalTag:
+        provider === "lmstudio" || provider === "ollama" ? true : undefined,
+      provider,
+      model,
+      thinkLevel: resolvedThinkLevel,
+      verboseLevel: resolvedVerboseLevel,
+      timeoutMs,
+      runId,
+      onPartialReply: opts?.onPartialReply
+        ? async (payload) => {
+            await startTypingOnText(payload.text);
+            await opts.onPartialReply?.({
+              text: payload.text,
+              mediaUrls: payload.mediaUrls,
+            });
+          }
+        : undefined,
+      shouldEmitToolResult,
+      onToolResult: opts?.onToolResult
+        ? async (payload) => {
+            await startTypingOnText(payload.text);
+            await opts.onToolResult?.({
+              text: payload.text,
+              mediaUrls: payload.mediaUrls,
+            });
+          }
+        : undefined,
+    });
+  } catch (err) {
+    const message = err instanceof Error ? err.message : String(err);
+    const isContextOverflow =
+      /context.*overflow|too large|context window/i.test(message);
+    defaultRuntime.error(`Embedded agent failed before reply: ${message}`);
+    return {
+      text: isContextOverflow
+        ? "⚠️ Context overflow - conversation too long. Starting fresh might help!"
+        : "⚠️ Agent failed. Check gateway logs.",
+    };
+  }
 
   if (
     shouldInjectGroupIntro &&
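
Review note: the fallback classifies failures by message text alone. Exercising the regex from the catch block against sample messages (illustrative strings, not an exhaustive list of provider errors):

// The classifier from the diff, checked against sample error messages.
const isContextOverflow = (message: string) =>
  /context.*overflow|too large|context window/i.test(message);

console.log(isContextOverflow("Context window exceeded")); // true, matches the test above
console.log(isContextOverflow("prompt is too large for this model")); // true
console.log(isContextOverflow("ECONNRESET")); // false, falls back to the generic reply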