feat: enforce final tag parsing for embedded PI

This commit is contained in:
Peter Steinberger
2025-12-24 00:52:33 +00:00
parent bc916dbf35
commit 3fcdd6c9d7
3 changed files with 151 additions and 6 deletions

View File

@@ -270,6 +270,7 @@ export async function runEmbeddedPiAgent(params: {
enqueue?: typeof enqueueCommand;
extraSystemPrompt?: string;
ownerNumbers?: string[];
enforceFinalTag?: boolean;
}): Promise<EmbeddedPiRunResult> {
const enqueue = params.enqueue ?? enqueueCommand;
return enqueue(async () => {
@@ -333,8 +334,7 @@ export async function runEmbeddedPiAgent(params: {
node: process.version,
model: `${provider}/${modelId}`,
};
const reasoningTagHint =
provider === "lmstudio" || provider === "ollama";
const reasoningTagHint = provider === "lmstudio" || provider === "ollama";
const systemPrompt = buildSystemPrompt({
appendPrompt: buildAgentSystemPromptAppend({
workspaceDir: resolvedWorkspace,
@@ -403,6 +403,7 @@ export async function runEmbeddedPiAgent(params: {
onToolResult: params.onToolResult,
onPartialReply: params.onPartialReply,
onAgentEvent: params.onAgentEvent,
enforceFinalTag: params.enforceFinalTag,
});
const abortTimer = setTimeout(

View File

@@ -0,0 +1,95 @@
import { describe, expect, it, vi } from "vitest";
import { subscribeEmbeddedPiSession } from "./pi-embedded-subscribe.js";
type StubSession = {
subscribe: (fn: (evt: unknown) => void) => () => void;
};
describe("subscribeEmbeddedPiSession", () => {
it("filters to <final> and falls back when tags are malformed", () => {
let handler: ((evt: unknown) => void) | undefined;
const session: StubSession = {
subscribe: (fn) => {
handler = fn;
return () => {};
},
};
const onPartialReply = vi.fn();
const onAgentEvent = vi.fn();
subscribeEmbeddedPiSession({
session: session as unknown as Parameters<
typeof subscribeEmbeddedPiSession
>[0]["session"],
runId: "run",
enforceFinalTag: true,
onPartialReply,
onAgentEvent,
});
handler?.({
type: "message_update",
message: { role: "assistant" },
assistantMessageEvent: {
type: "text_delta",
delta: "<final>Hi there</final>",
},
});
expect(onPartialReply).toHaveBeenCalled();
const firstPayload = onPartialReply.mock.calls[0][0];
expect(firstPayload.text).toBe("Hi there");
onPartialReply.mockReset();
handler?.({
type: "message_end",
message: { role: "assistant" },
});
handler?.({
type: "message_update",
message: { role: "assistant" },
assistantMessageEvent: {
type: "text_delta",
delta: "</final>Oops no start",
},
});
const secondPayload = onPartialReply.mock.calls[0][0];
expect(secondPayload.text).toContain("Oops no start");
});
it("does not require <final> when enforcement is off", () => {
let handler: ((evt: unknown) => void) | undefined;
const session: StubSession = {
subscribe: (fn) => {
handler = fn;
return () => {};
},
};
const onPartialReply = vi.fn();
subscribeEmbeddedPiSession({
session: session as unknown as Parameters<
typeof subscribeEmbeddedPiSession
>[0]["session"],
runId: "run",
onPartialReply,
});
handler?.({
type: "message_update",
message: { role: "assistant" },
assistantMessageEvent: {
type: "text_delta",
delta: "Hello world",
},
});
const payload = onPartialReply.mock.calls[0][0];
expect(payload.text).toBe("Hello world");
});
});

View File

@@ -14,6 +14,8 @@ import {
} from "./pi-embedded-utils.js";
const THINKING_TAG_RE = /<\s*\/?\s*think(?:ing)?\s*>/gi;
const THINKING_OPEN_RE = /<\s*think(?:ing)?\s*>/i;
const THINKING_CLOSE_RE = /<\s*\/\s*think(?:ing)?\s*>/i;
function stripThinkingSegments(text: string): string {
if (!text || !THINKING_TAG_RE.test(text)) return text;
@@ -36,6 +38,16 @@ function stripThinkingSegments(text: string): string {
return result;
}
function stripUnpairedThinkingTags(text: string): string {
if (!text) return text;
const hasOpen = THINKING_OPEN_RE.test(text);
const hasClose = THINKING_CLOSE_RE.test(text);
if (hasOpen && hasClose) return text;
if (!hasOpen) return text.replace(THINKING_CLOSE_RE, "");
if (!hasClose) return text.replace(THINKING_OPEN_RE, "");
return text;
}
export function subscribeEmbeddedPiSession(params: {
session: AgentSession;
runId: string;
@@ -53,12 +65,34 @@ export function subscribeEmbeddedPiSession(params: {
stream: string;
data: Record<string, unknown>;
}) => void;
enforceFinalTag?: boolean;
}) {
const assistantTexts: string[] = [];
const toolMetas: Array<{ toolName?: string; meta?: string }> = [];
const toolMetaById = new Map<string, string | undefined>();
let deltaBuffer = "";
let lastStreamedAssistant: string | undefined;
const FINAL_START_RE = /<\s*final\s*>/i;
const FINAL_END_RE = /<\s*\/\s*final\s*>/i;
// Local providers sometimes emit malformed tags; normalize before filtering.
const sanitizeFinalText = (text: string): string => {
if (!text) return text;
const hasStart = FINAL_START_RE.test(text);
const hasEnd = FINAL_END_RE.test(text);
if (hasStart && !hasEnd) return text.replace(FINAL_START_RE, "");
if (!hasStart && hasEnd) return text.replace(FINAL_END_RE, "");
return text;
};
const extractFinalText = (text: string): string | undefined => {
const cleaned = sanitizeFinalText(text);
const startMatch = FINAL_START_RE.exec(cleaned);
if (!startMatch) return undefined;
const startIndex = startMatch.index + startMatch[0].length;
const afterStart = cleaned.slice(startIndex);
const endMatch = FINAL_END_RE.exec(afterStart);
const endIndex = endMatch ? endMatch.index : afterStart.length;
return afterStart.slice(0, endIndex);
};
const toolDebouncer = createToolDebouncer((toolName, metas) => {
if (!params.onPartialReply) return;
@@ -182,7 +216,12 @@ export function subscribeEmbeddedPiSession(params: {
: "";
if (chunk) {
deltaBuffer += chunk;
const next = stripThinkingSegments(deltaBuffer).trim();
const cleaned = params.enforceFinalTag
? stripThinkingSegments(stripUnpairedThinkingTags(deltaBuffer))
: stripThinkingSegments(deltaBuffer);
const next = params.enforceFinalTag
? (extractFinalText(cleaned)?.trim() ?? cleaned.trim())
: cleaned.trim();
if (next && next !== lastStreamedAssistant) {
lastStreamedAssistant = next;
const { text: cleanedText, mediaUrls } =
@@ -217,9 +256,19 @@ export function subscribeEmbeddedPiSession(params: {
if (evt.type === "message_end") {
const msg = (evt as AgentEvent & { message: AppMessage }).message;
if (msg?.role === "assistant") {
const text = stripThinkingSegments(
extractAssistantText(msg as AssistantMessage),
);
const cleaned = params.enforceFinalTag
? stripThinkingSegments(
stripUnpairedThinkingTags(
extractAssistantText(msg as AssistantMessage),
),
)
: stripThinkingSegments(
extractAssistantText(msg as AssistantMessage),
);
const text =
params.enforceFinalTag && cleaned
? (extractFinalText(cleaned)?.trim() ?? cleaned)
: cleaned;
if (text) assistantTexts.push(text);
deltaBuffer = "";
}