refactor(agents): centralize failover handling

This commit is contained in:
Peter Steinberger
2026-01-09 21:31:13 +01:00
parent cfeaa34c16
commit 374aa856f2
7 changed files with 292 additions and 61 deletions

View File

@@ -151,6 +151,49 @@ describe("resolveAuthProfileOrder", () => {
expect(order).toEqual(["anthropic:work", "anthropic:default"]); expect(order).toEqual(["anthropic:work", "anthropic:default"]);
}); });
// A profile whose `disabledUntil` lies in the future must be ordered last,
// even when the store-level `order` explicitly lists it first.
it("pushes disabled profiles to the end even with store order", () => {
const now = Date.now();
const order = resolveAuthProfileOrder({
store: {
...store,
// Explicit store order puts the (disabled) default profile first.
order: { anthropic: ["anthropic:default", "anthropic:work"] },
usageStats: {
"anthropic:default": {
// Disabled for one more minute because of a billing failure.
disabledUntil: now + 60_000,
disabledReason: "billing",
},
"anthropic:work": { lastUsed: 1 },
},
},
provider: "anthropic",
});
// The usable profile comes first; the disabled one is pushed to the end.
expect(order).toEqual(["anthropic:work", "anthropic:default"]);
});
// Same expectation as the store-order case, but the explicit order comes from
// the config (`cfg.auth.order`) instead of the store.
it("pushes disabled profiles to the end even with configured order", () => {
const now = Date.now();
const order = resolveAuthProfileOrder({
cfg: {
auth: {
// Configured order puts the (disabled) default profile first.
order: { anthropic: ["anthropic:default", "anthropic:work"] },
profiles: cfg.auth.profiles,
},
},
store: {
...store,
usageStats: {
"anthropic:default": {
// Disabled for one more minute because of a billing failure.
disabledUntil: now + 60_000,
disabledReason: "billing",
},
"anthropic:work": { lastUsed: 1 },
},
},
provider: "anthropic",
});
// The usable profile comes first; the disabled one is pushed to the end.
expect(order).toEqual(["anthropic:work", "anthropic:default"]);
});
it("normalizes z.ai aliases in auth.order", () => { it("normalizes z.ai aliases in auth.order", () => {
const order = resolveAuthProfileOrder({ const order = resolveAuthProfileOrder({
cfg: { cfg: {

View File

@@ -72,11 +72,21 @@ export type AuthProfileCredential =
| TokenCredential | TokenCredential
| OAuthCredential; | OAuthCredential;
/**
 * Reason categories recorded when an auth profile fails.
 *
 * Passed as the `reason` of `markAuthProfileFailure` and used as the key of
 * `ProfileUsageStats.failureCounts`. A "billing" failure sets the longer
 * `disabledUntil` window; every other reason sets the regular `cooldownUntil`.
 */
export type AuthProfileFailureReason =
| "auth"
| "rate_limit"
| "billing"
| "timeout"
| "unknown";
/** Per-profile usage statistics for round-robin and cooldown tracking */ /** Per-profile usage statistics for round-robin and cooldown tracking */
export type ProfileUsageStats = { export type ProfileUsageStats = {
lastUsed?: number; lastUsed?: number;
cooldownUntil?: number; cooldownUntil?: number;
disabledUntil?: number;
disabledReason?: AuthProfileFailureReason;
errorCount?: number; errorCount?: number;
failureCounts?: Partial<Record<AuthProfileFailureReason, number>>;
}; };
export type AuthProfileStore = { export type AuthProfileStore = {
@@ -772,8 +782,9 @@ export function isProfileInCooldown(
profileId: string, profileId: string,
): boolean { ): boolean {
const stats = store.usageStats?.[profileId]; const stats = store.usageStats?.[profileId];
if (!stats?.cooldownUntil) return false; if (!stats) return false;
return Date.now() < stats.cooldownUntil; const unusableUntil = resolveProfileUnusableUntil(stats);
return unusableUntil ? Date.now() < unusableUntil : false;
} }
/** /**
@@ -796,6 +807,9 @@ export async function markAuthProfileUsed(params: {
lastUsed: Date.now(), lastUsed: Date.now(),
errorCount: 0, errorCount: 0,
cooldownUntil: undefined, cooldownUntil: undefined,
disabledUntil: undefined,
disabledReason: undefined,
failureCounts: undefined,
}; };
return true; return true;
}, },
@@ -812,6 +826,9 @@ export async function markAuthProfileUsed(params: {
lastUsed: Date.now(), lastUsed: Date.now(),
errorCount: 0, errorCount: 0,
cooldownUntil: undefined, cooldownUntil: undefined,
disabledUntil: undefined,
disabledReason: undefined,
failureCounts: undefined,
}; };
saveAuthProfileStore(store, agentDir); saveAuthProfileStore(store, agentDir);
} }
@@ -824,34 +841,74 @@ export function calculateAuthProfileCooldownMs(errorCount: number): number {
); );
} }
/**
 * Compute how long a profile stays disabled after repeated billing failures.
 *
 * The window escalates with the failure count: 30 min, 2 h, 8 h, and then
 * 24 h for the fourth failure onward.
 *
 * @param errorCount - Number of billing failures recorded so far. NaN and
 *   values below 1 are treated as 1; fractional values are rounded down.
 * @returns Disable duration in milliseconds; always a finite number.
 */
function calculateAuthProfileBillingDisableMs(errorCount: number): number {
  const maxDisableMs = 24 * 60 * 60 * 1000; // 24 hours
  const steps = [
    30 * 60 * 1000, // 30 min
    2 * 60 * 60 * 1000, // 2 hours
    8 * 60 * 60 * 1000, // 8 hours
    maxDisableMs,
  ];
  // The count is derived from stored usage stats, so guard against NaN and
  // fractional values: either would index `steps` out of range and yield
  // `undefined`, which in turn would produce a NaN `disabledUntil`.
  const count = Number.isNaN(errorCount) ? 1 : Math.floor(errorCount);
  const index = Math.min(Math.max(count, 1) - 1, steps.length - 1);
  return steps[index] ?? maxDisableMs;
}
/**
 * Resolve the latest timestamp until which a profile is unusable.
 *
 * Considers both `cooldownUntil` and `disabledUntil`; entries that are not
 * numbers, not finite, or not strictly positive are ignored.
 *
 * @param stats - Usage statistics of a single profile.
 * @returns The larger of the valid timestamps, or `null` when neither is set.
 */
function resolveProfileUnusableUntil(stats: ProfileUsageStats): number | null {
  let latest: number | null = null;
  for (const candidate of [stats.cooldownUntil, stats.disabledUntil]) {
    if (typeof candidate !== "number") continue;
    if (!Number.isFinite(candidate) || candidate <= 0) continue;
    if (latest === null || candidate > latest) latest = candidate;
  }
  return latest;
}
/**
 * Look up a profile's usage stats in the store and report until when the
 * profile is unusable (cooldown or disabled, whichever ends later).
 *
 * @param store - Auth profile store holding the per-profile usage stats.
 * @param profileId - Identifier of the profile to inspect.
 * @returns The unusable-until timestamp, or `null` when the profile has no
 *   stats or no valid cooldown/disabled timestamp.
 */
export function resolveProfileUnusableUntilForDisplay(
  store: AuthProfileStore,
  profileId: string,
): number | null {
  const profileStats = store.usageStats?.[profileId];
  return profileStats ? resolveProfileUnusableUntil(profileStats) : null;
}
/** /**
* Mark a profile as failed/rate-limited. Applies exponential backoff cooldown. * Mark a profile as failed for a specific reason. Billing failures are treated
* Cooldown times: 1min, 5min, 25min, max 1 hour. * as "disabled" (longer backoff) vs the regular cooldown window.
* Uses store lock to avoid overwriting concurrent usage updates.
*/ */
export async function markAuthProfileCooldown(params: { export async function markAuthProfileFailure(params: {
store: AuthProfileStore; store: AuthProfileStore;
profileId: string; profileId: string;
reason: AuthProfileFailureReason;
agentDir?: string; agentDir?: string;
}): Promise<void> { }): Promise<void> {
const { store, profileId, agentDir } = params; const { store, profileId, reason, agentDir } = params;
const updated = await updateAuthProfileStoreWithLock({ const updated = await updateAuthProfileStoreWithLock({
agentDir, agentDir,
updater: (freshStore) => { updater: (freshStore) => {
if (!freshStore.profiles[profileId]) return false; if (!freshStore.profiles[profileId]) return false;
freshStore.usageStats = freshStore.usageStats ?? {}; freshStore.usageStats = freshStore.usageStats ?? {};
const existing = freshStore.usageStats[profileId] ?? {}; const existing = freshStore.usageStats[profileId] ?? {};
const errorCount = (existing.errorCount ?? 0) + 1;
// Exponential backoff: 1min, 5min, 25min, capped at 1h const nextErrorCount = (existing.errorCount ?? 0) + 1;
const backoffMs = calculateAuthProfileCooldownMs(errorCount); const failureCounts = { ...existing.failureCounts };
failureCounts[reason] = (failureCounts[reason] ?? 0) + 1;
freshStore.usageStats[profileId] = { const now = Date.now();
const updatedStats: ProfileUsageStats = {
...existing, ...existing,
errorCount, errorCount: nextErrorCount,
cooldownUntil: Date.now() + backoffMs, failureCounts,
}; };
if (reason === "billing") {
const billingCount = failureCounts.billing ?? 1;
const backoffMs = calculateAuthProfileBillingDisableMs(billingCount);
updatedStats.disabledUntil = now + backoffMs;
updatedStats.disabledReason = "billing";
} else {
const backoffMs = calculateAuthProfileCooldownMs(nextErrorCount);
updatedStats.cooldownUntil = now + backoffMs;
}
freshStore.usageStats[profileId] = updatedStats;
return true; return true;
}, },
}); });
@@ -863,19 +920,48 @@ export async function markAuthProfileCooldown(params: {
store.usageStats = store.usageStats ?? {}; store.usageStats = store.usageStats ?? {};
const existing = store.usageStats[profileId] ?? {}; const existing = store.usageStats[profileId] ?? {};
const errorCount = (existing.errorCount ?? 0) + 1; const nextErrorCount = (existing.errorCount ?? 0) + 1;
const failureCounts = { ...existing.failureCounts };
failureCounts[reason] = (failureCounts[reason] ?? 0) + 1;
// Exponential backoff: 1min, 5min, 25min, capped at 1h const now = Date.now();
const backoffMs = calculateAuthProfileCooldownMs(errorCount); const updatedStats: ProfileUsageStats = {
store.usageStats[profileId] = {
...existing, ...existing,
errorCount, errorCount: nextErrorCount,
cooldownUntil: Date.now() + backoffMs, failureCounts,
}; };
if (reason === "billing") {
const billingCount = failureCounts.billing ?? 1;
const backoffMs = calculateAuthProfileBillingDisableMs(billingCount);
updatedStats.disabledUntil = now + backoffMs;
updatedStats.disabledReason = "billing";
} else {
const backoffMs = calculateAuthProfileCooldownMs(nextErrorCount);
updatedStats.cooldownUntil = now + backoffMs;
}
store.usageStats[profileId] = updatedStats;
saveAuthProfileStore(store, agentDir); saveAuthProfileStore(store, agentDir);
} }
/**
 * Mark a profile as failed/rate-limited without a specific failure reason.
 *
 * Thin wrapper that forwards to `markAuthProfileFailure` with the reason
 * "unknown", so the regular (non-billing) cooldown path is applied.
 */
export async function markAuthProfileCooldown(params: {
  store: AuthProfileStore;
  profileId: string;
  agentDir?: string;
}): Promise<void> {
  const { store, profileId, agentDir } = params;
  await markAuthProfileFailure({
    store,
    profileId,
    reason: "unknown",
    agentDir,
  });
}
/** /**
* Clear cooldown for a profile (e.g., manual reset). * Clear cooldown for a profile (e.g., manual reset).
* Uses store lock to avoid overwriting concurrent usage updates. * Uses store lock to avoid overwriting concurrent usage updates.
@@ -973,7 +1059,8 @@ export function resolveAuthProfileOrder(params: {
const inCooldown: Array<{ profileId: string; cooldownUntil: number }> = []; const inCooldown: Array<{ profileId: string; cooldownUntil: number }> = [];
for (const profileId of deduped) { for (const profileId of deduped) {
const cooldownUntil = store.usageStats?.[profileId]?.cooldownUntil; const cooldownUntil =
resolveProfileUnusableUntil(store.usageStats?.[profileId] ?? {}) ?? 0;
if ( if (
typeof cooldownUntil === "number" && typeof cooldownUntil === "number" &&
Number.isFinite(cooldownUntil) && Number.isFinite(cooldownUntil) &&
@@ -1057,7 +1144,8 @@ function orderProfilesByMode(
const cooldownSorted = inCooldown const cooldownSorted = inCooldown
.map((profileId) => ({ .map((profileId) => ({
profileId, profileId,
cooldownUntil: store.usageStats?.[profileId]?.cooldownUntil ?? now, cooldownUntil:
resolveProfileUnusableUntil(store.usageStats?.[profileId] ?? {}) ?? now,
})) }))
.sort((a, b) => a.cooldownUntil - b.cooldownUntil) .sort((a, b) => a.cooldownUntil - b.cooldownUntil)
.map((entry) => entry.profileId); .map((entry) => entry.profileId);

View File

@@ -56,6 +56,28 @@ describe("runWithModelFallback", () => {
expect(run.mock.calls[1]?.[1]).toBe("claude-haiku-3-5"); expect(run.mock.calls[1]?.[1]).toBe("claude-haiku-3-5");
}); });
// An error carrying HTTP status 402 on the first attempt must trigger a
// fallback to the next model candidate instead of failing the whole run.
it("falls back on 402 payment required", async () => {
const cfg = makeCfg();
// First call rejects with a 402-tagged error, second call succeeds.
const run = vi
.fn()
.mockRejectedValueOnce(
Object.assign(new Error("payment required"), { status: 402 }),
)
.mockResolvedValueOnce("ok");
const result = await runWithModelFallback({
cfg,
provider: "openai",
model: "gpt-4.1-mini",
run,
});
expect(result.result).toBe("ok");
expect(run).toHaveBeenCalledTimes(2);
// NOTE(review): the fallback target presumably comes from makeCfg() — confirm there.
expect(run.mock.calls[1]?.[0]).toBe("anthropic");
expect(run.mock.calls[1]?.[1]).toBe("claude-haiku-3-5");
});
it("falls back on billing errors", async () => { it("falls back on billing errors", async () => {
const cfg = makeCfg(); const cfg = makeCfg();
const run = vi const run = vi

View File

@@ -7,11 +7,7 @@ import {
resolveConfiguredModelRef, resolveConfiguredModelRef,
resolveModelRefFromString, resolveModelRefFromString,
} from "./model-selection.js"; } from "./model-selection.js";
import { import { isFailoverErrorMessage } from "./pi-embedded-helpers.js";
isAuthErrorMessage,
isBillingErrorMessage,
isRateLimitErrorMessage,
} from "./pi-embedded-helpers.js";
type ModelCandidate = { type ModelCandidate = {
provider: string; provider: string;
@@ -71,16 +67,6 @@ function getErrorMessage(err: unknown): string {
return ""; return "";
} }
/**
 * Report whether an error message looks like a timeout.
 *
 * The match is case-insensitive and substring-based.
 */
function isTimeoutErrorMessage(raw: string): boolean {
  const text = raw.toLowerCase();
  const needles = [
    "timeout",
    "timed out",
    "deadline exceeded",
    "context deadline exceeded",
  ];
  for (const needle of needles) {
    if (text.includes(needle)) return true;
  }
  return false;
}
function shouldFallbackForError(err: unknown): boolean { function shouldFallbackForError(err: unknown): boolean {
const statusCode = getStatusCode(err); const statusCode = getStatusCode(err);
if (statusCode && [401, 402, 403, 429].includes(statusCode)) return true; if (statusCode && [401, 402, 403, 429].includes(statusCode)) return true;
@@ -94,12 +80,7 @@ function shouldFallbackForError(err: unknown): boolean {
} }
const message = getErrorMessage(err); const message = getErrorMessage(err);
if (!message) return false; if (!message) return false;
return ( return isFailoverErrorMessage(message);
isAuthErrorMessage(message) ||
isRateLimitErrorMessage(message) ||
isBillingErrorMessage(message) ||
isTimeoutErrorMessage(message)
);
} }
function buildAllowedModelKeys( function buildAllowedModelKeys(

View File

@@ -3,9 +3,11 @@ import type { AssistantMessage } from "@mariozechner/pi-ai";
import { describe, expect, it } from "vitest"; import { describe, expect, it } from "vitest";
import { import {
buildBootstrapContextFiles, buildBootstrapContextFiles,
classifyFailoverReason,
formatAssistantErrorText, formatAssistantErrorText,
isBillingErrorMessage, isBillingErrorMessage,
isContextOverflowError, isContextOverflowError,
isFailoverErrorMessage,
isMessagingToolDuplicate, isMessagingToolDuplicate,
normalizeTextForComparison, normalizeTextForComparison,
sanitizeGoogleTurnOrdering, sanitizeGoogleTurnOrdering,
@@ -238,6 +240,30 @@ describe("isBillingErrorMessage", () => {
}); });
}); });
// Smoke test: one representative message per failover category must be
// recognised as a failover error.
describe("isFailoverErrorMessage", () => {
it("matches auth/rate/billing/timeout", () => {
// One sample each, in order: auth, rate limit, billing, timeout.
const samples = [
"invalid api key",
"429 rate limit exceeded",
"Your credit balance is too low",
"request timed out",
];
for (const sample of samples) {
expect(isFailoverErrorMessage(sample)).toBe(true);
}
});
});
// Pins the message -> reason mapping so callers can rely on stable labels.
describe("classifyFailoverReason", () => {
it("returns a stable reason", () => {
expect(classifyFailoverReason("invalid api key")).toBe("auth");
expect(classifyFailoverReason("429 too many requests")).toBe("rate_limit");
expect(classifyFailoverReason("credit balance too low")).toBe("billing");
expect(classifyFailoverReason("deadline exceeded")).toBe("timeout");
// A message that matches no category yields null, not "unknown".
expect(classifyFailoverReason("bad request")).toBeNull();
});
});
describe("formatAssistantErrorText", () => { describe("formatAssistantErrorText", () => {
const makeAssistantError = (errorMessage: string): AssistantMessage => const makeAssistantError = (errorMessage: string): AssistantMessage =>
({ ({

View File

@@ -261,6 +261,17 @@ export function isRateLimitErrorMessage(raw: string): boolean {
); );
} }
/**
 * Report whether an error message looks like a timeout.
 *
 * The match is case-insensitive and substring-based; an empty message never
 * matches.
 *
 * @param raw - Raw error message text.
 * @returns `true` when the message contains one of the known timeout phrases.
 */
export function isTimeoutErrorMessage(raw: string): boolean {
  const lowered = raw.toLowerCase();
  if (lowered.length === 0) return false;
  const markers = [
    "timeout",
    "timed out",
    "deadline exceeded",
    "context deadline exceeded",
  ];
  return markers.some((marker) => lowered.includes(marker));
}
export function isBillingErrorMessage(raw: string): boolean { export function isBillingErrorMessage(raw: string): boolean {
const value = raw.toLowerCase(); const value = raw.toLowerCase();
if (!value) return false; if (!value) return false;
@@ -308,6 +319,32 @@ export function isAuthAssistantError(
return isAuthErrorMessage(msg.errorMessage ?? ""); return isAuthErrorMessage(msg.errorMessage ?? "");
} }
/**
 * Categories of provider errors that can trigger failover.
 *
 * `classifyFailoverReason` returns one of the first four values or `null`;
 * it never produces "unknown".
 * NOTE(review): "unknown" looks like a fallback label for callers, and this
 * union has the same members as `AuthProfileFailureReason` — confirm whether
 * the two should share one definition.
 */
export type FailoverReason =
| "auth"
| "rate_limit"
| "billing"
| "timeout"
| "unknown";
/**
 * Classify an error message into a failover reason.
 *
 * Categories are checked in a fixed order (auth, rate limit, billing,
 * timeout) and the first match wins.
 *
 * @param raw - Raw error message text.
 * @returns The matching reason, or `null` when no category matches.
 */
export function classifyFailoverReason(raw: string): FailoverReason | null {
  const checks: Array<[(message: string) => boolean, FailoverReason]> = [
    [isAuthErrorMessage, "auth"],
    [isRateLimitErrorMessage, "rate_limit"],
    [isBillingErrorMessage, "billing"],
    [isTimeoutErrorMessage, "timeout"],
  ];
  for (const [matches, reason] of checks) {
    if (matches(raw)) return reason;
  }
  return null;
}
/**
 * Report whether an error message falls into any failover category.
 *
 * @param raw - Raw error message text.
 * @returns `true` when `classifyFailoverReason` yields a reason for it.
 */
export function isFailoverErrorMessage(raw: string): boolean {
  const reason = classifyFailoverReason(raw);
  return reason !== null;
}
/**
 * Report whether an assistant message ended in an error that qualifies for
 * failover.
 *
 * @param msg - Assistant message to inspect; may be `undefined`.
 * @returns `true` only when the message stopped with "error" and its error
 *   text (empty when absent) is a failover error message.
 */
export function isFailoverAssistantError(
  msg: AssistantMessage | undefined,
): boolean {
  if (msg?.stopReason !== "error") return false;
  return isFailoverErrorMessage(msg.errorMessage ?? "");
}
function extractSupportedValues(raw: string): string[] { function extractSupportedValues(raw: string): string[] {
const match = const match =
raw.match(/supported values are:\s*([^\n.]+)/i) ?? raw.match(/supported values are:\s*([^\n.]+)/i) ??

View File

@@ -37,7 +37,7 @@ import { normalizeMessageProvider } from "../utils/message-provider.js";
import { resolveUserPath } from "../utils.js"; import { resolveUserPath } from "../utils.js";
import { resolveClawdbotAgentDir } from "./agent-paths.js"; import { resolveClawdbotAgentDir } from "./agent-paths.js";
import { import {
markAuthProfileCooldown, markAuthProfileFailure,
markAuthProfileGood, markAuthProfileGood,
markAuthProfileUsed, markAuthProfileUsed,
} from "./auth-profiles.js"; } from "./auth-profiles.js";
@@ -55,17 +55,17 @@ import {
import { ensureClawdbotModelsJson } from "./models-config.js"; import { ensureClawdbotModelsJson } from "./models-config.js";
import { import {
buildBootstrapContextFiles, buildBootstrapContextFiles,
classifyFailoverReason,
type EmbeddedContextFile, type EmbeddedContextFile,
ensureSessionHeader, ensureSessionHeader,
formatAssistantErrorText, formatAssistantErrorText,
isAuthAssistantError, isAuthAssistantError,
isAuthErrorMessage,
isBillingAssistantError,
isBillingErrorMessage,
isContextOverflowError, isContextOverflowError,
isFailoverAssistantError,
isFailoverErrorMessage,
isGoogleModelApi, isGoogleModelApi,
isRateLimitAssistantError, isRateLimitAssistantError,
isRateLimitErrorMessage, isTimeoutErrorMessage,
pickFallbackThinkingLevel, pickFallbackThinkingLevel,
sanitizeGoogleTurnOrdering, sanitizeGoogleTurnOrdering,
sanitizeSessionMessagesImages, sanitizeSessionMessagesImages,
@@ -1438,10 +1438,22 @@ export async function runEmbeddedPiAgent(params: {
}, },
}; };
} }
const promptFailoverReason = classifyFailoverReason(errorText);
if ( if (
(isAuthErrorMessage(errorText) || promptFailoverReason &&
isRateLimitErrorMessage(errorText) || promptFailoverReason !== "timeout" &&
isBillingErrorMessage(errorText)) && lastProfileId
) {
await markAuthProfileFailure({
store: authStore,
profileId: lastProfileId,
reason: promptFailoverReason,
agentDir: params.agentDir,
});
}
if (
isFailoverErrorMessage(errorText) &&
promptFailoverReason !== "timeout" &&
(await advanceAuthProfile()) (await advanceAuthProfile())
) { ) {
continue; continue;
@@ -1484,19 +1496,26 @@ export async function runEmbeddedPiAgent(params: {
0; 0;
const authFailure = isAuthAssistantError(lastAssistant); const authFailure = isAuthAssistantError(lastAssistant);
const rateLimitFailure = isRateLimitAssistantError(lastAssistant); const rateLimitFailure = isRateLimitAssistantError(lastAssistant);
const billingFailure = isBillingAssistantError(lastAssistant); const failoverFailure = isFailoverAssistantError(lastAssistant);
const assistantFailoverReason = classifyFailoverReason(
lastAssistant?.errorMessage ?? "",
);
// Treat timeout as potential rate limit (Antigravity hangs on rate limit) // Treat timeout as potential rate limit (Antigravity hangs on rate limit)
const shouldRotate = const shouldRotate = (!aborted && failoverFailure) || timedOut;
(!aborted && (authFailure || rateLimitFailure || billingFailure)) ||
timedOut;
if (shouldRotate) { if (shouldRotate) {
// Mark current profile for cooldown before rotating // Mark current profile for cooldown before rotating
if (lastProfileId) { if (lastProfileId) {
await markAuthProfileCooldown({ const reason =
timedOut || assistantFailoverReason === "timeout"
? "timeout"
: (assistantFailoverReason ?? "unknown");
await markAuthProfileFailure({
store: authStore, store: authStore,
profileId: lastProfileId, profileId: lastProfileId,
reason,
agentDir: params.agentDir,
}); });
if (timedOut) { if (timedOut) {
log.warn( log.warn(
@@ -1518,10 +1537,25 @@ export async function runEmbeddedPiAgent(params: {
? "LLM request timed out." ? "LLM request timed out."
: rateLimitFailure : rateLimitFailure
? "LLM request rate limited." ? "LLM request rate limited."
: billingFailure : authFailure
? "LLM request payment required." ? "LLM request unauthorized."
: "LLM request unauthorized."); : "LLM request failed.");
throw new Error(message); const err = new Error(message);
(err as { failoverReason?: string }).failoverReason =
assistantFailoverReason ?? undefined;
if (assistantFailoverReason === "billing") {
(err as { status?: number }).status = 402;
} else if (assistantFailoverReason === "rate_limit") {
(err as { status?: number }).status = 429;
} else if (assistantFailoverReason === "auth") {
(err as { status?: number }).status = 401;
} else if (
assistantFailoverReason === "timeout" ||
isTimeoutErrorMessage(message)
) {
(err as { status?: number }).status = 408;
}
throw err;
} }
} }