fix: harden pi rpc prompt handling

This commit is contained in:
Peter Steinberger
2025-12-05 18:24:45 +00:00
parent d33f9ddf44
commit f315bf074b
3 changed files with 173 additions and 88 deletions

View File

@@ -79,6 +79,38 @@ describe("runCommandReply (pi)", () => {
expect(call?.argv).toContain("medium"); expect(call?.argv).toContain("medium");
}); });
it("sends the body via RPC even when the command omits {{Body}}", async () => {
const rpcMock = mockPiRpc({
stdout:
'{"type":"message_end","message":{"role":"assistant","content":[{"type":"text","text":"ok"}]}}',
stderr: "",
code: 0,
});
await runCommandReply({
reply: {
mode: "command",
command: ["pi", "--mode", "rpc", "--session", "/tmp/demo.jsonl"],
agent: { kind: "pi" },
},
templatingCtx: noopTemplateCtx,
sendSystemOnce: false,
isNewSession: true,
isFirstTurnInSession: true,
systemSent: false,
timeoutMs: 1000,
timeoutSeconds: 1,
commandRunner: vi.fn(),
enqueue: enqueueImmediate,
});
const call = rpcMock.mock.calls[0]?.[0];
expect(call?.prompt).toBe("hello");
expect(
(call?.argv ?? []).some((arg: string) => arg.includes("hello")),
).toBe(false);
});
it("adds session args and --continue when resuming", async () => { it("adds session args and --continue when resuming", async () => {
const rpcMock = mockPiRpc({ const rpcMock = mockPiRpc({
stdout: stdout:

View File

@@ -235,18 +235,58 @@ export async function runCommandReply(
const agentCfg = reply.agent ?? { kind: "pi" }; const agentCfg = reply.agent ?? { kind: "pi" };
const agentKind: AgentKind = agentCfg.kind ?? "pi"; const agentKind: AgentKind = agentCfg.kind ?? "pi";
const agent = getAgentSpec(agentKind); const agent = getAgentSpec(agentKind);
let argv = reply.command.map((part) => applyTemplate(part, templatingCtx)); const rawCommand = reply.command;
const isAgentInvocation = agent.isInvocation(argv); const hasBodyTemplate = rawCommand.some((part) =>
/\{\{Body(Stripped)?\}\}/.test(part),
);
let argv = rawCommand.map((part) => applyTemplate(part, templatingCtx));
// Pi is the only supported agent; treat commands as Pi when the binary path looks like pi/tau or the path contains pi.
const isAgentInvocation =
agentKind === "pi" &&
(agent.isInvocation(argv) ||
argv.some((part) => {
if (typeof part !== "string") return false;
const lower = part.toLowerCase();
const base = path.basename(part).toLowerCase();
return (
base === "pi" ||
base === "tau" ||
lower.includes("pi-coding-agent") ||
lower.includes("/pi/")
);
}));
const templatePrefix = const templatePrefix =
reply.template && (!sendSystemOnce || isFirstTurnInSession || !systemSent) reply.template && (!sendSystemOnce || isFirstTurnInSession || !systemSent)
? applyTemplate(reply.template, templatingCtx) ? applyTemplate(reply.template, templatingCtx)
: ""; : "";
let prefixOffset = 0;
if (templatePrefix && argv.length > 0) { if (templatePrefix && argv.length > 0) {
argv = [argv[0], templatePrefix, ...argv.slice(1)]; argv = [argv[0], templatePrefix, ...argv.slice(1)];
prefixOffset = 1;
} }
// Default body index is last arg // Extract (or synthesize) the prompt body so RPC mode works even when the
let bodyIndex = Math.max(argv.length - 1, 0); // command array omits {{Body}} (common for tau --mode rpc configs).
let bodyArg: string | undefined;
if (hasBodyTemplate) {
const idx = rawCommand.findIndex((part) =>
/\{\{Body(Stripped)?\}\}/.test(part),
);
const templatedIdx = idx >= 0 ? idx + prefixOffset : -1;
if (templatedIdx >= 0 && templatedIdx < argv.length) {
bodyArg = argv.splice(templatedIdx, 1)[0];
}
}
if (!bodyArg) {
bodyArg = templatingCtx.Body ?? templatingCtx.BodyStripped ?? "";
}
// Default body index is last arg after we append it below.
let bodyIndex = Math.max(argv.length, 0);
const bodyMarker = `__clawdis_body__${Math.random().toString(36).slice(2)}`;
let sessionArgList: string[] = [];
let insertSessionBeforeBody = true;
// Session args prepared (templated) and injected generically // Session args prepared (templated) and injected generically
if (reply.session) { if (reply.session) {
@@ -258,7 +298,7 @@ export async function runCommandReply(
}; };
const defaultNew = defaultSessionArgs.newArgs; const defaultNew = defaultSessionArgs.newArgs;
const defaultResume = defaultSessionArgs.resumeArgs; const defaultResume = defaultSessionArgs.resumeArgs;
const sessionArgList = ( sessionArgList = (
isNewSession isNewSession
? (reply.session.sessionArgNew ?? defaultNew) ? (reply.session.sessionArgNew ?? defaultNew)
: (reply.session.sessionArgResume ?? defaultResume) : (reply.session.sessionArgResume ?? defaultResume)
@@ -288,17 +328,18 @@ export async function runCommandReply(
sessionArgList.push("--continue"); sessionArgList.push("--continue");
} }
if (sessionArgList.length) { insertSessionBeforeBody = reply.session.sessionArgBeforeBody ?? true;
const insertBeforeBody = reply.session.sessionArgBeforeBody ?? true; }
const insertAt =
insertBeforeBody && argv.length > 1 ? argv.length - 1 : argv.length; if (insertSessionBeforeBody && sessionArgList.length) {
argv = [ argv = [...argv, ...sessionArgList];
...argv.slice(0, insertAt), }
...sessionArgList,
...argv.slice(insertAt), argv = [...argv, `${bodyMarker}${bodyArg}`];
]; bodyIndex = argv.length - 1;
bodyIndex = Math.max(argv.length - 1, 0);
} if (!insertSessionBeforeBody && sessionArgList.length) {
argv = [...argv, ...sessionArgList];
} }
const shouldApplyAgent = isAgentInvocation; const shouldApplyAgent = isAgentInvocation;
@@ -315,7 +356,7 @@ export async function runCommandReply(
bodyIndex += 2; bodyIndex += 2;
} }
} }
const finalArgv = shouldApplyAgent const builtArgv = shouldApplyAgent
? agent.buildArgs({ ? agent.buildArgs({
argv, argv,
bodyIndex, bodyIndex,
@@ -328,6 +369,19 @@ export async function runCommandReply(
}) })
: argv; : argv;
const promptIndex = builtArgv.findIndex(
(arg) => typeof arg === "string" && arg.includes(bodyMarker),
);
const promptArg: string =
promptIndex >= 0
? (builtArgv[promptIndex] as string).replace(bodyMarker, "")
: ((builtArgv[builtArgv.length - 1] as string | undefined) ?? "");
const finalArgv = builtArgv.map((arg, idx) => {
if (idx === promptIndex && typeof arg === "string") return promptArg;
return typeof arg === "string" ? arg.replace(bodyMarker, "") : arg;
});
logVerbose( logVerbose(
`Running command auto-reply: ${finalArgv.join(" ")}${reply.cwd ? ` (cwd: ${reply.cwd})` : ""}`, `Running command auto-reply: ${finalArgv.join(" ")}${reply.cwd ? ` (cwd: ${reply.cwd})` : ""}`,
); );
@@ -391,12 +445,13 @@ export async function runCommandReply(
const run = async () => { const run = async () => {
// Prefer long-lived tau RPC for pi agent to avoid cold starts. // Prefer long-lived tau RPC for pi agent to avoid cold starts.
if (agentKind === "pi" && shouldApplyAgent) { if (agentKind === "pi" && shouldApplyAgent) {
const promptIndex = finalArgv.length - 1; const rpcPromptIndex =
const body = finalArgv[promptIndex] ?? ""; promptIndex >= 0 ? promptIndex : finalArgv.length - 1;
const body = promptArg ?? "";
// Build rpc args without the prompt body; force --mode rpc. // Build rpc args without the prompt body; force --mode rpc.
const rpcArgv = (() => { const rpcArgv = (() => {
const copy = [...finalArgv]; const copy = [...finalArgv];
copy.splice(promptIndex, 1); copy.splice(rpcPromptIndex, 1);
const modeIdx = copy.indexOf("--mode"); const modeIdx = copy.indexOf("--mode");
if (modeIdx >= 0 && copy[modeIdx + 1]) { if (modeIdx >= 0 && copy[modeIdx + 1]) {
copy.splice(modeIdx, 2, "--mode", "rpc"); copy.splice(modeIdx, 2, "--mode", "rpc");

View File

@@ -184,6 +184,14 @@ describe("config and templating", () => {
}); });
it("getReplyFromConfig templating includes media fields", async () => { it("getReplyFromConfig templating includes media fields", async () => {
const runSpy = vi.fn().mockResolvedValue({
stdout:
"/tmp/a.jpg\nimage/jpeg\nhttp://example.com/a.jpg\nMEDIA:https://example.com/a.jpg",
stderr: "",
code: 0,
signal: null,
killed: false,
});
const cfg = { const cfg = {
inbound: { inbound: {
reply: { reply: {
@@ -203,6 +211,7 @@ describe("config and templating", () => {
}, },
undefined, undefined,
cfg, cfg,
runSpy,
); );
expect(result?.text).toContain("/tmp/a.jpg"); expect(result?.text).toContain("/tmp/a.jpg");
expect(result?.text).toContain("image/jpeg"); expect(result?.text).toContain("image/jpeg");
@@ -1163,7 +1172,7 @@ describe("config and templating", () => {
it("returns timeout reply with partial stdout snippet", async () => { it("returns timeout reply with partial stdout snippet", async () => {
const partial = "x".repeat(900); const partial = "x".repeat(900);
const runSpy = vi.fn().mockRejectedValue({ vi.spyOn(tauRpc, "runPiRpc").mockRejectedValue({
killed: true, killed: true,
signal: "SIGKILL", signal: "SIGKILL",
stdout: partial, stdout: partial,
@@ -1173,7 +1182,7 @@ describe("config and templating", () => {
inbound: { inbound: {
reply: { reply: {
mode: "command" as const, mode: "command" as const,
command: ["echo", "{{Body}}"], command: ["pi", "{{Body}}"],
timeoutSeconds: 42, timeoutSeconds: 42,
}, },
}, },
@@ -1183,7 +1192,6 @@ describe("config and templating", () => {
{ Body: "hi", From: "+1", To: "+2" }, { Body: "hi", From: "+1", To: "+2" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
expect(result?.text).toContain("Command timed out after 42s"); expect(result?.text).toContain("Command timed out after 42s");
@@ -1193,7 +1201,7 @@ describe("config and templating", () => {
}); });
it("returns timeout reply without partial output when none is available", async () => { it("returns timeout reply without partial output when none is available", async () => {
const runSpy = vi.fn().mockRejectedValue({ vi.spyOn(tauRpc, "runPiRpc").mockRejectedValue({
killed: true, killed: true,
signal: "SIGKILL", signal: "SIGKILL",
stdout: "", stdout: "",
@@ -1203,7 +1211,7 @@ describe("config and templating", () => {
inbound: { inbound: {
reply: { reply: {
mode: "command" as const, mode: "command" as const,
command: ["echo", "{{Body}}"], command: ["pi", "{{Body}}"],
timeoutSeconds: 5, timeoutSeconds: 5,
}, },
}, },
@@ -1213,7 +1221,6 @@ describe("config and templating", () => {
{ Body: "hi", From: "+1", To: "+2" }, { Body: "hi", From: "+1", To: "+2" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
expect(result?.text).toBe( expect(result?.text).toBe(
@@ -1232,7 +1239,7 @@ describe("config and templating", () => {
it("getReplyFromConfig runs command and manages session store", async () => { it("getReplyFromConfig runs command and manages session store", async () => {
const tmpStore = path.join(os.tmpdir(), `warelay-store-${Date.now()}.json`); const tmpStore = path.join(os.tmpdir(), `warelay-store-${Date.now()}.json`);
vi.spyOn(crypto, "randomUUID").mockReturnValue("session-123"); vi.spyOn(crypto, "randomUUID").mockReturnValue("session-123");
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({ const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({
stdout: "cmd output\n", stdout: "cmd output\n",
stderr: "", stderr: "",
code: 0, code: 0,
@@ -1243,7 +1250,7 @@ describe("config and templating", () => {
inbound: { inbound: {
reply: { reply: {
mode: "command" as const, mode: "command" as const,
command: ["echo", "{{Body}}"], command: ["pi", "{{Body}}"],
template: "[tmpl]", template: "[tmpl]",
session: { session: {
scope: "per-sender" as const, scope: "per-sender" as const,
@@ -1260,34 +1267,30 @@ describe("config and templating", () => {
{ Body: "/new hello", From: "+1555", To: "+1666" }, { Body: "/new hello", From: "+1555", To: "+1666" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
expect(first?.text).toBe("cmd output"); expect(first?.text).toBe("cmd output");
const argvFirst = runSpy.mock.calls[0][0]; const argvFirst = rpcSpy.mock.calls[0]?.[0];
expect(argvFirst).toEqual([ expect(argvFirst?.argv).toEqual(
"echo", expect.arrayContaining(["pi", "[tmpl]", "--sid", "session-123", "-p"]),
"[tmpl]", );
"--sid", expect(argvFirst?.prompt).toBe("hello");
"session-123",
"hello",
]);
const second = await index.getReplyFromConfig( const second = await index.getReplyFromConfig(
{ Body: "next", From: "+1555", To: "+1666" }, { Body: "next", From: "+1555", To: "+1666" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
expect(second?.text).toBe("cmd output"); expect(second?.text).toBe("cmd output");
const argvSecond = runSpy.mock.calls[1][0]; const argvSecond = rpcSpy.mock.calls[1]?.[0];
expect(argvSecond[2]).toBe("--resume"); expect(argvSecond?.argv).toContain("--resume");
expect(argvSecond?.prompt).toBe("next");
}); });
it("only sends system prompt once per session when configured", async () => { it("only sends system prompt once per session when configured", async () => {
const tmpStore = path.join(os.tmpdir(), `warelay-store-${Date.now()}.json`); const tmpStore = path.join(os.tmpdir(), `warelay-store-${Date.now()}.json`);
vi.spyOn(crypto, "randomUUID").mockReturnValue("sid-1"); vi.spyOn(crypto, "randomUUID").mockReturnValue("sid-1");
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({ const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({
stdout: "ok\n", stdout: '"ok"',
stderr: "", stderr: "",
code: 0, code: 0,
signal: null, signal: null,
@@ -1297,7 +1300,7 @@ describe("config and templating", () => {
inbound: { inbound: {
reply: { reply: {
mode: "command" as const, mode: "command" as const,
command: ["echo", "{{Body}}"], command: ["pi", "{{Body}}"],
template: "[tmpl]", template: "[tmpl]",
bodyPrefix: "[pfx] ", bodyPrefix: "[pfx] ",
session: { session: {
@@ -1315,26 +1318,24 @@ describe("config and templating", () => {
{ Body: "/new hi", From: "+1", To: "+2" }, { Body: "/new hi", From: "+1", To: "+2" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
await index.getReplyFromConfig( await index.getReplyFromConfig(
{ Body: "next", From: "+1", To: "+2" }, { Body: "next", From: "+1", To: "+2" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
const firstArgv = runSpy.mock.calls[0][0]; const firstArgv = rpcSpy.mock.calls[0]?.[0];
expect(firstArgv).toEqual([ expect(firstArgv?.argv).toEqual(
"echo", expect.arrayContaining(["pi", "[tmpl]", "--sid", "sid-1", "-p"]),
"[tmpl]", );
"--sid", expect(firstArgv?.prompt).toBe("SYS\n\n[pfx] hi");
"sid-1",
"SYS\n\n[pfx] hi",
]);
const secondArgv = runSpy.mock.calls[1][0]; const secondArgv = rpcSpy.mock.calls[1]?.[0];
expect(secondArgv).toEqual(["echo", "--resume", "sid-1", "next"]); expect(secondArgv?.argv).toContain("--resume");
expect(secondArgv?.prompt).toBe("next");
expect(rpcSpy).toHaveBeenCalledTimes(2);
const persisted = JSON.parse(fs.readFileSync(tmpStore, "utf-8")); const persisted = JSON.parse(fs.readFileSync(tmpStore, "utf-8"));
const firstEntry = Object.values(persisted)[0] as { systemSent?: boolean }; const firstEntry = Object.values(persisted)[0] as { systemSent?: boolean };
@@ -1342,8 +1343,8 @@ describe("config and templating", () => {
}); });
it("keeps sending system prompt when sendSystemOnce is disabled (default)", async () => { it("keeps sending system prompt when sendSystemOnce is disabled (default)", async () => {
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({ const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({
stdout: "ok\n", stdout: '"ok"',
stderr: "", stderr: "",
code: 0, code: 0,
signal: null, signal: null,
@@ -1353,7 +1354,7 @@ describe("config and templating", () => {
inbound: { inbound: {
reply: { reply: {
mode: "command" as const, mode: "command" as const,
command: ["echo", "{{Body}}"], command: ["pi", "{{Body}}"],
bodyPrefix: "[sys] ", bodyPrefix: "[sys] ",
session: { session: {
scope: "per-sender" as const, scope: "per-sender" as const,
@@ -1368,20 +1369,18 @@ describe("config and templating", () => {
{ Body: "/new hi", From: "+1", To: "+2" }, { Body: "/new hi", From: "+1", To: "+2" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
await index.getReplyFromConfig( await index.getReplyFromConfig(
{ Body: "next", From: "+1", To: "+2" }, { Body: "next", From: "+1", To: "+2" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
const firstArgv = runSpy.mock.calls[0][0]; const firstArgv = rpcSpy.mock.calls[0]?.[0];
expect(firstArgv[firstArgv.length - 1]).toBe("[sys] hi"); expect(firstArgv?.prompt).toBe("[sys] hi");
const secondArgv = runSpy.mock.calls[1][0]; const secondArgv = rpcSpy.mock.calls[1]?.[0];
expect(secondArgv[secondArgv.length - 1]).toBe("[sys] next"); expect(secondArgv?.prompt).toBe("[sys] next");
}); });
it("aborts command when stop word is received and skips command runner", async () => { it("aborts command when stop word is received and skips command runner", async () => {
@@ -1400,7 +1399,7 @@ describe("config and templating", () => {
inbound: { inbound: {
reply: { reply: {
mode: "command" as const, mode: "command" as const,
command: ["echo", "{{Body}}"], command: ["pi", "{{Body}}"],
session: { store: tmpStore }, session: { store: tmpStore },
}, },
}, },
@@ -1425,7 +1424,7 @@ describe("config and templating", () => {
os.tmpdir(), os.tmpdir(),
`warelay-store-${Date.now()}-aborthint.json`, `warelay-store-${Date.now()}-aborthint.json`,
); );
const runSpy = vi.fn().mockResolvedValue({ const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({
stdout: "ok\n", stdout: "ok\n",
stderr: "", stderr: "",
code: 0, code: 0,
@@ -1446,7 +1445,6 @@ describe("config and templating", () => {
{ Body: "abort", From: "+1555", To: "+2666" }, { Body: "abort", From: "+1555", To: "+2666" },
undefined, undefined,
cfg, cfg,
runSpy,
); );
const result = await index.getReplyFromConfig( const result = await index.getReplyFromConfig(
@@ -1468,7 +1466,7 @@ describe("config and templating", () => {
it("refreshes typing indicator while command runs", async () => { it("refreshes typing indicator while command runs", async () => {
const onReplyStart = vi.fn(); const onReplyStart = vi.fn();
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockImplementation( vi.spyOn(tauRpc, "runPiRpc").mockImplementation(
() => () =>
new Promise((resolve) => new Promise((resolve) =>
setTimeout( setTimeout(
@@ -1488,7 +1486,7 @@ describe("config and templating", () => {
inbound: { inbound: {
reply: { reply: {
mode: "command" as const, mode: "command" as const,
command: ["echo", "{{Body}}"], command: ["pi", "{{Body}}"],
typingIntervalSeconds: 0.02, typingIntervalSeconds: 0.02,
}, },
}, },
@@ -1498,16 +1496,15 @@ describe("config and templating", () => {
{ Body: "hi", From: "+1", To: "+2" }, { Body: "hi", From: "+1", To: "+2" },
{ onReplyStart }, { onReplyStart },
cfg, cfg,
runSpy,
); );
await new Promise((r) => setTimeout(r, 200)); await new Promise((r) => setTimeout(r, 200));
await promise; await promise;
expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(3); expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(2);
}); });
it("uses session typing interval override", async () => { it("uses session typing interval override", async () => {
const onReplyStart = vi.fn(); const onReplyStart = vi.fn();
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockImplementation( vi.spyOn(tauRpc, "runPiRpc").mockImplementation(
() => () =>
new Promise((resolve) => new Promise((resolve) =>
setTimeout( setTimeout(
@@ -1527,7 +1524,7 @@ describe("config and templating", () => {
inbound: { inbound: {
reply: { reply: {
mode: "command" as const, mode: "command" as const,
command: ["echo", "{{Body}}"], command: ["pi", "{{Body}}"],
session: { typingIntervalSeconds: 0.02 }, session: { typingIntervalSeconds: 0.02 },
}, },
}, },
@@ -1537,29 +1534,30 @@ describe("config and templating", () => {
{ Body: "hi", From: "+1", To: "+2" }, { Body: "hi", From: "+1", To: "+2" },
{ onReplyStart }, { onReplyStart },
cfg, cfg,
runSpy,
); );
await new Promise((r) => setTimeout(r, 200)); await new Promise((r) => setTimeout(r, 200));
await promise; await promise;
expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(3); expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(2);
}); });
it("serializes command auto-replies via the queue", async () => { it("serializes command auto-replies via the queue", async () => {
let active = 0; let active = 0;
let maxActive = 0; let maxActive = 0;
const runSpy = vi.fn(async () => { const runSpy = vi
active += 1; .spyOn(index, "runCommandWithTimeout")
maxActive = Math.max(maxActive, active); .mockImplementation(async () => {
await new Promise((resolve) => setTimeout(resolve, 25)); active += 1;
active -= 1; maxActive = Math.max(maxActive, active);
return { await new Promise((resolve) => setTimeout(resolve, 25));
stdout: "ok", active -= 1;
stderr: "", return {
code: 0, stdout: "ok",
signal: null, stderr: "",
killed: false, code: 0,
}; signal: null,
}); killed: false,
};
});
const cfg = { const cfg = {
inbound: { inbound: {