From 91ca52d3c5331f97c377c86a3ba554cf8b3c268b Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Fri, 23 Jan 2026 03:05:01 +0000 Subject: [PATCH] fix: honor user-pinned profiles and search ranking --- ...ded-pi-agent.auth-profile-rotation.test.ts | 76 ++++++++++++++++++- src/agents/pi-embedded-runner/run.ts | 23 ++++-- .../components/searchable-select-list.test.ts | 16 ++++ src/tui/components/searchable-select-list.ts | 24 +++--- 4 files changed, 120 insertions(+), 19 deletions(-) diff --git a/src/agents/pi-embedded-runner.run-embedded-pi-agent.auth-profile-rotation.test.ts b/src/agents/pi-embedded-runner.run-embedded-pi-agent.auth-profile-rotation.test.ts index 27bd96419..f6f395746 100644 --- a/src/agents/pi-embedded-runner.run-embedded-pi-agent.auth-profile-rotation.test.ts +++ b/src/agents/pi-embedded-runner.run-embedded-pi-agent.auth-profile-rotation.test.ts @@ -210,6 +210,74 @@ describe("runEmbeddedPiAgent auth profile rotation", () => { } }); + it("honors user-pinned profiles even when in cooldown", async () => { + vi.useFakeTimers(); + try { + const agentDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-agent-")); + const workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-workspace-")); + const now = Date.now(); + vi.setSystemTime(now); + + try { + const authPath = path.join(agentDir, "auth-profiles.json"); + const payload = { + version: 1, + profiles: { + "openai:p1": { type: "api_key", provider: "openai", key: "sk-one" }, + "openai:p2": { type: "api_key", provider: "openai", key: "sk-two" }, + }, + usageStats: { + "openai:p1": { lastUsed: 1, cooldownUntil: now + 60 * 60 * 1000 }, + "openai:p2": { lastUsed: 2 }, + }, + }; + await fs.writeFile(authPath, JSON.stringify(payload)); + + runEmbeddedAttemptMock.mockResolvedValueOnce( + makeAttempt({ + assistantTexts: ["ok"], + lastAssistant: buildAssistant({ + stopReason: "stop", + content: [{ type: "text", text: "ok" }], + }), + }), + ); + + await runEmbeddedPiAgent({ + sessionId: "session:test", + sessionKey: "agent:test:user-cooldown", + sessionFile: path.join(workspaceDir, "session.jsonl"), + workspaceDir, + agentDir, + config: makeConfig(), + prompt: "hello", + provider: "openai", + model: "mock-1", + authProfileId: "openai:p1", + authProfileIdSource: "user", + timeoutMs: 5_000, + runId: "run:user-cooldown", + }); + + expect(runEmbeddedAttemptMock).toHaveBeenCalledTimes(1); + + const stored = JSON.parse( + await fs.readFile(path.join(agentDir, "auth-profiles.json"), "utf-8"), + ) as { + usageStats?: Record; + }; + expect(stored.usageStats?.["openai:p1"]?.cooldownUntil).toBeUndefined(); + expect(stored.usageStats?.["openai:p1"]?.lastUsed).not.toBe(1); + expect(stored.usageStats?.["openai:p2"]?.lastUsed).toBe(2); + } finally { + await fs.rm(agentDir, { recursive: true, force: true }); + await fs.rm(workspaceDir, { recursive: true, force: true }); + } + } finally { + vi.useRealTimers(); + } + }); + it("ignores user-locked profile when provider mismatches", async () => { const agentDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-agent-")); const workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-workspace-")); @@ -329,10 +397,12 @@ describe("runEmbeddedPiAgent auth profile rotation", () => { profiles: { "openai:p1": { type: "api_key", provider: "openai", key: "sk-one" }, "openai:p2": { type: "api_key", provider: "openai", key: "sk-two" }, + "openai:p3": { type: "api_key", provider: "openai", key: "sk-three" }, }, usageStats: { "openai:p1": { lastUsed: 1 }, "openai:p2": { cooldownUntil: now + 60 * 60 * 1000 }, // p2 in cooldown + "openai:p3": { lastUsed: 3 }, }, }; await fs.writeFile(authPath, JSON.stringify(payload)); @@ -377,8 +447,12 @@ describe("runEmbeddedPiAgent auth profile rotation", () => { const stored = JSON.parse( await fs.readFile(path.join(agentDir, "auth-profiles.json"), "utf-8"), - ) as { usageStats?: Record }; + ) as { + usageStats?: Record; + }; expect(typeof stored.usageStats?.["openai:p1"]?.lastUsed).toBe("number"); + expect(typeof stored.usageStats?.["openai:p3"]?.lastUsed).toBe("number"); + expect(stored.usageStats?.["openai:p2"]?.cooldownUntil).toBe(now + 60 * 60 * 1000); } finally { await fs.rm(agentDir, { recursive: true, force: true }); await fs.rm(workspaceDir, { recursive: true, force: true }); diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index a2af256cc..a39b6fd96 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -149,7 +149,11 @@ export async function runEmbeddedPiAgent( if (lockedProfileId && !profileOrder.includes(lockedProfileId)) { throw new Error(`Auth profile "${lockedProfileId}" is not configured for ${provider}.`); } - const profileCandidates = profileOrder.length > 0 ? profileOrder : [undefined]; + const profileCandidates = lockedProfileId + ? [lockedProfileId] + : profileOrder.length > 0 + ? profileOrder + : [undefined]; let profileIndex = 0; const initialThinkLevel = params.thinkLevel ?? "off"; @@ -170,13 +174,14 @@ export async function runEmbeddedPiAgent( const applyApiKeyInfo = async (candidate?: string): Promise => { apiKeyInfo = await resolveApiKeyForCandidate(candidate); + const resolvedProfileId = apiKeyInfo.profileId ?? candidate; if (!apiKeyInfo.apiKey) { if (apiKeyInfo.mode !== "aws-sdk") { throw new Error( `No API key resolved for provider "${model.provider}" (auth mode: ${apiKeyInfo.mode}).`, ); } - lastProfileId = apiKeyInfo.profileId; + lastProfileId = resolvedProfileId; return; } if (model.provider === "github-copilot") { @@ -189,7 +194,7 @@ export async function runEmbeddedPiAgent( } else { authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey); } - lastProfileId = apiKeyInfo.profileId; + lastProfileId = resolvedProfileId; }; const advanceAuthProfile = async (): Promise => { @@ -218,7 +223,11 @@ export async function runEmbeddedPiAgent( try { while (profileIndex < profileCandidates.length) { const candidate = profileCandidates[profileIndex]; - if (candidate && isProfileInCooldown(authStore, candidate)) { + if ( + candidate && + candidate !== lockedProfileId && + isProfileInCooldown(authStore, candidate) + ) { profileIndex += 1; continue; } @@ -226,7 +235,9 @@ export async function runEmbeddedPiAgent( break; } if (profileIndex >= profileCandidates.length) { - throw new Error(`No available auth profile for ${provider} (all in cooldown or unavailable).`); + throw new Error( + `No available auth profile for ${provider} (all in cooldown or unavailable).`, + ); } } catch (err) { if (profileCandidates[profileIndex] === lockedProfileId) throw err; @@ -518,10 +529,12 @@ export async function runEmbeddedPiAgent( store: authStore, provider, profileId: lastProfileId, + agentDir: params.agentDir, }); await markAuthProfileUsed({ store: authStore, profileId: lastProfileId, + agentDir: params.agentDir, }); } return { diff --git a/src/tui/components/searchable-select-list.test.ts b/src/tui/components/searchable-select-list.test.ts index a16d77320..cf1da265e 100644 --- a/src/tui/components/searchable-select-list.test.ts +++ b/src/tui/components/searchable-select-list.test.ts @@ -76,6 +76,22 @@ describe("SearchableSelectList", () => { expect(selected?.value).toBe("opus-direct"); }); + it("keeps exact label matches ahead of description matches", () => { + const longPrefix = "x".repeat(250); + const items = [ + { value: "late-label", label: `${longPrefix}opus`, description: "late exact match" }, + { value: "desc-first", label: "provider/other", description: "opus in description" }, + ]; + const list = new SearchableSelectList(items, 5, mockTheme); + + for (const ch of "opus") { + list.handleInput(ch); + } + + const selected = list.getSelectedItem(); + expect(selected?.value).toBe("late-label"); + }); + it("exact label match beats description match", () => { const items = [ { diff --git a/src/tui/components/searchable-select-list.ts b/src/tui/components/searchable-select-list.ts index 886cc0170..f8e07e790 100644 --- a/src/tui/components/searchable-select-list.ts +++ b/src/tui/components/searchable-select-list.ts @@ -73,7 +73,7 @@ export class SearchableSelectList implements Component { */ private smartFilter(query: string): SelectItem[] { const q = query.toLowerCase(); - type ScoredItem = { item: SelectItem; score: number }; + type ScoredItem = { item: SelectItem; tier: number; score: number }; const scoredItems: ScoredItem[] = []; const fuzzyCandidates: SelectItem[] = []; @@ -81,22 +81,22 @@ export class SearchableSelectList implements Component { const label = item.label.toLowerCase(); const desc = (item.description ?? "").toLowerCase(); - // Tier 1: Exact substring in label (score 0-99) + // Tier 1: Exact substring in label const labelIndex = label.indexOf(q); if (labelIndex !== -1) { - scoredItems.push({ item, score: labelIndex }); + scoredItems.push({ item, tier: 0, score: labelIndex }); continue; } - // Tier 2: Word-boundary prefix in label (score 100-199) + // Tier 2: Word-boundary prefix in label const wordBoundaryIndex = findWordBoundaryIndex(label, q); if (wordBoundaryIndex !== null) { - scoredItems.push({ item, score: 100 + wordBoundaryIndex }); + scoredItems.push({ item, tier: 1, score: wordBoundaryIndex }); continue; } - // Tier 3: Exact substring in description (score 200-299) + // Tier 3: Exact substring in description const descIndex = desc.indexOf(q); if (descIndex !== -1) { - scoredItems.push({ item, score: 200 + descIndex }); + scoredItems.push({ item, tier: 2, score: descIndex }); continue; } // Tier 4: Fuzzy match (score 300+) @@ -108,10 +108,7 @@ export class SearchableSelectList implements Component { const preparedCandidates = prepareSearchItems(fuzzyCandidates); const fuzzyMatches = fuzzyFilterLower(preparedCandidates, q); - return [ - ...scoredItems.map((s) => s.item), - ...fuzzyMatches, - ]; + return [...scoredItems.map((s) => s.item), ...fuzzyMatches]; } private escapeRegex(str: string): string { @@ -119,9 +116,10 @@ export class SearchableSelectList implements Component { } private compareByScore = ( - a: { item: SelectItem; score: number }, - b: { item: SelectItem; score: number }, + a: { item: SelectItem; tier: number; score: number }, + b: { item: SelectItem; tier: number; score: number }, ) => { + if (a.tier !== b.tier) return a.tier - b.tier; if (a.score !== b.score) return a.score - b.score; return this.getItemLabel(a.item).localeCompare(this.getItemLabel(b.item)); };