fix: align pi model discovery with auth storage
This commit is contained in:
@@ -10,10 +10,15 @@ type ModelEntry = { id: string; contextWindow?: number };
|
||||
const MODEL_CACHE = new Map<string, number>();
|
||||
const loadPromise = (async () => {
|
||||
try {
|
||||
const { discoverModels } = await import("@mariozechner/pi-coding-agent");
|
||||
const { discoverAuthStorage, discoverModels } = await import(
|
||||
"@mariozechner/pi-coding-agent"
|
||||
);
|
||||
const cfg = loadConfig();
|
||||
await ensureClawdisModelsJson(cfg);
|
||||
const models = discoverModels(resolveClawdisAgentDir()) as ModelEntry[];
|
||||
const agentDir = resolveClawdisAgentDir();
|
||||
const authStorage = discoverAuthStorage(agentDir);
|
||||
const modelRegistry = discoverModels(authStorage, agentDir);
|
||||
const models = modelRegistry.getAll() as ModelEntry[];
|
||||
for (const m of models) {
|
||||
if (!m?.id) continue;
|
||||
if (typeof m.contextWindow === "number" && m.contextWindow > 0) {
|
||||
|
||||
@@ -25,41 +25,30 @@ export async function loadModelCatalog(params?: {
|
||||
if (modelCatalogPromise) return modelCatalogPromise;
|
||||
|
||||
modelCatalogPromise = (async () => {
|
||||
const piSdk = (await import("@mariozechner/pi-coding-agent")) as {
|
||||
discoverModels: (agentDir?: string) => Array<{
|
||||
id: string;
|
||||
name?: string;
|
||||
provider: string;
|
||||
contextWindow?: number;
|
||||
}>;
|
||||
};
|
||||
const piSdk = await import("@mariozechner/pi-coding-agent");
|
||||
|
||||
let entries: Array<{
|
||||
id: string;
|
||||
name?: string;
|
||||
provider: string;
|
||||
contextWindow?: number;
|
||||
}> = [];
|
||||
const models: ModelCatalogEntry[] = [];
|
||||
try {
|
||||
const cfg = params?.config ?? loadConfig();
|
||||
await ensureClawdisModelsJson(cfg);
|
||||
entries = piSdk.discoverModels(resolveClawdisAgentDir());
|
||||
const agentDir = resolveClawdisAgentDir();
|
||||
const authStorage = piSdk.discoverAuthStorage(agentDir);
|
||||
const registry = piSdk.discoverModels(authStorage, agentDir);
|
||||
const entries = registry.getAll();
|
||||
for (const entry of entries) {
|
||||
const id = String(entry?.id ?? "").trim();
|
||||
if (!id) continue;
|
||||
const provider = String(entry?.provider ?? "").trim();
|
||||
if (!provider) continue;
|
||||
const name = String(entry?.name ?? id).trim() || id;
|
||||
const contextWindow =
|
||||
typeof entry?.contextWindow === "number" && entry.contextWindow > 0
|
||||
? entry.contextWindow
|
||||
: undefined;
|
||||
models.push({ id, name, provider, contextWindow });
|
||||
}
|
||||
} catch {
|
||||
entries = [];
|
||||
}
|
||||
|
||||
const models: ModelCatalogEntry[] = [];
|
||||
for (const entry of entries) {
|
||||
const id = String(entry?.id ?? "").trim();
|
||||
if (!id) continue;
|
||||
const provider = String(entry?.provider ?? "").trim();
|
||||
if (!provider) continue;
|
||||
const name = String(entry?.name ?? id).trim() || id;
|
||||
const contextWindow =
|
||||
typeof entry?.contextWindow === "number" && entry.contextWindow > 0
|
||||
? entry.contextWindow
|
||||
: undefined;
|
||||
models.push({ id, name, provider, contextWindow });
|
||||
// Leave models empty on discovery errors.
|
||||
}
|
||||
|
||||
return models.sort((a, b) => {
|
||||
|
||||
@@ -7,11 +7,11 @@ import type { AppMessage, ThinkingLevel } from "@mariozechner/pi-agent-core";
|
||||
import {
|
||||
type Api,
|
||||
type AssistantMessage,
|
||||
getEnvApiKey,
|
||||
getOAuthApiKey,
|
||||
type Model,
|
||||
type OAuthCredentials,
|
||||
type OAuthProvider,
|
||||
getEnvApiKey,
|
||||
getOAuthApiKey,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import {
|
||||
buildSystemPrompt,
|
||||
@@ -213,16 +213,32 @@ function resolveModel(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
agentDir?: string,
|
||||
): { model?: Model<Api>; error?: string } {
|
||||
): {
|
||||
model?: Model<Api>;
|
||||
error?: string;
|
||||
authStorage: ReturnType<typeof discoverAuthStorage>;
|
||||
modelRegistry: ReturnType<typeof discoverModels>;
|
||||
} {
|
||||
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
|
||||
const authStorage = discoverAuthStorage(resolvedAgentDir);
|
||||
const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
|
||||
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
|
||||
if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
|
||||
return { model };
|
||||
if (!model) {
|
||||
return {
|
||||
error: `Unknown model: ${provider}/${modelId}`,
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
};
|
||||
}
|
||||
return { model, authStorage, modelRegistry };
|
||||
}
|
||||
|
||||
async function getApiKeyForModel(model: Model<Api>): Promise<string> {
|
||||
async function getApiKeyForModel(
|
||||
model: Model<Api>,
|
||||
authStorage: ReturnType<typeof discoverAuthStorage>,
|
||||
): Promise<string> {
|
||||
const storedKey = await authStorage.getApiKey(model.provider);
|
||||
if (storedKey) return storedKey;
|
||||
ensureOAuthStorage();
|
||||
if (model.provider === "anthropic") {
|
||||
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
||||
@@ -320,10 +336,16 @@ export async function runEmbeddedPiAgent(params: {
|
||||
const modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL;
|
||||
await ensureClawdisModelsJson(params.config);
|
||||
const agentDir = resolveClawdisAgentDir();
|
||||
const { model, error } = resolveModel(provider, modelId, agentDir);
|
||||
const { model, error, authStorage, modelRegistry } = resolveModel(
|
||||
provider,
|
||||
modelId,
|
||||
agentDir,
|
||||
);
|
||||
if (!model) {
|
||||
throw new Error(error ?? `Unknown model: ${provider}/${modelId}`);
|
||||
}
|
||||
const apiKey = await getApiKeyForModel(model, authStorage);
|
||||
authStorage.setRuntimeApiKey(model.provider, apiKey);
|
||||
|
||||
const thinkingLevel = mapThinkingLevel(params.thinkLevel);
|
||||
|
||||
@@ -402,6 +424,8 @@ export async function runEmbeddedPiAgent(params: {
|
||||
const { session } = await createAgentSession({
|
||||
cwd: resolvedWorkspace,
|
||||
agentDir,
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
model,
|
||||
thinkingLevel,
|
||||
systemPrompt,
|
||||
@@ -410,9 +434,6 @@ export async function runEmbeddedPiAgent(params: {
|
||||
tools,
|
||||
sessionManager,
|
||||
settingsManager,
|
||||
getApiKey: async (m) => {
|
||||
return await getApiKeyForModel(m as Model<Api>);
|
||||
},
|
||||
skills: promptSkills,
|
||||
contextFiles,
|
||||
});
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
|
||||
|
||||
@@ -8,6 +7,8 @@ type StubSession = {
|
||||
subscribe: (fn: (evt: unknown) => void) => () => void;
|
||||
};
|
||||
|
||||
type SessionEventHandler = (evt: unknown) => void;
|
||||
|
||||
describe("subscribeEmbeddedPiSession", () => {
|
||||
it("filters to <final> and falls back when tags are malformed", () => {
|
||||
let handler: ((evt: unknown) => void) | undefined;
|
||||
@@ -97,16 +98,16 @@ describe("subscribeEmbeddedPiSession", () => {
|
||||
});
|
||||
|
||||
it("waits for auto-compaction retry and clears buffered text", async () => {
|
||||
const listeners: Array<(evt: any) => void> = [];
|
||||
const listeners: SessionEventHandler[] = [];
|
||||
const session = {
|
||||
subscribe: (listener: (evt: any) => void) => {
|
||||
subscribe: (listener: SessionEventHandler) => {
|
||||
listeners.push(listener);
|
||||
return () => {
|
||||
const index = listeners.indexOf(listener);
|
||||
if (index !== -1) listeners.splice(index, 1);
|
||||
};
|
||||
},
|
||||
} as any;
|
||||
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
|
||||
|
||||
const subscription = subscribeEmbeddedPiSession({
|
||||
session,
|
||||
@@ -150,13 +151,13 @@ describe("subscribeEmbeddedPiSession", () => {
|
||||
});
|
||||
|
||||
it("resolves after compaction ends without retry", async () => {
|
||||
const listeners: Array<(evt: any) => void> = [];
|
||||
const listeners: SessionEventHandler[] = [];
|
||||
const session = {
|
||||
subscribe: (listener: (evt: any) => void) => {
|
||||
subscribe: (listener: SessionEventHandler) => {
|
||||
listeners.push(listener);
|
||||
return () => {};
|
||||
},
|
||||
} as any;
|
||||
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
|
||||
|
||||
const subscription = subscribeEmbeddedPiSession({
|
||||
session,
|
||||
@@ -184,13 +185,13 @@ describe("subscribeEmbeddedPiSession", () => {
|
||||
});
|
||||
|
||||
it("waits for multiple compaction retries before resolving", async () => {
|
||||
const listeners: Array<(evt: any) => void> = [];
|
||||
const listeners: SessionEventHandler[] = [];
|
||||
const session = {
|
||||
subscribe: (listener: (evt: any) => void) => {
|
||||
subscribe: (listener: SessionEventHandler) => {
|
||||
listeners.push(listener);
|
||||
return () => {};
|
||||
},
|
||||
} as any;
|
||||
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
|
||||
|
||||
const subscription = subscribeEmbeddedPiSession({
|
||||
session,
|
||||
|
||||
@@ -325,9 +325,7 @@ export function subscribeEmbeddedPiSession(params: {
|
||||
|
||||
if (evt.type === "auto_compaction_end") {
|
||||
compactionInFlight = false;
|
||||
const willRetry = Boolean(
|
||||
(evt as { willRetry?: unknown }).willRetry,
|
||||
);
|
||||
const willRetry = Boolean((evt as { willRetry?: unknown }).willRetry);
|
||||
if (willRetry) {
|
||||
noteCompactionRetry();
|
||||
resetForCompactionRetry();
|
||||
@@ -357,7 +355,7 @@ export function subscribeEmbeddedPiSession(params: {
|
||||
ensureCompactionPromise();
|
||||
return compactionRetryPromise ?? Promise.resolve();
|
||||
}
|
||||
return new Promise((resolve) => {
|
||||
return new Promise<void>((resolve) => {
|
||||
queueMicrotask(() => {
|
||||
if (compactionInFlight || pendingCompactionRetry > 0) {
|
||||
ensureCompactionPromise();
|
||||
|
||||
@@ -333,12 +333,25 @@ export async function monitorWebInbox(options: {
|
||||
return {
|
||||
close: async () => {
|
||||
try {
|
||||
if (typeof sock.ev.off === "function") {
|
||||
sock.ev.off("messages.upsert", handleMessagesUpsert);
|
||||
sock.ev.off("connection.update", handleConnectionUpdate);
|
||||
} else {
|
||||
sock.ev.removeListener?.("messages.upsert", handleMessagesUpsert);
|
||||
sock.ev.removeListener?.("connection.update", handleConnectionUpdate);
|
||||
const ev = sock.ev as unknown as {
|
||||
off?: (event: string, listener: (...args: unknown[]) => void) => void;
|
||||
removeListener?: (
|
||||
event: string,
|
||||
listener: (...args: unknown[]) => void,
|
||||
) => void;
|
||||
};
|
||||
const messagesUpsertHandler = handleMessagesUpsert as unknown as (
|
||||
...args: unknown[]
|
||||
) => void;
|
||||
const connectionUpdateHandler = handleConnectionUpdate as unknown as (
|
||||
...args: unknown[]
|
||||
) => void;
|
||||
if (typeof ev.off === "function") {
|
||||
ev.off("messages.upsert", messagesUpsertHandler);
|
||||
ev.off("connection.update", connectionUpdateHandler);
|
||||
} else if (typeof ev.removeListener === "function") {
|
||||
ev.removeListener("messages.upsert", messagesUpsertHandler);
|
||||
ev.removeListener("connection.update", connectionUpdateHandler);
|
||||
}
|
||||
sock.ws?.close();
|
||||
} catch (err) {
|
||||
|
||||
Reference in New Issue
Block a user