fix: harden pi rpc prompt handling

This commit is contained in:
Peter Steinberger
2025-12-05 18:24:45 +00:00
parent d33f9ddf44
commit f315bf074b
3 changed files with 173 additions and 88 deletions

View File

@@ -79,6 +79,38 @@ describe("runCommandReply (pi)", () => {
expect(call?.argv).toContain("medium");
});
it("sends the body via RPC even when the command omits {{Body}}", async () => {
const rpcMock = mockPiRpc({
stdout:
'{"type":"message_end","message":{"role":"assistant","content":[{"type":"text","text":"ok"}]}}',
stderr: "",
code: 0,
});
await runCommandReply({
reply: {
mode: "command",
command: ["pi", "--mode", "rpc", "--session", "/tmp/demo.jsonl"],
agent: { kind: "pi" },
},
templatingCtx: noopTemplateCtx,
sendSystemOnce: false,
isNewSession: true,
isFirstTurnInSession: true,
systemSent: false,
timeoutMs: 1000,
timeoutSeconds: 1,
commandRunner: vi.fn(),
enqueue: enqueueImmediate,
});
const call = rpcMock.mock.calls[0]?.[0];
expect(call?.prompt).toBe("hello");
expect(
(call?.argv ?? []).some((arg: string) => arg.includes("hello")),
).toBe(false);
});
it("adds session args and --continue when resuming", async () => {
const rpcMock = mockPiRpc({
stdout:

View File

@@ -235,18 +235,58 @@ export async function runCommandReply(
const agentCfg = reply.agent ?? { kind: "pi" };
const agentKind: AgentKind = agentCfg.kind ?? "pi";
const agent = getAgentSpec(agentKind);
let argv = reply.command.map((part) => applyTemplate(part, templatingCtx));
const isAgentInvocation = agent.isInvocation(argv);
const rawCommand = reply.command;
const hasBodyTemplate = rawCommand.some((part) =>
/\{\{Body(Stripped)?\}\}/.test(part),
);
let argv = rawCommand.map((part) => applyTemplate(part, templatingCtx));
// Pi is the only supported agent; treat commands as Pi when the binary path looks like pi/tau or the path contains pi.
const isAgentInvocation =
agentKind === "pi" &&
(agent.isInvocation(argv) ||
argv.some((part) => {
if (typeof part !== "string") return false;
const lower = part.toLowerCase();
const base = path.basename(part).toLowerCase();
return (
base === "pi" ||
base === "tau" ||
lower.includes("pi-coding-agent") ||
lower.includes("/pi/")
);
}));
const templatePrefix =
reply.template && (!sendSystemOnce || isFirstTurnInSession || !systemSent)
? applyTemplate(reply.template, templatingCtx)
: "";
let prefixOffset = 0;
if (templatePrefix && argv.length > 0) {
argv = [argv[0], templatePrefix, ...argv.slice(1)];
prefixOffset = 1;
}
// Default body index is last arg
let bodyIndex = Math.max(argv.length - 1, 0);
// Extract (or synthesize) the prompt body so RPC mode works even when the
// command array omits {{Body}} (common for tau --mode rpc configs).
let bodyArg: string | undefined;
if (hasBodyTemplate) {
const idx = rawCommand.findIndex((part) =>
/\{\{Body(Stripped)?\}\}/.test(part),
);
const templatedIdx = idx >= 0 ? idx + prefixOffset : -1;
if (templatedIdx >= 0 && templatedIdx < argv.length) {
bodyArg = argv.splice(templatedIdx, 1)[0];
}
}
if (!bodyArg) {
bodyArg = templatingCtx.Body ?? templatingCtx.BodyStripped ?? "";
}
// Default body index is last arg after we append it below.
let bodyIndex = Math.max(argv.length, 0);
const bodyMarker = `__clawdis_body__${Math.random().toString(36).slice(2)}`;
let sessionArgList: string[] = [];
let insertSessionBeforeBody = true;
// Session args prepared (templated) and injected generically
if (reply.session) {
@@ -258,7 +298,7 @@ export async function runCommandReply(
};
const defaultNew = defaultSessionArgs.newArgs;
const defaultResume = defaultSessionArgs.resumeArgs;
const sessionArgList = (
sessionArgList = (
isNewSession
? (reply.session.sessionArgNew ?? defaultNew)
: (reply.session.sessionArgResume ?? defaultResume)
@@ -288,17 +328,18 @@ export async function runCommandReply(
sessionArgList.push("--continue");
}
if (sessionArgList.length) {
const insertBeforeBody = reply.session.sessionArgBeforeBody ?? true;
const insertAt =
insertBeforeBody && argv.length > 1 ? argv.length - 1 : argv.length;
argv = [
...argv.slice(0, insertAt),
...sessionArgList,
...argv.slice(insertAt),
];
bodyIndex = Math.max(argv.length - 1, 0);
}
insertSessionBeforeBody = reply.session.sessionArgBeforeBody ?? true;
}
if (insertSessionBeforeBody && sessionArgList.length) {
argv = [...argv, ...sessionArgList];
}
argv = [...argv, `${bodyMarker}${bodyArg}`];
bodyIndex = argv.length - 1;
if (!insertSessionBeforeBody && sessionArgList.length) {
argv = [...argv, ...sessionArgList];
}
const shouldApplyAgent = isAgentInvocation;
@@ -315,7 +356,7 @@ export async function runCommandReply(
bodyIndex += 2;
}
}
const finalArgv = shouldApplyAgent
const builtArgv = shouldApplyAgent
? agent.buildArgs({
argv,
bodyIndex,
@@ -328,6 +369,19 @@ export async function runCommandReply(
})
: argv;
const promptIndex = builtArgv.findIndex(
(arg) => typeof arg === "string" && arg.includes(bodyMarker),
);
const promptArg: string =
promptIndex >= 0
? (builtArgv[promptIndex] as string).replace(bodyMarker, "")
: ((builtArgv[builtArgv.length - 1] as string | undefined) ?? "");
const finalArgv = builtArgv.map((arg, idx) => {
if (idx === promptIndex && typeof arg === "string") return promptArg;
return typeof arg === "string" ? arg.replace(bodyMarker, "") : arg;
});
logVerbose(
`Running command auto-reply: ${finalArgv.join(" ")}${reply.cwd ? ` (cwd: ${reply.cwd})` : ""}`,
);
@@ -391,12 +445,13 @@ export async function runCommandReply(
const run = async () => {
// Prefer long-lived tau RPC for pi agent to avoid cold starts.
if (agentKind === "pi" && shouldApplyAgent) {
const promptIndex = finalArgv.length - 1;
const body = finalArgv[promptIndex] ?? "";
const rpcPromptIndex =
promptIndex >= 0 ? promptIndex : finalArgv.length - 1;
const body = promptArg ?? "";
// Build rpc args without the prompt body; force --mode rpc.
const rpcArgv = (() => {
const copy = [...finalArgv];
copy.splice(promptIndex, 1);
copy.splice(rpcPromptIndex, 1);
const modeIdx = copy.indexOf("--mode");
if (modeIdx >= 0 && copy[modeIdx + 1]) {
copy.splice(modeIdx, 2, "--mode", "rpc");

View File

@@ -184,6 +184,14 @@ describe("config and templating", () => {
});
it("getReplyFromConfig templating includes media fields", async () => {
const runSpy = vi.fn().mockResolvedValue({
stdout:
"/tmp/a.jpg\nimage/jpeg\nhttp://example.com/a.jpg\nMEDIA:https://example.com/a.jpg",
stderr: "",
code: 0,
signal: null,
killed: false,
});
const cfg = {
inbound: {
reply: {
@@ -203,6 +211,7 @@ describe("config and templating", () => {
},
undefined,
cfg,
runSpy,
);
expect(result?.text).toContain("/tmp/a.jpg");
expect(result?.text).toContain("image/jpeg");
@@ -1163,7 +1172,7 @@ describe("config and templating", () => {
it("returns timeout reply with partial stdout snippet", async () => {
const partial = "x".repeat(900);
const runSpy = vi.fn().mockRejectedValue({
vi.spyOn(tauRpc, "runPiRpc").mockRejectedValue({
killed: true,
signal: "SIGKILL",
stdout: partial,
@@ -1173,7 +1182,7 @@ describe("config and templating", () => {
inbound: {
reply: {
mode: "command" as const,
command: ["echo", "{{Body}}"],
command: ["pi", "{{Body}}"],
timeoutSeconds: 42,
},
},
@@ -1183,7 +1192,6 @@ describe("config and templating", () => {
{ Body: "hi", From: "+1", To: "+2" },
undefined,
cfg,
runSpy,
);
expect(result?.text).toContain("Command timed out after 42s");
@@ -1193,7 +1201,7 @@ describe("config and templating", () => {
});
it("returns timeout reply without partial output when none is available", async () => {
const runSpy = vi.fn().mockRejectedValue({
vi.spyOn(tauRpc, "runPiRpc").mockRejectedValue({
killed: true,
signal: "SIGKILL",
stdout: "",
@@ -1203,7 +1211,7 @@ describe("config and templating", () => {
inbound: {
reply: {
mode: "command" as const,
command: ["echo", "{{Body}}"],
command: ["pi", "{{Body}}"],
timeoutSeconds: 5,
},
},
@@ -1213,7 +1221,6 @@ describe("config and templating", () => {
{ Body: "hi", From: "+1", To: "+2" },
undefined,
cfg,
runSpy,
);
expect(result?.text).toBe(
@@ -1232,7 +1239,7 @@ describe("config and templating", () => {
it("getReplyFromConfig runs command and manages session store", async () => {
const tmpStore = path.join(os.tmpdir(), `warelay-store-${Date.now()}.json`);
vi.spyOn(crypto, "randomUUID").mockReturnValue("session-123");
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({
const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({
stdout: "cmd output\n",
stderr: "",
code: 0,
@@ -1243,7 +1250,7 @@ describe("config and templating", () => {
inbound: {
reply: {
mode: "command" as const,
command: ["echo", "{{Body}}"],
command: ["pi", "{{Body}}"],
template: "[tmpl]",
session: {
scope: "per-sender" as const,
@@ -1260,34 +1267,30 @@ describe("config and templating", () => {
{ Body: "/new hello", From: "+1555", To: "+1666" },
undefined,
cfg,
runSpy,
);
expect(first?.text).toBe("cmd output");
const argvFirst = runSpy.mock.calls[0][0];
expect(argvFirst).toEqual([
"echo",
"[tmpl]",
"--sid",
"session-123",
"hello",
]);
const argvFirst = rpcSpy.mock.calls[0]?.[0];
expect(argvFirst?.argv).toEqual(
expect.arrayContaining(["pi", "[tmpl]", "--sid", "session-123", "-p"]),
);
expect(argvFirst?.prompt).toBe("hello");
const second = await index.getReplyFromConfig(
{ Body: "next", From: "+1555", To: "+1666" },
undefined,
cfg,
runSpy,
);
expect(second?.text).toBe("cmd output");
const argvSecond = runSpy.mock.calls[1][0];
expect(argvSecond[2]).toBe("--resume");
const argvSecond = rpcSpy.mock.calls[1]?.[0];
expect(argvSecond?.argv).toContain("--resume");
expect(argvSecond?.prompt).toBe("next");
});
it("only sends system prompt once per session when configured", async () => {
const tmpStore = path.join(os.tmpdir(), `warelay-store-${Date.now()}.json`);
vi.spyOn(crypto, "randomUUID").mockReturnValue("sid-1");
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({
stdout: "ok\n",
const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({
stdout: '"ok"',
stderr: "",
code: 0,
signal: null,
@@ -1297,7 +1300,7 @@ describe("config and templating", () => {
inbound: {
reply: {
mode: "command" as const,
command: ["echo", "{{Body}}"],
command: ["pi", "{{Body}}"],
template: "[tmpl]",
bodyPrefix: "[pfx] ",
session: {
@@ -1315,26 +1318,24 @@ describe("config and templating", () => {
{ Body: "/new hi", From: "+1", To: "+2" },
undefined,
cfg,
runSpy,
);
await index.getReplyFromConfig(
{ Body: "next", From: "+1", To: "+2" },
undefined,
cfg,
runSpy,
);
const firstArgv = runSpy.mock.calls[0][0];
expect(firstArgv).toEqual([
"echo",
"[tmpl]",
"--sid",
"sid-1",
"SYS\n\n[pfx] hi",
]);
const firstArgv = rpcSpy.mock.calls[0]?.[0];
expect(firstArgv?.argv).toEqual(
expect.arrayContaining(["pi", "[tmpl]", "--sid", "sid-1", "-p"]),
);
expect(firstArgv?.prompt).toBe("SYS\n\n[pfx] hi");
const secondArgv = runSpy.mock.calls[1][0];
expect(secondArgv).toEqual(["echo", "--resume", "sid-1", "next"]);
const secondArgv = rpcSpy.mock.calls[1]?.[0];
expect(secondArgv?.argv).toContain("--resume");
expect(secondArgv?.prompt).toBe("next");
expect(rpcSpy).toHaveBeenCalledTimes(2);
const persisted = JSON.parse(fs.readFileSync(tmpStore, "utf-8"));
const firstEntry = Object.values(persisted)[0] as { systemSent?: boolean };
@@ -1342,8 +1343,8 @@ describe("config and templating", () => {
});
it("keeps sending system prompt when sendSystemOnce is disabled (default)", async () => {
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({
stdout: "ok\n",
const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({
stdout: '"ok"',
stderr: "",
code: 0,
signal: null,
@@ -1353,7 +1354,7 @@ describe("config and templating", () => {
inbound: {
reply: {
mode: "command" as const,
command: ["echo", "{{Body}}"],
command: ["pi", "{{Body}}"],
bodyPrefix: "[sys] ",
session: {
scope: "per-sender" as const,
@@ -1368,20 +1369,18 @@ describe("config and templating", () => {
{ Body: "/new hi", From: "+1", To: "+2" },
undefined,
cfg,
runSpy,
);
await index.getReplyFromConfig(
{ Body: "next", From: "+1", To: "+2" },
undefined,
cfg,
runSpy,
);
const firstArgv = runSpy.mock.calls[0][0];
expect(firstArgv[firstArgv.length - 1]).toBe("[sys] hi");
const firstArgv = rpcSpy.mock.calls[0]?.[0];
expect(firstArgv?.prompt).toBe("[sys] hi");
const secondArgv = runSpy.mock.calls[1][0];
expect(secondArgv[secondArgv.length - 1]).toBe("[sys] next");
const secondArgv = rpcSpy.mock.calls[1]?.[0];
expect(secondArgv?.prompt).toBe("[sys] next");
});
it("aborts command when stop word is received and skips command runner", async () => {
@@ -1400,7 +1399,7 @@ describe("config and templating", () => {
inbound: {
reply: {
mode: "command" as const,
command: ["echo", "{{Body}}"],
command: ["pi", "{{Body}}"],
session: { store: tmpStore },
},
},
@@ -1425,7 +1424,7 @@ describe("config and templating", () => {
os.tmpdir(),
`warelay-store-${Date.now()}-aborthint.json`,
);
const runSpy = vi.fn().mockResolvedValue({
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({
stdout: "ok\n",
stderr: "",
code: 0,
@@ -1446,7 +1445,6 @@ describe("config and templating", () => {
{ Body: "abort", From: "+1555", To: "+2666" },
undefined,
cfg,
runSpy,
);
const result = await index.getReplyFromConfig(
@@ -1468,7 +1466,7 @@ describe("config and templating", () => {
it("refreshes typing indicator while command runs", async () => {
const onReplyStart = vi.fn();
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockImplementation(
vi.spyOn(tauRpc, "runPiRpc").mockImplementation(
() =>
new Promise((resolve) =>
setTimeout(
@@ -1488,7 +1486,7 @@ describe("config and templating", () => {
inbound: {
reply: {
mode: "command" as const,
command: ["echo", "{{Body}}"],
command: ["pi", "{{Body}}"],
typingIntervalSeconds: 0.02,
},
},
@@ -1498,16 +1496,15 @@ describe("config and templating", () => {
{ Body: "hi", From: "+1", To: "+2" },
{ onReplyStart },
cfg,
runSpy,
);
await new Promise((r) => setTimeout(r, 200));
await promise;
expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(3);
expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(2);
});
it("uses session typing interval override", async () => {
const onReplyStart = vi.fn();
const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockImplementation(
vi.spyOn(tauRpc, "runPiRpc").mockImplementation(
() =>
new Promise((resolve) =>
setTimeout(
@@ -1527,7 +1524,7 @@ describe("config and templating", () => {
inbound: {
reply: {
mode: "command" as const,
command: ["echo", "{{Body}}"],
command: ["pi", "{{Body}}"],
session: { typingIntervalSeconds: 0.02 },
},
},
@@ -1537,29 +1534,30 @@ describe("config and templating", () => {
{ Body: "hi", From: "+1", To: "+2" },
{ onReplyStart },
cfg,
runSpy,
);
await new Promise((r) => setTimeout(r, 200));
await promise;
expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(3);
expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(2);
});
it("serializes command auto-replies via the queue", async () => {
let active = 0;
let maxActive = 0;
const runSpy = vi.fn(async () => {
active += 1;
maxActive = Math.max(maxActive, active);
await new Promise((resolve) => setTimeout(resolve, 25));
active -= 1;
return {
stdout: "ok",
stderr: "",
code: 0,
signal: null,
killed: false,
};
});
const runSpy = vi
.spyOn(index, "runCommandWithTimeout")
.mockImplementation(async () => {
active += 1;
maxActive = Math.max(maxActive, active);
await new Promise((resolve) => setTimeout(resolve, 25));
active -= 1;
return {
stdout: "ok",
stderr: "",
code: 0,
signal: null,
killed: false,
};
});
const cfg = {
inbound: {