From 5898304fa04972d0a4d4da72b616f1c4f25d5d1b Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sat, 10 Jan 2026 01:26:20 +0100 Subject: [PATCH] fix: abort runs between tool calls --- CHANGELOG.md | 1 + src/agents/pi-embedded-runner.ts | 4 + src/agents/pi-tools.ts | 48 ++++++++- src/auto-reply/reply/abort.ts | 98 +++++++++++++++++++ .../reply/dispatch-from-config.test.ts | 34 +++++++ src/auto-reply/reply/dispatch-from-config.ts | 32 ++++++ 6 files changed, 216 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71253cc4e..ba7bd9fd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Unreleased +- Agent: fast abort on /stop and cancel tool calls between tool boundaries. (#617) - Models/Auth: add OpenCode Zen (multi-model proxy) onboarding. (#623) — thanks @magimetal - WhatsApp: refactor vCard parsing helper and improve empty contact card summaries. (#624) — thanks @steipete - WhatsApp: include phone numbers when multiple contacts are shared. (#625) — thanks @mahmoudashraf93 diff --git a/src/agents/pi-embedded-runner.ts b/src/agents/pi-embedded-runner.ts index a3c2fd545..5baf03776 100644 --- a/src/agents/pi-embedded-runner.ts +++ b/src/agents/pi-embedded-runner.ts @@ -853,6 +853,7 @@ export async function compactEmbeddedPiSession(params: { sessionKey: params.sessionKey ?? params.sessionId, agentDir, config: params.config, + abortSignal: runAbortController.signal, // No currentChannelId/currentThreadTs for compaction - not in message context }); const machineName = await getMachineDisplayName(); @@ -1045,6 +1046,7 @@ export async function runEmbeddedPiAgent(params: { const enqueueGlobal = params.enqueue ?? ((task, opts) => enqueueCommandInLane(globalLane, task, opts)); + const runAbortController = new AbortController(); return enqueueCommandInLane(sessionLane, () => enqueueGlobal(async () => { const started = Date.now(); @@ -1223,6 +1225,7 @@ export async function runEmbeddedPiAgent(params: { sessionKey: params.sessionKey ?? params.sessionId, agentDir, config: params.config, + abortSignal: runAbortController.signal, currentChannelId: params.currentChannelId, currentThreadTs: params.currentThreadTs, replyToMode: params.replyToMode, @@ -1326,6 +1329,7 @@ export async function runEmbeddedPiAgent(params: { const abortRun = (isTimeout = false) => { aborted = true; if (isTimeout) timedOut = true; + runAbortController.abort(); void session.abort(); }; let subscription: ReturnType; diff --git a/src/agents/pi-tools.ts b/src/agents/pi-tools.ts index 5a475bb1b..b330506dc 100644 --- a/src/agents/pi-tools.ts +++ b/src/agents/pi-tools.ts @@ -503,6 +503,48 @@ export const __testing = { cleanToolSchemaForGemini, } as const; +function throwAbortError(): never { + const err = new Error("Aborted"); + err.name = "AbortError"; + throw err; +} + +function combineAbortSignals( + a?: AbortSignal, + b?: AbortSignal, +): AbortSignal | undefined { + if (!a && !b) return undefined; + if (a && !b) return a; + if (b && !a) return b; + if (a?.aborted) return a; + if (b?.aborted) return b; + if (typeof AbortSignal.any === "function") { + return AbortSignal.any([a as AbortSignal, b as AbortSignal]); + } + const controller = new AbortController(); + const onAbort = () => controller.abort(); + a?.addEventListener("abort", onAbort, { once: true }); + b?.addEventListener("abort", onAbort, { once: true }); + return controller.signal; +} + +function wrapToolWithAbortSignal( + tool: AnyAgentTool, + abortSignal?: AbortSignal, +): AnyAgentTool { + if (!abortSignal) return tool; + const execute = tool.execute; + if (!execute) return tool; + return { + ...tool, + execute: async (toolCallId, params, signal, onUpdate) => { + const combined = combineAbortSignals(signal, abortSignal); + if (combined?.aborted) throwAbortError(); + return await execute(toolCallId, params, combined, onUpdate); + }, + }; +} + export function createClawdbotCodingTools(options?: { bash?: BashToolDefaults & ProcessToolDefaults; messageProvider?: string; @@ -511,6 +553,7 @@ export function createClawdbotCodingTools(options?: { sessionKey?: string; agentDir?: string; config?: ClawdbotConfig; + abortSignal?: AbortSignal; /** Current channel ID for auto-threading (Slack). */ currentChannelId?: string; /** Current thread timestamp for auto-threading (Slack). */ @@ -607,8 +650,11 @@ export function createClawdbotCodingTools(options?: { // Always normalize tool JSON Schemas before handing them to pi-agent/pi-ai. // Without this, some providers (notably OpenAI) will reject root-level union schemas. const normalized = subagentFiltered.map(normalizeToolParameters); + const withAbort = options?.abortSignal + ? normalized.map((tool) => wrapToolWithAbortSignal(tool, options.abortSignal)) + : normalized; // Anthropic blocks specific lowercase tool names (bash, read, write, edit) with OAuth tokens. // Always use capitalized versions for compatibility with both OAuth and regular API keys. - return renameBlockedToolsForOAuth(normalized); + return renameBlockedToolsForOAuth(withAbort); } diff --git a/src/auto-reply/reply/abort.ts b/src/auto-reply/reply/abort.ts index c0d431708..7f9fefa05 100644 --- a/src/auto-reply/reply/abort.ts +++ b/src/auto-reply/reply/abort.ts @@ -1,3 +1,22 @@ +import { abortEmbeddedPiRun } from "../../agents/pi-embedded.js"; +import type { ClawdbotConfig } from "../../config/config.js"; +import { + loadSessionStore, + resolveStorePath, + saveSessionStore, +} from "../../config/sessions.js"; +import { + parseAgentSessionKey, + resolveAgentIdFromSessionKey, +} from "../../routing/session-key.js"; +import { resolveCommandAuthorization } from "../command-auth.js"; +import { + normalizeCommandBody, + shouldHandleTextCommands, +} from "../commands-registry.js"; +import type { MsgContext } from "../templating.js"; +import { stripMentions, stripStructuralPrefixes } from "./mentions.js"; + const ABORT_TRIGGERS = new Set(["stop", "esc", "abort", "wait", "exit"]); const ABORT_MEMORY = new Map(); @@ -14,3 +33,82 @@ export function getAbortMemory(key: string): boolean | undefined { export function setAbortMemory(key: string, value: boolean): void { ABORT_MEMORY.set(key, value); } + +function resolveSessionEntryForKey( + store: Record | undefined, + sessionKey: string | undefined, +) { + if (!store || !sessionKey) return {}; + const direct = store[sessionKey]; + if (direct) return { entry: direct, key: sessionKey }; + const parsed = parseAgentSessionKey(sessionKey); + const legacyKey = parsed?.rest; + if (legacyKey && store[legacyKey]) { + return { entry: store[legacyKey], key: legacyKey }; + } + return {}; +} + +function resolveAbortTargetKey(ctx: MsgContext): string | undefined { + const target = ctx.CommandTargetSessionKey?.trim(); + if (target) return target; + const sessionKey = ctx.SessionKey?.trim(); + return sessionKey || undefined; +} + +export async function tryFastAbortFromMessage(params: { + ctx: MsgContext; + cfg: ClawdbotConfig; +}): Promise<{ handled: boolean; aborted: boolean }> { + const { ctx, cfg } = params; + const surface = (ctx.Surface ?? ctx.Provider ?? "").trim().toLowerCase(); + const allowTextCommands = shouldHandleTextCommands({ + cfg, + surface, + commandSource: ctx.CommandSource, + }); + if (!allowTextCommands) return { handled: false, aborted: false }; + + const commandAuthorized = ctx.CommandAuthorized ?? true; + const auth = resolveCommandAuthorization({ + ctx, + cfg, + commandAuthorized, + }); + if (!auth.isAuthorizedSender) return { handled: false, aborted: false }; + + const targetKey = resolveAbortTargetKey(ctx); + const agentId = resolveAgentIdFromSessionKey( + targetKey ?? ctx.SessionKey ?? "", + ); + const raw = stripStructuralPrefixes(ctx.Body ?? ""); + const isGroup = ctx.ChatType?.trim().toLowerCase() === "group"; + const stripped = isGroup ? stripMentions(raw, ctx, cfg, agentId) : raw; + const normalized = normalizeCommandBody(stripped); + const abortRequested = normalized === "/stop" || isAbortTrigger(stripped); + if (!abortRequested) return { handled: false, aborted: false }; + + const abortKey = targetKey ?? auth.from ?? auth.to; + + if (targetKey) { + const storePath = resolveStorePath(cfg.session?.store, { agentId }); + const store = loadSessionStore(storePath); + const { entry, key } = resolveSessionEntryForKey(store, targetKey); + const sessionId = entry?.sessionId; + const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false; + if (entry && key) { + entry.abortedLastRun = true; + entry.updatedAt = Date.now(); + store[key] = entry; + await saveSessionStore(storePath, store); + } else if (abortKey) { + setAbortMemory(abortKey, true); + } + return { handled: true, aborted }; + } + + if (abortKey) { + setAbortMemory(abortKey, true); + } + return { handled: true, aborted: false }; +} diff --git a/src/auto-reply/reply/dispatch-from-config.test.ts b/src/auto-reply/reply/dispatch-from-config.test.ts index 69ea5102d..557e804d1 100644 --- a/src/auto-reply/reply/dispatch-from-config.test.ts +++ b/src/auto-reply/reply/dispatch-from-config.test.ts @@ -7,6 +7,7 @@ import type { ReplyDispatcher } from "./reply-dispatcher.js"; const mocks = vi.hoisted(() => ({ routeReply: vi.fn(async () => ({ ok: true, messageId: "mock" })), + tryFastAbortFromMessage: vi.fn(async () => ({ handled: false, aborted: false })), })); vi.mock("./route-reply.js", () => ({ @@ -25,6 +26,10 @@ vi.mock("./route-reply.js", () => ({ routeReply: mocks.routeReply, })); +vi.mock("./abort.js", () => ({ + tryFastAbortFromMessage: mocks.tryFastAbortFromMessage, +})); + const { dispatchReplyFromConfig } = await import("./dispatch-from-config.js"); function createDispatcher(): ReplyDispatcher { @@ -39,6 +44,10 @@ function createDispatcher(): ReplyDispatcher { describe("dispatchReplyFromConfig", () => { it("does not route when Provider matches OriginatingChannel (even if Surface is missing)", async () => { + mocks.tryFastAbortFromMessage.mockResolvedValue({ + handled: false, + aborted: false, + }); mocks.routeReply.mockClear(); const cfg = {} as ClawdbotConfig; const dispatcher = createDispatcher(); @@ -60,6 +69,10 @@ describe("dispatchReplyFromConfig", () => { }); it("routes when OriginatingChannel differs from Provider", async () => { + mocks.tryFastAbortFromMessage.mockResolvedValue({ + handled: false, + aborted: false, + }); mocks.routeReply.mockClear(); const cfg = {} as ClawdbotConfig; const dispatcher = createDispatcher(); @@ -88,4 +101,25 @@ describe("dispatchReplyFromConfig", () => { }), ); }); + + it("fast-aborts without calling the reply resolver", async () => { + mocks.tryFastAbortFromMessage.mockResolvedValue({ + handled: true, + aborted: true, + }); + const cfg = {} as ClawdbotConfig; + const dispatcher = createDispatcher(); + const ctx: MsgContext = { + Provider: "telegram", + Body: "/stop", + }; + const replyResolver = vi.fn(async () => ({ text: "hi" }) as ReplyPayload); + + await dispatchReplyFromConfig({ ctx, cfg, dispatcher, replyResolver }); + + expect(replyResolver).not.toHaveBeenCalled(); + expect(dispatcher.sendFinalReply).toHaveBeenCalledWith({ + text: "⚙️ Agent was aborted.", + }); + }); }); diff --git a/src/auto-reply/reply/dispatch-from-config.ts b/src/auto-reply/reply/dispatch-from-config.ts index d53ca18d9..a2b8d5738 100644 --- a/src/auto-reply/reply/dispatch-from-config.ts +++ b/src/auto-reply/reply/dispatch-from-config.ts @@ -3,6 +3,7 @@ import { logVerbose } from "../../globals.js"; import { getReplyFromConfig } from "../reply.js"; import type { MsgContext } from "../templating.js"; import type { GetReplyOptions, ReplyPayload } from "../types.js"; +import { tryFastAbortFromMessage } from "./abort.js"; import type { ReplyDispatcher, ReplyDispatchKind } from "./reply-dispatcher.js"; import { isRoutableChannel, routeReply } from "./route-reply.js"; @@ -66,6 +67,37 @@ export async function dispatchReplyFromConfig(params: { } }; + const fastAbort = await tryFastAbortFromMessage({ ctx, cfg }); + if (fastAbort.handled) { + const payload = { text: "⚙️ Agent was aborted." } satisfies ReplyPayload; + let queuedFinal = false; + let routedFinalCount = 0; + if (shouldRouteToOriginating && originatingChannel && originatingTo) { + const result = await routeReply({ + payload, + channel: originatingChannel, + to: originatingTo, + sessionKey: ctx.SessionKey, + accountId: ctx.AccountId, + threadId: ctx.MessageThreadId, + cfg, + }); + queuedFinal = result.ok; + if (result.ok) routedFinalCount += 1; + if (!result.ok) { + logVerbose( + `dispatch-from-config: route-reply (abort) failed: ${result.error ?? "unknown error"}`, + ); + } + } else { + queuedFinal = dispatcher.sendFinalReply(payload); + } + await dispatcher.waitForIdle(); + const counts = dispatcher.getQueuedCounts(); + counts.final += routedFinalCount; + return { queuedFinal, counts }; + } + const replyResult = await (params.replyResolver ?? getReplyFromConfig)( ctx, {