feat: allow hook model overrides
This commit is contained in:
@@ -216,6 +216,8 @@ export type HookMappingConfig = {
|
|||||||
| "signal"
|
| "signal"
|
||||||
| "imessage";
|
| "imessage";
|
||||||
to?: string;
|
to?: string;
|
||||||
|
/** Override model for this hook (provider/model or alias). */
|
||||||
|
model?: string;
|
||||||
thinking?: string;
|
thinking?: string;
|
||||||
timeoutSeconds?: number;
|
timeoutSeconds?: number;
|
||||||
transform?: HookMappingTransform;
|
transform?: HookMappingTransform;
|
||||||
|
|||||||
@@ -743,6 +743,7 @@ const HookMappingSchema = z
|
|||||||
])
|
])
|
||||||
.optional(),
|
.optional(),
|
||||||
to: z.string().optional(),
|
to: z.string().optional(),
|
||||||
|
model: z.string().optional(),
|
||||||
thinking: z.string().optional(),
|
thinking: z.string().optional(),
|
||||||
timeoutSeconds: z.number().int().positive().optional(),
|
timeoutSeconds: z.number().int().positive().optional(),
|
||||||
transform: z
|
transform: z
|
||||||
|
|||||||
@@ -126,6 +126,78 @@ describe("runCronIsolatedAgentTurn", () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("uses model override when provided", async () => {
|
||||||
|
await withTempHome(async (home) => {
|
||||||
|
const storePath = await writeSessionStore(home);
|
||||||
|
const deps: CliDeps = {
|
||||||
|
sendMessageWhatsApp: vi.fn(),
|
||||||
|
sendMessageTelegram: vi.fn(),
|
||||||
|
sendMessageDiscord: vi.fn(),
|
||||||
|
sendMessageSignal: vi.fn(),
|
||||||
|
sendMessageIMessage: vi.fn(),
|
||||||
|
};
|
||||||
|
vi.mocked(runEmbeddedPiAgent).mockResolvedValue({
|
||||||
|
payloads: [{ text: "ok" }],
|
||||||
|
meta: {
|
||||||
|
durationMs: 5,
|
||||||
|
agentMeta: { sessionId: "s", provider: "p", model: "m" },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const res = await runCronIsolatedAgentTurn({
|
||||||
|
cfg: makeCfg(home, storePath),
|
||||||
|
deps,
|
||||||
|
job: makeJob({
|
||||||
|
kind: "agentTurn",
|
||||||
|
message: "do it",
|
||||||
|
model: "openai/gpt-4.1-mini",
|
||||||
|
}),
|
||||||
|
message: "do it",
|
||||||
|
sessionKey: "cron:job-1",
|
||||||
|
lane: "cron",
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(res.status).toBe("ok");
|
||||||
|
const call = vi.mocked(runEmbeddedPiAgent).mock.calls[0]?.[0] as {
|
||||||
|
provider?: string;
|
||||||
|
model?: string;
|
||||||
|
};
|
||||||
|
expect(call?.provider).toBe("openai");
|
||||||
|
expect(call?.model).toBe("gpt-4.1-mini");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("rejects invalid model override", async () => {
|
||||||
|
await withTempHome(async (home) => {
|
||||||
|
const storePath = await writeSessionStore(home);
|
||||||
|
const deps: CliDeps = {
|
||||||
|
sendMessageWhatsApp: vi.fn(),
|
||||||
|
sendMessageTelegram: vi.fn(),
|
||||||
|
sendMessageDiscord: vi.fn(),
|
||||||
|
sendMessageSignal: vi.fn(),
|
||||||
|
sendMessageIMessage: vi.fn(),
|
||||||
|
};
|
||||||
|
vi.mocked(runEmbeddedPiAgent).mockReset();
|
||||||
|
|
||||||
|
const res = await runCronIsolatedAgentTurn({
|
||||||
|
cfg: makeCfg(home, storePath),
|
||||||
|
deps,
|
||||||
|
job: makeJob({
|
||||||
|
kind: "agentTurn",
|
||||||
|
message: "do it",
|
||||||
|
model: "openai/",
|
||||||
|
}),
|
||||||
|
message: "do it",
|
||||||
|
sessionKey: "cron:job-1",
|
||||||
|
lane: "cron",
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(res.status).toBe("error");
|
||||||
|
expect(res.error).toMatch("invalid model");
|
||||||
|
expect(vi.mocked(runEmbeddedPiAgent)).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("defaults thinking to low for reasoning-capable models", async () => {
|
it("defaults thinking to low for reasoning-capable models", async () => {
|
||||||
await withTempHome(async (home) => {
|
await withTempHome(async (home) => {
|
||||||
const storePath = await writeSessionStore(home);
|
const storePath = await writeSessionStore(home);
|
||||||
|
|||||||
@@ -8,7 +8,11 @@ import {
|
|||||||
import { loadModelCatalog } from "../agents/model-catalog.js";
|
import { loadModelCatalog } from "../agents/model-catalog.js";
|
||||||
import { runWithModelFallback } from "../agents/model-fallback.js";
|
import { runWithModelFallback } from "../agents/model-fallback.js";
|
||||||
import {
|
import {
|
||||||
|
buildAllowedModelSet,
|
||||||
|
buildModelAliasIndex,
|
||||||
|
modelKey,
|
||||||
resolveConfiguredModelRef,
|
resolveConfiguredModelRef,
|
||||||
|
resolveModelRefFromString,
|
||||||
resolveThinkingDefault,
|
resolveThinkingDefault,
|
||||||
} from "../agents/model-selection.js";
|
} from "../agents/model-selection.js";
|
||||||
import { runEmbeddedPiAgent } from "../agents/pi-embedded.js";
|
import { runEmbeddedPiAgent } from "../agents/pi-embedded.js";
|
||||||
@@ -212,11 +216,59 @@ export async function runCronIsolatedAgentTurn(params: {
|
|||||||
});
|
});
|
||||||
const workspaceDir = workspace.dir;
|
const workspaceDir = workspace.dir;
|
||||||
|
|
||||||
const { provider, model } = resolveConfiguredModelRef({
|
const resolvedDefault = resolveConfiguredModelRef({
|
||||||
cfg: params.cfg,
|
cfg: params.cfg,
|
||||||
defaultProvider: DEFAULT_PROVIDER,
|
defaultProvider: DEFAULT_PROVIDER,
|
||||||
defaultModel: DEFAULT_MODEL,
|
defaultModel: DEFAULT_MODEL,
|
||||||
});
|
});
|
||||||
|
let provider = resolvedDefault.provider;
|
||||||
|
let model = resolvedDefault.model;
|
||||||
|
let catalog: Awaited<ReturnType<typeof loadModelCatalog>> | undefined;
|
||||||
|
const loadCatalog = async () => {
|
||||||
|
if (!catalog) {
|
||||||
|
catalog = await loadModelCatalog({ config: params.cfg });
|
||||||
|
}
|
||||||
|
return catalog;
|
||||||
|
};
|
||||||
|
const modelOverrideRaw =
|
||||||
|
params.job.payload.kind === "agentTurn"
|
||||||
|
? params.job.payload.model
|
||||||
|
: undefined;
|
||||||
|
if (modelOverrideRaw !== undefined) {
|
||||||
|
if (typeof modelOverrideRaw !== "string") {
|
||||||
|
return { status: "error", error: "invalid model: expected string" };
|
||||||
|
}
|
||||||
|
const trimmed = modelOverrideRaw.trim();
|
||||||
|
if (!trimmed) {
|
||||||
|
return { status: "error", error: "invalid model: empty" };
|
||||||
|
}
|
||||||
|
const aliasIndex = buildModelAliasIndex({
|
||||||
|
cfg: params.cfg,
|
||||||
|
defaultProvider: resolvedDefault.provider,
|
||||||
|
});
|
||||||
|
const resolvedOverride = resolveModelRefFromString({
|
||||||
|
raw: trimmed,
|
||||||
|
defaultProvider: resolvedDefault.provider,
|
||||||
|
aliasIndex,
|
||||||
|
});
|
||||||
|
if (!resolvedOverride) {
|
||||||
|
return { status: "error", error: `invalid model: ${trimmed}` };
|
||||||
|
}
|
||||||
|
const allowed = buildAllowedModelSet({
|
||||||
|
cfg: params.cfg,
|
||||||
|
catalog: await loadCatalog(),
|
||||||
|
defaultProvider: resolvedDefault.provider,
|
||||||
|
});
|
||||||
|
const key = modelKey(
|
||||||
|
resolvedOverride.ref.provider,
|
||||||
|
resolvedOverride.ref.model,
|
||||||
|
);
|
||||||
|
if (!allowed.allowAny && !allowed.allowedKeys.has(key)) {
|
||||||
|
return { status: "error", error: `model not allowed: ${key}` };
|
||||||
|
}
|
||||||
|
provider = resolvedOverride.ref.provider;
|
||||||
|
model = resolvedOverride.ref.model;
|
||||||
|
}
|
||||||
const now = Date.now();
|
const now = Date.now();
|
||||||
const cronSession = resolveCronSession({
|
const cronSession = resolveCronSession({
|
||||||
cfg: params.cfg,
|
cfg: params.cfg,
|
||||||
@@ -234,12 +286,11 @@ export async function runCronIsolatedAgentTurn(params: {
|
|||||||
);
|
);
|
||||||
let thinkLevel = jobThink ?? thinkOverride;
|
let thinkLevel = jobThink ?? thinkOverride;
|
||||||
if (!thinkLevel) {
|
if (!thinkLevel) {
|
||||||
const catalog = await loadModelCatalog({ config: params.cfg });
|
|
||||||
thinkLevel = resolveThinkingDefault({
|
thinkLevel = resolveThinkingDefault({
|
||||||
cfg: params.cfg,
|
cfg: params.cfg,
|
||||||
provider,
|
provider,
|
||||||
model,
|
model,
|
||||||
catalog,
|
catalog: await loadCatalog(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ export type CronPayload =
|
|||||||
| {
|
| {
|
||||||
kind: "agentTurn";
|
kind: "agentTurn";
|
||||||
message: string;
|
message: string;
|
||||||
|
/** Optional model override (provider/model or alias). */
|
||||||
|
model?: string;
|
||||||
thinking?: string;
|
thinking?: string;
|
||||||
timeoutSeconds?: number;
|
timeoutSeconds?: number;
|
||||||
deliver?: boolean;
|
deliver?: boolean;
|
||||||
|
|||||||
@@ -38,6 +38,30 @@ describe("hooks mapping", () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("passes model override from mapping", async () => {
|
||||||
|
const mappings = resolveHookMappings({
|
||||||
|
mappings: [
|
||||||
|
{
|
||||||
|
id: "demo",
|
||||||
|
match: { path: "gmail" },
|
||||||
|
action: "agent",
|
||||||
|
messageTemplate: "Subject: {{messages[0].subject}}",
|
||||||
|
model: "openai/gpt-4.1-mini",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
const result = await applyHookMappings(mappings, {
|
||||||
|
payload: { messages: [{ subject: "Hello" }] },
|
||||||
|
headers: {},
|
||||||
|
url: baseUrl,
|
||||||
|
path: "gmail",
|
||||||
|
});
|
||||||
|
expect(result?.ok).toBe(true);
|
||||||
|
if (result?.ok && result.action.kind === "agent") {
|
||||||
|
expect(result.action.model).toBe("openai/gpt-4.1-mini");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
it("runs transform module", async () => {
|
it("runs transform module", async () => {
|
||||||
const dir = fs.mkdtempSync(path.join(os.tmpdir(), "clawdbot-hooks-"));
|
const dir = fs.mkdtempSync(path.join(os.tmpdir(), "clawdbot-hooks-"));
|
||||||
const modPath = path.join(dir, "transform.mjs");
|
const modPath = path.join(dir, "transform.mjs");
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ export type HookMappingResolved = {
|
|||||||
| "signal"
|
| "signal"
|
||||||
| "imessage";
|
| "imessage";
|
||||||
to?: string;
|
to?: string;
|
||||||
|
model?: string;
|
||||||
thinking?: string;
|
thinking?: string;
|
||||||
timeoutSeconds?: number;
|
timeoutSeconds?: number;
|
||||||
transform?: HookMappingTransformResolved;
|
transform?: HookMappingTransformResolved;
|
||||||
@@ -66,6 +67,7 @@ export type HookAction =
|
|||||||
| "signal"
|
| "signal"
|
||||||
| "imessage";
|
| "imessage";
|
||||||
to?: string;
|
to?: string;
|
||||||
|
model?: string;
|
||||||
thinking?: string;
|
thinking?: string;
|
||||||
timeoutSeconds?: number;
|
timeoutSeconds?: number;
|
||||||
};
|
};
|
||||||
@@ -110,6 +112,7 @@ type HookTransformResult = Partial<{
|
|||||||
| "signal"
|
| "signal"
|
||||||
| "imessage";
|
| "imessage";
|
||||||
to: string;
|
to: string;
|
||||||
|
model: string;
|
||||||
thinking: string;
|
thinking: string;
|
||||||
timeoutSeconds: number;
|
timeoutSeconds: number;
|
||||||
}> | null;
|
}> | null;
|
||||||
@@ -198,6 +201,7 @@ function normalizeHookMapping(
|
|||||||
deliver: mapping.deliver,
|
deliver: mapping.deliver,
|
||||||
provider: mapping.provider,
|
provider: mapping.provider,
|
||||||
to: mapping.to,
|
to: mapping.to,
|
||||||
|
model: mapping.model,
|
||||||
thinking: mapping.thinking,
|
thinking: mapping.thinking,
|
||||||
timeoutSeconds: mapping.timeoutSeconds,
|
timeoutSeconds: mapping.timeoutSeconds,
|
||||||
transform,
|
transform,
|
||||||
@@ -243,6 +247,7 @@ function buildActionFromMapping(
|
|||||||
deliver: mapping.deliver,
|
deliver: mapping.deliver,
|
||||||
provider: mapping.provider,
|
provider: mapping.provider,
|
||||||
to: renderOptional(mapping.to, ctx),
|
to: renderOptional(mapping.to, ctx),
|
||||||
|
model: renderOptional(mapping.model, ctx),
|
||||||
thinking: renderOptional(mapping.thinking, ctx),
|
thinking: renderOptional(mapping.thinking, ctx),
|
||||||
timeoutSeconds: mapping.timeoutSeconds,
|
timeoutSeconds: mapping.timeoutSeconds,
|
||||||
},
|
},
|
||||||
@@ -293,6 +298,7 @@ function mergeAction(
|
|||||||
: baseAgent?.deliver,
|
: baseAgent?.deliver,
|
||||||
provider: override.provider ?? baseAgent?.provider,
|
provider: override.provider ?? baseAgent?.provider,
|
||||||
to: override.to ?? baseAgent?.to,
|
to: override.to ?? baseAgent?.to,
|
||||||
|
model: override.model ?? baseAgent?.model,
|
||||||
thinking: override.thinking ?? baseAgent?.thinking,
|
thinking: override.thinking ?? baseAgent?.thinking,
|
||||||
timeoutSeconds: override.timeoutSeconds ?? baseAgent?.timeoutSeconds,
|
timeoutSeconds: override.timeoutSeconds ?? baseAgent?.timeoutSeconds,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ export type HookAgentPayload = {
|
|||||||
| "signal"
|
| "signal"
|
||||||
| "imessage";
|
| "imessage";
|
||||||
to?: string;
|
to?: string;
|
||||||
|
model?: string;
|
||||||
thinking?: string;
|
thinking?: string;
|
||||||
timeoutSeconds?: number;
|
timeoutSeconds?: number;
|
||||||
};
|
};
|
||||||
@@ -201,6 +202,14 @@ export function normalizeAgentPayload(
|
|||||||
const toRaw = payload.to;
|
const toRaw = payload.to;
|
||||||
const to =
|
const to =
|
||||||
typeof toRaw === "string" && toRaw.trim() ? toRaw.trim() : undefined;
|
typeof toRaw === "string" && toRaw.trim() ? toRaw.trim() : undefined;
|
||||||
|
const modelRaw = payload.model;
|
||||||
|
const model =
|
||||||
|
typeof modelRaw === "string" && modelRaw.trim()
|
||||||
|
? modelRaw.trim()
|
||||||
|
: undefined;
|
||||||
|
if (modelRaw !== undefined && !model) {
|
||||||
|
return { ok: false, error: "model required" };
|
||||||
|
}
|
||||||
const deliver = payload.deliver === true;
|
const deliver = payload.deliver === true;
|
||||||
const thinkingRaw = payload.thinking;
|
const thinkingRaw = payload.thinking;
|
||||||
const thinking =
|
const thinking =
|
||||||
@@ -224,6 +233,7 @@ export function normalizeAgentPayload(
|
|||||||
deliver,
|
deliver,
|
||||||
provider,
|
provider,
|
||||||
to,
|
to,
|
||||||
|
model,
|
||||||
thinking,
|
thinking,
|
||||||
timeoutSeconds,
|
timeoutSeconds,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -664,6 +664,7 @@ export const CronPayloadSchema = Type.Union([
|
|||||||
{
|
{
|
||||||
kind: Type.Literal("agentTurn"),
|
kind: Type.Literal("agentTurn"),
|
||||||
message: NonEmptyString,
|
message: NonEmptyString,
|
||||||
|
model: Type.Optional(Type.String()),
|
||||||
thinking: Type.Optional(Type.String()),
|
thinking: Type.Optional(Type.String()),
|
||||||
timeoutSeconds: Type.Optional(Type.Integer({ minimum: 1 })),
|
timeoutSeconds: Type.Optional(Type.Integer({ minimum: 1 })),
|
||||||
deliver: Type.Optional(Type.Boolean()),
|
deliver: Type.Optional(Type.Boolean()),
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ type HookDispatchers = {
|
|||||||
| "signal"
|
| "signal"
|
||||||
| "imessage";
|
| "imessage";
|
||||||
to?: string;
|
to?: string;
|
||||||
|
model?: string;
|
||||||
thinking?: string;
|
thinking?: string;
|
||||||
timeoutSeconds?: number;
|
timeoutSeconds?: number;
|
||||||
}) => string;
|
}) => string;
|
||||||
@@ -177,6 +178,7 @@ export function createHooksRequestHandler(
|
|||||||
deliver: mapped.action.deliver === true,
|
deliver: mapped.action.deliver === true,
|
||||||
provider: mapped.action.provider ?? "last",
|
provider: mapped.action.provider ?? "last",
|
||||||
to: mapped.action.to,
|
to: mapped.action.to,
|
||||||
|
model: mapped.action.model,
|
||||||
thinking: mapped.action.thinking,
|
thinking: mapped.action.thinking,
|
||||||
timeoutSeconds: mapped.action.timeoutSeconds,
|
timeoutSeconds: mapped.action.timeoutSeconds,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -67,6 +67,37 @@ describe("gateway server hooks", () => {
|
|||||||
await server.close();
|
await server.close();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("hooks agent forwards model override", async () => {
|
||||||
|
testState.hooksConfig = { enabled: true, token: "hook-secret" };
|
||||||
|
cronIsolatedRun.mockClear();
|
||||||
|
cronIsolatedRun.mockResolvedValueOnce({
|
||||||
|
status: "ok",
|
||||||
|
summary: "done",
|
||||||
|
});
|
||||||
|
const port = await getFreePort();
|
||||||
|
const server = await startGatewayServer(port);
|
||||||
|
const res = await fetch(`http://127.0.0.1:${port}/hooks/agent`, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Authorization: "Bearer hook-secret",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
message: "Do it",
|
||||||
|
name: "Email",
|
||||||
|
model: "openai/gpt-4.1-mini",
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
expect(res.status).toBe(202);
|
||||||
|
await waitForSystemEvent();
|
||||||
|
const call = cronIsolatedRun.mock.calls[0]?.[0] as {
|
||||||
|
job?: { payload?: { model?: string } };
|
||||||
|
};
|
||||||
|
expect(call?.job?.payload?.model).toBe("openai/gpt-4.1-mini");
|
||||||
|
drainSystemEvents();
|
||||||
|
await server.close();
|
||||||
|
});
|
||||||
|
|
||||||
test("hooks wake accepts query token", async () => {
|
test("hooks wake accepts query token", async () => {
|
||||||
testState.hooksConfig = { enabled: true, token: "hook-secret" };
|
testState.hooksConfig = { enabled: true, token: "hook-secret" };
|
||||||
const port = await getFreePort();
|
const port = await getFreePort();
|
||||||
|
|||||||
@@ -502,6 +502,7 @@ export async function startGatewayServer(
|
|||||||
| "signal"
|
| "signal"
|
||||||
| "imessage";
|
| "imessage";
|
||||||
to?: string;
|
to?: string;
|
||||||
|
model?: string;
|
||||||
thinking?: string;
|
thinking?: string;
|
||||||
timeoutSeconds?: number;
|
timeoutSeconds?: number;
|
||||||
}) => {
|
}) => {
|
||||||
@@ -522,6 +523,7 @@ export async function startGatewayServer(
|
|||||||
payload: {
|
payload: {
|
||||||
kind: "agentTurn",
|
kind: "agentTurn",
|
||||||
message: value.message,
|
message: value.message,
|
||||||
|
model: value.model,
|
||||||
thinking: value.thinking,
|
thinking: value.thinking,
|
||||||
timeoutSeconds: value.timeoutSeconds,
|
timeoutSeconds: value.timeoutSeconds,
|
||||||
deliver: value.deliver,
|
deliver: value.deliver,
|
||||||
|
|||||||
Reference in New Issue
Block a user