fix: handle embedded agent overflow
This commit is contained in:
@@ -8,14 +8,16 @@ import {
|
|||||||
type Api,
|
type Api,
|
||||||
type AssistantMessage,
|
type AssistantMessage,
|
||||||
type Model,
|
type Model,
|
||||||
type OAuthStorage,
|
type OAuthCredentials,
|
||||||
setOAuthStorage,
|
type OAuthProvider,
|
||||||
|
getEnvApiKey,
|
||||||
|
getOAuthApiKey,
|
||||||
} from "@mariozechner/pi-ai";
|
} from "@mariozechner/pi-ai";
|
||||||
import {
|
import {
|
||||||
buildSystemPrompt,
|
buildSystemPrompt,
|
||||||
createAgentSession,
|
createAgentSession,
|
||||||
defaultGetApiKey,
|
discoverAuthStorage,
|
||||||
findModelByProviderAndId,
|
discoverModels,
|
||||||
SessionManager,
|
SessionManager,
|
||||||
SettingsManager,
|
SettingsManager,
|
||||||
type Skill,
|
type Skill,
|
||||||
@@ -91,7 +93,8 @@ const ACTIVE_EMBEDDED_RUNS = new Map<string, EmbeddedPiQueueHandle>();
|
|||||||
const OAUTH_FILENAME = "oauth.json";
|
const OAUTH_FILENAME = "oauth.json";
|
||||||
const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials");
|
const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials");
|
||||||
let oauthStorageConfigured = false;
|
let oauthStorageConfigured = false;
|
||||||
let cachedDefaultApiKey: ReturnType<typeof defaultGetApiKey> | null = null;
|
|
||||||
|
type OAuthStorage = Record<string, OAuthCredentials>;
|
||||||
|
|
||||||
function resolveSessionLane(key: string) {
|
function resolveSessionLane(key: string) {
|
||||||
const cleaned = key.trim() || "main";
|
const cleaned = key.trim() || "main";
|
||||||
@@ -178,18 +181,15 @@ function ensureOAuthStorage(): void {
|
|||||||
oauthStorageConfigured = true;
|
oauthStorageConfigured = true;
|
||||||
const oauthPath = resolveClawdisOAuthPath();
|
const oauthPath = resolveClawdisOAuthPath();
|
||||||
importLegacyOAuthIfNeeded(oauthPath);
|
importLegacyOAuthIfNeeded(oauthPath);
|
||||||
setOAuthStorage({
|
|
||||||
load: () => loadOAuthStorageAt(oauthPath) ?? {},
|
|
||||||
save: (storage) => saveOAuthStorageAt(oauthPath, storage),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getDefaultApiKey() {
|
function isOAuthProvider(provider: string): provider is OAuthProvider {
|
||||||
if (!cachedDefaultApiKey) {
|
return (
|
||||||
ensureOAuthStorage();
|
provider === "anthropic" ||
|
||||||
cachedDefaultApiKey = defaultGetApiKey();
|
provider === "github-copilot" ||
|
||||||
}
|
provider === "google-gemini-cli" ||
|
||||||
return cachedDefaultApiKey;
|
provider === "google-antigravity"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function queueEmbeddedPiMessage(
|
export function queueEmbeddedPiMessage(
|
||||||
@@ -214,11 +214,10 @@ function resolveModel(
|
|||||||
modelId: string,
|
modelId: string,
|
||||||
agentDir?: string,
|
agentDir?: string,
|
||||||
): { model?: Model<Api>; error?: string } {
|
): { model?: Model<Api>; error?: string } {
|
||||||
const model = findModelByProviderAndId(
|
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
|
||||||
provider,
|
const authStorage = discoverAuthStorage(resolvedAgentDir);
|
||||||
modelId,
|
const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
|
||||||
agentDir,
|
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
|
||||||
) as Model<Api> | null;
|
|
||||||
if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
|
if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
|
||||||
return { model };
|
return { model };
|
||||||
}
|
}
|
||||||
@@ -229,8 +228,24 @@ async function getApiKeyForModel(model: Model<Api>): Promise<string> {
|
|||||||
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
||||||
if (oauthEnv?.trim()) return oauthEnv.trim();
|
if (oauthEnv?.trim()) return oauthEnv.trim();
|
||||||
}
|
}
|
||||||
const key = await getDefaultApiKey()(model);
|
const envKey = getEnvApiKey(model.provider);
|
||||||
if (key) return key;
|
if (envKey) return envKey;
|
||||||
|
if (isOAuthProvider(model.provider)) {
|
||||||
|
const oauthPath = resolveClawdisOAuthPath();
|
||||||
|
const storage = loadOAuthStorageAt(oauthPath);
|
||||||
|
if (storage) {
|
||||||
|
try {
|
||||||
|
const result = await getOAuthApiKey(model.provider, storage);
|
||||||
|
if (result?.apiKey) {
|
||||||
|
storage[model.provider] = result.newCredentials;
|
||||||
|
saveOAuthStorageAt(oauthPath, storage);
|
||||||
|
return result.apiKey;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// fall through to error below
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
throw new Error(`No API key found for provider "${model.provider}"`);
|
throw new Error(`No API key found for provider "${model.provider}"`);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -423,6 +438,7 @@ export async function runEmbeddedPiAgent(params: {
|
|||||||
toolMetas,
|
toolMetas,
|
||||||
unsubscribe,
|
unsubscribe,
|
||||||
flush: flushToolDebouncer,
|
flush: flushToolDebouncer,
|
||||||
|
waitForCompactionRetry,
|
||||||
} = subscribeEmbeddedPiSession({
|
} = subscribeEmbeddedPiSession({
|
||||||
session,
|
session,
|
||||||
runId: params.runId,
|
runId: params.runId,
|
||||||
@@ -463,10 +479,10 @@ export async function runEmbeddedPiAgent(params: {
|
|||||||
await session.prompt(params.prompt);
|
await session.prompt(params.prompt);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
promptError = err;
|
promptError = err;
|
||||||
} finally {
|
|
||||||
messagesSnapshot = session.messages.slice();
|
|
||||||
sessionIdUsed = session.sessionId;
|
|
||||||
}
|
}
|
||||||
|
await waitForCompactionRetry();
|
||||||
|
messagesSnapshot = session.messages.slice();
|
||||||
|
sessionIdUsed = session.sessionId;
|
||||||
} finally {
|
} finally {
|
||||||
clearTimeout(abortTimer);
|
clearTimeout(abortTimer);
|
||||||
unsubscribe();
|
unsubscribe();
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
import { describe, expect, it, vi } from "vitest";
|
import { describe, expect, it, vi } from "vitest";
|
||||||
|
|
||||||
|
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||||
|
|
||||||
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
|
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
|
||||||
|
|
||||||
type StubSession = {
|
type StubSession = {
|
||||||
@@ -92,4 +95,57 @@ describe("subscribeEmbeddedPiSession", () => {
|
|||||||
const payload = onPartialReply.mock.calls[0][0];
|
const payload = onPartialReply.mock.calls[0][0];
|
||||||
expect(payload.text).toBe("Hello world");
|
expect(payload.text).toBe("Hello world");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("waits for auto-compaction retry and clears buffered text", async () => {
|
||||||
|
const listeners: Array<(evt: any) => void> = [];
|
||||||
|
const session = {
|
||||||
|
subscribe: (listener: (evt: any) => void) => {
|
||||||
|
listeners.push(listener);
|
||||||
|
return () => {
|
||||||
|
const index = listeners.indexOf(listener);
|
||||||
|
if (index !== -1) listeners.splice(index, 1);
|
||||||
|
};
|
||||||
|
},
|
||||||
|
} as any;
|
||||||
|
|
||||||
|
const subscription = subscribeEmbeddedPiSession({
|
||||||
|
session,
|
||||||
|
runId: "run-1",
|
||||||
|
});
|
||||||
|
|
||||||
|
const assistantMessage = {
|
||||||
|
role: "assistant",
|
||||||
|
content: [{ type: "text", text: "oops" }],
|
||||||
|
} as AssistantMessage;
|
||||||
|
|
||||||
|
for (const listener of listeners) {
|
||||||
|
listener({ type: "message_end", message: assistantMessage });
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(subscription.assistantTexts.length).toBe(1);
|
||||||
|
|
||||||
|
for (const listener of listeners) {
|
||||||
|
listener({
|
||||||
|
type: "auto_compaction_end",
|
||||||
|
willRetry: true,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(subscription.assistantTexts.length).toBe(0);
|
||||||
|
|
||||||
|
let resolved = false;
|
||||||
|
const waitPromise = subscription.waitForCompactionRetry().then(() => {
|
||||||
|
resolved = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.resolve();
|
||||||
|
expect(resolved).toBe(false);
|
||||||
|
|
||||||
|
for (const listener of listeners) {
|
||||||
|
listener({ type: "agent_end" });
|
||||||
|
}
|
||||||
|
|
||||||
|
await waitPromise;
|
||||||
|
expect(resolved).toBe(true);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -72,6 +72,41 @@ export function subscribeEmbeddedPiSession(params: {
|
|||||||
const toolMetaById = new Map<string, string | undefined>();
|
const toolMetaById = new Map<string, string | undefined>();
|
||||||
let deltaBuffer = "";
|
let deltaBuffer = "";
|
||||||
let lastStreamedAssistant: string | undefined;
|
let lastStreamedAssistant: string | undefined;
|
||||||
|
let compactionInFlight = false;
|
||||||
|
let pendingCompactionRetry = 0;
|
||||||
|
let compactionRetryResolve: (() => void) | undefined;
|
||||||
|
let compactionRetryPromise: Promise<void> | null = null;
|
||||||
|
|
||||||
|
const ensureCompactionPromise = () => {
|
||||||
|
if (!compactionRetryPromise) {
|
||||||
|
compactionRetryPromise = new Promise((resolve) => {
|
||||||
|
compactionRetryResolve = resolve;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const noteCompactionRetry = () => {
|
||||||
|
pendingCompactionRetry += 1;
|
||||||
|
ensureCompactionPromise();
|
||||||
|
};
|
||||||
|
|
||||||
|
const resolveCompactionRetry = () => {
|
||||||
|
if (pendingCompactionRetry <= 0) return;
|
||||||
|
pendingCompactionRetry -= 1;
|
||||||
|
if (pendingCompactionRetry === 0 && !compactionInFlight) {
|
||||||
|
compactionRetryResolve?.();
|
||||||
|
compactionRetryResolve = undefined;
|
||||||
|
compactionRetryPromise = null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const maybeResolveCompactionWait = () => {
|
||||||
|
if (pendingCompactionRetry === 0 && !compactionInFlight) {
|
||||||
|
compactionRetryResolve?.();
|
||||||
|
compactionRetryResolve = undefined;
|
||||||
|
compactionRetryPromise = null;
|
||||||
|
}
|
||||||
|
};
|
||||||
const FINAL_START_RE = /<\s*final\s*>/i;
|
const FINAL_START_RE = /<\s*final\s*>/i;
|
||||||
const FINAL_END_RE = /<\s*\/\s*final\s*>/i;
|
const FINAL_END_RE = /<\s*\/\s*final\s*>/i;
|
||||||
// Local providers sometimes emit malformed tags; normalize before filtering.
|
// Local providers sometimes emit malformed tags; normalize before filtering.
|
||||||
@@ -104,6 +139,15 @@ export function subscribeEmbeddedPiSession(params: {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const resetForCompactionRetry = () => {
|
||||||
|
assistantTexts.length = 0;
|
||||||
|
toolMetas.length = 0;
|
||||||
|
toolMetaById.clear();
|
||||||
|
deltaBuffer = "";
|
||||||
|
lastStreamedAssistant = undefined;
|
||||||
|
toolDebouncer.flush();
|
||||||
|
};
|
||||||
|
|
||||||
const unsubscribe = params.session.subscribe(
|
const unsubscribe = params.session.subscribe(
|
||||||
(evt: AgentEvent | { type: string; [k: string]: unknown }) => {
|
(evt: AgentEvent | { type: string; [k: string]: unknown }) => {
|
||||||
if (evt.type === "tool_execution_start") {
|
if (evt.type === "tool_execution_start") {
|
||||||
@@ -274,8 +318,31 @@ export function subscribeEmbeddedPiSession(params: {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (evt.type === "auto_compaction_start") {
|
||||||
|
compactionInFlight = true;
|
||||||
|
ensureCompactionPromise();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (evt.type === "auto_compaction_end") {
|
||||||
|
compactionInFlight = false;
|
||||||
|
const willRetry = Boolean(
|
||||||
|
(evt as { willRetry?: unknown }).willRetry,
|
||||||
|
);
|
||||||
|
if (willRetry) {
|
||||||
|
noteCompactionRetry();
|
||||||
|
resetForCompactionRetry();
|
||||||
|
} else {
|
||||||
|
maybeResolveCompactionWait();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (evt.type === "agent_end") {
|
if (evt.type === "agent_end") {
|
||||||
toolDebouncer.flush();
|
toolDebouncer.flush();
|
||||||
|
if (pendingCompactionRetry > 0) {
|
||||||
|
resolveCompactionRetry();
|
||||||
|
} else {
|
||||||
|
maybeResolveCompactionWait();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
@@ -285,5 +352,21 @@ export function subscribeEmbeddedPiSession(params: {
|
|||||||
toolMetas,
|
toolMetas,
|
||||||
unsubscribe,
|
unsubscribe,
|
||||||
flush: () => toolDebouncer.flush(),
|
flush: () => toolDebouncer.flush(),
|
||||||
|
waitForCompactionRetry: () => {
|
||||||
|
if (compactionInFlight || pendingCompactionRetry > 0) {
|
||||||
|
ensureCompactionPromise();
|
||||||
|
return compactionRetryPromise ?? Promise.resolve();
|
||||||
|
}
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
queueMicrotask(() => {
|
||||||
|
if (compactionInFlight || pendingCompactionRetry > 0) {
|
||||||
|
ensureCompactionPromise();
|
||||||
|
void (compactionRetryPromise ?? Promise.resolve()).then(resolve);
|
||||||
|
} else {
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -101,6 +101,30 @@ describe("trigger handling", () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("returns a context overflow fallback when the embedded agent throws", async () => {
|
||||||
|
await withTempHome(async (home) => {
|
||||||
|
vi.mocked(runEmbeddedPiAgent).mockRejectedValue(
|
||||||
|
new Error("Context window exceeded"),
|
||||||
|
);
|
||||||
|
|
||||||
|
const res = await getReplyFromConfig(
|
||||||
|
{
|
||||||
|
Body: "hello",
|
||||||
|
From: "+1002",
|
||||||
|
To: "+2000",
|
||||||
|
},
|
||||||
|
{},
|
||||||
|
makeCfg(home),
|
||||||
|
);
|
||||||
|
|
||||||
|
const text = Array.isArray(res) ? res[0]?.text : res?.text;
|
||||||
|
expect(text).toBe(
|
||||||
|
"⚠️ Context overflow - conversation too long. Starting fresh might help!",
|
||||||
|
);
|
||||||
|
expect(runEmbeddedPiAgent).toHaveBeenCalledOnce();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("uses heartbeat model override for heartbeat runs", async () => {
|
it("uses heartbeat model override for heartbeat runs", async () => {
|
||||||
await withTempHome(async (home) => {
|
await withTempHome(async (home) => {
|
||||||
vi.mocked(runEmbeddedPiAgent).mockResolvedValue({
|
vi.mocked(runEmbeddedPiAgent).mockResolvedValue({
|
||||||
|
|||||||
@@ -996,44 +996,57 @@ export async function getReplyFromConfig(
|
|||||||
await startTypingLoop();
|
await startTypingLoop();
|
||||||
}
|
}
|
||||||
const runId = crypto.randomUUID();
|
const runId = crypto.randomUUID();
|
||||||
const runResult = await runEmbeddedPiAgent({
|
let runResult: Awaited<ReturnType<typeof runEmbeddedPiAgent>>;
|
||||||
sessionId: sessionIdFinal,
|
try {
|
||||||
sessionKey,
|
runResult = await runEmbeddedPiAgent({
|
||||||
sessionFile,
|
sessionId: sessionIdFinal,
|
||||||
workspaceDir,
|
sessionKey,
|
||||||
config: cfg,
|
sessionFile,
|
||||||
skillsSnapshot,
|
workspaceDir,
|
||||||
prompt: commandBody,
|
config: cfg,
|
||||||
extraSystemPrompt: groupIntro || undefined,
|
skillsSnapshot,
|
||||||
ownerNumbers: ownerList.length > 0 ? ownerList : undefined,
|
prompt: commandBody,
|
||||||
enforceFinalTag:
|
extraSystemPrompt: groupIntro || undefined,
|
||||||
provider === "lmstudio" || provider === "ollama" ? true : undefined,
|
ownerNumbers: ownerList.length > 0 ? ownerList : undefined,
|
||||||
provider,
|
enforceFinalTag:
|
||||||
model,
|
provider === "lmstudio" || provider === "ollama" ? true : undefined,
|
||||||
thinkLevel: resolvedThinkLevel,
|
provider,
|
||||||
verboseLevel: resolvedVerboseLevel,
|
model,
|
||||||
timeoutMs,
|
thinkLevel: resolvedThinkLevel,
|
||||||
runId,
|
verboseLevel: resolvedVerboseLevel,
|
||||||
onPartialReply: opts?.onPartialReply
|
timeoutMs,
|
||||||
? async (payload) => {
|
runId,
|
||||||
await startTypingOnText(payload.text);
|
onPartialReply: opts?.onPartialReply
|
||||||
await opts.onPartialReply?.({
|
? async (payload) => {
|
||||||
text: payload.text,
|
await startTypingOnText(payload.text);
|
||||||
mediaUrls: payload.mediaUrls,
|
await opts.onPartialReply?.({
|
||||||
});
|
text: payload.text,
|
||||||
}
|
mediaUrls: payload.mediaUrls,
|
||||||
: undefined,
|
});
|
||||||
shouldEmitToolResult,
|
}
|
||||||
onToolResult: opts?.onToolResult
|
: undefined,
|
||||||
? async (payload) => {
|
shouldEmitToolResult,
|
||||||
await startTypingOnText(payload.text);
|
onToolResult: opts?.onToolResult
|
||||||
await opts.onToolResult?.({
|
? async (payload) => {
|
||||||
text: payload.text,
|
await startTypingOnText(payload.text);
|
||||||
mediaUrls: payload.mediaUrls,
|
await opts.onToolResult?.({
|
||||||
});
|
text: payload.text,
|
||||||
}
|
mediaUrls: payload.mediaUrls,
|
||||||
: undefined,
|
});
|
||||||
});
|
}
|
||||||
|
: undefined,
|
||||||
|
});
|
||||||
|
} catch (err) {
|
||||||
|
const message = err instanceof Error ? err.message : String(err);
|
||||||
|
const isContextOverflow =
|
||||||
|
/context.*overflow|too large|context window/i.test(message);
|
||||||
|
defaultRuntime.error(`Embedded agent failed before reply: ${message}`);
|
||||||
|
return {
|
||||||
|
text: isContextOverflow
|
||||||
|
? "⚠️ Context overflow - conversation too long. Starting fresh might help!"
|
||||||
|
: "⚠️ Agent failed. Check gateway logs.",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
shouldInjectGroupIntro &&
|
shouldInjectGroupIntro &&
|
||||||
|
|||||||
Reference in New Issue
Block a user