fix(auth): preserve auto-pin preference

Co-authored-by: Mykyta Bozhenko <21245729+cheeeee@users.noreply.github.com>
This commit is contained in:
Peter Steinberger
2026-01-18 08:22:50 +00:00
parent e49a2952d9
commit d3862ae30a
11 changed files with 271 additions and 17 deletions

View File

@@ -0,0 +1,211 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import type { AssistantMessage } from "@mariozechner/pi-ai";
import { beforeEach, describe, expect, it, vi } from "vitest";
import type { ClawdbotConfig } from "../config/config.js";
import type { EmbeddedRunAttemptResult } from "./pi-embedded-runner/run/types.js";
const runEmbeddedAttemptMock = vi.fn<Promise<EmbeddedRunAttemptResult>, [unknown]>();
vi.mock("./pi-embedded-runner/run/attempt.js", () => ({
runEmbeddedAttempt: (params: unknown) => runEmbeddedAttemptMock(params),
}));
let runEmbeddedPiAgent: typeof import("./pi-embedded-runner.js").runEmbeddedPiAgent;
beforeEach(async () => {
vi.useRealTimers();
vi.resetModules();
runEmbeddedAttemptMock.mockReset();
({ runEmbeddedPiAgent } = await import("./pi-embedded-runner.js"));
});
const baseUsage = {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
const buildAssistant = (overrides: Partial<AssistantMessage>): AssistantMessage => ({
role: "assistant",
content: [],
api: "openai-responses",
provider: "openai",
model: "mock-1",
usage: baseUsage,
stopReason: "stop",
timestamp: Date.now(),
...overrides,
});
const makeAttempt = (
overrides: Partial<EmbeddedRunAttemptResult>,
): EmbeddedRunAttemptResult => ({
aborted: false,
timedOut: false,
promptError: null,
sessionIdUsed: "session:test",
systemPromptReport: undefined,
messagesSnapshot: [],
assistantTexts: [],
toolMetas: [],
lastAssistant: undefined,
didSendViaMessagingTool: false,
messagingToolSentTexts: [],
messagingToolSentTargets: [],
cloudCodeAssistFormatError: false,
...overrides,
});
const makeConfig = (): ClawdbotConfig =>
({
agents: {
defaults: {
model: {
fallbacks: [],
},
},
},
models: {
providers: {
openai: {
api: "openai-responses",
apiKey: "sk-test",
baseUrl: "https://example.com",
models: [
{
id: "mock-1",
name: "Mock 1",
reasoning: false,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 16_000,
maxTokens: 2048,
},
],
},
},
},
}) satisfies ClawdbotConfig;
const writeAuthStore = async (agentDir: string) => {
const authPath = path.join(agentDir, "auth-profiles.json");
const payload = {
version: 1,
profiles: {
"openai:p1": { type: "api_key", provider: "openai", key: "sk-one" },
"openai:p2": { type: "api_key", provider: "openai", key: "sk-two" },
},
usageStats: {
"openai:p1": { lastUsed: 1 },
"openai:p2": { lastUsed: 2 },
},
};
await fs.writeFile(authPath, JSON.stringify(payload));
};
describe("runEmbeddedPiAgent auth profile rotation", () => {
it("rotates for auto-pinned profiles", async () => {
const agentDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-agent-"));
const workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-workspace-"));
try {
await writeAuthStore(agentDir);
runEmbeddedAttemptMock
.mockResolvedValueOnce(
makeAttempt({
assistantTexts: [],
lastAssistant: buildAssistant({
stopReason: "error",
errorMessage: "rate limit",
}),
}),
)
.mockResolvedValueOnce(
makeAttempt({
assistantTexts: ["ok"],
lastAssistant: buildAssistant({
stopReason: "stop",
content: [{ type: "text", text: "ok" }],
}),
}),
);
await runEmbeddedPiAgent({
sessionId: "session:test",
sessionKey: "agent:test:auto",
sessionFile: path.join(workspaceDir, "session.jsonl"),
workspaceDir,
agentDir,
config: makeConfig(),
prompt: "hello",
provider: "openai",
model: "mock-1",
authProfileId: "openai:p1",
authProfileIdSource: "auto",
timeoutMs: 5_000,
runId: "run:auto",
});
expect(runEmbeddedAttemptMock).toHaveBeenCalledTimes(2);
const stored = JSON.parse(
await fs.readFile(path.join(agentDir, "auth-profiles.json"), "utf-8"),
) as { usageStats?: Record<string, { lastUsed?: number }> };
expect(typeof stored.usageStats?.["openai:p2"]?.lastUsed).toBe("number");
} finally {
await fs.rm(agentDir, { recursive: true, force: true });
await fs.rm(workspaceDir, { recursive: true, force: true });
}
});
it("does not rotate for user-pinned profiles", async () => {
const agentDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-agent-"));
const workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-workspace-"));
try {
await writeAuthStore(agentDir);
runEmbeddedAttemptMock.mockResolvedValueOnce(
makeAttempt({
assistantTexts: [],
lastAssistant: buildAssistant({
stopReason: "error",
errorMessage: "rate limit",
}),
}),
);
await runEmbeddedPiAgent({
sessionId: "session:test",
sessionKey: "agent:test:user",
sessionFile: path.join(workspaceDir, "session.jsonl"),
workspaceDir,
agentDir,
config: makeConfig(),
prompt: "hello",
provider: "openai",
model: "mock-1",
authProfileId: "openai:p1",
authProfileIdSource: "user",
timeoutMs: 5_000,
runId: "run:user",
});
expect(runEmbeddedAttemptMock).toHaveBeenCalledTimes(1);
const stored = JSON.parse(
await fs.readFile(path.join(agentDir, "auth-profiles.json"), "utf-8"),
) as { usageStats?: Record<string, { lastUsed?: number }> };
expect(stored.usageStats?.["openai:p2"]?.lastUsed).toBeUndefined();
} finally {
await fs.rm(agentDir, { recursive: true, force: true });
await fs.rm(workspaceDir, { recursive: true, force: true });
}
});
});

View File

@@ -117,15 +117,17 @@ export async function runEmbeddedPiAgent(
}
const authStore = ensureAuthProfileStore(agentDir, { allowKeychainPrompt: false });
const explicitProfileId = params.authProfileId?.trim();
const preferredProfileId = params.authProfileId?.trim();
const lockedProfileId =
params.authProfileIdSource === "user" ? preferredProfileId : undefined;
const profileOrder = resolveAuthProfileOrder({
cfg: params.config,
store: authStore,
provider,
preferredProfile: explicitProfileId,
preferredProfile: preferredProfileId,
});
if (explicitProfileId && !profileOrder.includes(explicitProfileId)) {
throw new Error(`Auth profile "${explicitProfileId}" is not configured for ${provider}.`);
if (lockedProfileId && !profileOrder.includes(lockedProfileId)) {
throw new Error(`Auth profile "${lockedProfileId}" is not configured for ${provider}.`);
}
const profileCandidates = profileOrder.length > 0 ? profileOrder : [undefined];
let profileIndex = 0;
@@ -162,6 +164,7 @@ export async function runEmbeddedPiAgent(
};
const advanceAuthProfile = async (): Promise<boolean> => {
if (lockedProfileId) return false;
let nextIndex = profileIndex + 1;
while (nextIndex < profileCandidates.length) {
const candidate = profileCandidates[nextIndex];
@@ -172,7 +175,7 @@ export async function runEmbeddedPiAgent(
attemptedThinking.clear();
return true;
} catch (err) {
if (candidate && candidate === explicitProfileId) throw err;
if (candidate && candidate === lockedProfileId) throw err;
nextIndex += 1;
}
}
@@ -182,7 +185,7 @@ export async function runEmbeddedPiAgent(
try {
await applyApiKeyInfo(profileCandidates[profileIndex]);
} catch (err) {
if (profileCandidates[profileIndex] === explicitProfileId) throw err;
if (profileCandidates[profileIndex] === lockedProfileId) throw err;
const advanced = await advanceAuthProfile();
if (!advanced) throw err;
}

View File

@@ -30,6 +30,7 @@ export type RunEmbeddedPiAgentParams = {
provider?: string;
model?: string;
authProfileId?: string;
authProfileIdSource?: "auto" | "user";
thinkLevel?: ThinkLevel;
verboseLevel?: VerboseLevel;
reasoningLevel?: ReasoningLevel;

View File

@@ -200,6 +200,10 @@ export async function runAgentTurnWithFallback(params: {
throw err;
});
}
const authProfileId =
provider === params.followupRun.run.provider
? params.followupRun.run.authProfileId
: undefined;
return runEmbeddedPiAgent({
sessionId: params.followupRun.run.sessionId,
sessionKey: params.sessionKey,
@@ -222,7 +226,10 @@ export async function runAgentTurnWithFallback(params: {
enforceFinalTag: resolveEnforceFinalTag(params.followupRun.run, provider),
provider,
model,
authProfileId: params.followupRun.run.authProfileId,
authProfileId,
authProfileIdSource: authProfileId
? params.followupRun.run.authProfileIdSource
: undefined,
thinkLevel: params.followupRun.run.thinkLevel,
verboseLevel: params.followupRun.run.verboseLevel,
reasoningLevel: params.followupRun.run.reasoningLevel,

View File

@@ -96,8 +96,12 @@ export async function runMemoryFlushIfNeeded(params: {
params.followupRun.run.config,
resolveAgentIdFromSessionKey(params.followupRun.run.sessionKey),
),
run: (provider, model) =>
runEmbeddedPiAgent({
run: (provider, model) => {
const authProfileId =
provider === params.followupRun.run.provider
? params.followupRun.run.authProfileId
: undefined;
return runEmbeddedPiAgent({
sessionId: params.followupRun.run.sessionId,
sessionKey: params.sessionKey,
messageProvider: params.sessionCtx.Provider?.trim().toLowerCase() || undefined,
@@ -119,7 +123,10 @@ export async function runMemoryFlushIfNeeded(params: {
enforceFinalTag: resolveEnforceFinalTag(params.followupRun.run, provider),
provider,
model,
authProfileId: params.followupRun.run.authProfileId,
authProfileId,
authProfileIdSource: authProfileId
? params.followupRun.run.authProfileIdSource
: undefined,
thinkLevel: params.followupRun.run.thinkLevel,
verboseLevel: params.followupRun.run.verboseLevel,
reasoningLevel: params.followupRun.run.reasoningLevel,
@@ -136,7 +143,8 @@ export async function runMemoryFlushIfNeeded(params: {
}
}
},
}),
});
},
});
let memoryFlushCompactionCount =
activeSessionEntry?.compactionCount ??

View File

@@ -138,8 +138,10 @@ export function createFollowupRunner(params: {
queued.run.config,
resolveAgentIdFromSessionKey(queued.run.sessionKey),
),
run: (provider, model) =>
runEmbeddedPiAgent({
run: (provider, model) => {
const authProfileId =
provider === queued.run.provider ? queued.run.authProfileId : undefined;
return runEmbeddedPiAgent({
sessionId: queued.run.sessionId,
sessionKey: queued.run.sessionKey,
messageProvider: queued.run.messageProvider,
@@ -154,7 +156,8 @@ export function createFollowupRunner(params: {
enforceFinalTag: queued.run.enforceFinalTag,
provider,
model,
authProfileId: queued.run.authProfileId,
authProfileId,
authProfileIdSource: authProfileId ? queued.run.authProfileIdSource : undefined,
thinkLevel: queued.run.thinkLevel,
verboseLevel: queued.run.verboseLevel,
reasoningLevel: queued.run.reasoningLevel,
@@ -171,7 +174,8 @@ export function createFollowupRunner(params: {
autoCompactionCompleted = true;
}
},
}),
});
},
});
runResult = fallbackResult.result;
fallbackProvider = fallbackResult.provider;

View File

@@ -152,7 +152,13 @@ async function resolveSessionAuthProfileOverride(params: {
let current = sessionEntry.authProfileOverride?.trim();
if (current && !order.includes(current)) current = undefined;
const source = sessionEntry.authProfileOverrideSource ?? (current ? "user" : undefined);
const source =
sessionEntry.authProfileOverrideSource ??
(typeof sessionEntry.authProfileOverrideCompactionCount === "number"
? "auto"
: current
? "user"
: undefined);
if (source === "user" && current && !isNewSession) {
return current;
}
@@ -406,6 +412,7 @@ export async function runPreparedReply(
storePath,
isNewSession,
});
const authProfileIdSource = sessionEntry?.authProfileOverrideSource;
const followupRun = {
prompt: queuedBody,
messageId: sessionCtx.MessageSid,
@@ -430,6 +437,7 @@ export async function runPreparedReply(
provider,
model,
authProfileId,
authProfileIdSource,
thinkLevel: resolvedThinkLevel,
verboseLevel: resolvedVerboseLevel,
reasoningLevel: resolvedReasoningLevel,

View File

@@ -53,6 +53,7 @@ export type FollowupRun = {
provider: string;
model: string;
authProfileId?: string;
authProfileIdSource?: "auto" | "user";
thinkLevel?: ThinkLevel;
verboseLevel?: VerboseLevel;
reasoningLevel?: ReasoningLevel;

View File

@@ -372,6 +372,8 @@ export async function agentCommand(
images: opts.images,
});
}
const authProfileId =
providerOverride === provider ? sessionEntry?.authProfileOverride : undefined;
return runEmbeddedPiAgent({
sessionId,
sessionKey,
@@ -384,7 +386,8 @@ export async function agentCommand(
images: opts.images,
provider: providerOverride,
model: modelOverride,
authProfileId: sessionEntry?.authProfileOverride,
authProfileId,
authProfileIdSource: authProfileId ? sessionEntry?.authProfileOverrideSource : undefined,
thinkLevel: resolvedThinkLevel,
verboseLevel: resolvedVerboseLevel,
timeoutMs,