fix: align pi model discovery with auth storage

This commit is contained in:
Peter Steinberger
2025-12-26 11:49:13 +01:00
parent d31c5d7a2c
commit 82ced33747
6 changed files with 90 additions and 63 deletions

View File

@@ -10,10 +10,15 @@ type ModelEntry = { id: string; contextWindow?: number };
const MODEL_CACHE = new Map<string, number>(); const MODEL_CACHE = new Map<string, number>();
const loadPromise = (async () => { const loadPromise = (async () => {
try { try {
const { discoverModels } = await import("@mariozechner/pi-coding-agent"); const { discoverAuthStorage, discoverModels } = await import(
"@mariozechner/pi-coding-agent"
);
const cfg = loadConfig(); const cfg = loadConfig();
await ensureClawdisModelsJson(cfg); await ensureClawdisModelsJson(cfg);
const models = discoverModels(resolveClawdisAgentDir()) as ModelEntry[]; const agentDir = resolveClawdisAgentDir();
const authStorage = discoverAuthStorage(agentDir);
const modelRegistry = discoverModels(authStorage, agentDir);
const models = modelRegistry.getAll() as ModelEntry[];
for (const m of models) { for (const m of models) {
if (!m?.id) continue; if (!m?.id) continue;
if (typeof m.contextWindow === "number" && m.contextWindow > 0) { if (typeof m.contextWindow === "number" && m.contextWindow > 0) {

View File

@@ -25,41 +25,30 @@ export async function loadModelCatalog(params?: {
if (modelCatalogPromise) return modelCatalogPromise; if (modelCatalogPromise) return modelCatalogPromise;
modelCatalogPromise = (async () => { modelCatalogPromise = (async () => {
const piSdk = (await import("@mariozechner/pi-coding-agent")) as { const piSdk = await import("@mariozechner/pi-coding-agent");
discoverModels: (agentDir?: string) => Array<{
id: string;
name?: string;
provider: string;
contextWindow?: number;
}>;
};
let entries: Array<{ const models: ModelCatalogEntry[] = [];
id: string;
name?: string;
provider: string;
contextWindow?: number;
}> = [];
try { try {
const cfg = params?.config ?? loadConfig(); const cfg = params?.config ?? loadConfig();
await ensureClawdisModelsJson(cfg); await ensureClawdisModelsJson(cfg);
entries = piSdk.discoverModels(resolveClawdisAgentDir()); const agentDir = resolveClawdisAgentDir();
const authStorage = piSdk.discoverAuthStorage(agentDir);
const registry = piSdk.discoverModels(authStorage, agentDir);
const entries = registry.getAll();
for (const entry of entries) {
const id = String(entry?.id ?? "").trim();
if (!id) continue;
const provider = String(entry?.provider ?? "").trim();
if (!provider) continue;
const name = String(entry?.name ?? id).trim() || id;
const contextWindow =
typeof entry?.contextWindow === "number" && entry.contextWindow > 0
? entry.contextWindow
: undefined;
models.push({ id, name, provider, contextWindow });
}
} catch { } catch {
entries = []; // Leave models empty on discovery errors.
}
const models: ModelCatalogEntry[] = [];
for (const entry of entries) {
const id = String(entry?.id ?? "").trim();
if (!id) continue;
const provider = String(entry?.provider ?? "").trim();
if (!provider) continue;
const name = String(entry?.name ?? id).trim() || id;
const contextWindow =
typeof entry?.contextWindow === "number" && entry.contextWindow > 0
? entry.contextWindow
: undefined;
models.push({ id, name, provider, contextWindow });
} }
return models.sort((a, b) => { return models.sort((a, b) => {

View File

@@ -7,11 +7,11 @@ import type { AppMessage, ThinkingLevel } from "@mariozechner/pi-agent-core";
import { import {
type Api, type Api,
type AssistantMessage, type AssistantMessage,
getEnvApiKey,
getOAuthApiKey,
type Model, type Model,
type OAuthCredentials, type OAuthCredentials,
type OAuthProvider, type OAuthProvider,
getEnvApiKey,
getOAuthApiKey,
} from "@mariozechner/pi-ai"; } from "@mariozechner/pi-ai";
import { import {
buildSystemPrompt, buildSystemPrompt,
@@ -213,16 +213,32 @@ function resolveModel(
provider: string, provider: string,
modelId: string, modelId: string,
agentDir?: string, agentDir?: string,
): { model?: Model<Api>; error?: string } { ): {
model?: Model<Api>;
error?: string;
authStorage: ReturnType<typeof discoverAuthStorage>;
modelRegistry: ReturnType<typeof discoverModels>;
} {
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir(); const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
const authStorage = discoverAuthStorage(resolvedAgentDir); const authStorage = discoverAuthStorage(resolvedAgentDir);
const modelRegistry = discoverModels(authStorage, resolvedAgentDir); const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
const model = modelRegistry.find(provider, modelId) as Model<Api> | null; const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
if (!model) return { error: `Unknown model: ${provider}/${modelId}` }; if (!model) {
return { model }; return {
error: `Unknown model: ${provider}/${modelId}`,
authStorage,
modelRegistry,
};
}
return { model, authStorage, modelRegistry };
} }
async function getApiKeyForModel(model: Model<Api>): Promise<string> { async function getApiKeyForModel(
model: Model<Api>,
authStorage: ReturnType<typeof discoverAuthStorage>,
): Promise<string> {
const storedKey = await authStorage.getApiKey(model.provider);
if (storedKey) return storedKey;
ensureOAuthStorage(); ensureOAuthStorage();
if (model.provider === "anthropic") { if (model.provider === "anthropic") {
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN; const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
@@ -320,10 +336,16 @@ export async function runEmbeddedPiAgent(params: {
const modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL; const modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL;
await ensureClawdisModelsJson(params.config); await ensureClawdisModelsJson(params.config);
const agentDir = resolveClawdisAgentDir(); const agentDir = resolveClawdisAgentDir();
const { model, error } = resolveModel(provider, modelId, agentDir); const { model, error, authStorage, modelRegistry } = resolveModel(
provider,
modelId,
agentDir,
);
if (!model) { if (!model) {
throw new Error(error ?? `Unknown model: ${provider}/${modelId}`); throw new Error(error ?? `Unknown model: ${provider}/${modelId}`);
} }
const apiKey = await getApiKeyForModel(model, authStorage);
authStorage.setRuntimeApiKey(model.provider, apiKey);
const thinkingLevel = mapThinkingLevel(params.thinkLevel); const thinkingLevel = mapThinkingLevel(params.thinkLevel);
@@ -402,6 +424,8 @@ export async function runEmbeddedPiAgent(params: {
const { session } = await createAgentSession({ const { session } = await createAgentSession({
cwd: resolvedWorkspace, cwd: resolvedWorkspace,
agentDir, agentDir,
authStorage,
modelRegistry,
model, model,
thinkingLevel, thinkingLevel,
systemPrompt, systemPrompt,
@@ -410,9 +434,6 @@ export async function runEmbeddedPiAgent(params: {
tools, tools,
sessionManager, sessionManager,
settingsManager, settingsManager,
getApiKey: async (m) => {
return await getApiKeyForModel(m as Model<Api>);
},
skills: promptSkills, skills: promptSkills,
contextFiles, contextFiles,
}); });

View File

@@ -1,6 +1,5 @@
import { describe, expect, it, vi } from "vitest";
import type { AssistantMessage } from "@mariozechner/pi-ai"; import type { AssistantMessage } from "@mariozechner/pi-ai";
import { describe, expect, it, vi } from "vitest";
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js"; import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
@@ -8,6 +7,8 @@ type StubSession = {
subscribe: (fn: (evt: unknown) => void) => () => void; subscribe: (fn: (evt: unknown) => void) => () => void;
}; };
type SessionEventHandler = (evt: unknown) => void;
describe("subscribeEmbeddedPiSession", () => { describe("subscribeEmbeddedPiSession", () => {
it("filters to <final> and falls back when tags are malformed", () => { it("filters to <final> and falls back when tags are malformed", () => {
let handler: ((evt: unknown) => void) | undefined; let handler: ((evt: unknown) => void) | undefined;
@@ -97,16 +98,16 @@ describe("subscribeEmbeddedPiSession", () => {
}); });
it("waits for auto-compaction retry and clears buffered text", async () => { it("waits for auto-compaction retry and clears buffered text", async () => {
const listeners: Array<(evt: any) => void> = []; const listeners: SessionEventHandler[] = [];
const session = { const session = {
subscribe: (listener: (evt: any) => void) => { subscribe: (listener: SessionEventHandler) => {
listeners.push(listener); listeners.push(listener);
return () => { return () => {
const index = listeners.indexOf(listener); const index = listeners.indexOf(listener);
if (index !== -1) listeners.splice(index, 1); if (index !== -1) listeners.splice(index, 1);
}; };
}, },
} as any; } as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
const subscription = subscribeEmbeddedPiSession({ const subscription = subscribeEmbeddedPiSession({
session, session,
@@ -150,13 +151,13 @@ describe("subscribeEmbeddedPiSession", () => {
}); });
it("resolves after compaction ends without retry", async () => { it("resolves after compaction ends without retry", async () => {
const listeners: Array<(evt: any) => void> = []; const listeners: SessionEventHandler[] = [];
const session = { const session = {
subscribe: (listener: (evt: any) => void) => { subscribe: (listener: SessionEventHandler) => {
listeners.push(listener); listeners.push(listener);
return () => {}; return () => {};
}, },
} as any; } as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
const subscription = subscribeEmbeddedPiSession({ const subscription = subscribeEmbeddedPiSession({
session, session,
@@ -184,13 +185,13 @@ describe("subscribeEmbeddedPiSession", () => {
}); });
it("waits for multiple compaction retries before resolving", async () => { it("waits for multiple compaction retries before resolving", async () => {
const listeners: Array<(evt: any) => void> = []; const listeners: SessionEventHandler[] = [];
const session = { const session = {
subscribe: (listener: (evt: any) => void) => { subscribe: (listener: SessionEventHandler) => {
listeners.push(listener); listeners.push(listener);
return () => {}; return () => {};
}, },
} as any; } as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
const subscription = subscribeEmbeddedPiSession({ const subscription = subscribeEmbeddedPiSession({
session, session,

View File

@@ -325,9 +325,7 @@ export function subscribeEmbeddedPiSession(params: {
if (evt.type === "auto_compaction_end") { if (evt.type === "auto_compaction_end") {
compactionInFlight = false; compactionInFlight = false;
const willRetry = Boolean( const willRetry = Boolean((evt as { willRetry?: unknown }).willRetry);
(evt as { willRetry?: unknown }).willRetry,
);
if (willRetry) { if (willRetry) {
noteCompactionRetry(); noteCompactionRetry();
resetForCompactionRetry(); resetForCompactionRetry();
@@ -357,7 +355,7 @@ export function subscribeEmbeddedPiSession(params: {
ensureCompactionPromise(); ensureCompactionPromise();
return compactionRetryPromise ?? Promise.resolve(); return compactionRetryPromise ?? Promise.resolve();
} }
return new Promise((resolve) => { return new Promise<void>((resolve) => {
queueMicrotask(() => { queueMicrotask(() => {
if (compactionInFlight || pendingCompactionRetry > 0) { if (compactionInFlight || pendingCompactionRetry > 0) {
ensureCompactionPromise(); ensureCompactionPromise();

View File

@@ -333,12 +333,25 @@ export async function monitorWebInbox(options: {
return { return {
close: async () => { close: async () => {
try { try {
if (typeof sock.ev.off === "function") { const ev = sock.ev as unknown as {
sock.ev.off("messages.upsert", handleMessagesUpsert); off?: (event: string, listener: (...args: unknown[]) => void) => void;
sock.ev.off("connection.update", handleConnectionUpdate); removeListener?: (
} else { event: string,
sock.ev.removeListener?.("messages.upsert", handleMessagesUpsert); listener: (...args: unknown[]) => void,
sock.ev.removeListener?.("connection.update", handleConnectionUpdate); ) => void;
};
const messagesUpsertHandler = handleMessagesUpsert as unknown as (
...args: unknown[]
) => void;
const connectionUpdateHandler = handleConnectionUpdate as unknown as (
...args: unknown[]
) => void;
if (typeof ev.off === "function") {
ev.off("messages.upsert", messagesUpsertHandler);
ev.off("connection.update", connectionUpdateHandler);
} else if (typeof ev.removeListener === "function") {
ev.removeListener("messages.upsert", messagesUpsertHandler);
ev.removeListener("connection.update", connectionUpdateHandler);
} }
sock.ws?.close(); sock.ws?.close();
} catch (err) { } catch (err) {