feat: allow hook model overrides
This commit is contained in:
@@ -126,6 +126,78 @@ describe("runCronIsolatedAgentTurn", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("uses model override when provided", async () => {
|
||||
await withTempHome(async (home) => {
|
||||
const storePath = await writeSessionStore(home);
|
||||
const deps: CliDeps = {
|
||||
sendMessageWhatsApp: vi.fn(),
|
||||
sendMessageTelegram: vi.fn(),
|
||||
sendMessageDiscord: vi.fn(),
|
||||
sendMessageSignal: vi.fn(),
|
||||
sendMessageIMessage: vi.fn(),
|
||||
};
|
||||
vi.mocked(runEmbeddedPiAgent).mockResolvedValue({
|
||||
payloads: [{ text: "ok" }],
|
||||
meta: {
|
||||
durationMs: 5,
|
||||
agentMeta: { sessionId: "s", provider: "p", model: "m" },
|
||||
},
|
||||
});
|
||||
|
||||
const res = await runCronIsolatedAgentTurn({
|
||||
cfg: makeCfg(home, storePath),
|
||||
deps,
|
||||
job: makeJob({
|
||||
kind: "agentTurn",
|
||||
message: "do it",
|
||||
model: "openai/gpt-4.1-mini",
|
||||
}),
|
||||
message: "do it",
|
||||
sessionKey: "cron:job-1",
|
||||
lane: "cron",
|
||||
});
|
||||
|
||||
expect(res.status).toBe("ok");
|
||||
const call = vi.mocked(runEmbeddedPiAgent).mock.calls[0]?.[0] as {
|
||||
provider?: string;
|
||||
model?: string;
|
||||
};
|
||||
expect(call?.provider).toBe("openai");
|
||||
expect(call?.model).toBe("gpt-4.1-mini");
|
||||
});
|
||||
});
|
||||
|
||||
it("rejects invalid model override", async () => {
|
||||
await withTempHome(async (home) => {
|
||||
const storePath = await writeSessionStore(home);
|
||||
const deps: CliDeps = {
|
||||
sendMessageWhatsApp: vi.fn(),
|
||||
sendMessageTelegram: vi.fn(),
|
||||
sendMessageDiscord: vi.fn(),
|
||||
sendMessageSignal: vi.fn(),
|
||||
sendMessageIMessage: vi.fn(),
|
||||
};
|
||||
vi.mocked(runEmbeddedPiAgent).mockReset();
|
||||
|
||||
const res = await runCronIsolatedAgentTurn({
|
||||
cfg: makeCfg(home, storePath),
|
||||
deps,
|
||||
job: makeJob({
|
||||
kind: "agentTurn",
|
||||
message: "do it",
|
||||
model: "openai/",
|
||||
}),
|
||||
message: "do it",
|
||||
sessionKey: "cron:job-1",
|
||||
lane: "cron",
|
||||
});
|
||||
|
||||
expect(res.status).toBe("error");
|
||||
expect(res.error).toMatch("invalid model");
|
||||
expect(vi.mocked(runEmbeddedPiAgent)).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it("defaults thinking to low for reasoning-capable models", async () => {
|
||||
await withTempHome(async (home) => {
|
||||
const storePath = await writeSessionStore(home);
|
||||
|
||||
@@ -8,7 +8,11 @@ import {
|
||||
import { loadModelCatalog } from "../agents/model-catalog.js";
|
||||
import { runWithModelFallback } from "../agents/model-fallback.js";
|
||||
import {
|
||||
buildAllowedModelSet,
|
||||
buildModelAliasIndex,
|
||||
modelKey,
|
||||
resolveConfiguredModelRef,
|
||||
resolveModelRefFromString,
|
||||
resolveThinkingDefault,
|
||||
} from "../agents/model-selection.js";
|
||||
import { runEmbeddedPiAgent } from "../agents/pi-embedded.js";
|
||||
@@ -212,11 +216,59 @@ export async function runCronIsolatedAgentTurn(params: {
|
||||
});
|
||||
const workspaceDir = workspace.dir;
|
||||
|
||||
const { provider, model } = resolveConfiguredModelRef({
|
||||
const resolvedDefault = resolveConfiguredModelRef({
|
||||
cfg: params.cfg,
|
||||
defaultProvider: DEFAULT_PROVIDER,
|
||||
defaultModel: DEFAULT_MODEL,
|
||||
});
|
||||
let provider = resolvedDefault.provider;
|
||||
let model = resolvedDefault.model;
|
||||
let catalog: Awaited<ReturnType<typeof loadModelCatalog>> | undefined;
|
||||
const loadCatalog = async () => {
|
||||
if (!catalog) {
|
||||
catalog = await loadModelCatalog({ config: params.cfg });
|
||||
}
|
||||
return catalog;
|
||||
};
|
||||
const modelOverrideRaw =
|
||||
params.job.payload.kind === "agentTurn"
|
||||
? params.job.payload.model
|
||||
: undefined;
|
||||
if (modelOverrideRaw !== undefined) {
|
||||
if (typeof modelOverrideRaw !== "string") {
|
||||
return { status: "error", error: "invalid model: expected string" };
|
||||
}
|
||||
const trimmed = modelOverrideRaw.trim();
|
||||
if (!trimmed) {
|
||||
return { status: "error", error: "invalid model: empty" };
|
||||
}
|
||||
const aliasIndex = buildModelAliasIndex({
|
||||
cfg: params.cfg,
|
||||
defaultProvider: resolvedDefault.provider,
|
||||
});
|
||||
const resolvedOverride = resolveModelRefFromString({
|
||||
raw: trimmed,
|
||||
defaultProvider: resolvedDefault.provider,
|
||||
aliasIndex,
|
||||
});
|
||||
if (!resolvedOverride) {
|
||||
return { status: "error", error: `invalid model: ${trimmed}` };
|
||||
}
|
||||
const allowed = buildAllowedModelSet({
|
||||
cfg: params.cfg,
|
||||
catalog: await loadCatalog(),
|
||||
defaultProvider: resolvedDefault.provider,
|
||||
});
|
||||
const key = modelKey(
|
||||
resolvedOverride.ref.provider,
|
||||
resolvedOverride.ref.model,
|
||||
);
|
||||
if (!allowed.allowAny && !allowed.allowedKeys.has(key)) {
|
||||
return { status: "error", error: `model not allowed: ${key}` };
|
||||
}
|
||||
provider = resolvedOverride.ref.provider;
|
||||
model = resolvedOverride.ref.model;
|
||||
}
|
||||
const now = Date.now();
|
||||
const cronSession = resolveCronSession({
|
||||
cfg: params.cfg,
|
||||
@@ -234,12 +286,11 @@ export async function runCronIsolatedAgentTurn(params: {
|
||||
);
|
||||
let thinkLevel = jobThink ?? thinkOverride;
|
||||
if (!thinkLevel) {
|
||||
const catalog = await loadModelCatalog({ config: params.cfg });
|
||||
thinkLevel = resolveThinkingDefault({
|
||||
cfg: params.cfg,
|
||||
provider,
|
||||
model,
|
||||
catalog,
|
||||
catalog: await loadCatalog(),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ export type CronPayload =
|
||||
| {
|
||||
kind: "agentTurn";
|
||||
message: string;
|
||||
/** Optional model override (provider/model or alias). */
|
||||
model?: string;
|
||||
thinking?: string;
|
||||
timeoutSeconds?: number;
|
||||
deliver?: boolean;
|
||||
|
||||
Reference in New Issue
Block a user