From 734bb6b4fd9a3a2ee1d03eea9f8f5c802dcc04d0 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sun, 4 Jan 2026 17:50:55 +0100 Subject: [PATCH] feat: add models scan and fallbacks --- docs/configuration.md | 5 + docs/models.md | 21 +- src/agents/model-fallback.ts | 150 +++++++++ src/agents/model-scan.ts | 379 +++++++++++++++++++++ src/auto-reply/reply/agent-runner.ts | 267 ++++++++------- src/auto-reply/reply/followup-runner.ts | 53 +-- src/auto-reply/status.ts | 3 +- src/cli/models-cli.ts | 198 +++++++++++ src/cli/program.ts | 2 + src/commands/agent.ts | 62 ++-- src/commands/models.ts | 14 + src/commands/models/aliases.ts | 89 +++++ src/commands/models/fallbacks.ts | 134 ++++++++ src/commands/models/list.ts | 419 ++++++++++++++++++++++++ src/commands/models/scan.ts | 267 +++++++++++++++ src/commands/models/set.ts | 29 ++ src/commands/models/shared.ts | 95 ++++++ src/config/schema.ts | 3 + src/config/sessions.ts | 2 + src/config/types.ts | 2 + src/config/zod-schema.ts | 1 + src/cron/isolated-agent.ts | 50 ++- 22 files changed, 2058 insertions(+), 187 deletions(-) create mode 100644 src/agents/model-fallback.ts create mode 100644 src/agents/model-scan.ts create mode 100644 src/cli/models-cli.ts create mode 100644 src/commands/models.ts create mode 100644 src/commands/models/aliases.ts create mode 100644 src/commands/models/fallbacks.ts create mode 100644 src/commands/models/list.ts create mode 100644 src/commands/models/scan.ts create mode 100644 src/commands/models/set.ts create mode 100644 src/commands/models/shared.ts diff --git a/docs/configuration.md b/docs/configuration.md index 16207917c..e22414cf2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -430,6 +430,7 @@ Controls the embedded agent runtime (model/thinking/verbose/timeouts). `allowedModels` lets `/model` list/filter and enforce a per-session allowlist (omit to show the full catalog). `modelAliases` adds short names for `/model` (alias -> provider/model). +`modelFallbacks` lists ordered fallback models to try when the default fails. ```json5 { @@ -443,6 +444,10 @@ Controls the embedded agent runtime (model/thinking/verbose/timeouts). Opus: "anthropic/claude-opus-4-5", Sonnet: "anthropic/claude-sonnet-4-1" }, + modelFallbacks: [ + "openrouter/deepseek/deepseek-r1:free", + "openrouter/meta-llama/llama-3.3-70b-instruct:free" + ], thinkingDefault: "low", verboseDefault: "off", elevatedDefault: "on", diff --git a/docs/models.md b/docs/models.md index 5b6c3b876..939ffc6bf 100644 --- a/docs/models.md +++ b/docs/models.md @@ -12,18 +12,18 @@ that prefers tool-call + image-capable models and maintains ordered fallbacks. ## Command tree (draft) -- `clawdis models list` +- `clawdbot models list` - default: configured models only - flags: `--all` (full catalog), `--local`, `--provider `, `--json`, `--plain` -- `clawdis models status` - - show default model + last used + aliases + fallbacks -- `clawdis models set ` +- `clawdbot models status` + - show default model + aliases + fallbacks + allowlist +- `clawdbot models set ` - writes `agent.model` in config -- `clawdis models aliases list|add|remove` +- `clawdbot models aliases list|add|remove` - writes `agent.modelAliases` -- `clawdis models fallbacks list|add|remove|clear` +- `clawdbot models fallbacks list|add|remove|clear` - writes `agent.modelFallbacks` -- `clawdis models scan` +- `clawdbot models scan` - OpenRouter :free scan; probe tool-call + image; interactive selection ## Config changes @@ -38,7 +38,9 @@ that prefers tool-call + image-capable models and maintains ordered fallbacks. Input - OpenRouter `/models` list (filter `:free`) +- Requires `OPENROUTER_API_KEY` (or stored OpenRouter key in auth storage) - Optional filters: `--max-age-days`, `--min-params`, `--provider`, `--max-candidates` +- Probe controls: `--timeout`, `--concurrency` Probes (direct pi-ai complete) - Tool-call probe (required): @@ -49,13 +51,13 @@ Probes (direct pi-ai complete) Scoring/selection - Prefer models passing tool + image. - Fallback to tool-only if no tool+image pass. -- Rank by: tool+image first, then lower median latency, then larger context. +- Rank by: image ok, then lower tool latency, then larger context, then params. Interactive selection (TTY) - Multiselect list with per-model stats: - model id, tool ok, image ok, median latency, context, inferred params. - Pre-select top N (default 6). -- Non-TTY: auto-select; require `--yes` or use defaults. +- Non-TTY: auto-select; require `--yes`/`--no-input` to apply. Output - Writes `agent.modelFallbacks` ordered. @@ -64,6 +66,7 @@ Output ## Runtime fallback - On model failure: try `agent.modelFallbacks` in order. +- Ignore fallback entries not in `agent.allowedModels` (if allowlist set). - Persist last successful provider/model to session entry. - `/status` shows last used model (not just default). diff --git a/src/agents/model-fallback.ts b/src/agents/model-fallback.ts new file mode 100644 index 000000000..2f7e4f556 --- /dev/null +++ b/src/agents/model-fallback.ts @@ -0,0 +1,150 @@ +import type { ClawdbotConfig } from "../config/config.js"; +import { DEFAULT_MODEL, DEFAULT_PROVIDER } from "./defaults.js"; +import { + buildModelAliasIndex, + modelKey, + parseModelRef, + resolveModelRefFromString, +} from "./model-selection.js"; + +type ModelCandidate = { + provider: string; + model: string; +}; + +type FallbackAttempt = { + provider: string; + model: string; + error: string; +}; + +function isAbortError(err: unknown): boolean { + if (!err || typeof err !== "object") return false; + const name = "name" in err ? String(err.name) : ""; + if (name === "AbortError") return true; + const message = + "message" in err && typeof err.message === "string" + ? err.message.toLowerCase() + : ""; + return message.includes("aborted"); +} + +function buildAllowedModelKeys( + cfg: ClawdbotConfig | undefined, + defaultProvider: string, +): Set | null { + const rawAllowlist = cfg?.agent?.allowedModels ?? []; + if (rawAllowlist.length === 0) return null; + const keys = new Set(); + for (const raw of rawAllowlist) { + const parsed = parseModelRef(String(raw ?? ""), defaultProvider); + if (!parsed) continue; + keys.add(modelKey(parsed.provider, parsed.model)); + } + return keys.size > 0 ? keys : null; +} + +function resolveFallbackCandidates(params: { + cfg: ClawdbotConfig | undefined; + provider: string; + model: string; +}): ModelCandidate[] { + const provider = params.provider.trim() || DEFAULT_PROVIDER; + const model = params.model.trim() || DEFAULT_MODEL; + const aliasIndex = buildModelAliasIndex({ + cfg: params.cfg ?? {}, + defaultProvider: DEFAULT_PROVIDER, + }); + const allowlist = buildAllowedModelKeys(params.cfg, DEFAULT_PROVIDER); + const seen = new Set(); + const candidates: ModelCandidate[] = []; + + const addCandidate = (candidate: ModelCandidate, enforceAllowlist: boolean) => { + if (!candidate.provider || !candidate.model) return; + const key = modelKey(candidate.provider, candidate.model); + if (seen.has(key)) return; + if (enforceAllowlist && allowlist && !allowlist.has(key)) return; + seen.add(key); + candidates.push(candidate); + }; + + addCandidate({ provider, model }, false); + + for (const raw of params.cfg?.agent?.modelFallbacks ?? []) { + const resolved = resolveModelRefFromString({ + raw: String(raw ?? ""), + defaultProvider: DEFAULT_PROVIDER, + aliasIndex, + }); + if (!resolved) continue; + addCandidate(resolved.ref, true); + } + + return candidates; +} + +export async function runWithModelFallback(params: { + cfg: ClawdbotConfig | undefined; + provider: string; + model: string; + run: (provider: string, model: string) => Promise; + onError?: (attempt: { + provider: string; + model: string; + error: unknown; + attempt: number; + total: number; + }) => void | Promise; +}): Promise<{ + result: T; + provider: string; + model: string; + attempts: FallbackAttempt[]; +}> { + const candidates = resolveFallbackCandidates(params); + const attempts: FallbackAttempt[] = []; + let lastError: unknown; + + for (let i = 0; i < candidates.length; i += 1) { + const candidate = candidates[i] as ModelCandidate; + try { + const result = await params.run(candidate.provider, candidate.model); + return { + result, + provider: candidate.provider, + model: candidate.model, + attempts, + }; + } catch (err) { + if (isAbortError(err)) throw err; + lastError = err; + attempts.push({ + provider: candidate.provider, + model: candidate.model, + error: err instanceof Error ? err.message : String(err), + }); + await params.onError?.({ + provider: candidate.provider, + model: candidate.model, + error: err, + attempt: i + 1, + total: candidates.length, + }); + } + } + + if (attempts.length <= 1 && lastError) throw lastError; + const summary = + attempts.length > 0 + ? attempts + .map( + (attempt) => + `${attempt.provider}/${attempt.model}: ${attempt.error}`, + ) + .join(" | ") + : "unknown"; + throw new Error( + `All models failed (${attempts.length || candidates.length}): ${summary}`, + { cause: lastError instanceof Error ? lastError : undefined }, + ); +} diff --git a/src/agents/model-scan.ts b/src/agents/model-scan.ts new file mode 100644 index 000000000..ac5deefbb --- /dev/null +++ b/src/agents/model-scan.ts @@ -0,0 +1,379 @@ +import { Type } from "@sinclair/typebox"; +import { + complete, + getEnvApiKey, + getModel, + type Context, + type Model, + type Tool, + type OpenAICompletionsOptions, +} from "@mariozechner/pi-ai"; + +const OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models"; +const DEFAULT_TIMEOUT_MS = 12_000; +const DEFAULT_CONCURRENCY = 3; + +const BASE_IMAGE_PNG = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+X3mIAAAAASUVORK5CYII="; + +const TOOL_PING: Tool = { + name: "ping", + description: "Return OK.", + parameters: Type.Object({}), +}; + +type OpenRouterModelMeta = { + id: string; + name: string; + contextLength: number | null; + maxCompletionTokens: number | null; + supportedParametersCount: number; + modality: string | null; + inferredParamB: number | null; + createdAtMs: number | null; +}; + +export type ProbeResult = { + ok: boolean; + latencyMs: number | null; + error?: string; + skipped?: boolean; +}; + +export type ModelScanResult = { + id: string; + name: string; + provider: string; + modelRef: string; + contextLength: number | null; + maxCompletionTokens: number | null; + supportedParametersCount: number; + modality: string | null; + inferredParamB: number | null; + createdAtMs: number | null; + tool: ProbeResult; + image: ProbeResult; +}; + +export type OpenRouterScanOptions = { + apiKey?: string; + fetchImpl?: typeof fetch; + timeoutMs?: number; + concurrency?: number; + minParamB?: number; + maxAgeDays?: number; + providerFilter?: string; +}; + +type OpenAIModel = Model<"openai-completions">; + +function normalizeCreatedAtMs(value: unknown): number | null { + if (typeof value !== "number" || !Number.isFinite(value)) return null; + if (value <= 0) return null; + if (value > 1e12) return Math.round(value); + return Math.round(value * 1000); +} + +function inferParamBFromIdOrName(text: string): number | null { + const raw = text.toLowerCase(); + const matches = raw.matchAll( + /(?:^|[^a-z0-9])[a-z]?(\d+(?:\.\d+)?)b(?:[^a-z0-9]|$)/g, + ); + let best: number | null = null; + for (const match of matches) { + const numRaw = match[1]; + if (!numRaw) continue; + const value = Number(numRaw); + if (!Number.isFinite(value) || value <= 0) continue; + if (best === null || value > best) best = value; + } + return best; +} + +function parseModality(modality: string | null): Array<"text" | "image"> { + if (!modality) return ["text"]; + const normalized = modality.toLowerCase(); + const parts = normalized.split(/[^a-z]+/).filter(Boolean); + const hasImage = parts.includes("image"); + return hasImage ? ["text", "image"] : ["text"]; +} + +async function withTimeout( + timeoutMs: number, + fn: (signal: AbortSignal) => Promise, +): Promise { + const controller = new AbortController(); + const timer = setTimeout(() => controller.abort(), timeoutMs); + try { + return await fn(controller.signal); + } finally { + clearTimeout(timer); + } +} + +async function fetchOpenRouterModels( + fetchImpl: typeof fetch, +): Promise { + const res = await fetchImpl(OPENROUTER_MODELS_URL, { + headers: { Accept: "application/json" }, + }); + if (!res.ok) { + throw new Error(`OpenRouter /models failed: HTTP ${res.status}`); + } + const payload = (await res.json()) as { data?: unknown }; + const entries = Array.isArray(payload.data) ? payload.data : []; + + return entries + .map((entry) => { + if (!entry || typeof entry !== "object") return null; + const obj = entry as Record; + const id = typeof obj.id === "string" ? obj.id.trim() : ""; + if (!id) return null; + const name = + typeof obj.name === "string" && obj.name.trim() + ? obj.name.trim() + : id; + + const contextLength = + typeof obj.context_length === "number" && + Number.isFinite(obj.context_length) + ? obj.context_length + : null; + + const maxCompletionTokens = + typeof obj.max_completion_tokens === "number" && + Number.isFinite(obj.max_completion_tokens) + ? obj.max_completion_tokens + : typeof obj.max_output_tokens === "number" && + Number.isFinite(obj.max_output_tokens) + ? obj.max_output_tokens + : null; + + const supportedParametersCount = Array.isArray(obj.supported_parameters) + ? obj.supported_parameters.length + : 0; + + const modality = + typeof obj.modality === "string" && obj.modality.trim() + ? obj.modality.trim() + : null; + + const inferredParamB = inferParamBFromIdOrName(`${id} ${name}`); + const createdAtMs = normalizeCreatedAtMs(obj.created_at); + + return { + id, + name, + contextLength, + maxCompletionTokens, + supportedParametersCount, + modality, + inferredParamB, + createdAtMs, + } satisfies OpenRouterModelMeta; + }) + .filter((entry): entry is OpenRouterModelMeta => Boolean(entry)); +} + +async function probeTool( + model: OpenAIModel, + apiKey: string, + timeoutMs: number, +): Promise { + const context: Context = { + messages: [ + { + role: "user", + content: "Call the ping tool with {} and nothing else.", + timestamp: Date.now(), + }, + ], + tools: [TOOL_PING], + }; + const startedAt = Date.now(); + try { + const message = await withTimeout(timeoutMs, (signal) => + complete(model, context, { + apiKey, + maxTokens: 32, + temperature: 0, + toolChoice: "required", + signal, + } satisfies OpenAICompletionsOptions), + ); + + const hasToolCall = message.content.some( + (block) => block.type === "toolCall", + ); + if (!hasToolCall) { + return { + ok: false, + latencyMs: Date.now() - startedAt, + error: "No tool call returned", + }; + } + + return { ok: true, latencyMs: Date.now() - startedAt }; + } catch (err) { + return { + ok: false, + latencyMs: Date.now() - startedAt, + error: err instanceof Error ? err.message : String(err), + }; + } +} + +async function probeImage( + model: OpenAIModel, + apiKey: string, + timeoutMs: number, +): Promise { + const context: Context = { + messages: [ + { + role: "user", + content: [ + { type: "text", text: "Reply with OK." }, + { type: "image", data: BASE_IMAGE_PNG, mimeType: "image/png" }, + ], + timestamp: Date.now(), + }, + ], + }; + const startedAt = Date.now(); + try { + await withTimeout(timeoutMs, (signal) => + complete(model, context, { + apiKey, + maxTokens: 16, + temperature: 0, + signal, + } satisfies OpenAICompletionsOptions), + ); + return { ok: true, latencyMs: Date.now() - startedAt }; + } catch (err) { + return { + ok: false, + latencyMs: Date.now() - startedAt, + error: err instanceof Error ? err.message : String(err), + }; + } +} + +function ensureImageInput(model: OpenAIModel): OpenAIModel { + if (model.input.includes("image")) return model; + return { + ...model, + input: Array.from(new Set([...model.input, "image"])), + }; +} + +async function mapWithConcurrency( + items: T[], + concurrency: number, + fn: (item: T, index: number) => Promise, +): Promise { + const limit = Math.max(1, Math.floor(concurrency)); + const results: R[] = new Array(items.length); + let nextIndex = 0; + + const worker = async () => { + while (true) { + const current = nextIndex; + nextIndex += 1; + if (current >= items.length) return; + results[current] = await fn(items[current] as T, current); + } + }; + + await Promise.all( + Array.from({ length: Math.min(limit, items.length) }, () => worker()), + ); + return results; +} + +export async function scanOpenRouterModels( + options: OpenRouterScanOptions = {}, +): Promise { + const fetchImpl = options.fetchImpl ?? fetch; + const apiKey = + options.apiKey?.trim() || getEnvApiKey("openrouter") || ""; + if (!apiKey) { + throw new Error( + "Missing OpenRouter API key. Set OPENROUTER_API_KEY to run models scan.", + ); + } + + const timeoutMs = Math.max( + 1, + Math.floor(options.timeoutMs ?? DEFAULT_TIMEOUT_MS), + ); + const concurrency = Math.max( + 1, + Math.floor(options.concurrency ?? DEFAULT_CONCURRENCY), + ); + const minParamB = Math.max(0, Math.floor(options.minParamB ?? 0)); + const maxAgeDays = Math.max(0, Math.floor(options.maxAgeDays ?? 0)); + const providerFilter = options.providerFilter?.trim().toLowerCase() ?? ""; + + const catalog = await fetchOpenRouterModels(fetchImpl); + const now = Date.now(); + + const filtered = catalog.filter((entry) => { + if (!entry.id.endsWith(":free")) return false; + if (providerFilter) { + const prefix = entry.id.split("/")[0]?.toLowerCase() ?? ""; + if (prefix !== providerFilter) return false; + } + if (minParamB > 0) { + const params = entry.inferredParamB ?? 0; + if (params < minParamB) return false; + } + if (maxAgeDays > 0 && entry.createdAtMs) { + const ageMs = now - entry.createdAtMs; + const ageDays = ageMs / (24 * 60 * 60 * 1000); + if (ageDays > maxAgeDays) return false; + } + return true; + }); + + const baseModel = getModel( + "openrouter", + "openrouter/auto", + ) as OpenAIModel; + + return mapWithConcurrency(filtered, concurrency, async (entry) => { + const model: OpenAIModel = { + ...baseModel, + id: entry.id, + name: entry.name || entry.id, + contextWindow: entry.contextLength ?? baseModel.contextWindow, + maxTokens: entry.maxCompletionTokens ?? baseModel.maxTokens, + input: parseModality(entry.modality), + reasoning: baseModel.reasoning, + }; + + const toolResult = await probeTool(model, apiKey, timeoutMs); + const imageResult = toolResult.ok + ? await probeImage(ensureImageInput(model), apiKey, timeoutMs) + : { ok: false, latencyMs: null, skipped: true }; + + return { + id: entry.id, + name: entry.name, + provider: "openrouter", + modelRef: `openrouter/${entry.id}`, + contextLength: entry.contextLength, + maxCompletionTokens: entry.maxCompletionTokens, + supportedParametersCount: entry.supportedParametersCount, + modality: entry.modality, + inferredParamB: entry.inferredParamB, + createdAtMs: entry.createdAtMs, + tool: toolResult, + image: imageResult, + } satisfies ModelScanResult; + }); +} + +export { OPENROUTER_MODELS_URL }; +export type { OpenRouterModelMeta }; diff --git a/src/auto-reply/reply/agent-runner.ts b/src/auto-reply/reply/agent-runner.ts index ac0535308..f5eaccf74 100644 --- a/src/auto-reply/reply/agent-runner.ts +++ b/src/auto-reply/reply/agent-runner.ts @@ -5,6 +5,7 @@ import { queueEmbeddedPiMessage, runEmbeddedPiAgent, } from "../../agents/pi-embedded.js"; +import { runWithModelFallback } from "../../agents/model-fallback.js"; import { loadSessionStore, type SessionEntry, @@ -170,131 +171,154 @@ export async function runReplyAgent(params: { registerAgentRunContext(runId, { sessionKey }); } let runResult: Awaited>; + let fallbackProvider = followupRun.run.provider; + let fallbackModel = followupRun.run.model; try { - runResult = await runEmbeddedPiAgent({ - sessionId: followupRun.run.sessionId, - sessionKey, - surface: sessionCtx.Surface?.trim().toLowerCase() || undefined, - sessionFile: followupRun.run.sessionFile, - workspaceDir: followupRun.run.workspaceDir, - config: followupRun.run.config, - skillsSnapshot: followupRun.run.skillsSnapshot, - prompt: commandBody, - extraSystemPrompt: followupRun.run.extraSystemPrompt, - ownerNumbers: followupRun.run.ownerNumbers, - enforceFinalTag: followupRun.run.enforceFinalTag, + const fallbackResult = await runWithModelFallback({ + cfg: followupRun.run.config, provider: followupRun.run.provider, model: followupRun.run.model, - thinkLevel: followupRun.run.thinkLevel, - verboseLevel: followupRun.run.verboseLevel, - bashElevated: followupRun.run.bashElevated, - timeoutMs: followupRun.run.timeoutMs, - runId, - blockReplyBreak: resolvedBlockStreamingBreak, - blockReplyChunking, - onPartialReply: opts?.onPartialReply - ? async (payload) => { - let text = payload.text; - if (!opts?.isHeartbeat && text?.includes("HEARTBEAT_OK")) { - const stripped = stripHeartbeatToken(text, { mode: "message" }); - if (stripped.didStrip && !didLogHeartbeatStrip) { - didLogHeartbeatStrip = true; - logVerbose("Stripped stray HEARTBEAT_OK token from reply"); - } - if ( - stripped.shouldSkip && - (payload.mediaUrls?.length ?? 0) === 0 - ) { - return; - } - text = stripped.text; - } - await typing.startTypingOnText(text); - await opts.onPartialReply?.({ - text, - mediaUrls: payload.mediaUrls, - }); - } - : undefined, - onBlockReply: - blockStreamingEnabled && opts?.onBlockReply - ? async (payload) => { - let text = payload.text; - if (!opts?.isHeartbeat && text?.includes("HEARTBEAT_OK")) { - const stripped = stripHeartbeatToken(text, { - mode: "message", - }); - if (stripped.didStrip && !didLogHeartbeatStrip) { - didLogHeartbeatStrip = true; - logVerbose("Stripped stray HEARTBEAT_OK token from reply"); + run: (provider, model) => + runEmbeddedPiAgent({ + sessionId: followupRun.run.sessionId, + sessionKey, + surface: sessionCtx.Surface?.trim().toLowerCase() || undefined, + sessionFile: followupRun.run.sessionFile, + workspaceDir: followupRun.run.workspaceDir, + config: followupRun.run.config, + skillsSnapshot: followupRun.run.skillsSnapshot, + prompt: commandBody, + extraSystemPrompt: followupRun.run.extraSystemPrompt, + ownerNumbers: followupRun.run.ownerNumbers, + enforceFinalTag: followupRun.run.enforceFinalTag, + provider, + model, + thinkLevel: followupRun.run.thinkLevel, + verboseLevel: followupRun.run.verboseLevel, + bashElevated: followupRun.run.bashElevated, + timeoutMs: followupRun.run.timeoutMs, + runId, + blockReplyBreak: resolvedBlockStreamingBreak, + blockReplyChunking, + onPartialReply: opts?.onPartialReply + ? async (payload) => { + let text = payload.text; + if (!opts?.isHeartbeat && text?.includes("HEARTBEAT_OK")) { + const stripped = stripHeartbeatToken(text, { + mode: "message", + }); + if (stripped.didStrip && !didLogHeartbeatStrip) { + didLogHeartbeatStrip = true; + logVerbose("Stripped stray HEARTBEAT_OK token from reply"); + } + if ( + stripped.shouldSkip && + (payload.mediaUrls?.length ?? 0) === 0 + ) { + return; + } + text = stripped.text; } - const hasMedia = (payload.mediaUrls?.length ?? 0) > 0; - if (stripped.shouldSkip && !hasMedia) return; - text = stripped.text; - } - const tagResult = extractReplyToTag( - text, - sessionCtx.MessageSid, - ); - const cleaned = tagResult.cleaned || undefined; - const hasMedia = (payload.mediaUrls?.length ?? 0) > 0; - if (!cleaned && !hasMedia) return; - if (cleaned?.trim() === SILENT_REPLY_TOKEN && !hasMedia) return; - const blockPayload: ReplyPayload = { - text: cleaned, - mediaUrls: payload.mediaUrls, - mediaUrl: payload.mediaUrls?.[0], - replyToId: tagResult.replyToId, - }; - const payloadKey = buildPayloadKey(blockPayload); - if ( - streamedPayloadKeys.has(payloadKey) || - pendingStreamedPayloadKeys.has(payloadKey) - ) { - return; - } - pendingStreamedPayloadKeys.add(payloadKey); - const task = (async () => { - await typing.startTypingOnText(cleaned); - await opts.onBlockReply?.(blockPayload); - })() - .then(() => { - streamedPayloadKeys.add(payloadKey); - didStreamBlockReply = true; - }) - .catch((err) => { - logVerbose(`block reply delivery failed: ${String(err)}`); - }) - .finally(() => { - pendingStreamedPayloadKeys.delete(payloadKey); + await typing.startTypingOnText(text); + await opts.onPartialReply?.({ + text, + mediaUrls: payload.mediaUrls, }); - pendingBlockTasks.add(task); - void task.finally(() => pendingBlockTasks.delete(task)); - } - : undefined, - shouldEmitToolResult, - onToolResult: opts?.onToolResult - ? async (payload) => { - let text = payload.text; - if (!opts?.isHeartbeat && text?.includes("HEARTBEAT_OK")) { - const stripped = stripHeartbeatToken(text, { mode: "message" }); - if (stripped.didStrip && !didLogHeartbeatStrip) { - didLogHeartbeatStrip = true; - logVerbose("Stripped stray HEARTBEAT_OK token from reply"); } - if ( - stripped.shouldSkip && - (payload.mediaUrls?.length ?? 0) === 0 - ) { - return; + : undefined, + onBlockReply: + blockStreamingEnabled && opts?.onBlockReply + ? async (payload) => { + let text = payload.text; + if (!opts?.isHeartbeat && text?.includes("HEARTBEAT_OK")) { + const stripped = stripHeartbeatToken(text, { + mode: "message", + }); + if (stripped.didStrip && !didLogHeartbeatStrip) { + didLogHeartbeatStrip = true; + logVerbose( + "Stripped stray HEARTBEAT_OK token from reply", + ); + } + const hasMedia = (payload.mediaUrls?.length ?? 0) > 0; + if (stripped.shouldSkip && !hasMedia) return; + text = stripped.text; + } + const tagResult = extractReplyToTag( + text, + sessionCtx.MessageSid, + ); + const cleaned = tagResult.cleaned || undefined; + const hasMedia = (payload.mediaUrls?.length ?? 0) > 0; + if (!cleaned && !hasMedia) return; + if (cleaned?.trim() === SILENT_REPLY_TOKEN && !hasMedia) + return; + const blockPayload: ReplyPayload = { + text: cleaned, + mediaUrls: payload.mediaUrls, + mediaUrl: payload.mediaUrls?.[0], + replyToId: tagResult.replyToId, + }; + const payloadKey = buildPayloadKey(blockPayload); + if ( + streamedPayloadKeys.has(payloadKey) || + pendingStreamedPayloadKeys.has(payloadKey) + ) { + return; + } + pendingStreamedPayloadKeys.add(payloadKey); + const task = (async () => { + await typing.startTypingOnText(cleaned); + await opts.onBlockReply?.(blockPayload); + })() + .then(() => { + streamedPayloadKeys.add(payloadKey); + didStreamBlockReply = true; + }) + .catch((err) => { + logVerbose( + `block reply delivery failed: ${String(err)}`, + ); + }) + .finally(() => { + pendingStreamedPayloadKeys.delete(payloadKey); + }); + pendingBlockTasks.add(task); + void task.finally(() => pendingBlockTasks.delete(task)); + } + : undefined, + shouldEmitToolResult, + onToolResult: opts?.onToolResult + ? async (payload) => { + let text = payload.text; + if (!opts?.isHeartbeat && text?.includes("HEARTBEAT_OK")) { + const stripped = stripHeartbeatToken(text, { + mode: "message", + }); + if (stripped.didStrip && !didLogHeartbeatStrip) { + didLogHeartbeatStrip = true; + logVerbose("Stripped stray HEARTBEAT_OK token from reply"); + } + if ( + stripped.shouldSkip && + (payload.mediaUrls?.length ?? 0) === 0 + ) { + return; + } + text = stripped.text; + } + await typing.startTypingOnText(text); + await opts.onToolResult?.({ + text, + mediaUrls: payload.mediaUrls, + }); } - text = stripped.text; - } - await typing.startTypingOnText(text); - await opts.onToolResult?.({ text, mediaUrls: payload.mediaUrls }); - } - : undefined, + : undefined, + }), }); + runResult = fallbackResult.result; + fallbackProvider = fallbackResult.provider; + fallbackModel = fallbackResult.model; } catch (err) { const message = err instanceof Error ? err.message : String(err); const isContextOverflow = @@ -388,7 +412,12 @@ export async function runReplyAgent(params: { if (sessionStore && sessionKey) { const usage = runResult.meta.agentMeta?.usage; - const modelUsed = runResult.meta.agentMeta?.model ?? defaultModel; + const modelUsed = + runResult.meta.agentMeta?.model ?? fallbackModel ?? defaultModel; + const providerUsed = + runResult.meta.agentMeta?.provider ?? + fallbackProvider ?? + followupRun.run.provider; const contextTokensUsed = agentCfgContextTokens ?? lookupContextTokens(modelUsed) ?? @@ -408,6 +437,7 @@ export async function runReplyAgent(params: { outputTokens: output, totalTokens: promptTokens > 0 ? promptTokens : (usage.total ?? input), + modelProvider: providerUsed, model: modelUsed, contextTokens: contextTokensUsed ?? entry.contextTokens, updatedAt: Date.now(), @@ -422,6 +452,7 @@ export async function runReplyAgent(params: { if (entry) { sessionStore[sessionKey] = { ...entry, + modelProvider: providerUsed ?? entry.modelProvider, model: modelUsed ?? entry.model, contextTokens: contextTokensUsed ?? entry.contextTokens, }; diff --git a/src/auto-reply/reply/followup-runner.ts b/src/auto-reply/reply/followup-runner.ts index ed2d700e2..86f978ec3 100644 --- a/src/auto-reply/reply/followup-runner.ts +++ b/src/auto-reply/reply/followup-runner.ts @@ -1,6 +1,7 @@ import crypto from "node:crypto"; import { lookupContextTokens } from "../../agents/context.js"; import { DEFAULT_CONTEXT_TOKENS } from "../../agents/defaults.js"; +import { runWithModelFallback } from "../../agents/model-fallback.js"; import { runEmbeddedPiAgent } from "../../agents/pi-embedded.js"; import { type SessionEntry, saveSessionStore } from "../../config/sessions.js"; import { logVerbose } from "../../globals.js"; @@ -61,28 +62,39 @@ export function createFollowupRunner(params: { registerAgentRunContext(runId, { sessionKey: queued.run.sessionKey }); } let runResult: Awaited>; + let fallbackProvider = queued.run.provider; + let fallbackModel = queued.run.model; try { - runResult = await runEmbeddedPiAgent({ - sessionId: queued.run.sessionId, - sessionKey: queued.run.sessionKey, - surface: queued.run.surface, - sessionFile: queued.run.sessionFile, - workspaceDir: queued.run.workspaceDir, - config: queued.run.config, - skillsSnapshot: queued.run.skillsSnapshot, - prompt: queued.prompt, - extraSystemPrompt: queued.run.extraSystemPrompt, - ownerNumbers: queued.run.ownerNumbers, - enforceFinalTag: queued.run.enforceFinalTag, + const fallbackResult = await runWithModelFallback({ + cfg: queued.run.config, provider: queued.run.provider, model: queued.run.model, - thinkLevel: queued.run.thinkLevel, - verboseLevel: queued.run.verboseLevel, - bashElevated: queued.run.bashElevated, - timeoutMs: queued.run.timeoutMs, - runId, - blockReplyBreak: queued.run.blockReplyBreak, + run: (provider, model) => + runEmbeddedPiAgent({ + sessionId: queued.run.sessionId, + sessionKey: queued.run.sessionKey, + surface: queued.run.surface, + sessionFile: queued.run.sessionFile, + workspaceDir: queued.run.workspaceDir, + config: queued.run.config, + skillsSnapshot: queued.run.skillsSnapshot, + prompt: queued.prompt, + extraSystemPrompt: queued.run.extraSystemPrompt, + ownerNumbers: queued.run.ownerNumbers, + enforceFinalTag: queued.run.enforceFinalTag, + provider, + model, + thinkLevel: queued.run.thinkLevel, + verboseLevel: queued.run.verboseLevel, + bashElevated: queued.run.bashElevated, + timeoutMs: queued.run.timeoutMs, + runId, + blockReplyBreak: queued.run.blockReplyBreak, + }), }); + runResult = fallbackResult.result; + fallbackProvider = fallbackResult.provider; + fallbackModel = fallbackResult.model; } catch (err) { const message = err instanceof Error ? err.message : String(err); defaultRuntime.error?.(`Followup agent failed before reply: ${message}`); @@ -121,7 +133,8 @@ export function createFollowupRunner(params: { if (sessionStore && sessionKey) { const usage = runResult.meta.agentMeta?.usage; - const modelUsed = runResult.meta.agentMeta?.model ?? defaultModel; + const modelUsed = + runResult.meta.agentMeta?.model ?? fallbackModel ?? defaultModel; const contextTokensUsed = agentCfgContextTokens ?? lookupContextTokens(modelUsed) ?? @@ -141,6 +154,7 @@ export function createFollowupRunner(params: { outputTokens: output, totalTokens: promptTokens > 0 ? promptTokens : (usage.total ?? input), + modelProvider: fallbackProvider ?? entry.modelProvider, model: modelUsed, contextTokens: contextTokensUsed ?? entry.contextTokens, updatedAt: Date.now(), @@ -154,6 +168,7 @@ export function createFollowupRunner(params: { if (entry) { sessionStore[sessionKey] = { ...entry, + modelProvider: fallbackProvider ?? entry.modelProvider, model: modelUsed ?? entry.model, contextTokens: contextTokensUsed ?? entry.contextTokens, }; diff --git a/src/auto-reply/status.ts b/src/auto-reply/status.ts index a860a6452..1a808e86c 100644 --- a/src/auto-reply/status.ts +++ b/src/auto-reply/status.ts @@ -133,6 +133,7 @@ export function buildStatusMessage(args: StatusArgs): string { defaultProvider: DEFAULT_PROVIDER, defaultModel: DEFAULT_MODEL, }); + const provider = entry?.modelProvider ?? resolved.provider ?? DEFAULT_PROVIDER; let model = entry?.model ?? resolved.model ?? DEFAULT_MODEL; let contextTokens = entry?.contextTokens ?? @@ -204,7 +205,7 @@ export function buildStatusMessage(args: StatusArgs): string { const optionsLine = `Options: thinking=${thinkLevel} | verbose=${verboseLevel} | elevated=${elevatedLevel} (set with /think , /verbose on|off, /elevated on|off, /model )`; - const modelLabel = model ? `${resolved.provider}/${model}` : "unknown"; + const modelLabel = model ? `${provider}/${model}` : "unknown"; const agentLine = `Agent: embedded pi • ${modelLabel}`; diff --git a/src/cli/models-cli.ts b/src/cli/models-cli.ts new file mode 100644 index 000000000..c3e427c40 --- /dev/null +++ b/src/cli/models-cli.ts @@ -0,0 +1,198 @@ +import type { Command } from "commander"; + +import { + modelsAliasesAddCommand, + modelsAliasesListCommand, + modelsAliasesRemoveCommand, + modelsFallbacksAddCommand, + modelsFallbacksClearCommand, + modelsFallbacksListCommand, + modelsFallbacksRemoveCommand, + modelsListCommand, + modelsScanCommand, + modelsSetCommand, + modelsStatusCommand, +} from "../commands/models.js"; +import { defaultRuntime } from "../runtime.js"; + +export function registerModelsCli(program: Command) { + const models = program + .command("models") + .description("Model discovery, scanning, and configuration"); + + models + .command("list") + .description("List models (configured by default)") + .option("--all", "Show full model catalog", false) + .option("--local", "Filter to local models", false) + .option("--provider ", "Filter by provider") + .option("--json", "Output JSON", false) + .option("--plain", "Plain line output", false) + .action(async (opts) => { + try { + await modelsListCommand(opts, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + models + .command("status") + .description("Show configured model state") + .option("--json", "Output JSON", false) + .option("--plain", "Plain output", false) + .action(async (opts) => { + try { + await modelsStatusCommand(opts, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + models + .command("set") + .description("Set the default model") + .argument("", "Model id or alias") + .action(async (model: string) => { + try { + await modelsSetCommand(model, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + const aliases = models + .command("aliases") + .description("Manage model aliases"); + + aliases + .command("list") + .description("List model aliases") + .option("--json", "Output JSON", false) + .option("--plain", "Plain output", false) + .action(async (opts) => { + try { + await modelsAliasesListCommand(opts, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + aliases + .command("add") + .description("Add or update a model alias") + .argument("", "Alias name") + .argument("", "Model id or alias") + .action(async (alias: string, model: string) => { + try { + await modelsAliasesAddCommand(alias, model, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + aliases + .command("remove") + .description("Remove a model alias") + .argument("", "Alias name") + .action(async (alias: string) => { + try { + await modelsAliasesRemoveCommand(alias, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + const fallbacks = models + .command("fallbacks") + .description("Manage model fallback list"); + + fallbacks + .command("list") + .description("List fallback models") + .option("--json", "Output JSON", false) + .option("--plain", "Plain output", false) + .action(async (opts) => { + try { + await modelsFallbacksListCommand(opts, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + fallbacks + .command("add") + .description("Add a fallback model") + .argument("", "Model id or alias") + .action(async (model: string) => { + try { + await modelsFallbacksAddCommand(model, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + fallbacks + .command("remove") + .description("Remove a fallback model") + .argument("", "Model id or alias") + .action(async (model: string) => { + try { + await modelsFallbacksRemoveCommand(model, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + fallbacks + .command("clear") + .description("Clear all fallback models") + .action(async () => { + try { + await modelsFallbacksClearCommand(defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + models + .command("scan") + .description("Scan OpenRouter free models for tools + images") + .option("--min-params ", "Minimum parameter size (billions)") + .option("--max-age-days ", "Skip models older than N days") + .option("--provider ", "Filter by provider prefix") + .option("--max-candidates ", "Max fallback candidates", "6") + .option("--timeout ", "Per-probe timeout in ms") + .option("--concurrency ", "Probe concurrency") + .option("--yes", "Accept defaults without prompting", false) + .option("--no-input", "Disable prompts (use defaults)") + .option("--set-default", "Set agent.model to the first selection", false) + .option("--json", "Output JSON", false) + .action(async (opts) => { + try { + await modelsScanCommand(opts, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); + + models.action(async () => { + try { + await modelsStatusCommand({}, defaultRuntime); + } catch (err) { + defaultRuntime.error(String(err)); + defaultRuntime.exit(1); + } + }); +} diff --git a/src/cli/program.ts b/src/cli/program.ts index 3d87a49f4..9b7a866d9 100644 --- a/src/cli/program.ts +++ b/src/cli/program.ts @@ -22,6 +22,7 @@ import { createDefaultDeps } from "./deps.js"; import { registerDnsCli } from "./dns-cli.js"; import { registerGatewayCli } from "./gateway-cli.js"; import { registerHooksCli } from "./hooks-cli.js"; +import { registerModelsCli } from "./models-cli.js"; import { registerNodesCli } from "./nodes-cli.js"; import { forceFreePort } from "./ports.js"; import { registerTuiCli } from "./tui-cli.js"; @@ -399,6 +400,7 @@ Examples: registerCanvasCli(program); registerGatewayCli(program); + registerModelsCli(program); registerNodesCli(program); registerTuiCli(program); registerCronCli(program); diff --git a/src/commands/agent.ts b/src/commands/agent.ts index e13007e42..8b385c303 100644 --- a/src/commands/agent.ts +++ b/src/commands/agent.ts @@ -12,6 +12,7 @@ import { resolveConfiguredModelRef, resolveThinkingDefault, } from "../agents/model-selection.js"; +import { runWithModelFallback } from "../agents/model-fallback.js"; import { runEmbeddedPiAgent } from "../agents/pi-embedded.js"; import { buildWorkspaceSkillSnapshot } from "../agents/skills.js"; import { @@ -364,6 +365,8 @@ export async function agentCommand( }); let result: Awaited>; + let fallbackProvider = provider; + let fallbackModel = model; try { const surface = opts.surface?.trim().toLowerCase() || @@ -372,32 +375,41 @@ export async function agentCommand( if (!raw) return undefined; return raw === "imsg" ? "imessage" : raw; })(); - result = await runEmbeddedPiAgent({ - sessionId, - sessionKey, - surface, - sessionFile, - workspaceDir, - config: cfg, - skillsSnapshot, - prompt: body, + const fallbackResult = await runWithModelFallback({ + cfg, provider, model, - thinkLevel: resolvedThinkLevel, - verboseLevel: resolvedVerboseLevel, - timeoutMs, - runId, - lane: opts.lane, - abortSignal: opts.abortSignal, - extraSystemPrompt: opts.extraSystemPrompt, - onAgentEvent: (evt) => { - emitAgentEvent({ + run: (providerOverride, modelOverride) => + runEmbeddedPiAgent({ + sessionId, + sessionKey, + surface, + sessionFile, + workspaceDir, + config: cfg, + skillsSnapshot, + prompt: body, + provider: providerOverride, + model: modelOverride, + thinkLevel: resolvedThinkLevel, + verboseLevel: resolvedVerboseLevel, + timeoutMs, runId, - stream: evt.stream, - data: evt.data, - }); - }, + lane: opts.lane, + abortSignal: opts.abortSignal, + extraSystemPrompt: opts.extraSystemPrompt, + onAgentEvent: (evt) => { + emitAgentEvent({ + runId, + stream: evt.stream, + data: evt.data, + }); + }, + }), }); + result = fallbackResult.result; + fallbackProvider = fallbackResult.provider; + fallbackModel = fallbackResult.model; emitAgentEvent({ runId, stream: "job", @@ -431,7 +443,10 @@ export async function agentCommand( // Update token+model fields in the session store. if (sessionStore && sessionKey) { const usage = result.meta.agentMeta?.usage; - const modelUsed = result.meta.agentMeta?.model ?? model; + const modelUsed = + result.meta.agentMeta?.model ?? fallbackModel ?? model; + const providerUsed = + result.meta.agentMeta?.provider ?? fallbackProvider ?? provider; const contextTokens = agentCfg?.contextTokens ?? lookupContextTokens(modelUsed) ?? @@ -445,6 +460,7 @@ export async function agentCommand( ...entry, sessionId, updatedAt: Date.now(), + modelProvider: providerUsed, model: modelUsed, contextTokens, }; diff --git a/src/commands/models.ts b/src/commands/models.ts new file mode 100644 index 000000000..9916c0529 --- /dev/null +++ b/src/commands/models.ts @@ -0,0 +1,14 @@ +export { modelsListCommand, modelsStatusCommand } from "./models/list.js"; +export { + modelsAliasesAddCommand, + modelsAliasesListCommand, + modelsAliasesRemoveCommand, +} from "./models/aliases.js"; +export { + modelsFallbacksAddCommand, + modelsFallbacksClearCommand, + modelsFallbacksListCommand, + modelsFallbacksRemoveCommand, +} from "./models/fallbacks.js"; +export { modelsScanCommand } from "./models/scan.js"; +export { modelsSetCommand } from "./models/set.js"; diff --git a/src/commands/models/aliases.ts b/src/commands/models/aliases.ts new file mode 100644 index 000000000..9ed12db3f --- /dev/null +++ b/src/commands/models/aliases.ts @@ -0,0 +1,89 @@ +import { + CONFIG_PATH_CLAWDBOT, + loadConfig, +} from "../../config/config.js"; +import type { RuntimeEnv } from "../../runtime.js"; +import { + ensureFlagCompatibility, + normalizeAlias, + resolveModelTarget, + updateConfig, +} from "./shared.js"; + +export async function modelsAliasesListCommand( + opts: { json?: boolean; plain?: boolean }, + runtime: RuntimeEnv, +) { + ensureFlagCompatibility(opts); + const cfg = loadConfig(); + const aliases = cfg.agent?.modelAliases ?? {}; + + if (opts.json) { + runtime.log(JSON.stringify({ aliases }, null, 2)); + return; + } + if (opts.plain) { + for (const [alias, target] of Object.entries(aliases)) { + runtime.log(`${alias} ${target}`); + } + return; + } + + runtime.log(`Aliases (${Object.keys(aliases).length}):`); + if (Object.keys(aliases).length === 0) { + runtime.log("- none"); + return; + } + for (const [alias, target] of Object.entries(aliases)) { + runtime.log(`- ${alias} -> ${target}`); + } +} + +export async function modelsAliasesAddCommand( + aliasRaw: string, + modelRaw: string, + runtime: RuntimeEnv, +) { + const alias = normalizeAlias(aliasRaw); + const updated = await updateConfig((cfg) => { + const resolved = resolveModelTarget({ raw: modelRaw, cfg }); + const nextAliases = { ...(cfg.agent?.modelAliases ?? {}) }; + nextAliases[alias] = `${resolved.provider}/${resolved.model}`; + return { + ...cfg, + agent: { + ...cfg.agent, + modelAliases: nextAliases, + }, + }; + }); + + runtime.log(`Updated ${CONFIG_PATH_CLAWDBOT}`); + runtime.log(`Alias ${alias} -> ${updated.agent?.modelAliases?.[alias]}`); +} + +export async function modelsAliasesRemoveCommand( + aliasRaw: string, + runtime: RuntimeEnv, +) { + const alias = normalizeAlias(aliasRaw); + const updated = await updateConfig((cfg) => { + const nextAliases = { ...(cfg.agent?.modelAliases ?? {}) }; + if (!nextAliases[alias]) { + throw new Error(`Alias not found: ${alias}`); + } + delete nextAliases[alias]; + return { + ...cfg, + agent: { + ...cfg.agent, + modelAliases: nextAliases, + }, + }; + }); + + runtime.log(`Updated ${CONFIG_PATH_CLAWDBOT}`); + if (!updated.agent?.modelAliases || Object.keys(updated.agent.modelAliases).length === 0) { + runtime.log("No aliases configured."); + } +} diff --git a/src/commands/models/fallbacks.ts b/src/commands/models/fallbacks.ts new file mode 100644 index 000000000..ebfe406e7 --- /dev/null +++ b/src/commands/models/fallbacks.ts @@ -0,0 +1,134 @@ +import { + CONFIG_PATH_CLAWDBOT, + loadConfig, +} from "../../config/config.js"; +import type { RuntimeEnv } from "../../runtime.js"; +import { + buildModelAliasIndex, + resolveModelRefFromString, +} from "../../agents/model-selection.js"; +import { + DEFAULT_PROVIDER, + ensureFlagCompatibility, + modelKey, + resolveModelTarget, + updateConfig, +} from "./shared.js"; + +export async function modelsFallbacksListCommand( + opts: { json?: boolean; plain?: boolean }, + runtime: RuntimeEnv, +) { + ensureFlagCompatibility(opts); + const cfg = loadConfig(); + const fallbacks = cfg.agent?.modelFallbacks ?? []; + + if (opts.json) { + runtime.log(JSON.stringify({ fallbacks }, null, 2)); + return; + } + if (opts.plain) { + for (const entry of fallbacks) runtime.log(entry); + return; + } + + runtime.log(`Fallbacks (${fallbacks.length}):`); + if (fallbacks.length === 0) { + runtime.log("- none"); + return; + } + for (const entry of fallbacks) runtime.log(`- ${entry}`); +} + +export async function modelsFallbacksAddCommand( + modelRaw: string, + runtime: RuntimeEnv, +) { + const updated = await updateConfig((cfg) => { + const resolved = resolveModelTarget({ raw: modelRaw, cfg }); + const targetKey = modelKey(resolved.provider, resolved.model); + const aliasIndex = buildModelAliasIndex({ + cfg, + defaultProvider: DEFAULT_PROVIDER, + }); + const existing = cfg.agent?.modelFallbacks ?? []; + const existingKeys = existing + .map((entry) => + resolveModelRefFromString({ + raw: String(entry ?? ""), + defaultProvider: DEFAULT_PROVIDER, + aliasIndex, + }), + ) + .filter(Boolean) + .map((entry) => modelKey(entry!.ref.provider, entry!.ref.model)); + + if (existingKeys.includes(targetKey)) return cfg; + + return { + ...cfg, + agent: { + ...cfg.agent, + modelFallbacks: [...existing, targetKey], + }, + }; + }); + + runtime.log(`Updated ${CONFIG_PATH_CLAWDBOT}`); + runtime.log(`Fallbacks: ${(updated.agent?.modelFallbacks ?? []).join(", ")}`); +} + +export async function modelsFallbacksRemoveCommand( + modelRaw: string, + runtime: RuntimeEnv, +) { + const updated = await updateConfig((cfg) => { + const resolved = resolveModelTarget({ raw: modelRaw, cfg }); + const targetKey = modelKey(resolved.provider, resolved.model); + const aliasIndex = buildModelAliasIndex({ + cfg, + defaultProvider: DEFAULT_PROVIDER, + }); + const existing = cfg.agent?.modelFallbacks ?? []; + const filtered = existing.filter((entry) => { + const resolvedEntry = resolveModelRefFromString({ + raw: String(entry ?? ""), + defaultProvider: DEFAULT_PROVIDER, + aliasIndex, + }); + if (!resolvedEntry) return true; + return ( + modelKey(resolvedEntry.ref.provider, resolvedEntry.ref.model) !== + targetKey + ); + }); + + if (filtered.length === existing.length) { + throw new Error(`Fallback not found: ${targetKey}`); + } + + return { + ...cfg, + agent: { + ...cfg.agent, + modelFallbacks: filtered, + }, + }; + }); + + runtime.log(`Updated ${CONFIG_PATH_CLAWDBOT}`); + runtime.log(`Fallbacks: ${(updated.agent?.modelFallbacks ?? []).join(", ")}`); +} + +export async function modelsFallbacksClearCommand(runtime: RuntimeEnv) { + await updateConfig((cfg) => ({ + ...cfg, + agent: { + ...cfg.agent, + modelFallbacks: [], + }, + })); + + runtime.log(`Updated ${CONFIG_PATH_CLAWDBOT}`); + runtime.log("Fallback list cleared."); +} diff --git a/src/commands/models/list.ts b/src/commands/models/list.ts new file mode 100644 index 000000000..e74bb8b7d --- /dev/null +++ b/src/commands/models/list.ts @@ -0,0 +1,419 @@ +import chalk from "chalk"; +import { + discoverAuthStorage, + discoverModels, +} from "@mariozechner/pi-coding-agent"; +import { getEnvApiKey, type Api, type Model } from "@mariozechner/pi-ai"; + +import { resolveClawdbotAgentDir } from "../../agents/agent-paths.js"; +import { ensureClawdbotModelsJson } from "../../agents/models-config.js"; +import { + buildModelAliasIndex, + parseModelRef, + resolveModelRefFromString, + resolveConfiguredModelRef, +} from "../../agents/model-selection.js"; +import { + CONFIG_PATH_CLAWDBOT, + loadConfig, + type ClawdbotConfig, +} from "../../config/config.js"; +import { info } from "../../globals.js"; +import type { RuntimeEnv } from "../../runtime.js"; +import { + DEFAULT_MODEL, + DEFAULT_PROVIDER, + ensureFlagCompatibility, + formatTokenK, + modelKey, +} from "./shared.js"; + +const MODEL_PAD = 42; +const INPUT_PAD = 10; +const CTX_PAD = 8; +const LOCAL_PAD = 5; +const AUTH_PAD = 5; + +const isRich = (opts?: { json?: boolean; plain?: boolean }) => + Boolean(process.stdout.isTTY && chalk.level > 0 && !opts?.json && !opts?.plain); + +const pad = (value: string, size: number) => value.padEnd(size); + +const truncate = (value: string, max: number) => { + if (value.length <= max) return value; + if (max <= 3) return value.slice(0, max); + return `${value.slice(0, max - 3)}...`; +}; + +type ConfiguredEntry = { + key: string; + ref: { provider: string; model: string }; + tags: Set; + aliases: string[]; +}; + +type ModelRow = { + key: string; + name: string; + input: string; + contextWindow: number | null; + local: boolean | null; + available: boolean | null; + tags: string[]; + missing: boolean; +}; + +const isLocalBaseUrl = (baseUrl: string) => { + try { + const url = new URL(baseUrl); + const host = url.hostname.toLowerCase(); + return ( + host === "localhost" || + host === "127.0.0.1" || + host === "0.0.0.0" || + host === "::1" || + host.endsWith(".local") + ); + } catch { + return false; + } +}; + +const resolveConfiguredEntries = (cfg: ClawdbotConfig) => { + const resolvedDefault = resolveConfiguredModelRef({ + cfg, + defaultProvider: DEFAULT_PROVIDER, + defaultModel: DEFAULT_MODEL, + }); + const aliasIndex = buildModelAliasIndex({ + cfg, + defaultProvider: DEFAULT_PROVIDER, + }); + const order: string[] = []; + const tagsByKey = new Map>(); + const aliasesByKey = new Map(); + + for (const [key, aliases] of aliasIndex.byKey.entries()) { + aliasesByKey.set(key, aliases); + } + + const addEntry = (ref: { provider: string; model: string }, tag: string) => { + const key = modelKey(ref.provider, ref.model); + if (!tagsByKey.has(key)) { + tagsByKey.set(key, new Set()); + order.push(key); + } + tagsByKey.get(key)?.add(tag); + }; + + addEntry(resolvedDefault, "default"); + + (cfg.agent?.modelFallbacks ?? []).forEach((raw, idx) => { + const resolved = resolveModelRefFromString({ + raw: String(raw ?? ""), + defaultProvider: DEFAULT_PROVIDER, + aliasIndex, + }); + if (!resolved) return; + addEntry(resolved.ref, `fallback#${idx + 1}`); + }); + + (cfg.agent?.allowedModels ?? []).forEach((raw) => { + const parsed = parseModelRef(String(raw ?? ""), DEFAULT_PROVIDER); + if (!parsed) return; + addEntry(parsed, "allowed"); + }); + + for (const targetRaw of Object.values(cfg.agent?.modelAliases ?? {})) { + const resolved = resolveModelRefFromString({ + raw: String(targetRaw ?? ""), + defaultProvider: DEFAULT_PROVIDER, + aliasIndex, + }); + if (!resolved) continue; + addEntry(resolved.ref, "alias"); + } + + const entries: ConfiguredEntry[] = order.map((key) => { + const slash = key.indexOf("/"); + const provider = slash === -1 ? key : key.slice(0, slash); + const model = slash === -1 ? "" : key.slice(slash + 1); + return { + key, + ref: { provider, model }, + tags: tagsByKey.get(key) ?? new Set(), + aliases: aliasesByKey.get(key) ?? [], + } satisfies ConfiguredEntry; + }); + + return { entries }; +}; + +async function loadModelRegistry(cfg: ClawdbotConfig) { + await ensureClawdbotModelsJson(cfg); + const agentDir = resolveClawdbotAgentDir(); + const authStorage = discoverAuthStorage(agentDir); + const registry = discoverModels(authStorage, agentDir); + const models = registry.getAll() as Model[]; + const availableModels = registry.getAvailable() as Model[]; + const availableKeys = new Set( + availableModels.map((model) => modelKey(model.provider, model.id)), + ); + return { registry, models, availableKeys }; +} + +function toModelRow(params: { + model?: Model; + key: string; + tags: string[]; + aliases?: string[]; + availableKeys?: Set; +}): ModelRow { + const { model, key, tags, aliases = [], availableKeys } = params; + if (!model) { + return { + key, + name: key, + input: "-", + contextWindow: null, + local: null, + available: null, + tags: [...tags, "missing"], + missing: true, + }; + } + + const input = model.input.join("+") || "text"; + const local = isLocalBaseUrl(model.baseUrl); + const envKey = getEnvApiKey(model.provider); + const available = + availableKeys?.has(modelKey(model.provider, model.id)) || Boolean(envKey); + const aliasTags = aliases.length > 0 ? [`alias:${aliases.join(",")}`] : []; + const mergedTags = new Set(tags); + if (aliasTags.length > 0) { + for (const tag of mergedTags) { + if (tag === "alias" || tag.startsWith("alias:")) mergedTags.delete(tag); + } + for (const tag of aliasTags) mergedTags.add(tag); + } + + return { + key, + name: model.name || model.id, + input, + contextWindow: model.contextWindow ?? null, + local, + available, + tags: Array.from(mergedTags), + missing: false, + }; +} + +function printModelTable( + rows: ModelRow[], + runtime: RuntimeEnv, + opts: { json?: boolean; plain?: boolean } = {}, +) { + if (opts.json) { + runtime.log( + JSON.stringify( + { + count: rows.length, + models: rows, + }, + null, + 2, + ), + ); + return; + } + + if (opts.plain) { + for (const row of rows) runtime.log(row.key); + return; + } + + const rich = isRich(opts); + const header = [ + pad("Model", MODEL_PAD), + pad("Input", INPUT_PAD), + pad("Ctx", CTX_PAD), + pad("Local", LOCAL_PAD), + pad("Auth", AUTH_PAD), + "Tags", + ].join(" "); + runtime.log(rich ? chalk.bold(header) : header); + + for (const row of rows) { + const keyLabel = pad(truncate(row.key, MODEL_PAD), MODEL_PAD); + const inputLabel = pad(row.input || "-", INPUT_PAD); + const ctxLabel = pad(formatTokenK(row.contextWindow), CTX_PAD); + const localLabel = pad( + row.local === null ? "-" : row.local ? "yes" : "no", + LOCAL_PAD, + ); + const authLabel = pad( + row.available === null ? "-" : row.available ? "yes" : "no", + AUTH_PAD, + ); + const tagsLabel = row.tags.length > 0 ? row.tags.join(",") : ""; + + const line = [ + rich ? chalk.cyan(keyLabel) : keyLabel, + inputLabel, + ctxLabel, + localLabel, + authLabel, + rich ? chalk.gray(tagsLabel) : tagsLabel, + ].join(" "); + runtime.log(line); + } +} + +export async function modelsListCommand( + opts: { + all?: boolean; + local?: boolean; + provider?: string; + json?: boolean; + plain?: boolean; + }, + runtime: RuntimeEnv, +) { + ensureFlagCompatibility(opts); + const cfg = loadConfig(); + const providerFilter = opts.provider?.trim().toLowerCase(); + + let models: Model[] = []; + let availableKeys: Set | undefined; + try { + const loaded = await loadModelRegistry(cfg); + models = loaded.models; + availableKeys = loaded.availableKeys; + } catch (err) { + runtime.error(`Model registry unavailable: ${String(err)}`); + } + + const modelByKey = new Map( + models.map((model) => [modelKey(model.provider, model.id), model]), + ); + + const { entries } = resolveConfiguredEntries(cfg); + const configuredByKey = new Map(entries.map((entry) => [entry.key, entry])); + + const rows: ModelRow[] = []; + + if (opts.all) { + const sorted = [...models].sort((a, b) => { + const p = a.provider.localeCompare(b.provider); + if (p !== 0) return p; + return a.id.localeCompare(b.id); + }); + + for (const model of sorted) { + if (providerFilter && model.provider.toLowerCase() !== providerFilter) { + continue; + } + if (opts.local && !isLocalBaseUrl(model.baseUrl)) continue; + const key = modelKey(model.provider, model.id); + const configured = configuredByKey.get(key); + rows.push( + toModelRow({ + model, + key, + tags: configured ? Array.from(configured.tags) : [], + aliases: configured?.aliases ?? [], + availableKeys, + }), + ); + } + } else { + for (const entry of entries) { + if ( + providerFilter && + entry.ref.provider.toLowerCase() !== providerFilter + ) { + continue; + } + const model = modelByKey.get(entry.key); + if (opts.local && model && !isLocalBaseUrl(model.baseUrl)) continue; + if (opts.local && !model) continue; + rows.push( + toModelRow({ + model, + key: entry.key, + tags: Array.from(entry.tags), + aliases: entry.aliases, + availableKeys, + }), + ); + } + } + + if (rows.length === 0) { + runtime.log("No models found."); + return; + } + + printModelTable(rows, runtime, opts); +} + +export async function modelsStatusCommand( + opts: { json?: boolean; plain?: boolean }, + runtime: RuntimeEnv, +) { + ensureFlagCompatibility(opts); + const cfg = loadConfig(); + const resolved = resolveConfiguredModelRef({ + cfg, + defaultProvider: DEFAULT_PROVIDER, + defaultModel: DEFAULT_MODEL, + }); + + const rawModel = cfg.agent?.model?.trim() ?? ""; + const defaultLabel = rawModel || `${resolved.provider}/${resolved.model}`; + const fallbacks = cfg.agent?.modelFallbacks ?? []; + const aliases = cfg.agent?.modelAliases ?? {}; + const allowed = cfg.agent?.allowedModels ?? []; + + if (opts.json) { + runtime.log( + JSON.stringify( + { + configPath: CONFIG_PATH_CLAWDBOT, + defaultModel: defaultLabel, + resolvedDefault: `${resolved.provider}/${resolved.model}`, + fallbacks, + aliases, + allowed, + }, + null, + 2, + ), + ); + return; + } + + if (opts.plain) { + runtime.log(defaultLabel); + return; + } + + runtime.log(info(`Config: ${CONFIG_PATH_CLAWDBOT}`)); + runtime.log(`Default: ${defaultLabel}`); + runtime.log( + `Fallbacks (${fallbacks.length || 0}): ${fallbacks.join(", ") || "-"}`, + ); + runtime.log( + `Aliases (${Object.keys(aliases).length || 0}): ${ + Object.keys(aliases).length + ? Object.entries(aliases) + .map(([alias, target]) => `${alias} -> ${target}`) + .join(", ") + : "-" + }`, + ); + runtime.log( + `Allowed (${allowed.length || 0}): ${allowed.length ? allowed.join(", ") : "all"}`, + ); +} diff --git a/src/commands/models/scan.ts b/src/commands/models/scan.ts new file mode 100644 index 000000000..d02882cd5 --- /dev/null +++ b/src/commands/models/scan.ts @@ -0,0 +1,267 @@ +import { cancel, isCancel, multiselect } from "@clack/prompts"; +import { discoverAuthStorage } from "@mariozechner/pi-coding-agent"; + +import { resolveClawdbotAgentDir } from "../../agents/agent-paths.js"; +import { + scanOpenRouterModels, + type ModelScanResult, +} from "../../agents/model-scan.js"; +import { warn } from "../../globals.js"; +import type { RuntimeEnv } from "../../runtime.js"; +import { + buildAllowlistSet, + formatMs, + formatTokenK, + updateConfig, +} from "./shared.js"; +import { CONFIG_PATH_CLAWDBOT } from "../../config/config.js"; + +const MODEL_PAD = 42; +const CTX_PAD = 8; + +const pad = (value: string, size: number) => value.padEnd(size); + +const truncate = (value: string, max: number) => { + if (value.length <= max) return value; + if (max <= 3) return value.slice(0, max); + return `${value.slice(0, max - 3)}...`; +}; + + +function sortScanResults(results: ModelScanResult[]): ModelScanResult[] { + return results.slice().sort((a, b) => { + const aImage = a.image.ok ? 1 : 0; + const bImage = b.image.ok ? 1 : 0; + if (aImage !== bImage) return bImage - aImage; + + const aToolLatency = a.tool.latencyMs ?? Number.POSITIVE_INFINITY; + const bToolLatency = b.tool.latencyMs ?? Number.POSITIVE_INFINITY; + if (aToolLatency !== bToolLatency) return aToolLatency - bToolLatency; + + const aCtx = a.contextLength ?? 0; + const bCtx = b.contextLength ?? 0; + if (aCtx !== bCtx) return bCtx - aCtx; + + const aParams = a.inferredParamB ?? 0; + const bParams = b.inferredParamB ?? 0; + if (aParams !== bParams) return bParams - aParams; + + return a.modelRef.localeCompare(b.modelRef); + }); +} + +function buildScanHint(result: ModelScanResult): string { + const toolLabel = result.tool.ok + ? `tool ${formatMs(result.tool.latencyMs)}` + : "tool fail"; + const imageLabel = result.image.skipped + ? "img skip" + : result.image.ok + ? `img ${formatMs(result.image.latencyMs)}` + : "img fail"; + const ctxLabel = result.contextLength + ? `ctx ${formatTokenK(result.contextLength)}` + : "ctx ?"; + const paramLabel = result.inferredParamB ? `${result.inferredParamB}b` : null; + return [toolLabel, imageLabel, ctxLabel, paramLabel] + .filter(Boolean) + .join(" | "); +} + +function printScanSummary(results: ModelScanResult[], runtime: RuntimeEnv) { + const toolOk = results.filter((r) => r.tool.ok); + const imageOk = results.filter((r) => r.image.ok); + const toolImageOk = results.filter((r) => r.tool.ok && r.image.ok); + runtime.log( + `Scan results: tested ${results.length}, tool ok ${toolOk.length}, image ok ${imageOk.length}, tool+image ok ${toolImageOk.length}`, + ); +} + +function printScanTable(results: ModelScanResult[], runtime: RuntimeEnv) { + const header = [ + pad("Model", MODEL_PAD), + pad("Tool", 10), + pad("Image", 10), + pad("Ctx", CTX_PAD), + pad("Params", 8), + "Notes", + ].join(" "); + runtime.log(header); + + for (const entry of results) { + const modelLabel = pad(truncate(entry.modelRef, MODEL_PAD), MODEL_PAD); + const toolLabel = pad( + entry.tool.ok ? formatMs(entry.tool.latencyMs) : "fail", + 10, + ); + const imageLabel = pad( + entry.image.ok + ? formatMs(entry.image.latencyMs) + : entry.image.skipped + ? "skip" + : "fail", + 10, + ); + const ctxLabel = pad(formatTokenK(entry.contextLength), CTX_PAD); + const paramsLabel = pad( + entry.inferredParamB ? `${entry.inferredParamB}b` : "-", + 8, + ); + const notes = entry.modality ? `modality:${entry.modality}` : ""; + + runtime.log( + [modelLabel, toolLabel, imageLabel, ctxLabel, paramsLabel, notes].join( + " ", + ), + ); + } +} + +export async function modelsScanCommand( + opts: { + minParams?: string; + maxAgeDays?: string; + provider?: string; + maxCandidates?: string; + timeout?: string; + concurrency?: string; + yes?: boolean; + input?: boolean; + setDefault?: boolean; + json?: boolean; + }, + runtime: RuntimeEnv, +) { + const minParams = opts.minParams ? Number(opts.minParams) : undefined; + if (minParams !== undefined && (!Number.isFinite(minParams) || minParams < 0)) { + throw new Error("--min-params must be >= 0"); + } + const maxAgeDays = opts.maxAgeDays ? Number(opts.maxAgeDays) : undefined; + if (maxAgeDays !== undefined && (!Number.isFinite(maxAgeDays) || maxAgeDays < 0)) { + throw new Error("--max-age-days must be >= 0"); + } + const maxCandidates = opts.maxCandidates + ? Number(opts.maxCandidates) + : 6; + if (!Number.isFinite(maxCandidates) || maxCandidates <= 0) { + throw new Error("--max-candidates must be > 0"); + } + const timeout = opts.timeout ? Number(opts.timeout) : undefined; + if (timeout !== undefined && (!Number.isFinite(timeout) || timeout <= 0)) { + throw new Error("--timeout must be > 0"); + } + const concurrency = opts.concurrency ? Number(opts.concurrency) : undefined; + if (concurrency !== undefined && (!Number.isFinite(concurrency) || concurrency <= 0)) { + throw new Error("--concurrency must be > 0"); + } + + const authStorage = discoverAuthStorage(resolveClawdbotAgentDir()); + const storedKey = await authStorage.getApiKey("openrouter"); + const results = await scanOpenRouterModels({ + apiKey: storedKey ?? undefined, + minParamB: minParams, + maxAgeDays, + providerFilter: opts.provider, + timeoutMs: timeout, + concurrency, + }); + + const toolOk = results.filter((entry) => entry.tool.ok); + if (toolOk.length === 0) { + throw new Error("No tool-capable OpenRouter free models found."); + } + + const sorted = sortScanResults(toolOk); + const imagePreferred = sorted.filter((entry) => entry.image.ok); + const preselectPool = imagePreferred.length > 0 ? imagePreferred : sorted; + const preselected = preselectPool + .slice(0, Math.floor(maxCandidates)) + .map((entry) => entry.modelRef); + + if (!opts.json) { + printScanSummary(results, runtime); + printScanTable(sorted, runtime); + } + + const noInput = opts.input === false; + const canPrompt = process.stdin.isTTY && !opts.yes && !noInput && !opts.json; + let selected: string[] = preselected; + + if (canPrompt) { + const selection = await multiselect({ + message: "Select fallback models (ordered)", + options: sorted.map((entry) => ({ + value: entry.modelRef, + label: entry.modelRef, + hint: buildScanHint(entry), + })), + initialValues: preselected, + }); + + if (isCancel(selection)) { + cancel("Model scan cancelled."); + runtime.exit(0); + } + + selected = selection as string[]; + } else if (!process.stdin.isTTY && !opts.yes && !noInput && !opts.json) { + throw new Error("Non-interactive scan: pass --yes to apply defaults."); + } + + if (selected.length === 0) { + throw new Error("No models selected for fallbacks."); + } + + const updated = await updateConfig((cfg) => { + const next = { + ...cfg, + agent: { + ...cfg.agent, + modelFallbacks: selected, + ...(opts.setDefault ? { model: selected[0] } : {}), + }, + }; + return next; + }); + + const allowlist = buildAllowlistSet(updated); + const allowlistMissing = + allowlist.size > 0 + ? selected.filter((entry) => !allowlist.has(entry)) + : []; + + if (opts.json) { + runtime.log( + JSON.stringify( + { + selected, + setDefault: Boolean(opts.setDefault), + results, + warnings: + allowlistMissing.length > 0 + ? [ + `Selected models not in agent.allowedModels: ${allowlistMissing.join(", ")}`, + ] + : [], + }, + null, + 2, + ), + ); + return; + } + + if (allowlistMissing.length > 0) { + runtime.log( + warn( + `Warning: ${allowlistMissing.length} selected models are not in agent.allowedModels and will be ignored by fallback: ${allowlistMissing.join(", ")}`, + ), + ); + } + + runtime.log(`Updated ${CONFIG_PATH_CLAWDBOT}`); + runtime.log(`Fallbacks: ${selected.join(", ")}`); + if (opts.setDefault) { + runtime.log(`Default model: ${selected[0]}`); + } +} diff --git a/src/commands/models/set.ts b/src/commands/models/set.ts new file mode 100644 index 000000000..b3d8f4532 --- /dev/null +++ b/src/commands/models/set.ts @@ -0,0 +1,29 @@ +import { CONFIG_PATH_CLAWDBOT } from "../../config/config.js"; +import type { RuntimeEnv } from "../../runtime.js"; +import { buildAllowlistSet, modelKey, resolveModelTarget, updateConfig } from "./shared.js"; + +export async function modelsSetCommand( + modelRaw: string, + runtime: RuntimeEnv, +) { + const updated = await updateConfig((cfg) => { + const resolved = resolveModelTarget({ raw: modelRaw, cfg }); + const allowlist = buildAllowlistSet(cfg); + if (allowlist.size > 0) { + const key = modelKey(resolved.provider, resolved.model); + if (!allowlist.has(key)) { + throw new Error(`Model ${key} is not in agent.allowedModels.`); + } + } + return { + ...cfg, + agent: { + ...cfg.agent, + model: `${resolved.provider}/${resolved.model}`, + }, + }; + }); + + runtime.log(`Updated ${CONFIG_PATH_CLAWDBOT}`); + runtime.log(`Default model: ${updated.agent?.model ?? modelRaw}`); +} diff --git a/src/commands/models/shared.ts b/src/commands/models/shared.ts new file mode 100644 index 000000000..fedf425b4 --- /dev/null +++ b/src/commands/models/shared.ts @@ -0,0 +1,95 @@ +import { + DEFAULT_MODEL, + DEFAULT_PROVIDER, +} from "../../agents/defaults.js"; +import { + buildModelAliasIndex, + modelKey, + parseModelRef, + resolveModelRefFromString, +} from "../../agents/model-selection.js"; +import { + readConfigFileSnapshot, + writeConfigFile, + type ClawdbotConfig, +} from "../../config/config.js"; + +export const ensureFlagCompatibility = (opts: { + json?: boolean; + plain?: boolean; +}) => { + if (opts.json && opts.plain) { + throw new Error("Choose either --json or --plain, not both."); + } +}; + +export const formatTokenK = (value?: number | null) => { + if (!value || !Number.isFinite(value)) return "-"; + if (value < 1024) return `${Math.round(value)}`; + return `${Math.round(value / 1024)}k`; +}; + +export const formatMs = (value?: number | null) => { + if (value === null || value === undefined) return "-"; + if (!Number.isFinite(value)) return "-"; + if (value < 1000) return `${Math.round(value)}ms`; + return `${Math.round(value / 100) / 10}s`; +}; + +export async function updateConfig( + mutator: (cfg: ClawdbotConfig) => ClawdbotConfig, +): Promise { + const snapshot = await readConfigFileSnapshot(); + if (!snapshot.valid) { + const issues = snapshot.issues + .map((issue) => `- ${issue.path}: ${issue.message}`) + .join("\n"); + throw new Error(`Invalid config at ${snapshot.path}\n${issues}`); + } + const next = mutator(snapshot.config); + await writeConfigFile(next); + return next; +} + +export function resolveModelTarget(params: { + raw: string; + cfg: ClawdbotConfig; +}): { provider: string; model: string } { + const aliasIndex = buildModelAliasIndex({ + cfg: params.cfg, + defaultProvider: DEFAULT_PROVIDER, + }); + const resolved = resolveModelRefFromString({ + raw: params.raw, + defaultProvider: DEFAULT_PROVIDER, + aliasIndex, + }); + if (!resolved) { + throw new Error(`Invalid model reference: ${params.raw}`); + } + return resolved.ref; +} + +export function buildAllowlistSet(cfg: ClawdbotConfig): Set { + const allowed = new Set(); + for (const raw of cfg.agent?.allowedModels ?? []) { + const parsed = parseModelRef(String(raw ?? ""), DEFAULT_PROVIDER); + if (!parsed) continue; + allowed.add(modelKey(parsed.provider, parsed.model)); + } + return allowed; +} + +export function normalizeAlias(alias: string): string { + const trimmed = alias.trim(); + if (!trimmed) throw new Error("Alias cannot be empty."); + if (!/^[A-Za-z0-9_.:-]+$/.test(trimmed)) { + throw new Error( + "Alias must use letters, numbers, dots, underscores, colons, or dashes.", + ); + } + return trimmed; +} + +export { modelKey }; +export { DEFAULT_MODEL, DEFAULT_PROVIDER }; diff --git a/src/config/schema.ts b/src/config/schema.ts index a21392f7b..2573eb736 100644 --- a/src/config/schema.ts +++ b/src/config/schema.ts @@ -88,6 +88,7 @@ const FIELD_LABELS: Record = { "gateway.reload.debounceMs": "Config Reload Debounce (ms)", "agent.workspace": "Workspace", "agent.model": "Default Model", + "agent.modelFallbacks": "Model Fallbacks", "ui.seamColor": "Accent Color", "browser.controlUrl": "Browser Control URL", "session.agentToAgent.maxPingPongTurns": "Agent-to-Agent Ping-Pong Turns", @@ -111,6 +112,8 @@ const FIELD_HELP: Record = { 'Hot reload strategy for config changes ("hybrid" recommended).', "gateway.reload.debounceMs": "Debounce window (ms) before applying config changes.", + "agent.modelFallbacks": + "Ordered fallback models (provider/model). Used when the primary model fails.", "session.agentToAgent.maxPingPongTurns": "Max reply-back turns between requester and target (0–5).", }; diff --git a/src/config/sessions.ts b/src/config/sessions.ts index 3d8e94408..3aa691832 100644 --- a/src/config/sessions.ts +++ b/src/config/sessions.ts @@ -50,6 +50,7 @@ export type SessionEntry = { inputTokens?: number; outputTokens?: number; totalTokens?: number; + modelProvider?: string; model?: string; contextTokens?: number; displayName?: string; @@ -335,6 +336,7 @@ export async function updateLastRoute(params: { inputTokens: existing?.inputTokens, outputTokens: existing?.outputTokens, totalTokens: existing?.totalTokens, + modelProvider: existing?.modelProvider, model: existing?.model, contextTokens: existing?.contextTokens, displayName: existing?.displayName, diff --git a/src/config/types.ts b/src/config/types.ts index 0c106c071..26c465f34 100644 --- a/src/config/types.ts +++ b/src/config/types.ts @@ -666,6 +666,8 @@ export type ClawdbotConfig = { allowedModels?: string[]; /** Optional model aliases for /model (alias -> provider/model). */ modelAliases?: Record; + /** Ordered fallback models (provider/model). */ + modelFallbacks?: string[]; /** Optional display-only context window override (used for % in status UIs). */ contextTokens?: number; /** Default thinking level when no /think directive is present. */ diff --git a/src/config/zod-schema.ts b/src/config/zod-schema.ts index b778c605f..4e1f4f17f 100644 --- a/src/config/zod-schema.ts +++ b/src/config/zod-schema.ts @@ -366,6 +366,7 @@ export const ClawdbotSchema = z.object({ workspace: z.string().optional(), allowedModels: z.array(z.string()).optional(), modelAliases: z.record(z.string(), z.string()).optional(), + modelFallbacks: z.array(z.string()).optional(), contextTokens: z.number().int().positive().optional(), thinkingDefault: z .union([ diff --git a/src/cron/isolated-agent.ts b/src/cron/isolated-agent.ts index ca4c89b78..0a7def723 100644 --- a/src/cron/isolated-agent.ts +++ b/src/cron/isolated-agent.ts @@ -10,6 +10,7 @@ import { resolveConfiguredModelRef, resolveThinkingDefault, } from "../agents/model-selection.js"; +import { runWithModelFallback } from "../agents/model-fallback.js"; import { runEmbeddedPiAgent } from "../agents/pi-embedded.js"; import { buildWorkspaceSkillSnapshot } from "../agents/skills.js"; import { @@ -264,6 +265,8 @@ export async function runCronIsolatedAgentTurn(params: { } let runResult: Awaited>; + let fallbackProvider = provider; + let fallbackModel = model; try { const sessionFile = resolveSessionTranscriptPath( cronSession.sessionEntry.sessionId, @@ -272,25 +275,34 @@ export async function runCronIsolatedAgentTurn(params: { sessionKey: params.sessionKey, }); const surface = resolvedDelivery.channel; - runResult = await runEmbeddedPiAgent({ - sessionId: cronSession.sessionEntry.sessionId, - sessionKey: params.sessionKey, - surface, - sessionFile, - workspaceDir, - config: params.cfg, - skillsSnapshot, - prompt: commandBody, - lane: params.lane ?? "cron", + const fallbackResult = await runWithModelFallback({ + cfg: params.cfg, provider, model, - thinkLevel, - verboseLevel: - (cronSession.sessionEntry.verboseLevel as "on" | "off" | undefined) ?? - (agentCfg?.verboseDefault as "on" | "off" | undefined), - timeoutMs, - runId: cronSession.sessionEntry.sessionId, + run: (providerOverride, modelOverride) => + runEmbeddedPiAgent({ + sessionId: cronSession.sessionEntry.sessionId, + sessionKey: params.sessionKey, + surface, + sessionFile, + workspaceDir, + config: params.cfg, + skillsSnapshot, + prompt: commandBody, + lane: params.lane ?? "cron", + provider: providerOverride, + model: modelOverride, + thinkLevel, + verboseLevel: + (cronSession.sessionEntry.verboseLevel as "on" | "off" | undefined) ?? + (agentCfg?.verboseDefault as "on" | "off" | undefined), + timeoutMs, + runId: cronSession.sessionEntry.sessionId, + }), }); + runResult = fallbackResult.result; + fallbackProvider = fallbackResult.provider; + fallbackModel = fallbackResult.model; } catch (err) { return { status: "error", error: String(err) }; } @@ -300,12 +312,16 @@ export async function runCronIsolatedAgentTurn(params: { // Update token+model fields in the session store. { const usage = runResult.meta.agentMeta?.usage; - const modelUsed = runResult.meta.agentMeta?.model ?? model; + const modelUsed = + runResult.meta.agentMeta?.model ?? fallbackModel ?? model; + const providerUsed = + runResult.meta.agentMeta?.provider ?? fallbackProvider ?? provider; const contextTokens = agentCfg?.contextTokens ?? lookupContextTokens(modelUsed) ?? DEFAULT_CONTEXT_TOKENS; + cronSession.sessionEntry.modelProvider = providerUsed; cronSession.sessionEntry.model = modelUsed; cronSession.sessionEntry.contextTokens = contextTokens; if (usage) {