From 3fcdd6c9d7af3531bccb17fab96097dd154b05e6 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Wed, 24 Dec 2025 00:52:33 +0000 Subject: [PATCH] feat: enforce final tag parsing for embedded PI --- src/agents/pi-embedded-runner.ts | 5 +- src/agents/pi-embedded-subscribe.test.ts | 95 ++++++++++++++++++++++++ src/agents/pi-embedded-subscribe.ts | 57 +++++++++++++- 3 files changed, 151 insertions(+), 6 deletions(-) create mode 100644 src/agents/pi-embedded-subscribe.test.ts diff --git a/src/agents/pi-embedded-runner.ts b/src/agents/pi-embedded-runner.ts index 2939ed78a..5fd25fd44 100644 --- a/src/agents/pi-embedded-runner.ts +++ b/src/agents/pi-embedded-runner.ts @@ -270,6 +270,7 @@ export async function runEmbeddedPiAgent(params: { enqueue?: typeof enqueueCommand; extraSystemPrompt?: string; ownerNumbers?: string[]; + enforceFinalTag?: boolean; }): Promise { const enqueue = params.enqueue ?? enqueueCommand; return enqueue(async () => { @@ -333,8 +334,7 @@ export async function runEmbeddedPiAgent(params: { node: process.version, model: `${provider}/${modelId}`, }; - const reasoningTagHint = - provider === "lmstudio" || provider === "ollama"; + const reasoningTagHint = provider === "lmstudio" || provider === "ollama"; const systemPrompt = buildSystemPrompt({ appendPrompt: buildAgentSystemPromptAppend({ workspaceDir: resolvedWorkspace, @@ -403,6 +403,7 @@ export async function runEmbeddedPiAgent(params: { onToolResult: params.onToolResult, onPartialReply: params.onPartialReply, onAgentEvent: params.onAgentEvent, + enforceFinalTag: params.enforceFinalTag, }); const abortTimer = setTimeout( diff --git a/src/agents/pi-embedded-subscribe.test.ts b/src/agents/pi-embedded-subscribe.test.ts new file mode 100644 index 000000000..e5cd3e75f --- /dev/null +++ b/src/agents/pi-embedded-subscribe.test.ts @@ -0,0 +1,95 @@ +import { describe, expect, it, vi } from "vitest"; +import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js"; + +type StubSession = { + subscribe: (fn: (evt: unknown) => void) => () => void; +}; + +describe("subscribeEmbeddedPiSession", () => { + it("filters to and falls back when tags are malformed", () => { + let handler: ((evt: unknown) => void) | undefined; + const session: StubSession = { + subscribe: (fn) => { + handler = fn; + return () => {}; + }, + }; + + const onPartialReply = vi.fn(); + const onAgentEvent = vi.fn(); + + subscribeEmbeddedPiSession({ + session: session as unknown as Parameters< + typeof subscribeEmbeddedPiSession + >[0]["session"], + runId: "run", + enforceFinalTag: true, + onPartialReply, + onAgentEvent, + }); + + handler?.({ + type: "message_update", + message: { role: "assistant" }, + assistantMessageEvent: { + type: "text_delta", + delta: "Hi there", + }, + }); + + expect(onPartialReply).toHaveBeenCalled(); + const firstPayload = onPartialReply.mock.calls[0][0]; + expect(firstPayload.text).toBe("Hi there"); + + onPartialReply.mockReset(); + + handler?.({ + type: "message_end", + message: { role: "assistant" }, + }); + + handler?.({ + type: "message_update", + message: { role: "assistant" }, + assistantMessageEvent: { + type: "text_delta", + delta: "Oops no start", + }, + }); + + const secondPayload = onPartialReply.mock.calls[0][0]; + expect(secondPayload.text).toContain("Oops no start"); + }); + + it("does not require when enforcement is off", () => { + let handler: ((evt: unknown) => void) | undefined; + const session: StubSession = { + subscribe: (fn) => { + handler = fn; + return () => {}; + }, + }; + + const onPartialReply = vi.fn(); + + subscribeEmbeddedPiSession({ + session: session as unknown as Parameters< + typeof subscribeEmbeddedPiSession + >[0]["session"], + runId: "run", + onPartialReply, + }); + + handler?.({ + type: "message_update", + message: { role: "assistant" }, + assistantMessageEvent: { + type: "text_delta", + delta: "Hello world", + }, + }); + + const payload = onPartialReply.mock.calls[0][0]; + expect(payload.text).toBe("Hello world"); + }); +}); diff --git a/src/agents/pi-embedded-subscribe.ts b/src/agents/pi-embedded-subscribe.ts index 6ab029295..165c96e86 100644 --- a/src/agents/pi-embedded-subscribe.ts +++ b/src/agents/pi-embedded-subscribe.ts @@ -14,6 +14,8 @@ import { } from "./pi-embedded-utils.js"; const THINKING_TAG_RE = /<\s*\/?\s*think(?:ing)?\s*>/gi; +const THINKING_OPEN_RE = /<\s*think(?:ing)?\s*>/i; +const THINKING_CLOSE_RE = /<\s*\/\s*think(?:ing)?\s*>/i; function stripThinkingSegments(text: string): string { if (!text || !THINKING_TAG_RE.test(text)) return text; @@ -36,6 +38,16 @@ function stripThinkingSegments(text: string): string { return result; } +function stripUnpairedThinkingTags(text: string): string { + if (!text) return text; + const hasOpen = THINKING_OPEN_RE.test(text); + const hasClose = THINKING_CLOSE_RE.test(text); + if (hasOpen && hasClose) return text; + if (!hasOpen) return text.replace(THINKING_CLOSE_RE, ""); + if (!hasClose) return text.replace(THINKING_OPEN_RE, ""); + return text; +} + export function subscribeEmbeddedPiSession(params: { session: AgentSession; runId: string; @@ -53,12 +65,34 @@ export function subscribeEmbeddedPiSession(params: { stream: string; data: Record; }) => void; + enforceFinalTag?: boolean; }) { const assistantTexts: string[] = []; const toolMetas: Array<{ toolName?: string; meta?: string }> = []; const toolMetaById = new Map(); let deltaBuffer = ""; let lastStreamedAssistant: string | undefined; + const FINAL_START_RE = /<\s*final\s*>/i; + const FINAL_END_RE = /<\s*\/\s*final\s*>/i; + // Local providers sometimes emit malformed tags; normalize before filtering. + const sanitizeFinalText = (text: string): string => { + if (!text) return text; + const hasStart = FINAL_START_RE.test(text); + const hasEnd = FINAL_END_RE.test(text); + if (hasStart && !hasEnd) return text.replace(FINAL_START_RE, ""); + if (!hasStart && hasEnd) return text.replace(FINAL_END_RE, ""); + return text; + }; + const extractFinalText = (text: string): string | undefined => { + const cleaned = sanitizeFinalText(text); + const startMatch = FINAL_START_RE.exec(cleaned); + if (!startMatch) return undefined; + const startIndex = startMatch.index + startMatch[0].length; + const afterStart = cleaned.slice(startIndex); + const endMatch = FINAL_END_RE.exec(afterStart); + const endIndex = endMatch ? endMatch.index : afterStart.length; + return afterStart.slice(0, endIndex); + }; const toolDebouncer = createToolDebouncer((toolName, metas) => { if (!params.onPartialReply) return; @@ -182,7 +216,12 @@ export function subscribeEmbeddedPiSession(params: { : ""; if (chunk) { deltaBuffer += chunk; - const next = stripThinkingSegments(deltaBuffer).trim(); + const cleaned = params.enforceFinalTag + ? stripThinkingSegments(stripUnpairedThinkingTags(deltaBuffer)) + : stripThinkingSegments(deltaBuffer); + const next = params.enforceFinalTag + ? (extractFinalText(cleaned)?.trim() ?? cleaned.trim()) + : cleaned.trim(); if (next && next !== lastStreamedAssistant) { lastStreamedAssistant = next; const { text: cleanedText, mediaUrls } = @@ -217,9 +256,19 @@ export function subscribeEmbeddedPiSession(params: { if (evt.type === "message_end") { const msg = (evt as AgentEvent & { message: AppMessage }).message; if (msg?.role === "assistant") { - const text = stripThinkingSegments( - extractAssistantText(msg as AssistantMessage), - ); + const cleaned = params.enforceFinalTag + ? stripThinkingSegments( + stripUnpairedThinkingTags( + extractAssistantText(msg as AssistantMessage), + ), + ) + : stripThinkingSegments( + extractAssistantText(msg as AssistantMessage), + ); + const text = + params.enforceFinalTag && cleaned + ? (extractFinalText(cleaned)?.trim() ?? cleaned) + : cleaned; if (text) assistantTexts.push(text); deltaBuffer = ""; }