From f315bf074b7f38c056e82fe8d999422aff762d3c Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Fri, 5 Dec 2025 18:24:45 +0000 Subject: [PATCH] fix: harden pi rpc prompt handling --- src/auto-reply/command-reply.test.ts | 32 +++++++ src/auto-reply/command-reply.ts | 95 +++++++++++++++---- src/index.core.test.ts | 134 +++++++++++++-------------- 3 files changed, 173 insertions(+), 88 deletions(-) diff --git a/src/auto-reply/command-reply.test.ts b/src/auto-reply/command-reply.test.ts index 17570a2e5..22f626a4e 100644 --- a/src/auto-reply/command-reply.test.ts +++ b/src/auto-reply/command-reply.test.ts @@ -79,6 +79,38 @@ describe("runCommandReply (pi)", () => { expect(call?.argv).toContain("medium"); }); + it("sends the body via RPC even when the command omits {{Body}}", async () => { + const rpcMock = mockPiRpc({ + stdout: + '{"type":"message_end","message":{"role":"assistant","content":[{"type":"text","text":"ok"}]}}', + stderr: "", + code: 0, + }); + + await runCommandReply({ + reply: { + mode: "command", + command: ["pi", "--mode", "rpc", "--session", "/tmp/demo.jsonl"], + agent: { kind: "pi" }, + }, + templatingCtx: noopTemplateCtx, + sendSystemOnce: false, + isNewSession: true, + isFirstTurnInSession: true, + systemSent: false, + timeoutMs: 1000, + timeoutSeconds: 1, + commandRunner: vi.fn(), + enqueue: enqueueImmediate, + }); + + const call = rpcMock.mock.calls[0]?.[0]; + expect(call?.prompt).toBe("hello"); + expect( + (call?.argv ?? []).some((arg: string) => arg.includes("hello")), + ).toBe(false); + }); + it("adds session args and --continue when resuming", async () => { const rpcMock = mockPiRpc({ stdout: diff --git a/src/auto-reply/command-reply.ts b/src/auto-reply/command-reply.ts index d15b5b2a4..11269a94e 100644 --- a/src/auto-reply/command-reply.ts +++ b/src/auto-reply/command-reply.ts @@ -235,18 +235,58 @@ export async function runCommandReply( const agentCfg = reply.agent ?? { kind: "pi" }; const agentKind: AgentKind = agentCfg.kind ?? "pi"; const agent = getAgentSpec(agentKind); - let argv = reply.command.map((part) => applyTemplate(part, templatingCtx)); - const isAgentInvocation = agent.isInvocation(argv); + const rawCommand = reply.command; + const hasBodyTemplate = rawCommand.some((part) => + /\{\{Body(Stripped)?\}\}/.test(part), + ); + let argv = rawCommand.map((part) => applyTemplate(part, templatingCtx)); + // Pi is the only supported agent; treat commands as Pi when the binary path looks like pi/tau or the path contains pi. + const isAgentInvocation = + agentKind === "pi" && + (agent.isInvocation(argv) || + argv.some((part) => { + if (typeof part !== "string") return false; + const lower = part.toLowerCase(); + const base = path.basename(part).toLowerCase(); + return ( + base === "pi" || + base === "tau" || + lower.includes("pi-coding-agent") || + lower.includes("/pi/") + ); + })); const templatePrefix = reply.template && (!sendSystemOnce || isFirstTurnInSession || !systemSent) ? applyTemplate(reply.template, templatingCtx) : ""; + let prefixOffset = 0; if (templatePrefix && argv.length > 0) { argv = [argv[0], templatePrefix, ...argv.slice(1)]; + prefixOffset = 1; } - // Default body index is last arg - let bodyIndex = Math.max(argv.length - 1, 0); + // Extract (or synthesize) the prompt body so RPC mode works even when the + // command array omits {{Body}} (common for tau --mode rpc configs). + let bodyArg: string | undefined; + if (hasBodyTemplate) { + const idx = rawCommand.findIndex((part) => + /\{\{Body(Stripped)?\}\}/.test(part), + ); + const templatedIdx = idx >= 0 ? idx + prefixOffset : -1; + if (templatedIdx >= 0 && templatedIdx < argv.length) { + bodyArg = argv.splice(templatedIdx, 1)[0]; + } + } + if (!bodyArg) { + bodyArg = templatingCtx.Body ?? templatingCtx.BodyStripped ?? ""; + } + + // Default body index is last arg after we append it below. + let bodyIndex = Math.max(argv.length, 0); + + const bodyMarker = `__clawdis_body__${Math.random().toString(36).slice(2)}`; + let sessionArgList: string[] = []; + let insertSessionBeforeBody = true; // Session args prepared (templated) and injected generically if (reply.session) { @@ -258,7 +298,7 @@ export async function runCommandReply( }; const defaultNew = defaultSessionArgs.newArgs; const defaultResume = defaultSessionArgs.resumeArgs; - const sessionArgList = ( + sessionArgList = ( isNewSession ? (reply.session.sessionArgNew ?? defaultNew) : (reply.session.sessionArgResume ?? defaultResume) @@ -288,17 +328,18 @@ export async function runCommandReply( sessionArgList.push("--continue"); } - if (sessionArgList.length) { - const insertBeforeBody = reply.session.sessionArgBeforeBody ?? true; - const insertAt = - insertBeforeBody && argv.length > 1 ? argv.length - 1 : argv.length; - argv = [ - ...argv.slice(0, insertAt), - ...sessionArgList, - ...argv.slice(insertAt), - ]; - bodyIndex = Math.max(argv.length - 1, 0); - } + insertSessionBeforeBody = reply.session.sessionArgBeforeBody ?? true; + } + + if (insertSessionBeforeBody && sessionArgList.length) { + argv = [...argv, ...sessionArgList]; + } + + argv = [...argv, `${bodyMarker}${bodyArg}`]; + bodyIndex = argv.length - 1; + + if (!insertSessionBeforeBody && sessionArgList.length) { + argv = [...argv, ...sessionArgList]; } const shouldApplyAgent = isAgentInvocation; @@ -315,7 +356,7 @@ export async function runCommandReply( bodyIndex += 2; } } - const finalArgv = shouldApplyAgent + const builtArgv = shouldApplyAgent ? agent.buildArgs({ argv, bodyIndex, @@ -328,6 +369,19 @@ export async function runCommandReply( }) : argv; + const promptIndex = builtArgv.findIndex( + (arg) => typeof arg === "string" && arg.includes(bodyMarker), + ); + const promptArg: string = + promptIndex >= 0 + ? (builtArgv[promptIndex] as string).replace(bodyMarker, "") + : ((builtArgv[builtArgv.length - 1] as string | undefined) ?? ""); + + const finalArgv = builtArgv.map((arg, idx) => { + if (idx === promptIndex && typeof arg === "string") return promptArg; + return typeof arg === "string" ? arg.replace(bodyMarker, "") : arg; + }); + logVerbose( `Running command auto-reply: ${finalArgv.join(" ")}${reply.cwd ? ` (cwd: ${reply.cwd})` : ""}`, ); @@ -391,12 +445,13 @@ export async function runCommandReply( const run = async () => { // Prefer long-lived tau RPC for pi agent to avoid cold starts. if (agentKind === "pi" && shouldApplyAgent) { - const promptIndex = finalArgv.length - 1; - const body = finalArgv[promptIndex] ?? ""; + const rpcPromptIndex = + promptIndex >= 0 ? promptIndex : finalArgv.length - 1; + const body = promptArg ?? ""; // Build rpc args without the prompt body; force --mode rpc. const rpcArgv = (() => { const copy = [...finalArgv]; - copy.splice(promptIndex, 1); + copy.splice(rpcPromptIndex, 1); const modeIdx = copy.indexOf("--mode"); if (modeIdx >= 0 && copy[modeIdx + 1]) { copy.splice(modeIdx, 2, "--mode", "rpc"); diff --git a/src/index.core.test.ts b/src/index.core.test.ts index 0cd0ba3bb..4d6b9890f 100644 --- a/src/index.core.test.ts +++ b/src/index.core.test.ts @@ -184,6 +184,14 @@ describe("config and templating", () => { }); it("getReplyFromConfig templating includes media fields", async () => { + const runSpy = vi.fn().mockResolvedValue({ + stdout: + "/tmp/a.jpg\nimage/jpeg\nhttp://example.com/a.jpg\nMEDIA:https://example.com/a.jpg", + stderr: "", + code: 0, + signal: null, + killed: false, + }); const cfg = { inbound: { reply: { @@ -203,6 +211,7 @@ describe("config and templating", () => { }, undefined, cfg, + runSpy, ); expect(result?.text).toContain("/tmp/a.jpg"); expect(result?.text).toContain("image/jpeg"); @@ -1163,7 +1172,7 @@ describe("config and templating", () => { it("returns timeout reply with partial stdout snippet", async () => { const partial = "x".repeat(900); - const runSpy = vi.fn().mockRejectedValue({ + vi.spyOn(tauRpc, "runPiRpc").mockRejectedValue({ killed: true, signal: "SIGKILL", stdout: partial, @@ -1173,7 +1182,7 @@ describe("config and templating", () => { inbound: { reply: { mode: "command" as const, - command: ["echo", "{{Body}}"], + command: ["pi", "{{Body}}"], timeoutSeconds: 42, }, }, @@ -1183,7 +1192,6 @@ describe("config and templating", () => { { Body: "hi", From: "+1", To: "+2" }, undefined, cfg, - runSpy, ); expect(result?.text).toContain("Command timed out after 42s"); @@ -1193,7 +1201,7 @@ describe("config and templating", () => { }); it("returns timeout reply without partial output when none is available", async () => { - const runSpy = vi.fn().mockRejectedValue({ + vi.spyOn(tauRpc, "runPiRpc").mockRejectedValue({ killed: true, signal: "SIGKILL", stdout: "", @@ -1203,7 +1211,7 @@ describe("config and templating", () => { inbound: { reply: { mode: "command" as const, - command: ["echo", "{{Body}}"], + command: ["pi", "{{Body}}"], timeoutSeconds: 5, }, }, @@ -1213,7 +1221,6 @@ describe("config and templating", () => { { Body: "hi", From: "+1", To: "+2" }, undefined, cfg, - runSpy, ); expect(result?.text).toBe( @@ -1232,7 +1239,7 @@ describe("config and templating", () => { it("getReplyFromConfig runs command and manages session store", async () => { const tmpStore = path.join(os.tmpdir(), `warelay-store-${Date.now()}.json`); vi.spyOn(crypto, "randomUUID").mockReturnValue("session-123"); - const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({ + const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({ stdout: "cmd output\n", stderr: "", code: 0, @@ -1243,7 +1250,7 @@ describe("config and templating", () => { inbound: { reply: { mode: "command" as const, - command: ["echo", "{{Body}}"], + command: ["pi", "{{Body}}"], template: "[tmpl]", session: { scope: "per-sender" as const, @@ -1260,34 +1267,30 @@ describe("config and templating", () => { { Body: "/new hello", From: "+1555", To: "+1666" }, undefined, cfg, - runSpy, ); expect(first?.text).toBe("cmd output"); - const argvFirst = runSpy.mock.calls[0][0]; - expect(argvFirst).toEqual([ - "echo", - "[tmpl]", - "--sid", - "session-123", - "hello", - ]); + const argvFirst = rpcSpy.mock.calls[0]?.[0]; + expect(argvFirst?.argv).toEqual( + expect.arrayContaining(["pi", "[tmpl]", "--sid", "session-123", "-p"]), + ); + expect(argvFirst?.prompt).toBe("hello"); const second = await index.getReplyFromConfig( { Body: "next", From: "+1555", To: "+1666" }, undefined, cfg, - runSpy, ); expect(second?.text).toBe("cmd output"); - const argvSecond = runSpy.mock.calls[1][0]; - expect(argvSecond[2]).toBe("--resume"); + const argvSecond = rpcSpy.mock.calls[1]?.[0]; + expect(argvSecond?.argv).toContain("--resume"); + expect(argvSecond?.prompt).toBe("next"); }); it("only sends system prompt once per session when configured", async () => { const tmpStore = path.join(os.tmpdir(), `warelay-store-${Date.now()}.json`); vi.spyOn(crypto, "randomUUID").mockReturnValue("sid-1"); - const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({ - stdout: "ok\n", + const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({ + stdout: '"ok"', stderr: "", code: 0, signal: null, @@ -1297,7 +1300,7 @@ describe("config and templating", () => { inbound: { reply: { mode: "command" as const, - command: ["echo", "{{Body}}"], + command: ["pi", "{{Body}}"], template: "[tmpl]", bodyPrefix: "[pfx] ", session: { @@ -1315,26 +1318,24 @@ describe("config and templating", () => { { Body: "/new hi", From: "+1", To: "+2" }, undefined, cfg, - runSpy, ); await index.getReplyFromConfig( { Body: "next", From: "+1", To: "+2" }, undefined, cfg, - runSpy, ); - const firstArgv = runSpy.mock.calls[0][0]; - expect(firstArgv).toEqual([ - "echo", - "[tmpl]", - "--sid", - "sid-1", - "SYS\n\n[pfx] hi", - ]); + const firstArgv = rpcSpy.mock.calls[0]?.[0]; + expect(firstArgv?.argv).toEqual( + expect.arrayContaining(["pi", "[tmpl]", "--sid", "sid-1", "-p"]), + ); + expect(firstArgv?.prompt).toBe("SYS\n\n[pfx] hi"); - const secondArgv = runSpy.mock.calls[1][0]; - expect(secondArgv).toEqual(["echo", "--resume", "sid-1", "next"]); + const secondArgv = rpcSpy.mock.calls[1]?.[0]; + expect(secondArgv?.argv).toContain("--resume"); + expect(secondArgv?.prompt).toBe("next"); + + expect(rpcSpy).toHaveBeenCalledTimes(2); const persisted = JSON.parse(fs.readFileSync(tmpStore, "utf-8")); const firstEntry = Object.values(persisted)[0] as { systemSent?: boolean }; @@ -1342,8 +1343,8 @@ describe("config and templating", () => { }); it("keeps sending system prompt when sendSystemOnce is disabled (default)", async () => { - const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({ - stdout: "ok\n", + const rpcSpy = vi.spyOn(tauRpc, "runPiRpc").mockResolvedValue({ + stdout: '"ok"', stderr: "", code: 0, signal: null, @@ -1353,7 +1354,7 @@ describe("config and templating", () => { inbound: { reply: { mode: "command" as const, - command: ["echo", "{{Body}}"], + command: ["pi", "{{Body}}"], bodyPrefix: "[sys] ", session: { scope: "per-sender" as const, @@ -1368,20 +1369,18 @@ describe("config and templating", () => { { Body: "/new hi", From: "+1", To: "+2" }, undefined, cfg, - runSpy, ); await index.getReplyFromConfig( { Body: "next", From: "+1", To: "+2" }, undefined, cfg, - runSpy, ); - const firstArgv = runSpy.mock.calls[0][0]; - expect(firstArgv[firstArgv.length - 1]).toBe("[sys] hi"); + const firstArgv = rpcSpy.mock.calls[0]?.[0]; + expect(firstArgv?.prompt).toBe("[sys] hi"); - const secondArgv = runSpy.mock.calls[1][0]; - expect(secondArgv[secondArgv.length - 1]).toBe("[sys] next"); + const secondArgv = rpcSpy.mock.calls[1]?.[0]; + expect(secondArgv?.prompt).toBe("[sys] next"); }); it("aborts command when stop word is received and skips command runner", async () => { @@ -1400,7 +1399,7 @@ describe("config and templating", () => { inbound: { reply: { mode: "command" as const, - command: ["echo", "{{Body}}"], + command: ["pi", "{{Body}}"], session: { store: tmpStore }, }, }, @@ -1425,7 +1424,7 @@ describe("config and templating", () => { os.tmpdir(), `warelay-store-${Date.now()}-aborthint.json`, ); - const runSpy = vi.fn().mockResolvedValue({ + const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockResolvedValue({ stdout: "ok\n", stderr: "", code: 0, @@ -1446,7 +1445,6 @@ describe("config and templating", () => { { Body: "abort", From: "+1555", To: "+2666" }, undefined, cfg, - runSpy, ); const result = await index.getReplyFromConfig( @@ -1468,7 +1466,7 @@ describe("config and templating", () => { it("refreshes typing indicator while command runs", async () => { const onReplyStart = vi.fn(); - const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockImplementation( + vi.spyOn(tauRpc, "runPiRpc").mockImplementation( () => new Promise((resolve) => setTimeout( @@ -1488,7 +1486,7 @@ describe("config and templating", () => { inbound: { reply: { mode: "command" as const, - command: ["echo", "{{Body}}"], + command: ["pi", "{{Body}}"], typingIntervalSeconds: 0.02, }, }, @@ -1498,16 +1496,15 @@ describe("config and templating", () => { { Body: "hi", From: "+1", To: "+2" }, { onReplyStart }, cfg, - runSpy, ); await new Promise((r) => setTimeout(r, 200)); await promise; - expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(3); + expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(2); }); it("uses session typing interval override", async () => { const onReplyStart = vi.fn(); - const runSpy = vi.spyOn(index, "runCommandWithTimeout").mockImplementation( + vi.spyOn(tauRpc, "runPiRpc").mockImplementation( () => new Promise((resolve) => setTimeout( @@ -1527,7 +1524,7 @@ describe("config and templating", () => { inbound: { reply: { mode: "command" as const, - command: ["echo", "{{Body}}"], + command: ["pi", "{{Body}}"], session: { typingIntervalSeconds: 0.02 }, }, }, @@ -1537,29 +1534,30 @@ describe("config and templating", () => { { Body: "hi", From: "+1", To: "+2" }, { onReplyStart }, cfg, - runSpy, ); await new Promise((r) => setTimeout(r, 200)); await promise; - expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(3); + expect(onReplyStart.mock.calls.length).toBeGreaterThanOrEqual(2); }); it("serializes command auto-replies via the queue", async () => { let active = 0; let maxActive = 0; - const runSpy = vi.fn(async () => { - active += 1; - maxActive = Math.max(maxActive, active); - await new Promise((resolve) => setTimeout(resolve, 25)); - active -= 1; - return { - stdout: "ok", - stderr: "", - code: 0, - signal: null, - killed: false, - }; - }); + const runSpy = vi + .spyOn(index, "runCommandWithTimeout") + .mockImplementation(async () => { + active += 1; + maxActive = Math.max(maxActive, active); + await new Promise((resolve) => setTimeout(resolve, 25)); + active -= 1; + return { + stdout: "ok", + stderr: "", + code: 0, + signal: null, + killed: false, + }; + }); const cfg = { inbound: {