fix: abort runs between tool calls

This commit is contained in:
Peter Steinberger
2026-01-10 01:26:20 +01:00
parent a0a64a625e
commit 5898304fa0
6 changed files with 216 additions and 1 deletions

View File

@@ -1,3 +1,22 @@
import { abortEmbeddedPiRun } from "../../agents/pi-embedded.js";
import type { ClawdbotConfig } from "../../config/config.js";
import {
loadSessionStore,
resolveStorePath,
saveSessionStore,
} from "../../config/sessions.js";
import {
parseAgentSessionKey,
resolveAgentIdFromSessionKey,
} from "../../routing/session-key.js";
import { resolveCommandAuthorization } from "../command-auth.js";
import {
normalizeCommandBody,
shouldHandleTextCommands,
} from "../commands-registry.js";
import type { MsgContext } from "../templating.js";
import { stripMentions, stripStructuralPrefixes } from "./mentions.js";
const ABORT_TRIGGERS = new Set(["stop", "esc", "abort", "wait", "exit"]);
const ABORT_MEMORY = new Map<string, boolean>();
@@ -14,3 +33,82 @@ export function getAbortMemory(key: string): boolean | undefined {
export function setAbortMemory(key: string, value: boolean): void {
ABORT_MEMORY.set(key, value);
}
function resolveSessionEntryForKey(
store: Record<string, { sessionId: string; updatedAt: number }> | undefined,
sessionKey: string | undefined,
) {
if (!store || !sessionKey) return {};
const direct = store[sessionKey];
if (direct) return { entry: direct, key: sessionKey };
const parsed = parseAgentSessionKey(sessionKey);
const legacyKey = parsed?.rest;
if (legacyKey && store[legacyKey]) {
return { entry: store[legacyKey], key: legacyKey };
}
return {};
}
function resolveAbortTargetKey(ctx: MsgContext): string | undefined {
const target = ctx.CommandTargetSessionKey?.trim();
if (target) return target;
const sessionKey = ctx.SessionKey?.trim();
return sessionKey || undefined;
}
export async function tryFastAbortFromMessage(params: {
ctx: MsgContext;
cfg: ClawdbotConfig;
}): Promise<{ handled: boolean; aborted: boolean }> {
const { ctx, cfg } = params;
const surface = (ctx.Surface ?? ctx.Provider ?? "").trim().toLowerCase();
const allowTextCommands = shouldHandleTextCommands({
cfg,
surface,
commandSource: ctx.CommandSource,
});
if (!allowTextCommands) return { handled: false, aborted: false };
const commandAuthorized = ctx.CommandAuthorized ?? true;
const auth = resolveCommandAuthorization({
ctx,
cfg,
commandAuthorized,
});
if (!auth.isAuthorizedSender) return { handled: false, aborted: false };
const targetKey = resolveAbortTargetKey(ctx);
const agentId = resolveAgentIdFromSessionKey(
targetKey ?? ctx.SessionKey ?? "",
);
const raw = stripStructuralPrefixes(ctx.Body ?? "");
const isGroup = ctx.ChatType?.trim().toLowerCase() === "group";
const stripped = isGroup ? stripMentions(raw, ctx, cfg, agentId) : raw;
const normalized = normalizeCommandBody(stripped);
const abortRequested = normalized === "/stop" || isAbortTrigger(stripped);
if (!abortRequested) return { handled: false, aborted: false };
const abortKey = targetKey ?? auth.from ?? auth.to;
if (targetKey) {
const storePath = resolveStorePath(cfg.session?.store, { agentId });
const store = loadSessionStore(storePath);
const { entry, key } = resolveSessionEntryForKey(store, targetKey);
const sessionId = entry?.sessionId;
const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false;
if (entry && key) {
entry.abortedLastRun = true;
entry.updatedAt = Date.now();
store[key] = entry;
await saveSessionStore(storePath, store);
} else if (abortKey) {
setAbortMemory(abortKey, true);
}
return { handled: true, aborted };
}
if (abortKey) {
setAbortMemory(abortKey, true);
}
return { handled: true, aborted: false };
}

View File

@@ -7,6 +7,7 @@ import type { ReplyDispatcher } from "./reply-dispatcher.js";
const mocks = vi.hoisted(() => ({
routeReply: vi.fn(async () => ({ ok: true, messageId: "mock" })),
tryFastAbortFromMessage: vi.fn(async () => ({ handled: false, aborted: false })),
}));
vi.mock("./route-reply.js", () => ({
@@ -25,6 +26,10 @@ vi.mock("./route-reply.js", () => ({
routeReply: mocks.routeReply,
}));
vi.mock("./abort.js", () => ({
tryFastAbortFromMessage: mocks.tryFastAbortFromMessage,
}));
const { dispatchReplyFromConfig } = await import("./dispatch-from-config.js");
function createDispatcher(): ReplyDispatcher {
@@ -39,6 +44,10 @@ function createDispatcher(): ReplyDispatcher {
describe("dispatchReplyFromConfig", () => {
it("does not route when Provider matches OriginatingChannel (even if Surface is missing)", async () => {
mocks.tryFastAbortFromMessage.mockResolvedValue({
handled: false,
aborted: false,
});
mocks.routeReply.mockClear();
const cfg = {} as ClawdbotConfig;
const dispatcher = createDispatcher();
@@ -60,6 +69,10 @@ describe("dispatchReplyFromConfig", () => {
});
it("routes when OriginatingChannel differs from Provider", async () => {
mocks.tryFastAbortFromMessage.mockResolvedValue({
handled: false,
aborted: false,
});
mocks.routeReply.mockClear();
const cfg = {} as ClawdbotConfig;
const dispatcher = createDispatcher();
@@ -88,4 +101,25 @@ describe("dispatchReplyFromConfig", () => {
}),
);
});
it("fast-aborts without calling the reply resolver", async () => {
mocks.tryFastAbortFromMessage.mockResolvedValue({
handled: true,
aborted: true,
});
const cfg = {} as ClawdbotConfig;
const dispatcher = createDispatcher();
const ctx: MsgContext = {
Provider: "telegram",
Body: "/stop",
};
const replyResolver = vi.fn(async () => ({ text: "hi" }) as ReplyPayload);
await dispatchReplyFromConfig({ ctx, cfg, dispatcher, replyResolver });
expect(replyResolver).not.toHaveBeenCalled();
expect(dispatcher.sendFinalReply).toHaveBeenCalledWith({
text: "⚙️ Agent was aborted.",
});
});
});

View File

@@ -3,6 +3,7 @@ import { logVerbose } from "../../globals.js";
import { getReplyFromConfig } from "../reply.js";
import type { MsgContext } from "../templating.js";
import type { GetReplyOptions, ReplyPayload } from "../types.js";
import { tryFastAbortFromMessage } from "./abort.js";
import type { ReplyDispatcher, ReplyDispatchKind } from "./reply-dispatcher.js";
import { isRoutableChannel, routeReply } from "./route-reply.js";
@@ -66,6 +67,37 @@ export async function dispatchReplyFromConfig(params: {
}
};
const fastAbort = await tryFastAbortFromMessage({ ctx, cfg });
if (fastAbort.handled) {
const payload = { text: "⚙️ Agent was aborted." } satisfies ReplyPayload;
let queuedFinal = false;
let routedFinalCount = 0;
if (shouldRouteToOriginating && originatingChannel && originatingTo) {
const result = await routeReply({
payload,
channel: originatingChannel,
to: originatingTo,
sessionKey: ctx.SessionKey,
accountId: ctx.AccountId,
threadId: ctx.MessageThreadId,
cfg,
});
queuedFinal = result.ok;
if (result.ok) routedFinalCount += 1;
if (!result.ok) {
logVerbose(
`dispatch-from-config: route-reply (abort) failed: ${result.error ?? "unknown error"}`,
);
}
} else {
queuedFinal = dispatcher.sendFinalReply(payload);
}
await dispatcher.waitForIdle();
const counts = dispatcher.getQueuedCounts();
counts.final += routedFinalCount;
return { queuedFinal, counts };
}
const replyResult = await (params.replyResolver ?? getReplyFromConfig)(
ctx,
{