fix: honor user-pinned profiles and search ranking

This commit is contained in:
Peter Steinberger
2026-01-23 03:05:01 +00:00
parent 917bcb714e
commit 91ca52d3c5
4 changed files with 120 additions and 19 deletions

View File

@@ -210,6 +210,74 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
} }
}); });
it("honors user-pinned profiles even when in cooldown", async () => {
vi.useFakeTimers();
try {
const agentDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-agent-"));
const workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-workspace-"));
const now = Date.now();
vi.setSystemTime(now);
try {
const authPath = path.join(agentDir, "auth-profiles.json");
const payload = {
version: 1,
profiles: {
"openai:p1": { type: "api_key", provider: "openai", key: "sk-one" },
"openai:p2": { type: "api_key", provider: "openai", key: "sk-two" },
},
usageStats: {
"openai:p1": { lastUsed: 1, cooldownUntil: now + 60 * 60 * 1000 },
"openai:p2": { lastUsed: 2 },
},
};
await fs.writeFile(authPath, JSON.stringify(payload));
runEmbeddedAttemptMock.mockResolvedValueOnce(
makeAttempt({
assistantTexts: ["ok"],
lastAssistant: buildAssistant({
stopReason: "stop",
content: [{ type: "text", text: "ok" }],
}),
}),
);
await runEmbeddedPiAgent({
sessionId: "session:test",
sessionKey: "agent:test:user-cooldown",
sessionFile: path.join(workspaceDir, "session.jsonl"),
workspaceDir,
agentDir,
config: makeConfig(),
prompt: "hello",
provider: "openai",
model: "mock-1",
authProfileId: "openai:p1",
authProfileIdSource: "user",
timeoutMs: 5_000,
runId: "run:user-cooldown",
});
expect(runEmbeddedAttemptMock).toHaveBeenCalledTimes(1);
const stored = JSON.parse(
await fs.readFile(path.join(agentDir, "auth-profiles.json"), "utf-8"),
) as {
usageStats?: Record<string, { lastUsed?: number; cooldownUntil?: number }>;
};
expect(stored.usageStats?.["openai:p1"]?.cooldownUntil).toBeUndefined();
expect(stored.usageStats?.["openai:p1"]?.lastUsed).not.toBe(1);
expect(stored.usageStats?.["openai:p2"]?.lastUsed).toBe(2);
} finally {
await fs.rm(agentDir, { recursive: true, force: true });
await fs.rm(workspaceDir, { recursive: true, force: true });
}
} finally {
vi.useRealTimers();
}
});
it("ignores user-locked profile when provider mismatches", async () => { it("ignores user-locked profile when provider mismatches", async () => {
const agentDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-agent-")); const agentDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-agent-"));
const workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-workspace-")); const workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-workspace-"));
@@ -329,10 +397,12 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
profiles: { profiles: {
"openai:p1": { type: "api_key", provider: "openai", key: "sk-one" }, "openai:p1": { type: "api_key", provider: "openai", key: "sk-one" },
"openai:p2": { type: "api_key", provider: "openai", key: "sk-two" }, "openai:p2": { type: "api_key", provider: "openai", key: "sk-two" },
"openai:p3": { type: "api_key", provider: "openai", key: "sk-three" },
}, },
usageStats: { usageStats: {
"openai:p1": { lastUsed: 1 }, "openai:p1": { lastUsed: 1 },
"openai:p2": { cooldownUntil: now + 60 * 60 * 1000 }, // p2 in cooldown "openai:p2": { cooldownUntil: now + 60 * 60 * 1000 }, // p2 in cooldown
"openai:p3": { lastUsed: 3 },
}, },
}; };
await fs.writeFile(authPath, JSON.stringify(payload)); await fs.writeFile(authPath, JSON.stringify(payload));
@@ -377,8 +447,12 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
const stored = JSON.parse( const stored = JSON.parse(
await fs.readFile(path.join(agentDir, "auth-profiles.json"), "utf-8"), await fs.readFile(path.join(agentDir, "auth-profiles.json"), "utf-8"),
) as { usageStats?: Record<string, { lastUsed?: number }> }; ) as {
usageStats?: Record<string, { lastUsed?: number; cooldownUntil?: number }>;
};
expect(typeof stored.usageStats?.["openai:p1"]?.lastUsed).toBe("number"); expect(typeof stored.usageStats?.["openai:p1"]?.lastUsed).toBe("number");
expect(typeof stored.usageStats?.["openai:p3"]?.lastUsed).toBe("number");
expect(stored.usageStats?.["openai:p2"]?.cooldownUntil).toBe(now + 60 * 60 * 1000);
} finally { } finally {
await fs.rm(agentDir, { recursive: true, force: true }); await fs.rm(agentDir, { recursive: true, force: true });
await fs.rm(workspaceDir, { recursive: true, force: true }); await fs.rm(workspaceDir, { recursive: true, force: true });

View File

@@ -149,7 +149,11 @@ export async function runEmbeddedPiAgent(
if (lockedProfileId && !profileOrder.includes(lockedProfileId)) { if (lockedProfileId && !profileOrder.includes(lockedProfileId)) {
throw new Error(`Auth profile "${lockedProfileId}" is not configured for ${provider}.`); throw new Error(`Auth profile "${lockedProfileId}" is not configured for ${provider}.`);
} }
const profileCandidates = profileOrder.length > 0 ? profileOrder : [undefined]; const profileCandidates = lockedProfileId
? [lockedProfileId]
: profileOrder.length > 0
? profileOrder
: [undefined];
let profileIndex = 0; let profileIndex = 0;
const initialThinkLevel = params.thinkLevel ?? "off"; const initialThinkLevel = params.thinkLevel ?? "off";
@@ -170,13 +174,14 @@ export async function runEmbeddedPiAgent(
const applyApiKeyInfo = async (candidate?: string): Promise<void> => { const applyApiKeyInfo = async (candidate?: string): Promise<void> => {
apiKeyInfo = await resolveApiKeyForCandidate(candidate); apiKeyInfo = await resolveApiKeyForCandidate(candidate);
const resolvedProfileId = apiKeyInfo.profileId ?? candidate;
if (!apiKeyInfo.apiKey) { if (!apiKeyInfo.apiKey) {
if (apiKeyInfo.mode !== "aws-sdk") { if (apiKeyInfo.mode !== "aws-sdk") {
throw new Error( throw new Error(
`No API key resolved for provider "${model.provider}" (auth mode: ${apiKeyInfo.mode}).`, `No API key resolved for provider "${model.provider}" (auth mode: ${apiKeyInfo.mode}).`,
); );
} }
lastProfileId = apiKeyInfo.profileId; lastProfileId = resolvedProfileId;
return; return;
} }
if (model.provider === "github-copilot") { if (model.provider === "github-copilot") {
@@ -189,7 +194,7 @@ export async function runEmbeddedPiAgent(
} else { } else {
authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey); authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey);
} }
lastProfileId = apiKeyInfo.profileId; lastProfileId = resolvedProfileId;
}; };
const advanceAuthProfile = async (): Promise<boolean> => { const advanceAuthProfile = async (): Promise<boolean> => {
@@ -218,7 +223,11 @@ export async function runEmbeddedPiAgent(
try { try {
while (profileIndex < profileCandidates.length) { while (profileIndex < profileCandidates.length) {
const candidate = profileCandidates[profileIndex]; const candidate = profileCandidates[profileIndex];
if (candidate && isProfileInCooldown(authStore, candidate)) { if (
candidate &&
candidate !== lockedProfileId &&
isProfileInCooldown(authStore, candidate)
) {
profileIndex += 1; profileIndex += 1;
continue; continue;
} }
@@ -226,7 +235,9 @@ export async function runEmbeddedPiAgent(
break; break;
} }
if (profileIndex >= profileCandidates.length) { if (profileIndex >= profileCandidates.length) {
throw new Error(`No available auth profile for ${provider} (all in cooldown or unavailable).`); throw new Error(
`No available auth profile for ${provider} (all in cooldown or unavailable).`,
);
} }
} catch (err) { } catch (err) {
if (profileCandidates[profileIndex] === lockedProfileId) throw err; if (profileCandidates[profileIndex] === lockedProfileId) throw err;
@@ -518,10 +529,12 @@ export async function runEmbeddedPiAgent(
store: authStore, store: authStore,
provider, provider,
profileId: lastProfileId, profileId: lastProfileId,
agentDir: params.agentDir,
}); });
await markAuthProfileUsed({ await markAuthProfileUsed({
store: authStore, store: authStore,
profileId: lastProfileId, profileId: lastProfileId,
agentDir: params.agentDir,
}); });
} }
return { return {

View File

@@ -76,6 +76,22 @@ describe("SearchableSelectList", () => {
expect(selected?.value).toBe("opus-direct"); expect(selected?.value).toBe("opus-direct");
}); });
it("keeps exact label matches ahead of description matches", () => {
const longPrefix = "x".repeat(250);
const items = [
{ value: "late-label", label: `${longPrefix}opus`, description: "late exact match" },
{ value: "desc-first", label: "provider/other", description: "opus in description" },
];
const list = new SearchableSelectList(items, 5, mockTheme);
for (const ch of "opus") {
list.handleInput(ch);
}
const selected = list.getSelectedItem();
expect(selected?.value).toBe("late-label");
});
it("exact label match beats description match", () => { it("exact label match beats description match", () => {
const items = [ const items = [
{ {

View File

@@ -73,7 +73,7 @@ export class SearchableSelectList implements Component {
*/ */
private smartFilter(query: string): SelectItem[] { private smartFilter(query: string): SelectItem[] {
const q = query.toLowerCase(); const q = query.toLowerCase();
type ScoredItem = { item: SelectItem; score: number }; type ScoredItem = { item: SelectItem; tier: number; score: number };
const scoredItems: ScoredItem[] = []; const scoredItems: ScoredItem[] = [];
const fuzzyCandidates: SelectItem[] = []; const fuzzyCandidates: SelectItem[] = [];
@@ -81,22 +81,22 @@ export class SearchableSelectList implements Component {
const label = item.label.toLowerCase(); const label = item.label.toLowerCase();
const desc = (item.description ?? "").toLowerCase(); const desc = (item.description ?? "").toLowerCase();
// Tier 1: Exact substring in label (score 0-99) // Tier 1: Exact substring in label
const labelIndex = label.indexOf(q); const labelIndex = label.indexOf(q);
if (labelIndex !== -1) { if (labelIndex !== -1) {
scoredItems.push({ item, score: labelIndex }); scoredItems.push({ item, tier: 0, score: labelIndex });
continue; continue;
} }
// Tier 2: Word-boundary prefix in label (score 100-199) // Tier 2: Word-boundary prefix in label
const wordBoundaryIndex = findWordBoundaryIndex(label, q); const wordBoundaryIndex = findWordBoundaryIndex(label, q);
if (wordBoundaryIndex !== null) { if (wordBoundaryIndex !== null) {
scoredItems.push({ item, score: 100 + wordBoundaryIndex }); scoredItems.push({ item, tier: 1, score: wordBoundaryIndex });
continue; continue;
} }
// Tier 3: Exact substring in description (score 200-299) // Tier 3: Exact substring in description
const descIndex = desc.indexOf(q); const descIndex = desc.indexOf(q);
if (descIndex !== -1) { if (descIndex !== -1) {
scoredItems.push({ item, score: 200 + descIndex }); scoredItems.push({ item, tier: 2, score: descIndex });
continue; continue;
} }
// Tier 4: Fuzzy match (score 300+) // Tier 4: Fuzzy match (score 300+)
@@ -108,10 +108,7 @@ export class SearchableSelectList implements Component {
const preparedCandidates = prepareSearchItems(fuzzyCandidates); const preparedCandidates = prepareSearchItems(fuzzyCandidates);
const fuzzyMatches = fuzzyFilterLower(preparedCandidates, q); const fuzzyMatches = fuzzyFilterLower(preparedCandidates, q);
return [ return [...scoredItems.map((s) => s.item), ...fuzzyMatches];
...scoredItems.map((s) => s.item),
...fuzzyMatches,
];
} }
private escapeRegex(str: string): string { private escapeRegex(str: string): string {
@@ -119,9 +116,10 @@ export class SearchableSelectList implements Component {
} }
private compareByScore = ( private compareByScore = (
a: { item: SelectItem; score: number }, a: { item: SelectItem; tier: number; score: number },
b: { item: SelectItem; score: number }, b: { item: SelectItem; tier: number; score: number },
) => { ) => {
if (a.tier !== b.tier) return a.tier - b.tier;
if (a.score !== b.score) return a.score - b.score; if (a.score !== b.score) return a.score - b.score;
return this.getItemLabel(a.item).localeCompare(this.getItemLabel(b.item)); return this.getItemLabel(a.item).localeCompare(this.getItemLabel(b.item));
}; };