fix: align pi model discovery with auth storage
This commit is contained in:
@@ -10,10 +10,15 @@ type ModelEntry = { id: string; contextWindow?: number };
|
|||||||
const MODEL_CACHE = new Map<string, number>();
|
const MODEL_CACHE = new Map<string, number>();
|
||||||
const loadPromise = (async () => {
|
const loadPromise = (async () => {
|
||||||
try {
|
try {
|
||||||
const { discoverModels } = await import("@mariozechner/pi-coding-agent");
|
const { discoverAuthStorage, discoverModels } = await import(
|
||||||
|
"@mariozechner/pi-coding-agent"
|
||||||
|
);
|
||||||
const cfg = loadConfig();
|
const cfg = loadConfig();
|
||||||
await ensureClawdisModelsJson(cfg);
|
await ensureClawdisModelsJson(cfg);
|
||||||
const models = discoverModels(resolveClawdisAgentDir()) as ModelEntry[];
|
const agentDir = resolveClawdisAgentDir();
|
||||||
|
const authStorage = discoverAuthStorage(agentDir);
|
||||||
|
const modelRegistry = discoverModels(authStorage, agentDir);
|
||||||
|
const models = modelRegistry.getAll() as ModelEntry[];
|
||||||
for (const m of models) {
|
for (const m of models) {
|
||||||
if (!m?.id) continue;
|
if (!m?.id) continue;
|
||||||
if (typeof m.contextWindow === "number" && m.contextWindow > 0) {
|
if (typeof m.contextWindow === "number" && m.contextWindow > 0) {
|
||||||
|
|||||||
@@ -25,41 +25,30 @@ export async function loadModelCatalog(params?: {
|
|||||||
if (modelCatalogPromise) return modelCatalogPromise;
|
if (modelCatalogPromise) return modelCatalogPromise;
|
||||||
|
|
||||||
modelCatalogPromise = (async () => {
|
modelCatalogPromise = (async () => {
|
||||||
const piSdk = (await import("@mariozechner/pi-coding-agent")) as {
|
const piSdk = await import("@mariozechner/pi-coding-agent");
|
||||||
discoverModels: (agentDir?: string) => Array<{
|
|
||||||
id: string;
|
|
||||||
name?: string;
|
|
||||||
provider: string;
|
|
||||||
contextWindow?: number;
|
|
||||||
}>;
|
|
||||||
};
|
|
||||||
|
|
||||||
let entries: Array<{
|
const models: ModelCatalogEntry[] = [];
|
||||||
id: string;
|
|
||||||
name?: string;
|
|
||||||
provider: string;
|
|
||||||
contextWindow?: number;
|
|
||||||
}> = [];
|
|
||||||
try {
|
try {
|
||||||
const cfg = params?.config ?? loadConfig();
|
const cfg = params?.config ?? loadConfig();
|
||||||
await ensureClawdisModelsJson(cfg);
|
await ensureClawdisModelsJson(cfg);
|
||||||
entries = piSdk.discoverModels(resolveClawdisAgentDir());
|
const agentDir = resolveClawdisAgentDir();
|
||||||
|
const authStorage = piSdk.discoverAuthStorage(agentDir);
|
||||||
|
const registry = piSdk.discoverModels(authStorage, agentDir);
|
||||||
|
const entries = registry.getAll();
|
||||||
|
for (const entry of entries) {
|
||||||
|
const id = String(entry?.id ?? "").trim();
|
||||||
|
if (!id) continue;
|
||||||
|
const provider = String(entry?.provider ?? "").trim();
|
||||||
|
if (!provider) continue;
|
||||||
|
const name = String(entry?.name ?? id).trim() || id;
|
||||||
|
const contextWindow =
|
||||||
|
typeof entry?.contextWindow === "number" && entry.contextWindow > 0
|
||||||
|
? entry.contextWindow
|
||||||
|
: undefined;
|
||||||
|
models.push({ id, name, provider, contextWindow });
|
||||||
|
}
|
||||||
} catch {
|
} catch {
|
||||||
entries = [];
|
// Leave models empty on discovery errors.
|
||||||
}
|
|
||||||
|
|
||||||
const models: ModelCatalogEntry[] = [];
|
|
||||||
for (const entry of entries) {
|
|
||||||
const id = String(entry?.id ?? "").trim();
|
|
||||||
if (!id) continue;
|
|
||||||
const provider = String(entry?.provider ?? "").trim();
|
|
||||||
if (!provider) continue;
|
|
||||||
const name = String(entry?.name ?? id).trim() || id;
|
|
||||||
const contextWindow =
|
|
||||||
typeof entry?.contextWindow === "number" && entry.contextWindow > 0
|
|
||||||
? entry.contextWindow
|
|
||||||
: undefined;
|
|
||||||
models.push({ id, name, provider, contextWindow });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return models.sort((a, b) => {
|
return models.sort((a, b) => {
|
||||||
|
|||||||
@@ -7,11 +7,11 @@ import type { AppMessage, ThinkingLevel } from "@mariozechner/pi-agent-core";
|
|||||||
import {
|
import {
|
||||||
type Api,
|
type Api,
|
||||||
type AssistantMessage,
|
type AssistantMessage,
|
||||||
|
getEnvApiKey,
|
||||||
|
getOAuthApiKey,
|
||||||
type Model,
|
type Model,
|
||||||
type OAuthCredentials,
|
type OAuthCredentials,
|
||||||
type OAuthProvider,
|
type OAuthProvider,
|
||||||
getEnvApiKey,
|
|
||||||
getOAuthApiKey,
|
|
||||||
} from "@mariozechner/pi-ai";
|
} from "@mariozechner/pi-ai";
|
||||||
import {
|
import {
|
||||||
buildSystemPrompt,
|
buildSystemPrompt,
|
||||||
@@ -213,16 +213,32 @@ function resolveModel(
|
|||||||
provider: string,
|
provider: string,
|
||||||
modelId: string,
|
modelId: string,
|
||||||
agentDir?: string,
|
agentDir?: string,
|
||||||
): { model?: Model<Api>; error?: string } {
|
): {
|
||||||
|
model?: Model<Api>;
|
||||||
|
error?: string;
|
||||||
|
authStorage: ReturnType<typeof discoverAuthStorage>;
|
||||||
|
modelRegistry: ReturnType<typeof discoverModels>;
|
||||||
|
} {
|
||||||
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
|
const resolvedAgentDir = agentDir ?? resolveClawdisAgentDir();
|
||||||
const authStorage = discoverAuthStorage(resolvedAgentDir);
|
const authStorage = discoverAuthStorage(resolvedAgentDir);
|
||||||
const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
|
const modelRegistry = discoverModels(authStorage, resolvedAgentDir);
|
||||||
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
|
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
|
||||||
if (!model) return { error: `Unknown model: ${provider}/${modelId}` };
|
if (!model) {
|
||||||
return { model };
|
return {
|
||||||
|
error: `Unknown model: ${provider}/${modelId}`,
|
||||||
|
authStorage,
|
||||||
|
modelRegistry,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return { model, authStorage, modelRegistry };
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getApiKeyForModel(model: Model<Api>): Promise<string> {
|
async function getApiKeyForModel(
|
||||||
|
model: Model<Api>,
|
||||||
|
authStorage: ReturnType<typeof discoverAuthStorage>,
|
||||||
|
): Promise<string> {
|
||||||
|
const storedKey = await authStorage.getApiKey(model.provider);
|
||||||
|
if (storedKey) return storedKey;
|
||||||
ensureOAuthStorage();
|
ensureOAuthStorage();
|
||||||
if (model.provider === "anthropic") {
|
if (model.provider === "anthropic") {
|
||||||
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
|
||||||
@@ -320,10 +336,16 @@ export async function runEmbeddedPiAgent(params: {
|
|||||||
const modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL;
|
const modelId = (params.model ?? DEFAULT_MODEL).trim() || DEFAULT_MODEL;
|
||||||
await ensureClawdisModelsJson(params.config);
|
await ensureClawdisModelsJson(params.config);
|
||||||
const agentDir = resolveClawdisAgentDir();
|
const agentDir = resolveClawdisAgentDir();
|
||||||
const { model, error } = resolveModel(provider, modelId, agentDir);
|
const { model, error, authStorage, modelRegistry } = resolveModel(
|
||||||
|
provider,
|
||||||
|
modelId,
|
||||||
|
agentDir,
|
||||||
|
);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
throw new Error(error ?? `Unknown model: ${provider}/${modelId}`);
|
throw new Error(error ?? `Unknown model: ${provider}/${modelId}`);
|
||||||
}
|
}
|
||||||
|
const apiKey = await getApiKeyForModel(model, authStorage);
|
||||||
|
authStorage.setRuntimeApiKey(model.provider, apiKey);
|
||||||
|
|
||||||
const thinkingLevel = mapThinkingLevel(params.thinkLevel);
|
const thinkingLevel = mapThinkingLevel(params.thinkLevel);
|
||||||
|
|
||||||
@@ -402,6 +424,8 @@ export async function runEmbeddedPiAgent(params: {
|
|||||||
const { session } = await createAgentSession({
|
const { session } = await createAgentSession({
|
||||||
cwd: resolvedWorkspace,
|
cwd: resolvedWorkspace,
|
||||||
agentDir,
|
agentDir,
|
||||||
|
authStorage,
|
||||||
|
modelRegistry,
|
||||||
model,
|
model,
|
||||||
thinkingLevel,
|
thinkingLevel,
|
||||||
systemPrompt,
|
systemPrompt,
|
||||||
@@ -410,9 +434,6 @@ export async function runEmbeddedPiAgent(params: {
|
|||||||
tools,
|
tools,
|
||||||
sessionManager,
|
sessionManager,
|
||||||
settingsManager,
|
settingsManager,
|
||||||
getApiKey: async (m) => {
|
|
||||||
return await getApiKeyForModel(m as Model<Api>);
|
|
||||||
},
|
|
||||||
skills: promptSkills,
|
skills: promptSkills,
|
||||||
contextFiles,
|
contextFiles,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import { describe, expect, it, vi } from "vitest";
|
|
||||||
|
|
||||||
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||||
|
import { describe, expect, it, vi } from "vitest";
|
||||||
|
|
||||||
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
|
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
|
||||||
|
|
||||||
@@ -8,6 +7,8 @@ type StubSession = {
|
|||||||
subscribe: (fn: (evt: unknown) => void) => () => void;
|
subscribe: (fn: (evt: unknown) => void) => () => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type SessionEventHandler = (evt: unknown) => void;
|
||||||
|
|
||||||
describe("subscribeEmbeddedPiSession", () => {
|
describe("subscribeEmbeddedPiSession", () => {
|
||||||
it("filters to <final> and falls back when tags are malformed", () => {
|
it("filters to <final> and falls back when tags are malformed", () => {
|
||||||
let handler: ((evt: unknown) => void) | undefined;
|
let handler: ((evt: unknown) => void) | undefined;
|
||||||
@@ -97,16 +98,16 @@ describe("subscribeEmbeddedPiSession", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it("waits for auto-compaction retry and clears buffered text", async () => {
|
it("waits for auto-compaction retry and clears buffered text", async () => {
|
||||||
const listeners: Array<(evt: any) => void> = [];
|
const listeners: SessionEventHandler[] = [];
|
||||||
const session = {
|
const session = {
|
||||||
subscribe: (listener: (evt: any) => void) => {
|
subscribe: (listener: SessionEventHandler) => {
|
||||||
listeners.push(listener);
|
listeners.push(listener);
|
||||||
return () => {
|
return () => {
|
||||||
const index = listeners.indexOf(listener);
|
const index = listeners.indexOf(listener);
|
||||||
if (index !== -1) listeners.splice(index, 1);
|
if (index !== -1) listeners.splice(index, 1);
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
} as any;
|
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
|
||||||
|
|
||||||
const subscription = subscribeEmbeddedPiSession({
|
const subscription = subscribeEmbeddedPiSession({
|
||||||
session,
|
session,
|
||||||
@@ -150,13 +151,13 @@ describe("subscribeEmbeddedPiSession", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it("resolves after compaction ends without retry", async () => {
|
it("resolves after compaction ends without retry", async () => {
|
||||||
const listeners: Array<(evt: any) => void> = [];
|
const listeners: SessionEventHandler[] = [];
|
||||||
const session = {
|
const session = {
|
||||||
subscribe: (listener: (evt: any) => void) => {
|
subscribe: (listener: SessionEventHandler) => {
|
||||||
listeners.push(listener);
|
listeners.push(listener);
|
||||||
return () => {};
|
return () => {};
|
||||||
},
|
},
|
||||||
} as any;
|
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
|
||||||
|
|
||||||
const subscription = subscribeEmbeddedPiSession({
|
const subscription = subscribeEmbeddedPiSession({
|
||||||
session,
|
session,
|
||||||
@@ -184,13 +185,13 @@ describe("subscribeEmbeddedPiSession", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it("waits for multiple compaction retries before resolving", async () => {
|
it("waits for multiple compaction retries before resolving", async () => {
|
||||||
const listeners: Array<(evt: any) => void> = [];
|
const listeners: SessionEventHandler[] = [];
|
||||||
const session = {
|
const session = {
|
||||||
subscribe: (listener: (evt: any) => void) => {
|
subscribe: (listener: SessionEventHandler) => {
|
||||||
listeners.push(listener);
|
listeners.push(listener);
|
||||||
return () => {};
|
return () => {};
|
||||||
},
|
},
|
||||||
} as any;
|
} as unknown as Parameters<typeof subscribeEmbeddedPiSession>[0]["session"];
|
||||||
|
|
||||||
const subscription = subscribeEmbeddedPiSession({
|
const subscription = subscribeEmbeddedPiSession({
|
||||||
session,
|
session,
|
||||||
|
|||||||
@@ -325,9 +325,7 @@ export function subscribeEmbeddedPiSession(params: {
|
|||||||
|
|
||||||
if (evt.type === "auto_compaction_end") {
|
if (evt.type === "auto_compaction_end") {
|
||||||
compactionInFlight = false;
|
compactionInFlight = false;
|
||||||
const willRetry = Boolean(
|
const willRetry = Boolean((evt as { willRetry?: unknown }).willRetry);
|
||||||
(evt as { willRetry?: unknown }).willRetry,
|
|
||||||
);
|
|
||||||
if (willRetry) {
|
if (willRetry) {
|
||||||
noteCompactionRetry();
|
noteCompactionRetry();
|
||||||
resetForCompactionRetry();
|
resetForCompactionRetry();
|
||||||
@@ -357,7 +355,7 @@ export function subscribeEmbeddedPiSession(params: {
|
|||||||
ensureCompactionPromise();
|
ensureCompactionPromise();
|
||||||
return compactionRetryPromise ?? Promise.resolve();
|
return compactionRetryPromise ?? Promise.resolve();
|
||||||
}
|
}
|
||||||
return new Promise((resolve) => {
|
return new Promise<void>((resolve) => {
|
||||||
queueMicrotask(() => {
|
queueMicrotask(() => {
|
||||||
if (compactionInFlight || pendingCompactionRetry > 0) {
|
if (compactionInFlight || pendingCompactionRetry > 0) {
|
||||||
ensureCompactionPromise();
|
ensureCompactionPromise();
|
||||||
|
|||||||
@@ -333,12 +333,25 @@ export async function monitorWebInbox(options: {
|
|||||||
return {
|
return {
|
||||||
close: async () => {
|
close: async () => {
|
||||||
try {
|
try {
|
||||||
if (typeof sock.ev.off === "function") {
|
const ev = sock.ev as unknown as {
|
||||||
sock.ev.off("messages.upsert", handleMessagesUpsert);
|
off?: (event: string, listener: (...args: unknown[]) => void) => void;
|
||||||
sock.ev.off("connection.update", handleConnectionUpdate);
|
removeListener?: (
|
||||||
} else {
|
event: string,
|
||||||
sock.ev.removeListener?.("messages.upsert", handleMessagesUpsert);
|
listener: (...args: unknown[]) => void,
|
||||||
sock.ev.removeListener?.("connection.update", handleConnectionUpdate);
|
) => void;
|
||||||
|
};
|
||||||
|
const messagesUpsertHandler = handleMessagesUpsert as unknown as (
|
||||||
|
...args: unknown[]
|
||||||
|
) => void;
|
||||||
|
const connectionUpdateHandler = handleConnectionUpdate as unknown as (
|
||||||
|
...args: unknown[]
|
||||||
|
) => void;
|
||||||
|
if (typeof ev.off === "function") {
|
||||||
|
ev.off("messages.upsert", messagesUpsertHandler);
|
||||||
|
ev.off("connection.update", connectionUpdateHandler);
|
||||||
|
} else if (typeof ev.removeListener === "function") {
|
||||||
|
ev.removeListener("messages.upsert", messagesUpsertHandler);
|
||||||
|
ev.removeListener("connection.update", connectionUpdateHandler);
|
||||||
}
|
}
|
||||||
sock.ws?.close();
|
sock.ws?.close();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
Reference in New Issue
Block a user