fix: align pi model discovery with auth storage

This commit is contained in:
Peter Steinberger
2025-12-26 11:49:13 +01:00
parent d31c5d7a2c
commit 82ced33747
6 changed files with 90 additions and 63 deletions

View File

@@ -10,10 +10,15 @@ type ModelEntry = { id: string; contextWindow?: number };
const MODEL_CACHE = new Map<string, number>();
const loadPromise = (async () => {
try {
const { discoverModels } = await import("@mariozechner/pi-coding-agent");
const { discoverAuthStorage, discoverModels } = await import(
"@mariozechner/pi-coding-agent"
);
const cfg = loadConfig();
await ensureClawdisModelsJson(cfg);
const models = discoverModels(resolveClawdisAgentDir()) as ModelEntry[];
const agentDir = resolveClawdisAgentDir();
const authStorage = discoverAuthStorage(agentDir);
const modelRegistry = discoverModels(authStorage, agentDir);
const models = modelRegistry.getAll() as ModelEntry[];
for (const m of models) {
if (!m?.id) continue;
if (typeof m.contextWindow === "number" && m.contextWindow > 0) {

View File

@@ -25,41 +25,30 @@ export async function loadModelCatalog(params?: {
if (modelCatalogPromise) return modelCatalogPromise;
modelCatalogPromise = (async () => {
const piSdk = (await import("@mariozechner/pi-coding-agent")) as {
discoverModels: (agentDir?: string) => Array<{
id: string;
name?: string;
provider: string;
contextWindow?: number;
}>;
};
const piSdk = await import("@mariozechner/pi-coding-agent");
let entries: Array<{
id: string;
name?: string;
provider: string;
contextWindow?: number;
}> = [];
const models: ModelCatalogEntry[] = [];
try {
const cfg = params?.config ?? loadConfig();
await ensureClawdisModelsJson(cfg);
entries = piSdk.discoverModels(resolveClawdisAgentDir());
const agentDir = resolveClawdisAgentDir();
const authStorage = piSdk.discoverAuthStorage(agentDir);
const registry = piSdk.discoverModels(authStorage, agentDir);
const entries = registry.getAll();
for (const entry of entries) {
const id = String(entry?.id ?? "").trim();
if (!id) continue;
const provider = String(entry?.provider ?? "").trim();
if (!provider) continue;
const name = String(entry?.name ?? id).trim() || id;
const contextWindow =
typeof entry?.contextWindow === "number" && entry.contextWindow > 0
? entry.contextWindow
: undefined;
models.push({ id, name, provider, contextWindow });
}
} catch {
entries = [];
}
const models: ModelCatalogEntry[] = [];
for (const entry of entries) {
const id = String(entry?.id ?? "").trim();
if (!id) continue;
const provider = String(entry?.provider ?? "").trim();
if (!provider) continue;
const name = String(entry?.name ?? id).trim() || id;
const contextWindow =
typeof entry?.contextWindow === "number" && entry.contextWindow > 0
? entry.contextWindow
: undefined;
models.push({ id, name, provider, contextWindow });
// Leave models empty on discovery errors.
}
return models.sort((a, b) => {

View File

@@ -7,11 +7,11 @@ import type { AppMessage, ThinkingLevel } from "@mariozechner/pi-agent-core";
import {
type Api,
type AssistantMessage,
getEnvApiKey,
getOAuthApiKey,
type Model,
type OAuthCredentials,
type OAuthProvider,
getEnvApiKey,
getOAuthApiKey,
} from "@mariozechner/pi-ai";
import {
buildSystemPrompt,
@@ -213,16 +213,32 @@ function resolveModel(
provider: string,
modelId: string,
agentDir?: string,
): { model?: Model<Api>; error?: string } {
): {
model?: Model<Api>;
error?: string;
authStorage: ReturnType<typeof discoverAuthStorage>;
modelRegistry: ReturnType<typeof discoverModels>;
} {
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
const authStorage = discoverAuthStorage(resolvedAgentDir);
const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
return { model };
if (!model) {
return {
error: `Unknown model: ${provider}/${modelId}`,
authStorage,
modelRegistry,
};
}
return { model, authStorage, modelRegistry };
}
async function getApiKeyForModel(model: Model<Api>): Promise<string> {
async function getApiKeyForModel(
model: Model<Api>,
authStorage: ReturnType<typeof discoverAuthStorage>,
): Promise<string> {
const storedKey = await authStorage.getApiKey(model.provider);
if (storedKey) return storedKey;
ensureOAuthStorage();
if (model.provider === "anthropic") {
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
@@ -320,10 +336,16 @@ export async function runEmbeddedPiAgent(params: {
const modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL;
await ensureClawdisModelsJson(params.config);
const agentDir = resolveClawdisAgentDir();
const { model, error } = resolveModel(provider, modelId, agentDir);
const { model, error, authStorage, modelRegistry } = resolveModel(
provider,
modelId,
agentDir,
);
if (!model) {
throw new Error(error ?? `Unknown model: ${provider}/${modelId}`);
}
const apiKey = await getApiKeyForModel(model, authStorage);
authStorage.setRuntimeApiKey(model.provider, apiKey);
const thinkingLevel = mapThinkingLevel(params.thinkLevel);
@@ -402,6 +424,8 @@ export async function runEmbeddedPiAgent(params: {
const { session } = await createAgentSession({
cwd: resolvedWorkspace,
agentDir,
authStorage,
modelRegistry,
model,
thinkingLevel,
systemPrompt,
@@ -410,9 +434,6 @@ export async function runEmbeddedPiAgent(params: {
tools,
sessionManager,
settingsManager,
getApiKey: async (m) => {
return await getApiKeyForModel(m as Model<Api>);
},
skills: promptSkills,
contextFiles,
});

View File

@@ -1,6 +1,5 @@
import { describe, expect, it, vi } from "vitest";
import type { AssistantMessage } from "@mariozechner/pi-ai";
import { describe, expect, it, vi } from "vitest";
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
@@ -8,6 +7,8 @@ type StubSession = {
subscribe: (fn: (evt: unknown) => void) => () => void;
};
type SessionEventHandler = (evt: unknown) => void;
describe("subscribeEmbeddedPiSession", () => {
it("filters to <final> and falls back when tags are malformed", () => {
let handler: ((evt: unknown) => void) | undefined;
@@ -97,16 +98,16 @@ describe("subscribeEmbeddedPiSession", () => {
});
it("waits for auto-compaction retry and clears buffered text", async () => {
const listeners: Array<(evt: any) => void> = [];
const listeners: SessionEventHandler[] = [];
const session = {
subscribe: (listener: (evt: any) => void) => {
subscribe: (listener: SessionEventHandler) => {
listeners.push(listener);
return () => {
const index = listeners.indexOf(listener);
if (index !== -1) listeners.splice(index, 1);
};
},
} as any;
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
const subscription = subscribeEmbeddedPiSession({
session,
@@ -150,13 +151,13 @@ describe("subscribeEmbeddedPiSession", () => {
});
it("resolves after compaction ends without retry", async () => {
const listeners: Array<(evt: any) => void> = [];
const listeners: SessionEventHandler[] = [];
const session = {
subscribe: (listener: (evt: any) => void) => {
subscribe: (listener: SessionEventHandler) => {
listeners.push(listener);
return () => {};
},
} as any;
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
const subscription = subscribeEmbeddedPiSession({
session,
@@ -184,13 +185,13 @@ describe("subscribeEmbeddedPiSession", () => {
});
it("waits for multiple compaction retries before resolving", async () => {
const listeners: Array<(evt: any) => void> = [];
const listeners: SessionEventHandler[] = [];
const session = {
subscribe: (listener: (evt: any) => void) => {
subscribe: (listener: SessionEventHandler) => {
listeners.push(listener);
return () => {};
},
} as any;
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
const subscription = subscribeEmbeddedPiSession({
session,

View File

@@ -325,9 +325,7 @@ export function subscribeEmbeddedPiSession(params: {
if (evt.type === "auto_compaction_end") {
compactionInFlight = false;
const willRetry = Boolean(
(evt as { willRetry?: unknown }).willRetry,
);
const willRetry = Boolean((evt as { willRetry?: unknown }).willRetry);
if (willRetry) {
noteCompactionRetry();
resetForCompactionRetry();
@@ -357,7 +355,7 @@ export function subscribeEmbeddedPiSession(params: {
ensureCompactionPromise();
return compactionRetryPromise ?? Promise.resolve();
}
return new Promise((resolve) => {
return new Promise<void>((resolve) => {
queueMicrotask(() => {
if (compactionInFlight || pendingCompactionRetry > 0) {
ensureCompactionPromise();

View File

@@ -333,12 +333,25 @@ export async function monitorWebInbox(options: {
return {
close: async () => {
try {
if (typeof sock.ev.off === "function") {
sock.ev.off("messages.upsert", handleMessagesUpsert);
sock.ev.off("connection.update", handleConnectionUpdate);
} else {
sock.ev.removeListener?.("messages.upsert", handleMessagesUpsert);
sock.ev.removeListener?.("connection.update", handleConnectionUpdate);
const ev = sock.ev as unknown as {
off?: (event: string, listener: (...args: unknown[]) => void) => void;
removeListener?: (
event: string,
listener: (...args: unknown[]) => void,
) => void;
};
const messagesUpsertHandler = handleMessagesUpsert as unknown as (
...args: unknown[]
) => void;
const connectionUpdateHandler = handleConnectionUpdate as unknown as (
...args: unknown[]
) => void;
if (typeof ev.off === "function") {
ev.off("messages.upsert", messagesUpsertHandler);
ev.off("connection.update", connectionUpdateHandler);
} else if (typeof ev.removeListener === "function") {
ev.removeListener("messages.upsert", messagesUpsertHandler);
ev.removeListener("connection.update", connectionUpdateHandler);
}
sock.ws?.close();
} catch (err) {