diff --git a/src/agents/pi-embedded.ts b/src/agents/pi-embedded.ts index fdef78c30..cbfeb2f0b 100644 --- a/src/agents/pi-embedded.ts +++ b/src/agents/pi-embedded.ts @@ -9,6 +9,7 @@ import { type ThinkingLevel, } from "@mariozechner/pi-agent-core"; import { + type AgentToolResult, type Api, type AssistantMessage, getApiKey, @@ -38,7 +39,10 @@ import { inferToolMetaFromArgs, } from "./pi-embedded-utils.js"; import { getAnthropicOAuthToken } from "./pi-oauth.js"; -import { createClawdisCodingTools } from "./pi-tools.js"; +import { + createClawdisCodingTools, + sanitizeContentBlocksImages, +} from "./pi-tools.js"; import { buildAgentSystemPrompt } from "./system-prompt.js"; import { loadWorkspaceBootstrapFiles } from "./workspace.js"; @@ -130,6 +134,65 @@ async function getApiKeyForProvider( return getApiKey(provider) ?? undefined; } +type ContentBlock = AgentToolResult["content"][number]; + +async function sanitizeSessionMessagesImages( + messages: AppMessage[], + label: string, +): Promise { + const out: AppMessage[] = []; + for (const msg of messages) { + if (!msg || typeof msg !== "object") { + out.push(msg); + continue; + } + + const role = (msg as { role?: unknown }).role; + if (role === "toolResult") { + const toolMsg = msg as Extract; + const content = Array.isArray(toolMsg.content) ? toolMsg.content : []; + const nextContent = (await sanitizeContentBlocksImages( + content as ContentBlock[], + label, + )) as unknown as typeof toolMsg.content; + out.push({ ...toolMsg, content: nextContent }); + continue; + } + + if (role === "user") { + const userMsg = msg as Extract; + const content = userMsg.content; + if (Array.isArray(content)) { + const nextContent = (await sanitizeContentBlocksImages( + content as unknown as ContentBlock[], + label, + )) as unknown as typeof userMsg.content; + out.push({ ...userMsg, content: nextContent }); + continue; + } + } + + out.push(msg); + } + return out; +} + +function formatAssistantErrorText(msg: AssistantMessage): string | undefined { + if (msg.stopReason !== "error") return undefined; + const raw = (msg.errorMessage ?? "").trim(); + if (!raw) return "LLM request failed with an unknown error."; + + const invalidRequest = raw.match( + /"type":"invalid_request_error".*?"message":"([^"]+)"/, + ); + if (invalidRequest?.[1]) { + return `LLM request rejected: ${invalidRequest[1]}`; + } + + // Keep it short for WhatsApp. + return raw.length > 600 ? `${raw.slice(0, 600)}…` : raw; +} + export async function runEmbeddedPiAgent(params: { sessionId: string; sessionFile: string; @@ -221,7 +284,11 @@ export async function runEmbeddedPiAgent(params: { }); // Resume messages from the transcript if present. - const prior = sessionManager.loadSession().messages; + const priorRaw = sessionManager.loadSession().messages; + const prior = await sanitizeSessionMessagesImages( + priorRaw, + "session:history", + ); if (prior.length > 0) { agent.replaceMessages(prior); } @@ -456,6 +523,11 @@ export async function runEmbeddedPiAgent(params: { const replyItems: Array<{ text: string; media?: string[] }> = []; + const errorText = lastAssistant + ? formatAssistantErrorText(lastAssistant) + : undefined; + if (errorText) replyItems.push({ text: errorText }); + const inlineToolResults = params.verboseLevel === "on" && !params.onPartialReply && diff --git a/src/agents/pi-tools.test.ts b/src/agents/pi-tools.test.ts index c7ad2a4e5..97ac6f017 100644 --- a/src/agents/pi-tools.test.ts +++ b/src/agents/pi-tools.test.ts @@ -1,6 +1,7 @@ import fs from "node:fs/promises"; import os from "node:os"; import path from "node:path"; +import sharp from "sharp"; import { describe, expect, it } from "vitest"; import { createClawdisCodingTools } from "./pi-tools.js"; @@ -29,4 +30,42 @@ describe("createClawdisCodingTools", () => { expect(image?.mimeType).toBe("image/png"); }); + + it("downscales oversized images for LLM safety", async () => { + const tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdis-pi-")); + const filePath = path.join(tmpDir, "oversized.png"); + + const buf = await sharp({ + create: { + width: 2001, + height: 10, + channels: 3, + background: { r: 0, g: 0, b: 0 }, + }, + }) + .png() + .toBuffer(); + await fs.writeFile(filePath, buf); + + const read = createClawdisCodingTools().find((t) => t.name === "read"); + expect(read).toBeTruthy(); + if (!read) throw new Error("read tool missing"); + + const res = await read.execute("toolCallId", { path: filePath }); + const image = res.content.find( + (b): b is { type: "image"; mimeType: string; data: string } => + !!b && + typeof b === "object" && + (b as Record).type === "image" && + typeof (b as Record).mimeType === "string" && + typeof (b as Record).data === "string", + ); + expect(image).toBeTruthy(); + if (!image) throw new Error("image block missing"); + + const decoded = Buffer.from(image.data, "base64"); + const meta = await sharp(decoded).metadata(); + expect(meta.width).toBeLessThanOrEqual(2000); + expect(meta.height).toBeLessThanOrEqual(2000); + }); }); diff --git a/src/agents/pi-tools.ts b/src/agents/pi-tools.ts index eb3e70454..e62d624e2 100644 --- a/src/agents/pi-tools.ts +++ b/src/agents/pi-tools.ts @@ -1,6 +1,7 @@ import type { AgentTool, AgentToolResult } from "@mariozechner/pi-ai"; -import { codingTools, readTool } from "@mariozechner/pi-coding-agent"; +import { bashTool, codingTools, readTool } from "@mariozechner/pi-coding-agent"; import type { TSchema } from "@sinclair/typebox"; +import sharp from "sharp"; import { detectMime } from "../media/mime.js"; @@ -10,6 +11,8 @@ type ToolContentBlock = AgentToolResult["content"][number]; type ImageContentBlock = Extract; type TextContentBlock = Extract; +const MAX_IMAGE_DIMENSION_PX = 2000; + function sniffMimeFromBase64(base64: string): string | undefined { const trimmed = base64.trim(); if (!trimmed) return undefined; @@ -94,6 +97,122 @@ function normalizeReadImageResult( type AnyAgentTool = AgentTool; +function isImageBlock(block: unknown): block is ImageContentBlock { + if (!block || typeof block !== "object") return false; + const rec = block as Record; + return ( + rec.type === "image" && + typeof rec.data === "string" && + typeof rec.mimeType === "string" + ); +} + +function isTextBlock(block: unknown): block is TextContentBlock { + if (!block || typeof block !== "object") return false; + const rec = block as Record; + return rec.type === "text" && typeof rec.text === "string"; +} + +async function resizeImageBase64IfNeeded(params: { + base64: string; + mimeType: string; + maxDimensionPx: number; +}): Promise<{ base64: string; mimeType: string; resized: boolean }> { + const buf = Buffer.from(params.base64, "base64"); + const img = sharp(buf, { failOnError: false }); + const meta = await img.metadata(); + + const width = meta.width; + const height = meta.height; + if ( + typeof width !== "number" || + typeof height !== "number" || + (width <= params.maxDimensionPx && height <= params.maxDimensionPx) + ) { + return { base64: params.base64, mimeType: params.mimeType, resized: false }; + } + + const resized = img.resize({ + width: params.maxDimensionPx, + height: params.maxDimensionPx, + fit: "inside", + withoutEnlargement: true, + }); + + const mime = params.mimeType.toLowerCase(); + let out: Buffer; + if (mime === "image/jpeg" || mime === "image/jpg") { + out = await resized.jpeg({ quality: 85 }).toBuffer(); + } else if (mime === "image/webp") { + out = await resized.webp({ quality: 85 }).toBuffer(); + } else if (mime === "image/png") { + out = await resized.png().toBuffer(); + } else { + out = await resized.png().toBuffer(); + } + + const sniffed = detectMime({ buffer: out.slice(0, 256) }); + const nextMime = sniffed?.startsWith("image/") ? sniffed : params.mimeType; + + return { base64: out.toString("base64"), mimeType: nextMime, resized: true }; +} + +export async function sanitizeContentBlocksImages( + blocks: ToolContentBlock[], + label: string, + opts: { maxDimensionPx?: number } = {}, +): Promise { + const maxDimensionPx = Math.max( + opts.maxDimensionPx ?? MAX_IMAGE_DIMENSION_PX, + 1, + ); + const out: ToolContentBlock[] = []; + + for (const block of blocks) { + if (!isImageBlock(block)) { + out.push(block); + continue; + } + + const data = block.data.trim(); + if (!data) { + out.push({ + type: "text", + text: `[${label}] omitted empty image payload`, + } satisfies TextContentBlock); + continue; + } + + try { + const resized = await resizeImageBase64IfNeeded({ + base64: data, + mimeType: block.mimeType, + maxDimensionPx, + }); + out.push({ ...block, data: resized.base64, mimeType: resized.mimeType }); + } catch (err) { + out.push({ + type: "text", + text: `[${label}] omitted image payload: ${String(err)}`, + } satisfies TextContentBlock); + } + } + + return out; +} + +export async function sanitizeToolResultImages( + result: AgentToolResult, + label: string, + opts: { maxDimensionPx?: number } = {}, +): Promise> { + const content = Array.isArray(result.content) ? result.content : []; + if (!content.some((b) => isImageBlock(b) || isTextBlock(b))) return result; + + const next = await sanitizeContentBlocksImages(content, label, opts); + return { ...result, content: next }; +} + function createClawdisReadTool(base: AnyAgentTool): AnyAgentTool { return { ...base, @@ -109,7 +228,22 @@ function createClawdisReadTool(base: AnyAgentTool): AnyAgentTool { : undefined; const filePath = typeof record?.path === "string" ? String(record.path) : ""; - return normalizeReadImageResult(result, filePath); + const normalized = normalizeReadImageResult(result, filePath); + return sanitizeToolResultImages(normalized, `read:${filePath}`); + }, + }; +} + +function createClawdisBashTool(base: AnyAgentTool): AnyAgentTool { + return { + ...base, + execute: async (toolCallId, params, signal) => { + const result = (await base.execute( + toolCallId, + params, + signal, + )) as AgentToolResult; + return sanitizeToolResultImages(result, "bash"); }, }; } @@ -118,6 +252,8 @@ export function createClawdisCodingTools(): AnyAgentTool[] { return (codingTools as unknown as AnyAgentTool[]).map((tool) => tool.name === readTool.name ? createClawdisReadTool(tool) - : (tool as AnyAgentTool), + : tool.name === bashTool.name + ? createClawdisBashTool(tool) + : (tool as AnyAgentTool), ); }