fix: handle embedded agent overflow

This commit is contained in:
Peter Steinberger
2025-12-26 10:16:50 +01:00
parent 8059e83c49
commit d28265cfbe
5 changed files with 255 additions and 63 deletions

View File

@@ -8,14 +8,16 @@ import {
type Api,
type AssistantMessage,
type Model,
type OAuthStorage,
setOAuthStorage,
type OAuthCredentials,
type OAuthProvider,
getEnvApiKey,
getOAuthApiKey,
} from "@mariozechner/pi-ai";
import {
buildSystemPrompt,
createAgentSession,
defaultGetApiKey,
findModelByProviderAndId,
discoverAuthStorage,
discoverModels,
SessionManager,
SettingsManager,
type Skill,
@@ -91,7 +93,8 @@ const ACTIVE_EMBEDDED_RUNS = new Map<string, EmbeddedPiQueueHandle>();
const OAUTH_FILENAME = "oauth.json";
const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials");
let oauthStorageConfigured = false;
let cachedDefaultApiKey: ReturnType<typeof defaultGetApiKey> | null = null;
type OAuthStorage = Record<string, OAuthCredentials>;
function resolveSessionLane(key: string) {
const cleaned = key.trim() || "main";
@@ -178,18 +181,15 @@ function ensureOAuthStorage(): void {
oauthStorageConfigured = true;
const oauthPath = resolveClawdisOAuthPath();
importLegacyOAuthIfNeeded(oauthPath);
setOAuthStorage({
load: () => loadOAuthStorageAt(oauthPath) ?? {},
save: (storage) => saveOAuthStorageAt(oauthPath, storage),
});
}
function getDefaultApiKey() {
if (!cachedDefaultApiKey) {
ensureOAuthStorage();
cachedDefaultApiKey = defaultGetApiKey();
}
return cachedDefaultApiKey;
function isOAuthProvider(provider: string): provider is OAuthProvider {
return (
provider === "anthropic" ||
provider === "github-copilot" ||
provider === "google-gemini-cli" ||
provider === "google-antigravity"
);
}
export function queueEmbeddedPiMessage(
@@ -214,11 +214,10 @@ function resolveModel(
modelId: string,
agentDir?: string,
): { model?: Model<Api>; error?: string } {
const model = findModelByProviderAndId(
provider,
modelId,
agentDir,
) as Model<Api> | null;
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
const authStorage = discoverAuthStorage(resolvedAgentDir);
const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
return { model };
}
@@ -229,8 +228,24 @@ async function getApiKeyForModel(model: Model<Api>): Promise<string> {
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
if (oauthEnv?.trim()) return oauthEnv.trim();
}
const key = await getDefaultApiKey()(model);
if (key) return key;
const envKey = getEnvApiKey(model.provider);
if (envKey) return envKey;
if (isOAuthProvider(model.provider)) {
const oauthPath = resolveClawdisOAuthPath();
const storage = loadOAuthStorageAt(oauthPath);
if (storage) {
try {
const result = await getOAuthApiKey(model.provider, storage);
if (result?.apiKey) {
storage[model.provider] = result.newCredentials;
saveOAuthStorageAt(oauthPath, storage);
return result.apiKey;
}
} catch {
// fall through to error below
}
}
}
throw new Error(`No API key found for provider "${model.provider}"`);
}
@@ -423,6 +438,7 @@ export async function runEmbeddedPiAgent(params: {
toolMetas,
unsubscribe,
flush: flushToolDebouncer,
waitForCompactionRetry,
} = subscribeEmbeddedPiSession({
session,
runId: params.runId,
@@ -463,10 +479,10 @@ export async function runEmbeddedPiAgent(params: {
await session.prompt(params.prompt);
} catch (err) {
promptError = err;
} finally {
messagesSnapshot = session.messages.slice();
sessionIdUsed = session.sessionId;
}
await waitForCompactionRetry();
messagesSnapshot = session.messages.slice();
sessionIdUsed = session.sessionId;
} finally {
clearTimeout(abortTimer);
unsubscribe();