fix: handle embedded agent overflow
This commit is contained in:
@@ -8,14 +8,16 @@ import {
|
||||
type Api,
|
||||
type AssistantMessage,
|
||||
type Model,
|
||||
type OAuthStorage,
|
||||
setOAuthStorage,
|
||||
type OAuthCredentials,
|
||||
type OAuthProvider,
|
||||
getEnvApiKey,
|
||||
getOAuthApiKey,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import {
|
||||
buildSystemPrompt,
|
||||
createAgentSession,
|
||||
defaultGetApiKey,
|
||||
findModelByProviderAndId,
|
||||
discoverAuthStorage,
|
||||
discoverModels,
|
||||
SessionManager,
|
||||
SettingsManager,
|
||||
type Skill,
|
||||
@@ -91,7 +93,8 @@ const ACTIVE_EMBEDDED_RUNS = new Map<string, EmbeddedPiQueueHandle>();
|
||||
const OAUTH_FILENAME = "oauth.json";
|
||||
const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials");
|
||||
let oauthStorageConfigured = false;
|
||||
let cachedDefaultApiKey: ReturnType<typeof defaultGetApiKey> | null = null;
|
||||
|
||||
type OAuthStorage = Record<string, OAuthCredentials>;
|
||||
|
||||
function resolveSessionLane(key: string) {
|
||||
const cleaned = key.trim() || "main";
|
||||
@@ -178,18 +181,15 @@ function ensureOAuthStorage(): void {
|
||||
oauthStorageConfigured = true;
|
||||
const oauthPath = resolveClawdisOAuthPath();
|
||||
importLegacyOAuthIfNeeded(oauthPath);
|
||||
setOAuthStorage({
|
||||
load: () => loadOAuthStorageAt(oauthPath) ?? {},
|
||||
save: (storage) => saveOAuthStorageAt(oauthPath, storage),
|
||||
});
|
||||
}
|
||||
|
||||
function getDefaultApiKey() {
|
||||
if (!cachedDefaultApiKey) {
|
||||
ensureOAuthStorage();
|
||||
cachedDefaultApiKey = defaultGetApiKey();
|
||||
}
|
||||
return cachedDefaultApiKey;
|
||||
function isOAuthProvider(provider: string): provider is OAuthProvider {
|
||||
return (
|
||||
provider === "anthropic" ||
|
||||
provider === "github-copilot" ||
|
||||
provider === "google-gemini-cli" ||
|
||||
provider === "google-antigravity"
|
||||
);
|
||||
}
|
||||
|
||||
export function queueEmbeddedPiMessage(
|
||||
@@ -214,11 +214,10 @@ function resolveModel(
|
||||
modelId: string,
|
||||
agentDir?: string,
|
||||
): { model?: Model<Api>; error?: string } {
|
||||
const model = findModelByProviderAndId(
|
||||
provider,
|
||||
modelId,
|
||||
agentDir,
|
||||
) as Model<Api> | null;
|
||||
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
|
||||
const authStorage = discoverAuthStorage(resolvedAgentDir);
|
||||
const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
|
||||
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
|
||||
if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
|
||||
return { model };
|
||||
}
|
||||
@@ -229,8 +228,24 @@ async function getApiKeyForModel(model: Model<Api>): Promise<string> {
|
||||
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
||||
if (oauthEnv?.trim()) return oauthEnv.trim();
|
||||
}
|
||||
const key = await getDefaultApiKey()(model);
|
||||
if (key) return key;
|
||||
const envKey = getEnvApiKey(model.provider);
|
||||
if (envKey) return envKey;
|
||||
if (isOAuthProvider(model.provider)) {
|
||||
const oauthPath = resolveClawdisOAuthPath();
|
||||
const storage = loadOAuthStorageAt(oauthPath);
|
||||
if (storage) {
|
||||
try {
|
||||
const result = await getOAuthApiKey(model.provider, storage);
|
||||
if (result?.apiKey) {
|
||||
storage[model.provider] = result.newCredentials;
|
||||
saveOAuthStorageAt(oauthPath, storage);
|
||||
return result.apiKey;
|
||||
}
|
||||
} catch {
|
||||
// fall through to error below
|
||||
}
|
||||
}
|
||||
}
|
||||
throw new Error(`No API key found for provider "${model.provider}"`);
|
||||
}
|
||||
|
||||
@@ -423,6 +438,7 @@ export async function runEmbeddedPiAgent(params: {
|
||||
toolMetas,
|
||||
unsubscribe,
|
||||
flush: flushToolDebouncer,
|
||||
waitForCompactionRetry,
|
||||
} = subscribeEmbeddedPiSession({
|
||||
session,
|
||||
runId: params.runId,
|
||||
@@ -463,10 +479,10 @@ export async function runEmbeddedPiAgent(params: {
|
||||
await session.prompt(params.prompt);
|
||||
} catch (err) {
|
||||
promptError = err;
|
||||
} finally {
|
||||
messagesSnapshot = session.messages.slice();
|
||||
sessionIdUsed = session.sessionId;
|
||||
}
|
||||
await waitForCompactionRetry();
|
||||
messagesSnapshot = session.messages.slice();
|
||||
sessionIdUsed = session.sessionId;
|
||||
} finally {
|
||||
clearTimeout(abortTimer);
|
||||
unsubscribe();
|
||||
|
||||
Reference in New Issue
Block a user