fix: handle embedded agent overflow
This commit is contained in:
@@ -8,14 +8,16 @@ import {
|
||||
type Api,
|
||||
type AssistantMessage,
|
||||
type Model,
|
||||
type OAuthStorage,
|
||||
setOAuthStorage,
|
||||
type OAuthCredentials,
|
||||
type OAuthProvider,
|
||||
getEnvApiKey,
|
||||
getOAuthApiKey,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import {
|
||||
buildSystemPrompt,
|
||||
createAgentSession,
|
||||
defaultGetApiKey,
|
||||
findModelByProviderAndId,
|
||||
discoverAuthStorage,
|
||||
discoverModels,
|
||||
SessionManager,
|
||||
SettingsManager,
|
||||
type Skill,
|
||||
@@ -91,7 +93,8 @@ const ACTIVE_EMBEDDED_RUNS = new Map<string, EmbeddedPiQueueHandle>();
|
||||
const OAUTH_FILENAME = "oauth.json";
|
||||
const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials");
|
||||
let oauthStorageConfigured = false;
|
||||
let cachedDefaultApiKey: ReturnType<typeof defaultGetApiKey> | null = null;
|
||||
|
||||
type OAuthStorage = Record<string, OAuthCredentials>;
|
||||
|
||||
function resolveSessionLane(key: string) {
|
||||
const cleaned = key.trim() || "main";
|
||||
@@ -178,18 +181,15 @@ function ensureOAuthStorage(): void {
|
||||
oauthStorageConfigured = true;
|
||||
const oauthPath = resolveClawdisOAuthPath();
|
||||
importLegacyOAuthIfNeeded(oauthPath);
|
||||
setOAuthStorage({
|
||||
load: () => loadOAuthStorageAt(oauthPath) ?? {},
|
||||
save: (storage) => saveOAuthStorageAt(oauthPath, storage),
|
||||
});
|
||||
}
|
||||
|
||||
function getDefaultApiKey() {
|
||||
if (!cachedDefaultApiKey) {
|
||||
ensureOAuthStorage();
|
||||
cachedDefaultApiKey = defaultGetApiKey();
|
||||
}
|
||||
return cachedDefaultApiKey;
|
||||
/**
 * Type guard: reports whether `provider` is one of the provider ids this
 * module treats as OAuth-backed.
 *
 * The comparison is exact (case-sensitive, no trimming).
 *
 * @param provider - Provider id to test.
 * @returns `true` (narrowing `provider` to `OAuthProvider`) for
 *   "anthropic", "github-copilot", "google-gemini-cli" and
 *   "google-antigravity"; `false` for anything else.
 */
function isOAuthProvider(provider: string): provider is OAuthProvider {
  return (
    provider === "anthropic" ||
    provider === "github-copilot" ||
    provider === "google-gemini-cli" ||
    provider === "google-antigravity"
  );
}
|
||||
|
||||
export function queueEmbeddedPiMessage(
|
||||
@@ -214,11 +214,10 @@ function resolveModel(
|
||||
modelId: string,
|
||||
agentDir?: string,
|
||||
): { model?: Model<Api>; error?: string } {
|
||||
const model = findModelByProviderAndId(
|
||||
provider,
|
||||
modelId,
|
||||
agentDir,
|
||||
) as Model<Api> | null;
|
||||
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
|
||||
const authStorage = discoverAuthStorage(resolvedAgentDir);
|
||||
const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
|
||||
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
|
||||
if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
|
||||
return { model };
|
||||
}
|
||||
@@ -229,8 +228,24 @@ async function getApiKeyForModel(model: Model<Api>): Promise<string> {
|
||||
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
||||
if (oauthEnv?.trim()) return oauthEnv.trim();
|
||||
}
|
||||
const key = await getDefaultApiKey()(model);
|
||||
if (key) return key;
|
||||
const envKey = getEnvApiKey(model.provider);
|
||||
if (envKey) return envKey;
|
||||
if (isOAuthProvider(model.provider)) {
|
||||
const oauthPath = resolveClawdisOAuthPath();
|
||||
const storage = loadOAuthStorageAt(oauthPath);
|
||||
if (storage) {
|
||||
try {
|
||||
const result = await getOAuthApiKey(model.provider, storage);
|
||||
if (result?.apiKey) {
|
||||
storage[model.provider] = result.newCredentials;
|
||||
saveOAuthStorageAt(oauthPath, storage);
|
||||
return result.apiKey;
|
||||
}
|
||||
} catch {
|
||||
// fall through to error below
|
||||
}
|
||||
}
|
||||
}
|
||||
throw new Error(`No API key found for provider "${model.provider}"`);
|
||||
}
|
||||
|
||||
@@ -423,6 +438,7 @@ export async function runEmbeddedPiAgent(params: {
|
||||
toolMetas,
|
||||
unsubscribe,
|
||||
flush: flushToolDebouncer,
|
||||
waitForCompactionRetry,
|
||||
} = subscribeEmbeddedPiSession({
|
||||
session,
|
||||
runId: params.runId,
|
||||
@@ -463,10 +479,10 @@ export async function runEmbeddedPiAgent(params: {
|
||||
await session.prompt(params.prompt);
|
||||
} catch (err) {
|
||||
promptError = err;
|
||||
} finally {
|
||||
messagesSnapshot = session.messages.slice();
|
||||
sessionIdUsed = session.sessionId;
|
||||
}
|
||||
await waitForCompactionRetry();
|
||||
messagesSnapshot = session.messages.slice();
|
||||
sessionIdUsed = session.sessionId;
|
||||
} finally {
|
||||
clearTimeout(abortTimer);
|
||||
unsubscribe();
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||
|
||||
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
|
||||
|
||||
type StubSession = {
|
||||
@@ -92,4 +95,57 @@ describe("subscribeEmbeddedPiSession", () => {
|
||||
const payload = onPartialReply.mock.calls[0][0];
|
||||
expect(payload.text).toBe("Hello world");
|
||||
});
|
||||
|
||||
it("waits for auto-compaction retry and clears buffered text", async () => {
  // Minimal session stub: it only records subscribers so the test can push
  // events at them by hand.
  const listeners: Array<(evt: any) => void> = [];
  const session = {
    subscribe: (listener: (evt: any) => void) => {
      listeners.push(listener);
      return () => {
        const index = listeners.indexOf(listener);
        if (index !== -1) listeners.splice(index, 1);
      };
    },
  } as any;

  // Deliver one event to every registered subscriber.
  const emit = (evt: Record<string, unknown>) => {
    for (const listener of listeners) {
      listener(evt);
    }
  };

  const subscription = subscribeEmbeddedPiSession({
    session,
    runId: "run-1",
  });

  const assistantMessage = {
    role: "assistant",
    content: [{ type: "text", text: "oops" }],
  } as AssistantMessage;

  // A finished assistant message is collected into assistantTexts.
  emit({ type: "message_end", message: assistantMessage });
  expect(subscription.assistantTexts.length).toBe(1);

  // Compaction ended and announced a retry: the collected text is dropped.
  emit({
    type: "auto_compaction_end",
    willRetry: true,
  });
  expect(subscription.assistantTexts.length).toBe(0);

  // While the retry is pending the wait promise must stay unresolved.
  let resolved = false;
  const waitPromise = subscription.waitForCompactionRetry().then(() => {
    resolved = true;
  });

  await Promise.resolve();
  expect(resolved).toBe(false);

  // agent_end marks the retried run as finished and releases the waiter.
  emit({ type: "agent_end" });

  await waitPromise;
  expect(resolved).toBe(true);
});
|
||||
});
|
||||
|
||||
@@ -72,6 +72,41 @@ export function subscribeEmbeddedPiSession(params: {
|
||||
const toolMetaById = new Map<string, string | undefined>();
|
||||
let deltaBuffer = "";
|
||||
let lastStreamedAssistant: string | undefined;
|
||||
let compactionInFlight = false;
|
||||
let pendingCompactionRetry = 0;
|
||||
let compactionRetryResolve: (() => void) | undefined;
|
||||
let compactionRetryPromise: Promise<void> | null = null;
|
||||
|
||||
// Lazily creates the promise that waiters block on during compaction.
// If one already exists it is reused, so every caller waits on the same
// promise; its resolver is stashed in `compactionRetryResolve`.
const ensureCompactionPromise = () => {
  if (!compactionRetryPromise) {
    compactionRetryPromise = new Promise<void>((resolve) => {
      compactionRetryResolve = resolve;
    });
  }
};
|
||||
|
||||
// Records that one more retry is outstanding and makes sure there is a
// promise for waiters to block on until it completes.
const noteCompactionRetry = () => {
  pendingCompactionRetry += 1;
  ensureCompactionPromise();
};
|
||||
|
||||
// Releases waiters once nothing is outstanding: no retry is pending and no
// compaction is in flight. Clears the promise and resolver afterwards so the
// next compaction starts with a fresh promise. A no-op otherwise.
const maybeResolveCompactionWait = () => {
  if (pendingCompactionRetry === 0 && !compactionInFlight) {
    compactionRetryResolve?.();
    compactionRetryResolve = undefined;
    compactionRetryPromise = null;
  }
};

// Marks one pending retry as finished, then releases waiters if that was the
// last outstanding piece of work. Ignored when no retry is pending, so the
// counter never goes negative.
const resolveCompactionRetry = () => {
  if (pendingCompactionRetry <= 0) return;
  pendingCompactionRetry -= 1;
  maybeResolveCompactionWait();
};
|
||||
const FINAL_START_RE = /<\s*final\s*>/i;
|
||||
const FINAL_END_RE = /<\s*\/\s*final\s*>/i;
|
||||
// Local providers sometimes emit malformed tags; normalize before filtering.
|
||||
@@ -104,6 +139,15 @@ export function subscribeEmbeddedPiSession(params: {
|
||||
});
|
||||
});
|
||||
|
||||
// Discards everything collected so far so that a retried run starts from a
// clean slate: collected assistant texts, tool metadata (list and by-id map),
// the streaming delta buffer and the last streamed assistant text.
// The arrays are emptied in place (length = 0) rather than reassigned, so
// references already handed out to callers observe the reset.
const resetForCompactionRetry = () => {
  assistantTexts.length = 0;
  toolMetas.length = 0;
  toolMetaById.clear();
  deltaBuffer = "";
  lastStreamedAssistant = undefined;
  // NOTE(review): the flush runs after the collections are cleared;
  // presumably it drains whatever the debouncer still holds — confirm it
  // cannot emit a stale tool summary from the abandoned attempt.
  toolDebouncer.flush();
};
|
||||
|
||||
const unsubscribe = params.session.subscribe(
|
||||
(evt: AgentEvent | { type: string; [k: string]: unknown }) => {
|
||||
if (evt.type === "tool_execution_start") {
|
||||
@@ -274,8 +318,31 @@ export function subscribeEmbeddedPiSession(params: {
|
||||
}
|
||||
}
|
||||
|
||||
if (evt.type === "auto_compaction_start") {
|
||||
compactionInFlight = true;
|
||||
ensureCompactionPromise();
|
||||
}
|
||||
|
||||
if (evt.type === "auto_compaction_end") {
|
||||
compactionInFlight = false;
|
||||
const willRetry = Boolean(
|
||||
(evt as { willRetry?: unknown }).willRetry,
|
||||
);
|
||||
if (willRetry) {
|
||||
noteCompactionRetry();
|
||||
resetForCompactionRetry();
|
||||
} else {
|
||||
maybeResolveCompactionWait();
|
||||
}
|
||||
}
|
||||
|
||||
if (evt.type === "agent_end") {
|
||||
toolDebouncer.flush();
|
||||
if (pendingCompactionRetry > 0) {
|
||||
resolveCompactionRetry();
|
||||
} else {
|
||||
maybeResolveCompactionWait();
|
||||
}
|
||||
}
|
||||
},
|
||||
);
|
||||
@@ -285,5 +352,21 @@ export function subscribeEmbeddedPiSession(params: {
|
||||
toolMetas,
|
||||
unsubscribe,
|
||||
flush: () => toolDebouncer.flush(),
|
||||
// Resolves once no compaction is in flight and no retry is pending.
// - If work is already outstanding, returns the shared compaction promise.
// - Otherwise defers the check by one microtask and looks again, then either
//   chains onto the shared promise or resolves immediately.
// NOTE(review): the microtask hop presumably catches compaction events
// emitted right after the caller's prompt settles — confirm against the
// session's event ordering.
waitForCompactionRetry: (): Promise<void> => {
  if (compactionInFlight || pendingCompactionRetry > 0) {
    ensureCompactionPromise();
    return compactionRetryPromise ?? Promise.resolve();
  }
  return new Promise<void>((resolve) => {
    queueMicrotask(() => {
      if (compactionInFlight || pendingCompactionRetry > 0) {
        ensureCompactionPromise();
        void (compactionRetryPromise ?? Promise.resolve()).then(resolve);
      } else {
        resolve();
      }
    });
  });
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -101,6 +101,30 @@ describe("trigger handling", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("returns a context overflow fallback when the embedded agent throws", async () => {
  await withTempHome(async (home) => {
    // Simulate the embedded agent rejecting with a context-window error.
    vi.mocked(runEmbeddedPiAgent).mockRejectedValue(
      new Error("Context window exceeded"),
    );

    const res = await getReplyFromConfig(
      {
        Body: "hello",
        From: "+1002",
        To: "+2000",
      },
      {},
      makeCfg(home),
    );

    // The reply may be a single payload or an array; inspect the first text.
    const text = Array.isArray(res) ? res[0]?.text : res?.text;
    expect(text).toBe(
      "⚠️ Context overflow - conversation too long. Starting fresh might help!",
    );
    // The failure is reported to the user, not retried.
    expect(runEmbeddedPiAgent).toHaveBeenCalledOnce();
  });
});
|
||||
|
||||
it("uses heartbeat model override for heartbeat runs", async () => {
|
||||
await withTempHome(async (home) => {
|
||||
vi.mocked(runEmbeddedPiAgent).mockResolvedValue({
|
||||
|
||||
@@ -996,44 +996,57 @@ export async function getReplyFromConfig(
|
||||
await startTypingLoop();
|
||||
}
|
||||
const runId = crypto.randomUUID();
|
||||
const runResult = await runEmbeddedPiAgent({
|
||||
sessionId: sessionIdFinal,
|
||||
sessionKey,
|
||||
sessionFile,
|
||||
workspaceDir,
|
||||
config: cfg,
|
||||
skillsSnapshot,
|
||||
prompt: commandBody,
|
||||
extraSystemPrompt: groupIntro || undefined,
|
||||
ownerNumbers: ownerList.length > 0 ? ownerList : undefined,
|
||||
enforceFinalTag:
|
||||
provider === "lmstudio" || provider === "ollama" ? true : undefined,
|
||||
provider,
|
||||
model,
|
||||
thinkLevel: resolvedThinkLevel,
|
||||
verboseLevel: resolvedVerboseLevel,
|
||||
timeoutMs,
|
||||
runId,
|
||||
onPartialReply: opts?.onPartialReply
|
||||
? async (payload) => {
|
||||
await startTypingOnText(payload.text);
|
||||
await opts.onPartialReply?.({
|
||||
text: payload.text,
|
||||
mediaUrls: payload.mediaUrls,
|
||||
});
|
||||
}
|
||||
: undefined,
|
||||
shouldEmitToolResult,
|
||||
onToolResult: opts?.onToolResult
|
||||
? async (payload) => {
|
||||
await startTypingOnText(payload.text);
|
||||
await opts.onToolResult?.({
|
||||
text: payload.text,
|
||||
mediaUrls: payload.mediaUrls,
|
||||
});
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
let runResult: Awaited<ReturnType<typeof runEmbeddedPiAgent>>;
|
||||
try {
|
||||
runResult = await runEmbeddedPiAgent({
|
||||
sessionId: sessionIdFinal,
|
||||
sessionKey,
|
||||
sessionFile,
|
||||
workspaceDir,
|
||||
config: cfg,
|
||||
skillsSnapshot,
|
||||
prompt: commandBody,
|
||||
extraSystemPrompt: groupIntro || undefined,
|
||||
ownerNumbers: ownerList.length > 0 ? ownerList : undefined,
|
||||
enforceFinalTag:
|
||||
provider === "lmstudio" || provider === "ollama" ? true : undefined,
|
||||
provider,
|
||||
model,
|
||||
thinkLevel: resolvedThinkLevel,
|
||||
verboseLevel: resolvedVerboseLevel,
|
||||
timeoutMs,
|
||||
runId,
|
||||
onPartialReply: opts?.onPartialReply
|
||||
? async (payload) => {
|
||||
await startTypingOnText(payload.text);
|
||||
await opts.onPartialReply?.({
|
||||
text: payload.text,
|
||||
mediaUrls: payload.mediaUrls,
|
||||
});
|
||||
}
|
||||
: undefined,
|
||||
shouldEmitToolResult,
|
||||
onToolResult: opts?.onToolResult
|
||||
? async (payload) => {
|
||||
await startTypingOnText(payload.text);
|
||||
await opts.onToolResult?.({
|
||||
text: payload.text,
|
||||
mediaUrls: payload.mediaUrls,
|
||||
});
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const isContextOverflow =
|
||||
/context.*overflow|too large|context window/i.test(message);
|
||||
defaultRuntime.error(`Embedded agent failed before reply: ${message}`);
|
||||
return {
|
||||
text: isContextOverflow
|
||||
? "⚠️ Context overflow - conversation too long. Starting fresh might help!"
|
||||
: "⚠️ Agent failed. Check gateway logs.",
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
shouldInjectGroupIntro &&
|
||||
|
||||
Reference in New Issue
Block a user