From 64fc5fa9fc73e6b9baa78f4dfc74cc9d79902be9 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Fri, 9 Jan 2026 04:18:21 +0000 Subject: [PATCH] fix: allow default model outside allowlist --- src/agents/model-selection.test.ts | 156 ++++-------------- src/agents/model-selection.ts | 12 ++ .../reply/agent-runner.claude-cli.test.ts | 137 +++++++++++++++ src/auto-reply/reply/agent-runner.ts | 71 +++++++- src/auto-reply/reply/model-selection.ts | 2 + src/commands/agent.ts | 1 + src/cron/isolated-agent.ts | 1 + src/gateway/server-bridge.ts | 1 + src/gateway/server-methods/sessions.ts | 1 + 9 files changed, 255 insertions(+), 127 deletions(-) create mode 100644 src/auto-reply/reply/agent-runner.claude-cli.test.ts diff --git a/src/agents/model-selection.test.ts b/src/agents/model-selection.test.ts index 8131da54f..6011ab4fd 100644 --- a/src/agents/model-selection.test.ts +++ b/src/agents/model-selection.test.ts @@ -1,146 +1,56 @@ import { describe, expect, it } from "vitest"; import type { ClawdbotConfig } from "../config/config.js"; -import { DEFAULT_MODEL, DEFAULT_PROVIDER } from "./defaults.js"; -import { - normalizeProviderId, - resolveConfiguredModelRef, -} from "./model-selection.js"; +import { buildAllowedModelSet, modelKey } from "./model-selection.js"; -describe("resolveConfiguredModelRef", () => { - it("parses provider/model from agent.model.primary", () => { - const cfg = { - agent: { model: { primary: "openai/gpt-4.1-mini" } }, - } satisfies ClawdbotConfig; +const catalog = [ + { + provider: "openai", + id: "gpt-4", + name: "GPT-4", + }, +]; - const resolved = resolveConfiguredModelRef({ - cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, - }); - - expect(resolved).toEqual({ provider: "openai", model: "gpt-4.1-mini" }); - }); - - it("falls back to anthropic when agent.model.primary omits provider", () => { - const cfg = { - agent: { model: { primary: "claude-opus-4-5" } }, - } satisfies ClawdbotConfig; - - const resolved = resolveConfiguredModelRef({ - cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, - }); - - expect(resolved).toEqual({ - provider: "anthropic", - model: "claude-opus-4-5", - }); - }); - - it("falls back to defaults when agent.model is missing", () => { - const cfg = {} satisfies ClawdbotConfig; - - const resolved = resolveConfiguredModelRef({ - cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, - }); - - expect(resolved).toEqual({ - provider: DEFAULT_PROVIDER, - model: DEFAULT_MODEL, - }); - }); - - it("resolves agent.model aliases when configured", () => { +describe("buildAllowedModelSet", () => { + it("always allows the configured default model", () => { const cfg = { agent: { - model: { primary: "Opus" }, models: { - "anthropic/claude-opus-4-5": { alias: "Opus" }, + "openai/gpt-4": { alias: "gpt4" }, }, }, - } satisfies ClawdbotConfig; + } as ClawdbotConfig; - const resolved = resolveConfiguredModelRef({ + const allowed = buildAllowedModelSet({ cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, + catalog, + defaultProvider: "claude-cli", + defaultModel: "opus-4.5", }); - expect(resolved).toEqual({ - provider: "anthropic", - model: "claude-opus-4-5", - }); + expect(allowed.allowAny).toBe(false); + expect(allowed.allowedKeys.has(modelKey("openai", "gpt-4"))).toBe(true); + expect( + allowed.allowedKeys.has(modelKey("claude-cli", "opus-4.5")), + ).toBe(true); }); - it("normalizes z.ai provider in agent.model", () => { + it("includes the default model when no allowlist is set", () => { const cfg = { - agent: { model: "z.ai/glm-4.7" }, - } satisfies ClawdbotConfig; + agent: {}, + } as ClawdbotConfig; - const resolved = resolveConfiguredModelRef({ + const allowed = buildAllowedModelSet({ cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, + catalog, + defaultProvider: "claude-cli", + defaultModel: "opus-4.5", }); - expect(resolved).toEqual({ provider: "zai", model: "glm-4.7" }); - }); - - it("normalizes z-ai provider in agent.model", () => { - const cfg = { - agent: { model: "z-ai/glm-4.7" }, - } satisfies ClawdbotConfig; - - const resolved = resolveConfiguredModelRef({ - cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, - }); - - expect(resolved).toEqual({ provider: "zai", model: "glm-4.7" }); - }); - - it("normalizes provider casing in agent.model", () => { - const cfg = { - agent: { model: "OpenAI/gpt-4.1-mini" }, - } satisfies ClawdbotConfig; - - const resolved = resolveConfiguredModelRef({ - cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, - }); - - expect(resolved).toEqual({ provider: "openai", model: "gpt-4.1-mini" }); - }); - - it("normalizes z.ai casing in agent.model", () => { - const cfg = { - agent: { model: "Z.AI/glm-4.7" }, - } satisfies ClawdbotConfig; - - const resolved = resolveConfiguredModelRef({ - cfg, - defaultProvider: DEFAULT_PROVIDER, - defaultModel: DEFAULT_MODEL, - }); - - expect(resolved).toEqual({ provider: "zai", model: "glm-4.7" }); - }); -}); - -describe("normalizeProviderId", () => { - it("normalizes z.ai aliases to canonical zai", () => { - expect(normalizeProviderId("z.ai")).toBe("zai"); - expect(normalizeProviderId("z-ai")).toBe("zai"); - }); - - it("normalizes provider casing", () => { - expect(normalizeProviderId("OpenAI")).toBe("openai"); - expect(normalizeProviderId("Z.AI")).toBe("zai"); + expect(allowed.allowAny).toBe(true); + expect(allowed.allowedKeys.has(modelKey("openai", "gpt-4"))).toBe(true); + expect( + allowed.allowedKeys.has(modelKey("claude-cli", "opus-4.5")), + ).toBe(true); }); }); diff --git a/src/agents/model-selection.ts b/src/agents/model-selection.ts index 12a06a44b..93ce39aaa 100644 --- a/src/agents/model-selection.ts +++ b/src/agents/model-selection.ts @@ -124,6 +124,7 @@ export function buildAllowedModelSet(params: { cfg: ClawdbotConfig; catalog: ModelCatalogEntry[]; defaultProvider: string; + defaultModel?: string; }): { allowAny: boolean; allowedCatalog: ModelCatalogEntry[]; @@ -134,11 +135,17 @@ export function buildAllowedModelSet(params: { return Object.keys(modelMap); })(); const allowAny = rawAllowlist.length === 0; + const defaultModel = params.defaultModel?.trim(); + const defaultKey = + defaultModel && params.defaultProvider + ? modelKey(params.defaultProvider, defaultModel) + : undefined; const catalogKeys = new Set( params.catalog.map((entry) => modelKey(entry.provider, entry.id)), ); if (allowAny) { + if (defaultKey) catalogKeys.add(defaultKey); return { allowAny: true, allowedCatalog: params.catalog, @@ -156,11 +163,16 @@ export function buildAllowedModelSet(params: { } } + if (defaultKey) { + allowedKeys.add(defaultKey); + } + const allowedCatalog = params.catalog.filter((entry) => allowedKeys.has(modelKey(entry.provider, entry.id)), ); if (allowedCatalog.length === 0) { + if (defaultKey) catalogKeys.add(defaultKey); return { allowAny: true, allowedCatalog: params.catalog, diff --git a/src/auto-reply/reply/agent-runner.claude-cli.test.ts b/src/auto-reply/reply/agent-runner.claude-cli.test.ts new file mode 100644 index 000000000..f4e109148 --- /dev/null +++ b/src/auto-reply/reply/agent-runner.claude-cli.test.ts @@ -0,0 +1,137 @@ +import crypto from "node:crypto"; +import { describe, expect, it, vi } from "vitest"; + +import type { TemplateContext } from "../templating.js"; +import { onAgentEvent } from "../../infra/agent-events.js"; +import type { FollowupRun, QueueSettings } from "./queue.js"; +import { createMockTypingController } from "./test-helpers.js"; + +const runEmbeddedPiAgentMock = vi.fn(); +const runClaudeCliAgentMock = vi.fn(); + +vi.mock("../../agents/model-fallback.js", () => ({ + runWithModelFallback: async ({ + provider, + model, + run, + }: { + provider: string; + model: string; + run: (provider: string, model: string) => Promise; + }) => ({ + result: await run(provider, model), + provider, + model, + }), +})); + +vi.mock("../../agents/pi-embedded.js", () => ({ + queueEmbeddedPiMessage: vi.fn().mockReturnValue(false), + runEmbeddedPiAgent: (params: unknown) => runEmbeddedPiAgentMock(params), +})); + +vi.mock("../../agents/claude-cli-runner.js", () => ({ + runClaudeCliAgent: (params: unknown) => runClaudeCliAgentMock(params), +})); + +vi.mock("./queue.js", async () => { + const actual = + await vi.importActual("./queue.js"); + return { + ...actual, + enqueueFollowupRun: vi.fn(), + scheduleFollowupDrain: vi.fn(), + }; +}); + +import { runReplyAgent } from "./agent-runner.js"; + +function createRun() { + const typing = createMockTypingController(); + const sessionCtx = { + Provider: "webchat", + OriginatingTo: "session:1", + AccountId: "primary", + MessageSid: "msg", + } as unknown as TemplateContext; + const resolvedQueue = { mode: "interrupt" } as unknown as QueueSettings; + const followupRun = { + prompt: "hello", + summaryLine: "hello", + enqueuedAt: Date.now(), + run: { + sessionId: "session", + sessionKey: "main", + messageProvider: "webchat", + sessionFile: "/tmp/session.jsonl", + workspaceDir: "/tmp", + config: {}, + skillsSnapshot: {}, + provider: "claude-cli", + model: "opus-4.5", + thinkLevel: "low", + verboseLevel: "off", + elevatedLevel: "off", + bashElevated: { + enabled: false, + allowed: false, + defaultLevel: "off", + }, + timeoutMs: 1_000, + blockReplyBreak: "message_end", + }, + } as unknown as FollowupRun; + + return runReplyAgent({ + commandBody: "hello", + followupRun, + queueKey: "main", + resolvedQueue, + shouldSteer: false, + shouldFollowup: false, + isActive: false, + isStreaming: false, + typing, + sessionCtx, + defaultModel: "claude-cli/opus-4.5", + resolvedVerboseLevel: "off", + isNewSession: false, + blockStreamingEnabled: false, + resolvedBlockStreamingBreak: "message_end", + shouldInjectGroupIntro: false, + typingMode: "instant", + }); +} + +describe("runReplyAgent claude-cli routing", () => { + it("uses claude-cli runner for claude-cli provider", async () => { + const randomSpy = vi + .spyOn(crypto, "randomUUID") + .mockReturnValue("run-1"); + const lifecyclePhases: string[] = []; + const unsubscribe = onAgentEvent((evt) => { + if (evt.runId !== "run-1") return; + if (evt.stream !== "lifecycle") return; + const phase = evt.data?.phase; + if (typeof phase === "string") lifecyclePhases.push(phase); + }); + runClaudeCliAgentMock.mockResolvedValueOnce({ + payloads: [{ text: "ok" }], + meta: { + agentMeta: { + provider: "claude-cli", + model: "opus-4.5", + }, + }, + }); + + const result = await createRun(); + unsubscribe(); + randomSpy.mockRestore(); + + expect(runClaudeCliAgentMock).toHaveBeenCalledTimes(1); + expect(runEmbeddedPiAgentMock).not.toHaveBeenCalled(); + expect(lifecyclePhases).toEqual(["start", "end"]); + expect(result).toMatchObject({ text: "ok" }); + }); +}); diff --git a/src/auto-reply/reply/agent-runner.ts b/src/auto-reply/reply/agent-runner.ts index 076b92fa2..068cb2259 100644 --- a/src/auto-reply/reply/agent-runner.ts +++ b/src/auto-reply/reply/agent-runner.ts @@ -1,5 +1,6 @@ import crypto from "node:crypto"; import fs from "node:fs"; +import { runClaudeCliAgent } from "../../agents/claude-cli-runner.js"; import { lookupContextTokens } from "../../agents/context.js"; import { DEFAULT_CONTEXT_TOKENS } from "../../agents/defaults.js"; import { resolveModelAuthMode } from "../../agents/model-auth.js"; @@ -17,7 +18,7 @@ import { } from "../../config/sessions.js"; import type { TypingMode } from "../../config/types.js"; import { logVerbose } from "../../globals.js"; -import { registerAgentRunContext } from "../../infra/agent-events.js"; +import { emitAgentEvent, registerAgentRunContext } from "../../infra/agent-events.js"; import { defaultRuntime } from "../../runtime.js"; import { estimateUsageCost, @@ -326,8 +327,61 @@ export async function runReplyAgent(params: { cfg: followupRun.run.config, provider: followupRun.run.provider, model: followupRun.run.model, - run: (provider, model) => - runEmbeddedPiAgent({ + run: (provider, model) => { + if (provider === "claude-cli") { + const startedAt = Date.now(); + emitAgentEvent({ + runId, + stream: "lifecycle", + data: { + phase: "start", + startedAt, + }, + }); + return runClaudeCliAgent({ + sessionId: followupRun.run.sessionId, + sessionKey, + sessionFile: followupRun.run.sessionFile, + workspaceDir: followupRun.run.workspaceDir, + config: followupRun.run.config, + prompt: commandBody, + provider, + model, + thinkLevel: followupRun.run.thinkLevel, + timeoutMs: followupRun.run.timeoutMs, + runId, + extraSystemPrompt: followupRun.run.extraSystemPrompt, + ownerNumbers: followupRun.run.ownerNumbers, + resumeSessionId: + sessionEntry?.claudeCliSessionId?.trim() || undefined, + }) + .then((result) => { + emitAgentEvent({ + runId, + stream: "lifecycle", + data: { + phase: "end", + startedAt, + endedAt: Date.now(), + }, + }); + return result; + }) + .catch((err) => { + emitAgentEvent({ + runId, + stream: "lifecycle", + data: { + phase: "error", + startedAt, + endedAt: Date.now(), + error: err instanceof Error ? err.message : String(err), + }, + }); + throw err; + }); + } + return runEmbeddedPiAgent({ sessionId: followupRun.run.sessionId, sessionKey, messageProvider: @@ -554,7 +608,8 @@ export async function runReplyAgent(params: { pendingToolTasks.add(task); } : undefined, - }), + }); + }, }); runResult = fallbackResult.result; fallbackProvider = fallbackResult.provider; @@ -716,6 +771,10 @@ export async function runReplyAgent(params: { runResult.meta.agentMeta?.provider ?? fallbackProvider ?? followupRun.run.provider; + const cliSessionId = + providerUsed === "claude-cli" + ? runResult.meta.agentMeta?.sessionId?.trim() + : undefined; const contextTokensUsed = agentCfgContextTokens ?? lookupContextTokens(modelUsed) ?? @@ -741,6 +800,9 @@ export async function runReplyAgent(params: { contextTokens: contextTokensUsed ?? entry.contextTokens, updatedAt: Date.now(), }; + if (cliSessionId) { + nextEntry.claudeCliSessionId = cliSessionId; + } sessionStore[sessionKey] = nextEntry; if (storePath) { await saveSessionStore(storePath, sessionStore); @@ -754,6 +816,7 @@ export async function runReplyAgent(params: { modelProvider: providerUsed ?? entry.modelProvider, model: modelUsed ?? entry.model, contextTokens: contextTokensUsed ?? entry.contextTokens, + claudeCliSessionId: cliSessionId ?? entry.claudeCliSessionId, }; if (storePath) { await saveSessionStore(storePath, sessionStore); diff --git a/src/auto-reply/reply/model-selection.ts b/src/auto-reply/reply/model-selection.ts index a4cb67359..63f58b721 100644 --- a/src/auto-reply/reply/model-selection.ts +++ b/src/auto-reply/reply/model-selection.ts @@ -52,6 +52,7 @@ export async function createModelSelectionState(params: { sessionKey, storePath, defaultProvider, + defaultModel, } = params; let provider = params.provider; @@ -76,6 +77,7 @@ export async function createModelSelectionState(params: { cfg, catalog: modelCatalog, defaultProvider, + defaultModel, }); allowedModelCatalog = allowed.allowedCatalog; allowedModelKeys = allowed.allowedKeys; diff --git a/src/commands/agent.ts b/src/commands/agent.ts index 5f6d25b9e..af8a9edc8 100644 --- a/src/commands/agent.ts +++ b/src/commands/agent.ts @@ -337,6 +337,7 @@ export async function agentCommand( cfg, catalog: modelCatalog, defaultProvider, + defaultModel, }); allowedModelKeys = allowed.allowedKeys; allowedModelCatalog = allowed.allowedCatalog; diff --git a/src/cron/isolated-agent.ts b/src/cron/isolated-agent.ts index a4db61028..170c3228a 100644 --- a/src/cron/isolated-agent.ts +++ b/src/cron/isolated-agent.ts @@ -319,6 +319,7 @@ export async function runCronIsolatedAgentTurn(params: { cfg: params.cfg, catalog: await loadCatalog(), defaultProvider: resolvedDefault.provider, + defaultModel: resolvedDefault.model, }); const key = modelKey( resolvedOverride.ref.provider, diff --git a/src/gateway/server-bridge.ts b/src/gateway/server-bridge.ts index 5d4b2cf5e..eea960170 100644 --- a/src/gateway/server-bridge.ts +++ b/src/gateway/server-bridge.ts @@ -518,6 +518,7 @@ export function createBridgeHandlers(ctx: BridgeHandlersContext) { cfg, catalog, defaultProvider: resolvedDefault.provider, + defaultModel: resolvedDefault.model, }); const key = modelKey(resolved.ref.provider, resolved.ref.model); if (!allowed.allowAny && !allowed.allowedKeys.has(key)) { diff --git a/src/gateway/server-methods/sessions.ts b/src/gateway/server-methods/sessions.ts index d3cc38f04..fb265c891 100644 --- a/src/gateway/server-methods/sessions.ts +++ b/src/gateway/server-methods/sessions.ts @@ -299,6 +299,7 @@ export const sessionsHandlers: GatewayRequestHandlers = { cfg, catalog, defaultProvider: resolvedDefault.provider, + defaultModel: resolvedDefault.model, }); const key = modelKey(resolved.ref.provider, resolved.ref.model); if (!allowed.allowAny && !allowed.allowedKeys.has(key)) {