Merge origin/main into fix-discord-accountId
This commit is contained in:
@@ -1,70 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
type AuthProfileStore,
|
||||
ensureAuthProfileStore,
|
||||
resolveApiKeyForProfile,
|
||||
} from "./auth-profiles.js";
|
||||
|
||||
vi.mock("@mariozechner/pi-ai", () => ({
|
||||
getOAuthApiKey: vi.fn(() => {
|
||||
throw new Error("refresh should not be called");
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("auth-profiles (github-copilot)", () => {
|
||||
const previousStateDir = process.env.CLAWDBOT_STATE_DIR;
|
||||
const previousAgentDir = process.env.CLAWDBOT_AGENT_DIR;
|
||||
const previousPiAgentDir = process.env.PI_CODING_AGENT_DIR;
|
||||
let tempDir: string | null = null;
|
||||
|
||||
afterEach(async () => {
|
||||
vi.unstubAllGlobals();
|
||||
if (tempDir) {
|
||||
await fs.rm(tempDir, { recursive: true, force: true });
|
||||
tempDir = null;
|
||||
}
|
||||
if (previousStateDir === undefined) delete process.env.CLAWDBOT_STATE_DIR;
|
||||
else process.env.CLAWDBOT_STATE_DIR = previousStateDir;
|
||||
if (previousAgentDir === undefined) delete process.env.CLAWDBOT_AGENT_DIR;
|
||||
else process.env.CLAWDBOT_AGENT_DIR = previousAgentDir;
|
||||
if (previousPiAgentDir === undefined) delete process.env.PI_CODING_AGENT_DIR;
|
||||
else process.env.PI_CODING_AGENT_DIR = previousPiAgentDir;
|
||||
});
|
||||
|
||||
it("treats copilot oauth tokens with expires=0 as non-expiring", async () => {
|
||||
tempDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-copilot-"));
|
||||
process.env.CLAWDBOT_STATE_DIR = tempDir;
|
||||
process.env.CLAWDBOT_AGENT_DIR = path.join(tempDir, "agents", "main", "agent");
|
||||
process.env.PI_CODING_AGENT_DIR = process.env.CLAWDBOT_AGENT_DIR;
|
||||
|
||||
const authProfilePath = path.join(tempDir, "agents", "main", "agent", "auth-profiles.json");
|
||||
await fs.mkdir(path.dirname(authProfilePath), { recursive: true });
|
||||
|
||||
const store: AuthProfileStore = {
|
||||
version: 1,
|
||||
profiles: {
|
||||
"github-copilot:github": {
|
||||
type: "oauth",
|
||||
provider: "github-copilot",
|
||||
refresh: "gh-token",
|
||||
access: "gh-token",
|
||||
expires: 0,
|
||||
enterpriseUrl: "company.ghe.com",
|
||||
},
|
||||
},
|
||||
};
|
||||
await fs.writeFile(authProfilePath, `${JSON.stringify(store)}\n`);
|
||||
|
||||
const loaded = ensureAuthProfileStore();
|
||||
const resolved = await resolveApiKeyForProfile({
|
||||
store: loaded,
|
||||
profileId: "github-copilot:github",
|
||||
});
|
||||
|
||||
expect(resolved?.apiKey).toBe("gh-token");
|
||||
});
|
||||
});
|
||||
@@ -39,15 +39,6 @@ async function refreshOAuthTokenWithLock(params: {
|
||||
const cred = store.profiles[params.profileId];
|
||||
if (!cred || cred.type !== "oauth") return null;
|
||||
|
||||
if (
|
||||
cred.provider === "github-copilot" &&
|
||||
(!Number.isFinite(cred.expires) || cred.expires <= 0)
|
||||
) {
|
||||
return {
|
||||
apiKey: buildOAuthApiKey(cred.provider, cred),
|
||||
newCredentials: cred,
|
||||
};
|
||||
}
|
||||
if (Date.now() < cred.expires) {
|
||||
return {
|
||||
apiKey: buildOAuthApiKey(cred.provider, cred),
|
||||
@@ -112,20 +103,6 @@ async function tryResolveOAuthProfile(params: {
|
||||
if (profileConfig && profileConfig.provider !== cred.provider) return null;
|
||||
if (profileConfig && profileConfig.mode !== cred.type) return null;
|
||||
|
||||
if (cred.provider === "github-copilot" && (!Number.isFinite(cred.expires) || cred.expires <= 0)) {
|
||||
return {
|
||||
apiKey: buildOAuthApiKey(cred.provider, cred),
|
||||
provider: cred.provider,
|
||||
email: cred.email,
|
||||
};
|
||||
}
|
||||
if (cred.provider === "github-copilot" && (!Number.isFinite(cred.expires) || cred.expires <= 0)) {
|
||||
return {
|
||||
apiKey: buildOAuthApiKey(cred.provider, cred),
|
||||
provider: cred.provider,
|
||||
email: cred.email,
|
||||
};
|
||||
}
|
||||
if (Date.now() < cred.expires) {
|
||||
return {
|
||||
apiKey: buildOAuthApiKey(cred.provider, cred),
|
||||
|
||||
@@ -19,7 +19,6 @@ export type TokenCredential = {
|
||||
token: string;
|
||||
/** Optional expiry timestamp (ms since epoch). */
|
||||
expires?: number;
|
||||
enterpriseUrl?: string;
|
||||
email?: string;
|
||||
};
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import path from "node:path";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { withTempHome as withTempHomeBase } from "../../test/helpers/temp-home.js";
|
||||
import type { ClawdbotConfig } from "../config/config.js";
|
||||
import { DEFAULT_GITHUB_COPILOT_BASE_URL } from "../providers/github-copilot-utils.js";
|
||||
|
||||
async function withTempHome<T>(fn: (home: string) => Promise<T>): Promise<T> {
|
||||
return withTempHomeBase(fn, { prefix: "clawdbot-models-" });
|
||||
@@ -52,6 +51,16 @@ describe("models-config", () => {
|
||||
try {
|
||||
vi.resetModules();
|
||||
|
||||
vi.doMock("../providers/github-copilot-token.js", () => ({
|
||||
DEFAULT_COPILOT_API_BASE_URL: "https://api.individual.githubcopilot.com",
|
||||
resolveCopilotApiToken: vi.fn().mockResolvedValue({
|
||||
token: "copilot",
|
||||
expiresAt: Date.now() + 60 * 60 * 1000,
|
||||
source: "mock",
|
||||
baseUrl: "https://api.copilot.example",
|
||||
}),
|
||||
}));
|
||||
|
||||
const { ensureClawdbotModelsJson } = await import("./models-config.js");
|
||||
|
||||
const agentDir = path.join(home, "agent-default-base-url");
|
||||
@@ -62,55 +71,48 @@ describe("models-config", () => {
|
||||
providers: Record<string, { baseUrl?: string; models?: unknown[] }>;
|
||||
};
|
||||
|
||||
expect(parsed.providers["github-copilot"]?.baseUrl).toBe(DEFAULT_GITHUB_COPILOT_BASE_URL);
|
||||
expect(parsed.providers["github-copilot"]?.baseUrl).toBe("https://api.copilot.example");
|
||||
expect(parsed.providers["github-copilot"]?.models?.length ?? 0).toBe(0);
|
||||
} finally {
|
||||
process.env.COPILOT_GITHUB_TOKEN = previous;
|
||||
}
|
||||
});
|
||||
});
|
||||
it("uses enterprise URL from auth profiles to derive base URL", async () => {
|
||||
it("prefers COPILOT_GITHUB_TOKEN over GH_TOKEN and GITHUB_TOKEN", async () => {
|
||||
await withTempHome(async () => {
|
||||
const previous = process.env.COPILOT_GITHUB_TOKEN;
|
||||
const previousGh = process.env.GH_TOKEN;
|
||||
const previousGithub = process.env.GITHUB_TOKEN;
|
||||
process.env.COPILOT_GITHUB_TOKEN = "copilot-token";
|
||||
process.env.GH_TOKEN = "gh-token";
|
||||
process.env.GITHUB_TOKEN = "github-token";
|
||||
|
||||
try {
|
||||
vi.resetModules();
|
||||
|
||||
const agentDir = path.join(process.env.HOME ?? home, "agent-enterprise");
|
||||
await fs.mkdir(agentDir, { recursive: true });
|
||||
await fs.writeFile(
|
||||
path.join(agentDir, "auth-profiles.json"),
|
||||
JSON.stringify(
|
||||
{
|
||||
version: 1,
|
||||
profiles: {
|
||||
"github-copilot:github": {
|
||||
type: "oauth",
|
||||
provider: "github-copilot",
|
||||
refresh: "gh-token",
|
||||
access: "gh-token",
|
||||
expires: 0,
|
||||
enterpriseUrl: "company.ghe.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
const resolveCopilotApiToken = vi.fn().mockResolvedValue({
|
||||
token: "copilot",
|
||||
expiresAt: Date.now() + 60 * 60 * 1000,
|
||||
source: "mock",
|
||||
baseUrl: "https://api.copilot.example",
|
||||
});
|
||||
|
||||
vi.doMock("../providers/github-copilot-token.js", () => ({
|
||||
DEFAULT_COPILOT_API_BASE_URL: "https://api.individual.githubcopilot.com",
|
||||
resolveCopilotApiToken,
|
||||
}));
|
||||
|
||||
const { ensureClawdbotModelsJson } = await import("./models-config.js");
|
||||
|
||||
await ensureClawdbotModelsJson({ models: { providers: {} } }, agentDir);
|
||||
await ensureClawdbotModelsJson({ models: { providers: {} } });
|
||||
|
||||
const raw = await fs.readFile(path.join(agentDir, "models.json"), "utf8");
|
||||
const parsed = JSON.parse(raw) as {
|
||||
providers: Record<string, { baseUrl?: string }>;
|
||||
};
|
||||
|
||||
expect(parsed.providers["github-copilot"]?.baseUrl).toBe(
|
||||
"https://copilot-api.company.ghe.com",
|
||||
expect(resolveCopilotApiToken).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ githubToken: "copilot-token" }),
|
||||
);
|
||||
} finally {
|
||||
// no-op
|
||||
process.env.COPILOT_GITHUB_TOKEN = previous;
|
||||
process.env.GH_TOKEN = previousGh;
|
||||
process.env.GITHUB_TOKEN = previousGithub;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,7 +3,6 @@ import path from "node:path";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { withTempHome as withTempHomeBase } from "../../test/helpers/temp-home.js";
|
||||
import type { ClawdbotConfig } from "../config/config.js";
|
||||
import { DEFAULT_GITHUB_COPILOT_BASE_URL } from "../providers/github-copilot-utils.js";
|
||||
|
||||
async function withTempHome<T>(fn: (home: string) => Promise<T>): Promise<T> {
|
||||
return withTempHomeBase(fn, { prefix: "clawdbot-models-" });
|
||||
@@ -44,7 +43,7 @@ describe("models-config", () => {
|
||||
process.env.HOME = previousHome;
|
||||
});
|
||||
|
||||
it("uses default baseUrl when env token is present", async () => {
|
||||
it("falls back to default baseUrl when token exchange fails", async () => {
|
||||
await withTempHome(async () => {
|
||||
const previous = process.env.COPILOT_GITHUB_TOKEN;
|
||||
process.env.COPILOT_GITHUB_TOKEN = "gh-token";
|
||||
@@ -52,6 +51,11 @@ describe("models-config", () => {
|
||||
try {
|
||||
vi.resetModules();
|
||||
|
||||
vi.doMock("../providers/github-copilot-token.js", () => ({
|
||||
DEFAULT_COPILOT_API_BASE_URL: "https://api.default.test",
|
||||
resolveCopilotApiToken: vi.fn().mockRejectedValue(new Error("boom")),
|
||||
}));
|
||||
|
||||
const { ensureClawdbotModelsJson } = await import("./models-config.js");
|
||||
const { resolveClawdbotAgentDir } = await import("./agent-paths.js");
|
||||
|
||||
@@ -63,13 +67,13 @@ describe("models-config", () => {
|
||||
providers: Record<string, { baseUrl?: string }>;
|
||||
};
|
||||
|
||||
expect(parsed.providers["github-copilot"]?.baseUrl).toBe(DEFAULT_GITHUB_COPILOT_BASE_URL);
|
||||
expect(parsed.providers["github-copilot"]?.baseUrl).toBe("https://api.default.test");
|
||||
} finally {
|
||||
process.env.COPILOT_GITHUB_TOKEN = previous;
|
||||
}
|
||||
});
|
||||
});
|
||||
it("normalizes enterprise URL when deriving base URL", async () => {
|
||||
it("uses agentDir override auth profiles for copilot injection", async () => {
|
||||
await withTempHome(async (home) => {
|
||||
const previous = process.env.COPILOT_GITHUB_TOKEN;
|
||||
const previousGh = process.env.GH_TOKEN;
|
||||
@@ -90,12 +94,9 @@ describe("models-config", () => {
|
||||
version: 1,
|
||||
profiles: {
|
||||
"github-copilot:github": {
|
||||
type: "oauth",
|
||||
type: "token",
|
||||
provider: "github-copilot",
|
||||
refresh: "gh-profile-token",
|
||||
access: "gh-profile-token",
|
||||
expires: 0,
|
||||
enterpriseUrl: "https://company.ghe.com/",
|
||||
token: "gh-profile-token",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -104,6 +105,16 @@ describe("models-config", () => {
|
||||
),
|
||||
);
|
||||
|
||||
vi.doMock("../providers/github-copilot-token.js", () => ({
|
||||
DEFAULT_COPILOT_API_BASE_URL: "https://api.individual.githubcopilot.com",
|
||||
resolveCopilotApiToken: vi.fn().mockResolvedValue({
|
||||
token: "copilot",
|
||||
expiresAt: Date.now() + 60 * 60 * 1000,
|
||||
source: "mock",
|
||||
baseUrl: "https://api.copilot.example",
|
||||
}),
|
||||
}));
|
||||
|
||||
const { ensureClawdbotModelsJson } = await import("./models-config.js");
|
||||
|
||||
await ensureClawdbotModelsJson({ models: { providers: {} } }, agentDir);
|
||||
@@ -113,9 +124,7 @@ describe("models-config", () => {
|
||||
providers: Record<string, { baseUrl?: string }>;
|
||||
};
|
||||
|
||||
expect(parsed.providers["github-copilot"]?.baseUrl).toBe(
|
||||
"https://copilot-api.company.ghe.com",
|
||||
);
|
||||
expect(parsed.providers["github-copilot"]?.baseUrl).toBe("https://api.copilot.example");
|
||||
} finally {
|
||||
if (previous === undefined) delete process.env.COPILOT_GITHUB_TOKEN;
|
||||
else process.env.COPILOT_GITHUB_TOKEN = previous;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import type { ClawdbotConfig } from "../config/config.js";
|
||||
import {
|
||||
normalizeGithubCopilotDomain,
|
||||
resolveGithubCopilotBaseUrl,
|
||||
} from "../providers/github-copilot-utils.js";
|
||||
DEFAULT_COPILOT_API_BASE_URL,
|
||||
resolveCopilotApiToken,
|
||||
} from "../providers/github-copilot-token.js";
|
||||
import { ensureAuthProfileStore, listProfilesForProvider } from "./auth-profiles.js";
|
||||
import { resolveAwsSdkEnvVarName, resolveEnvApiKey } from "./model-auth.js";
|
||||
import {
|
||||
@@ -331,18 +331,29 @@ export async function resolveImplicitCopilotProvider(params: {
|
||||
|
||||
if (!hasProfile && !githubToken) return null;
|
||||
|
||||
let enterpriseDomain: string | null = null;
|
||||
if (hasProfile) {
|
||||
let selectedGithubToken = githubToken;
|
||||
if (!selectedGithubToken && hasProfile) {
|
||||
// Use the first available profile as a default for discovery (it will be
|
||||
// re-resolved per-run by the embedded runner).
|
||||
const profileId = listProfilesForProvider(authStore, "github-copilot")[0];
|
||||
const profile = profileId ? authStore.profiles[profileId] : undefined;
|
||||
if (profile && "enterpriseUrl" in profile && typeof profile.enterpriseUrl === "string") {
|
||||
enterpriseDomain = normalizeGithubCopilotDomain(profile.enterpriseUrl);
|
||||
if (profile && profile.type === "token") {
|
||||
selectedGithubToken = profile.token;
|
||||
}
|
||||
}
|
||||
|
||||
const baseUrl = resolveGithubCopilotBaseUrl(enterpriseDomain);
|
||||
let baseUrl = DEFAULT_COPILOT_API_BASE_URL;
|
||||
if (selectedGithubToken) {
|
||||
try {
|
||||
const token = await resolveCopilotApiToken({
|
||||
githubToken: selectedGithubToken,
|
||||
env,
|
||||
});
|
||||
baseUrl = token.baseUrl;
|
||||
} catch {
|
||||
baseUrl = DEFAULT_COPILOT_API_BASE_URL;
|
||||
}
|
||||
}
|
||||
|
||||
// pi-coding-agent's ModelRegistry marks a model "available" only if its
|
||||
// `AuthStorage` has auth configured for that provider (via auth.json/env/etc).
|
||||
@@ -353,7 +364,7 @@ export async function resolveImplicitCopilotProvider(params: {
|
||||
// GitHub token (not the exchanged Copilot token), and (3) matches existing
|
||||
// patterns for OAuth-like providers in pi-coding-agent.
|
||||
// Note: we deliberately do not write pi-coding-agent's `auth.json` here.
|
||||
// Clawdbot uses its own auth store and passes the GitHub token at runtime.
|
||||
// Clawdbot uses its own auth store and exchanges tokens at runtime.
|
||||
// `models list` uses Clawdbot's auth heuristics for availability.
|
||||
|
||||
// We intentionally do NOT define custom models for Copilot in models.json.
|
||||
|
||||
@@ -3,7 +3,6 @@ import path from "node:path";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { withTempHome as withTempHomeBase } from "../../test/helpers/temp-home.js";
|
||||
import type { ClawdbotConfig } from "../config/config.js";
|
||||
import { DEFAULT_GITHUB_COPILOT_BASE_URL } from "../providers/github-copilot-utils.js";
|
||||
|
||||
async function withTempHome<T>(fn: (home: string) => Promise<T>): Promise<T> {
|
||||
return withTempHomeBase(fn, { prefix: "clawdbot-models-" });
|
||||
@@ -81,16 +80,25 @@ describe("models-config", () => {
|
||||
),
|
||||
);
|
||||
|
||||
const resolveCopilotApiToken = vi.fn().mockResolvedValue({
|
||||
token: "copilot",
|
||||
expiresAt: Date.now() + 60 * 60 * 1000,
|
||||
source: "mock",
|
||||
baseUrl: "https://api.copilot.example",
|
||||
});
|
||||
|
||||
vi.doMock("../providers/github-copilot-token.js", () => ({
|
||||
DEFAULT_COPILOT_API_BASE_URL: "https://api.individual.githubcopilot.com",
|
||||
resolveCopilotApiToken,
|
||||
}));
|
||||
|
||||
const { ensureClawdbotModelsJson } = await import("./models-config.js");
|
||||
|
||||
await ensureClawdbotModelsJson({ models: { providers: {} } }, agentDir);
|
||||
|
||||
const raw = await fs.readFile(path.join(agentDir, "models.json"), "utf8");
|
||||
const parsed = JSON.parse(raw) as {
|
||||
providers: Record<string, { baseUrl?: string }>;
|
||||
};
|
||||
|
||||
expect(parsed.providers["github-copilot"]?.baseUrl).toBe(DEFAULT_GITHUB_COPILOT_BASE_URL);
|
||||
expect(resolveCopilotApiToken).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ githubToken: "alpha-token" }),
|
||||
);
|
||||
} finally {
|
||||
if (previous === undefined) delete process.env.COPILOT_GITHUB_TOKEN;
|
||||
else process.env.COPILOT_GITHUB_TOKEN = previous;
|
||||
@@ -109,6 +117,16 @@ describe("models-config", () => {
|
||||
try {
|
||||
vi.resetModules();
|
||||
|
||||
vi.doMock("../providers/github-copilot-token.js", () => ({
|
||||
DEFAULT_COPILOT_API_BASE_URL: "https://api.individual.githubcopilot.com",
|
||||
resolveCopilotApiToken: vi.fn().mockResolvedValue({
|
||||
token: "copilot",
|
||||
expiresAt: Date.now() + 60 * 60 * 1000,
|
||||
source: "mock",
|
||||
baseUrl: "https://api.copilot.example",
|
||||
}),
|
||||
}));
|
||||
|
||||
const { ensureClawdbotModelsJson } = await import("./models-config.js");
|
||||
const { resolveClawdbotAgentDir } = await import("./agent-paths.js");
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os from "node:os";
|
||||
import path from "node:path";
|
||||
|
||||
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
import type { ClawdbotConfig } from "../config/config.js";
|
||||
import type { EmbeddedRunAttemptResult } from "./pi-embedded-runner/run/types.js";
|
||||
@@ -16,13 +16,15 @@ vi.mock("./pi-embedded-runner/run/attempt.js", () => ({
|
||||
|
||||
let runEmbeddedPiAgent: typeof import("./pi-embedded-runner.js").runEmbeddedPiAgent;
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.useRealTimers();
|
||||
vi.resetModules();
|
||||
runEmbeddedAttemptMock.mockReset();
|
||||
beforeAll(async () => {
|
||||
({ runEmbeddedPiAgent } = await import("./pi-embedded-runner.js"));
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
runEmbeddedAttemptMock.mockReset();
|
||||
});
|
||||
|
||||
const baseUsage = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
|
||||
@@ -128,6 +128,13 @@ export async function compactEmbeddedPiSession(params: {
|
||||
`No API key resolved for provider "${model.provider}" (auth mode: ${apiKeyInfo.mode}).`,
|
||||
);
|
||||
}
|
||||
} else if (model.provider === "github-copilot") {
|
||||
const { resolveCopilotApiToken } =
|
||||
await import("../../providers/github-copilot-token.js");
|
||||
const copilotToken = await resolveCopilotApiToken({
|
||||
githubToken: apiKeyInfo.apiKey,
|
||||
});
|
||||
authStorage.setRuntimeApiKey(model.provider, copilotToken.token);
|
||||
} else {
|
||||
authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey);
|
||||
}
|
||||
|
||||
@@ -7,20 +7,9 @@ import { resolveClawdbotAgentDir } from "../agent-paths.js";
|
||||
import { DEFAULT_CONTEXT_TOKENS } from "../defaults.js";
|
||||
import { normalizeModelCompat } from "../model-compat.js";
|
||||
import { normalizeProviderId } from "../model-selection.js";
|
||||
import { resolveGithubCopilotUserAgent } from "../../providers/github-copilot-utils.js";
|
||||
|
||||
type InlineModelEntry = ModelDefinitionConfig & { provider: string };
|
||||
|
||||
function applyProviderModelOverrides(model: Model<Api>): Model<Api> {
|
||||
if (model.provider === "github-copilot") {
|
||||
const headers = model.headers
|
||||
? { ...model.headers, "User-Agent": resolveGithubCopilotUserAgent() }
|
||||
: { "User-Agent": resolveGithubCopilotUserAgent() };
|
||||
return { ...model, headers };
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
export function buildInlineProviderModels(
|
||||
providers: Record<string, { models?: ModelDefinitionConfig[] }>,
|
||||
): InlineModelEntry[] {
|
||||
@@ -71,7 +60,7 @@ export function resolveModel(
|
||||
if (inlineMatch) {
|
||||
const normalized = normalizeModelCompat(inlineMatch as Model<Api>);
|
||||
return {
|
||||
model: applyProviderModelOverrides(normalized),
|
||||
model: normalized,
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
};
|
||||
@@ -89,7 +78,7 @@ export function resolveModel(
|
||||
contextWindow: providerCfg?.models?.[0]?.contextWindow ?? DEFAULT_CONTEXT_TOKENS,
|
||||
maxTokens: providerCfg?.models?.[0]?.maxTokens ?? DEFAULT_CONTEXT_TOKENS,
|
||||
} as Model<Api>);
|
||||
return { model: applyProviderModelOverrides(fallbackModel), authStorage, modelRegistry };
|
||||
return { model: fallbackModel, authStorage, modelRegistry };
|
||||
}
|
||||
return {
|
||||
error: `Unknown model: ${provider}/${modelId}`,
|
||||
@@ -97,9 +86,5 @@ export function resolveModel(
|
||||
modelRegistry,
|
||||
};
|
||||
}
|
||||
return {
|
||||
model: applyProviderModelOverrides(normalizeModelCompat(model)),
|
||||
authStorage,
|
||||
modelRegistry,
|
||||
};
|
||||
return { model: normalizeModelCompat(model), authStorage, modelRegistry };
|
||||
}
|
||||
|
||||
@@ -184,8 +184,17 @@ export async function runEmbeddedPiAgent(
|
||||
lastProfileId = resolvedProfileId;
|
||||
return;
|
||||
}
|
||||
authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey);
|
||||
lastProfileId = resolvedProfileId;
|
||||
if (model.provider === "github-copilot") {
|
||||
const { resolveCopilotApiToken } =
|
||||
await import("../../providers/github-copilot-token.js");
|
||||
const copilotToken = await resolveCopilotApiToken({
|
||||
githubToken: apiKeyInfo.apiKey,
|
||||
});
|
||||
authStorage.setRuntimeApiKey(model.provider, copilotToken.token);
|
||||
} else {
|
||||
authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey);
|
||||
}
|
||||
lastProfileId = apiKeyInfo.profileId;
|
||||
};
|
||||
|
||||
const advanceAuthProfile = async (): Promise<boolean> => {
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { __testing, createClawdbotCodingTools } from "./pi-tools.js";
|
||||
import { createBrowserTool } from "./tools/browser-tool.js";
|
||||
|
||||
describe("createClawdbotCodingTools", () => {
|
||||
describe("Claude/Gemini alias support", () => {
|
||||
it("adds Claude-style aliases to schemas without dropping metadata", () => {
|
||||
const base: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string", description: "Path" },
|
||||
content: { type: "string", description: "Body" },
|
||||
},
|
||||
},
|
||||
execute: vi.fn(),
|
||||
};
|
||||
|
||||
const patched = __testing.patchToolSchemaForClaudeCompatibility(base);
|
||||
const params = patched.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
};
|
||||
const props = params.properties ?? {};
|
||||
|
||||
expect(props.file_path).toEqual(props.path);
|
||||
expect(params.required ?? []).not.toContain("path");
|
||||
expect(params.required ?? []).not.toContain("file_path");
|
||||
});
|
||||
|
||||
it("normalizes file_path to path and enforces required groups at runtime", async () => {
|
||||
const execute = vi.fn(async (_id, args) => args);
|
||||
const tool: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string" },
|
||||
content: { type: "string" },
|
||||
},
|
||||
},
|
||||
execute,
|
||||
};
|
||||
|
||||
const wrapped = __testing.wrapToolParamNormalization(tool, [{ keys: ["path", "file_path"] }]);
|
||||
|
||||
await wrapped.execute("tool-1", { file_path: "foo.txt", content: "x" });
|
||||
expect(execute).toHaveBeenCalledWith(
|
||||
"tool-1",
|
||||
{ path: "foo.txt", content: "x" },
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
await expect(wrapped.execute("tool-2", { content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
await expect(wrapped.execute("tool-3", { file_path: " ", content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps browser tool schema OpenAI-compatible without normalization", () => {
|
||||
const browser = createBrowserTool();
|
||||
const schema = browser.parameters as { type?: unknown; anyOf?: unknown };
|
||||
expect(schema.type).toBe("object");
|
||||
expect(schema.anyOf).toBeUndefined();
|
||||
});
|
||||
it("mentions Chrome extension relay in browser tool description", () => {
|
||||
const browser = createBrowserTool();
|
||||
expect(browser.description).toMatch(/Chrome extension/i);
|
||||
expect(browser.description).toMatch(/profile="chrome"/i);
|
||||
});
|
||||
it("keeps browser tool schema properties after normalization", () => {
|
||||
const tools = createClawdbotCodingTools();
|
||||
const browser = tools.find((tool) => tool.name === "browser");
|
||||
expect(browser).toBeDefined();
|
||||
const parameters = browser?.parameters as {
|
||||
anyOf?: unknown[];
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
};
|
||||
expect(parameters.properties?.action).toBeDefined();
|
||||
expect(parameters.properties?.target).toBeDefined();
|
||||
expect(parameters.properties?.controlUrl).toBeDefined();
|
||||
expect(parameters.properties?.targetUrl).toBeDefined();
|
||||
expect(parameters.properties?.request).toBeDefined();
|
||||
expect(parameters.required ?? []).toContain("action");
|
||||
});
|
||||
it("exposes raw for gateway config.apply tool calls", () => {
|
||||
const tools = createClawdbotCodingTools();
|
||||
const gateway = tools.find((tool) => tool.name === "gateway");
|
||||
expect(gateway).toBeDefined();
|
||||
|
||||
const parameters = gateway?.parameters as {
|
||||
type?: unknown;
|
||||
required?: string[];
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
expect(parameters.type).toBe("object");
|
||||
expect(parameters.properties?.raw).toBeDefined();
|
||||
expect(parameters.required ?? []).not.toContain("raw");
|
||||
});
|
||||
it("flattens anyOf-of-literals to enum for provider compatibility", () => {
|
||||
const tools = createClawdbotCodingTools();
|
||||
const browser = tools.find((tool) => tool.name === "browser");
|
||||
expect(browser).toBeDefined();
|
||||
|
||||
const parameters = browser?.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
const action = parameters.properties?.action as
|
||||
| {
|
||||
type?: unknown;
|
||||
enum?: unknown[];
|
||||
anyOf?: unknown[];
|
||||
}
|
||||
| undefined;
|
||||
|
||||
expect(action?.type).toBe("string");
|
||||
expect(action?.anyOf).toBeUndefined();
|
||||
expect(Array.isArray(action?.enum)).toBe(true);
|
||||
expect(action?.enum).toContain("act");
|
||||
|
||||
const snapshotFormat = parameters.properties?.snapshotFormat as
|
||||
| {
|
||||
type?: unknown;
|
||||
enum?: unknown[];
|
||||
anyOf?: unknown[];
|
||||
}
|
||||
| undefined;
|
||||
expect(snapshotFormat?.type).toBe("string");
|
||||
expect(snapshotFormat?.anyOf).toBeUndefined();
|
||||
expect(snapshotFormat?.enum).toEqual(["aria", "ai"]);
|
||||
});
|
||||
it("inlines local $ref before removing unsupported keywords", () => {
|
||||
const cleaned = __testing.cleanToolSchemaForGemini({
|
||||
type: "object",
|
||||
properties: {
|
||||
foo: { $ref: "#/$defs/Foo" },
|
||||
},
|
||||
$defs: {
|
||||
Foo: { type: "string", enum: ["a", "b"] },
|
||||
},
|
||||
}) as {
|
||||
$defs?: unknown;
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
expect(cleaned.$defs).toBeUndefined();
|
||||
expect(cleaned.properties).toBeDefined();
|
||||
expect(cleaned.properties?.foo).toMatchObject({
|
||||
type: "string",
|
||||
enum: ["a", "b"],
|
||||
});
|
||||
});
|
||||
it("cleans tuple items schemas", () => {
|
||||
const cleaned = __testing.cleanToolSchemaForGemini({
|
||||
type: "object",
|
||||
properties: {
|
||||
tuples: {
|
||||
type: "array",
|
||||
items: [
|
||||
{ type: "string", format: "uuid" },
|
||||
{ type: "number", minimum: 1 },
|
||||
],
|
||||
},
|
||||
},
|
||||
}) as {
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
const tuples = cleaned.properties?.tuples as { items?: unknown } | undefined;
|
||||
const items = Array.isArray(tuples?.items) ? tuples?.items : [];
|
||||
const first = items[0] as { format?: unknown } | undefined;
|
||||
const second = items[1] as { minimum?: unknown } | undefined;
|
||||
|
||||
expect(first?.format).toBeUndefined();
|
||||
expect(second?.minimum).toBeUndefined();
|
||||
});
|
||||
it("drops null-only union variants without flattening other unions", () => {
|
||||
const cleaned = __testing.cleanToolSchemaForGemini({
|
||||
type: "object",
|
||||
properties: {
|
||||
parentId: { anyOf: [{ type: "string" }, { type: "null" }] },
|
||||
count: { oneOf: [{ type: "string" }, { type: "number" }] },
|
||||
},
|
||||
}) as {
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
const parentId = cleaned.properties?.parentId as
|
||||
| { type?: unknown; anyOf?: unknown; oneOf?: unknown }
|
||||
| undefined;
|
||||
expect(parentId?.anyOf).toBeUndefined();
|
||||
expect(parentId?.oneOf).toBeUndefined();
|
||||
expect(parentId?.type).toBe("string");
|
||||
|
||||
const count = cleaned.properties?.count as
|
||||
| { type?: unknown; anyOf?: unknown; oneOf?: unknown }
|
||||
| undefined;
|
||||
expect(count?.anyOf).toBeUndefined();
|
||||
expect(Array.isArray(count?.oneOf)).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -1,74 +1,11 @@
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { ClawdbotConfig } from "../config/config.js";
|
||||
import { __testing, createClawdbotCodingTools } from "./pi-tools.js";
|
||||
import { createClawdbotCodingTools } from "./pi-tools.js";
|
||||
|
||||
const defaultTools = createClawdbotCodingTools();
|
||||
|
||||
describe("createClawdbotCodingTools", () => {
|
||||
describe("Claude/Gemini alias support", () => {
|
||||
it("adds Claude-style aliases to schemas without dropping metadata", () => {
|
||||
const base: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string", description: "Path" },
|
||||
content: { type: "string", description: "Body" },
|
||||
},
|
||||
},
|
||||
execute: vi.fn(),
|
||||
};
|
||||
|
||||
const patched = __testing.patchToolSchemaForClaudeCompatibility(base);
|
||||
const params = patched.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
};
|
||||
const props = params.properties ?? {};
|
||||
|
||||
expect(props.file_path).toEqual(props.path);
|
||||
expect(params.required ?? []).not.toContain("path");
|
||||
expect(params.required ?? []).not.toContain("file_path");
|
||||
});
|
||||
|
||||
it("normalizes file_path to path and enforces required groups at runtime", async () => {
|
||||
const execute = vi.fn(async (_id, args) => args);
|
||||
const tool: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string" },
|
||||
content: { type: "string" },
|
||||
},
|
||||
},
|
||||
execute,
|
||||
};
|
||||
|
||||
const wrapped = __testing.wrapToolParamNormalization(tool, [{ keys: ["path", "file_path"] }]);
|
||||
|
||||
await wrapped.execute("tool-1", { file_path: "foo.txt", content: "x" });
|
||||
expect(execute).toHaveBeenCalledWith(
|
||||
"tool-1",
|
||||
{ path: "foo.txt", content: "x" },
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
await expect(wrapped.execute("tool-2", { content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
await expect(wrapped.execute("tool-3", { file_path: " ", content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("preserves action enums in normalized schemas", () => {
|
||||
const tools = createClawdbotCodingTools();
|
||||
const toolNames = ["browser", "canvas", "nodes", "cron", "gateway", "message"];
|
||||
|
||||
const collectActionValues = (schema: unknown, values: Set<string>): void => {
|
||||
@@ -88,7 +25,7 @@ describe("createClawdbotCodingTools", () => {
|
||||
};
|
||||
|
||||
for (const name of toolNames) {
|
||||
const tool = tools.find((candidate) => candidate.name === name);
|
||||
const tool = defaultTools.find((candidate) => candidate.name === name);
|
||||
expect(tool).toBeDefined();
|
||||
const parameters = tool?.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
@@ -108,10 +45,9 @@ describe("createClawdbotCodingTools", () => {
|
||||
}
|
||||
});
|
||||
it("includes exec and process tools by default", () => {
|
||||
const tools = createClawdbotCodingTools();
|
||||
expect(tools.some((tool) => tool.name === "exec")).toBe(true);
|
||||
expect(tools.some((tool) => tool.name === "process")).toBe(true);
|
||||
expect(tools.some((tool) => tool.name === "apply_patch")).toBe(false);
|
||||
expect(defaultTools.some((tool) => tool.name === "exec")).toBe(true);
|
||||
expect(defaultTools.some((tool) => tool.name === "process")).toBe(true);
|
||||
expect(defaultTools.some((tool) => tool.name === "apply_patch")).toBe(false);
|
||||
});
|
||||
it("gates apply_patch behind tools.exec.applyPatch for OpenAI models", () => {
|
||||
const config: ClawdbotConfig = {
|
||||
|
||||
@@ -1,78 +1,15 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import sharp from "sharp";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { __testing, createClawdbotCodingTools } from "./pi-tools.js";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { createClawdbotCodingTools } from "./pi-tools.js";
|
||||
|
||||
const defaultTools = createClawdbotCodingTools();
|
||||
|
||||
describe("createClawdbotCodingTools", () => {
|
||||
describe("Claude/Gemini alias support", () => {
|
||||
it("adds Claude-style aliases to schemas without dropping metadata", () => {
|
||||
const base: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string", description: "Path" },
|
||||
content: { type: "string", description: "Body" },
|
||||
},
|
||||
},
|
||||
execute: vi.fn(),
|
||||
};
|
||||
|
||||
const patched = __testing.patchToolSchemaForClaudeCompatibility(base);
|
||||
const params = patched.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
};
|
||||
const props = params.properties ?? {};
|
||||
|
||||
expect(props.file_path).toEqual(props.path);
|
||||
expect(params.required ?? []).not.toContain("path");
|
||||
expect(params.required ?? []).not.toContain("file_path");
|
||||
});
|
||||
|
||||
it("normalizes file_path to path and enforces required groups at runtime", async () => {
|
||||
const execute = vi.fn(async (_id, args) => args);
|
||||
const tool: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string" },
|
||||
content: { type: "string" },
|
||||
},
|
||||
},
|
||||
execute,
|
||||
};
|
||||
|
||||
const wrapped = __testing.wrapToolParamNormalization(tool, [{ keys: ["path", "file_path"] }]);
|
||||
|
||||
await wrapped.execute("tool-1", { file_path: "foo.txt", content: "x" });
|
||||
expect(execute).toHaveBeenCalledWith(
|
||||
"tool-1",
|
||||
{ path: "foo.txt", content: "x" },
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
await expect(wrapped.execute("tool-2", { content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
await expect(wrapped.execute("tool-3", { file_path: " ", content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps read tool image metadata intact", async () => {
|
||||
const tools = createClawdbotCodingTools();
|
||||
const readTool = tools.find((tool) => tool.name === "read");
|
||||
const readTool = defaultTools.find((tool) => tool.name === "read");
|
||||
expect(readTool).toBeDefined();
|
||||
|
||||
const tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-read-"));
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { __testing, createClawdbotCodingTools } from "./pi-tools.js";
|
||||
|
||||
describe("createClawdbotCodingTools", () => {
|
||||
describe("Claude/Gemini alias support", () => {
|
||||
it("adds Claude-style aliases to schemas without dropping metadata", () => {
|
||||
const base: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string", description: "Path" },
|
||||
content: { type: "string", description: "Body" },
|
||||
},
|
||||
},
|
||||
execute: vi.fn(),
|
||||
};
|
||||
|
||||
const patched = __testing.patchToolSchemaForClaudeCompatibility(base);
|
||||
const params = patched.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
};
|
||||
const props = params.properties ?? {};
|
||||
|
||||
expect(props.file_path).toEqual(props.path);
|
||||
expect(params.required ?? []).not.toContain("path");
|
||||
expect(params.required ?? []).not.toContain("file_path");
|
||||
});
|
||||
|
||||
it("normalizes file_path to path and enforces required groups at runtime", async () => {
|
||||
const execute = vi.fn(async (_id, args) => args);
|
||||
const tool: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string" },
|
||||
content: { type: "string" },
|
||||
},
|
||||
},
|
||||
execute,
|
||||
};
|
||||
|
||||
const wrapped = __testing.wrapToolParamNormalization(tool, [{ keys: ["path", "file_path"] }]);
|
||||
|
||||
await wrapped.execute("tool-1", { file_path: "foo.txt", content: "x" });
|
||||
expect(execute).toHaveBeenCalledWith(
|
||||
"tool-1",
|
||||
{ path: "foo.txt", content: "x" },
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
await expect(wrapped.execute("tool-2", { content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
await expect(wrapped.execute("tool-3", { file_path: " ", content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("applies tool profiles before allow/deny policies", () => {
|
||||
const tools = createClawdbotCodingTools({
|
||||
config: { tools: { profile: "messaging" } },
|
||||
});
|
||||
const names = new Set(tools.map((tool) => tool.name));
|
||||
expect(names.has("message")).toBe(true);
|
||||
expect(names.has("sessions_send")).toBe(true);
|
||||
expect(names.has("sessions_spawn")).toBe(false);
|
||||
expect(names.has("exec")).toBe(false);
|
||||
expect(names.has("browser")).toBe(false);
|
||||
});
|
||||
it("expands group shorthands in global tool policy", () => {
|
||||
const tools = createClawdbotCodingTools({
|
||||
config: { tools: { allow: ["group:fs"] } },
|
||||
});
|
||||
const names = new Set(tools.map((tool) => tool.name));
|
||||
expect(names.has("read")).toBe(true);
|
||||
expect(names.has("write")).toBe(true);
|
||||
expect(names.has("edit")).toBe(true);
|
||||
expect(names.has("exec")).toBe(false);
|
||||
expect(names.has("browser")).toBe(false);
|
||||
});
|
||||
it("expands group shorthands in global tool deny policy", () => {
|
||||
const tools = createClawdbotCodingTools({
|
||||
config: { tools: { deny: ["group:fs"] } },
|
||||
});
|
||||
const names = new Set(tools.map((tool) => tool.name));
|
||||
expect(names.has("read")).toBe(false);
|
||||
expect(names.has("write")).toBe(false);
|
||||
expect(names.has("edit")).toBe(false);
|
||||
expect(names.has("exec")).toBe(true);
|
||||
});
|
||||
it("lets agent profiles override global profiles", () => {
|
||||
const tools = createClawdbotCodingTools({
|
||||
sessionKey: "agent:work:main",
|
||||
config: {
|
||||
tools: { profile: "coding" },
|
||||
agents: {
|
||||
list: [{ id: "work", tools: { profile: "messaging" } }],
|
||||
},
|
||||
},
|
||||
});
|
||||
const names = new Set(tools.map((tool) => tool.name));
|
||||
expect(names.has("message")).toBe(true);
|
||||
expect(names.has("exec")).toBe(false);
|
||||
expect(names.has("read")).toBe(false);
|
||||
});
|
||||
it("removes unsupported JSON Schema keywords for Cloud Code Assist API compatibility", () => {
|
||||
const tools = createClawdbotCodingTools();
|
||||
|
||||
// Helper to recursively check schema for unsupported keywords
|
||||
const unsupportedKeywords = new Set([
|
||||
"patternProperties",
|
||||
"additionalProperties",
|
||||
"$schema",
|
||||
"$id",
|
||||
"$ref",
|
||||
"$defs",
|
||||
"definitions",
|
||||
"examples",
|
||||
"minLength",
|
||||
"maxLength",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"multipleOf",
|
||||
"pattern",
|
||||
"format",
|
||||
"minItems",
|
||||
"maxItems",
|
||||
"uniqueItems",
|
||||
"minProperties",
|
||||
"maxProperties",
|
||||
]);
|
||||
|
||||
const findUnsupportedKeywords = (schema: unknown, path: string): string[] => {
|
||||
const found: string[] = [];
|
||||
if (!schema || typeof schema !== "object") return found;
|
||||
if (Array.isArray(schema)) {
|
||||
schema.forEach((item, i) => {
|
||||
found.push(...findUnsupportedKeywords(item, `${path}[${i}]`));
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
const record = schema as Record<string, unknown>;
|
||||
const properties =
|
||||
record.properties &&
|
||||
typeof record.properties === "object" &&
|
||||
!Array.isArray(record.properties)
|
||||
? (record.properties as Record<string, unknown>)
|
||||
: undefined;
|
||||
if (properties) {
|
||||
for (const [key, value] of Object.entries(properties)) {
|
||||
found.push(...findUnsupportedKeywords(value, `${path}.properties.${key}`));
|
||||
}
|
||||
}
|
||||
|
||||
for (const [key, value] of Object.entries(record)) {
|
||||
if (key === "properties") continue;
|
||||
if (unsupportedKeywords.has(key)) {
|
||||
found.push(`${path}.${key}`);
|
||||
}
|
||||
if (value && typeof value === "object") {
|
||||
found.push(...findUnsupportedKeywords(value, `${path}.${key}`));
|
||||
}
|
||||
}
|
||||
return found;
|
||||
};
|
||||
|
||||
for (const tool of tools) {
|
||||
const violations = findUnsupportedKeywords(tool.parameters, `${tool.name}.parameters`);
|
||||
expect(violations).toEqual([]);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1,74 +1,10 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { __testing, createClawdbotCodingTools } from "./pi-tools.js";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { createClawdbotCodingTools } from "./pi-tools.js";
|
||||
|
||||
describe("createClawdbotCodingTools", () => {
|
||||
describe("Claude/Gemini alias support", () => {
|
||||
it("adds Claude-style aliases to schemas without dropping metadata", () => {
|
||||
const base: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string", description: "Path" },
|
||||
content: { type: "string", description: "Body" },
|
||||
},
|
||||
},
|
||||
execute: vi.fn(),
|
||||
};
|
||||
|
||||
const patched = __testing.patchToolSchemaForClaudeCompatibility(base);
|
||||
const params = patched.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
};
|
||||
const props = params.properties ?? {};
|
||||
|
||||
expect(props.file_path).toEqual(props.path);
|
||||
expect(params.required ?? []).not.toContain("path");
|
||||
expect(params.required ?? []).not.toContain("file_path");
|
||||
});
|
||||
|
||||
it("normalizes file_path to path and enforces required groups at runtime", async () => {
|
||||
const execute = vi.fn(async (_id, args) => args);
|
||||
const tool: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string" },
|
||||
content: { type: "string" },
|
||||
},
|
||||
},
|
||||
execute,
|
||||
};
|
||||
|
||||
const wrapped = __testing.wrapToolParamNormalization(tool, [{ keys: ["path", "file_path"] }]);
|
||||
|
||||
await wrapped.execute("tool-1", { file_path: "foo.txt", content: "x" });
|
||||
expect(execute).toHaveBeenCalledWith(
|
||||
"tool-1",
|
||||
{ path: "foo.txt", content: "x" },
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
await expect(wrapped.execute("tool-2", { content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
await expect(wrapped.execute("tool-3", { file_path: " ", content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("uses workspaceDir for Read tool path resolution", async () => {
|
||||
const tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-ws-"));
|
||||
try {
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { __testing, createClawdbotCodingTools } from "./pi-tools.js";
|
||||
import { createSandboxedReadTool } from "./pi-tools.read.js";
|
||||
|
||||
describe("createClawdbotCodingTools", () => {
|
||||
describe("Claude/Gemini alias support", () => {
|
||||
it("adds Claude-style aliases to schemas without dropping metadata", () => {
|
||||
const base: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string", description: "Path" },
|
||||
content: { type: "string", description: "Body" },
|
||||
},
|
||||
},
|
||||
execute: vi.fn(),
|
||||
};
|
||||
|
||||
const patched = __testing.patchToolSchemaForClaudeCompatibility(base);
|
||||
const params = patched.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
};
|
||||
const props = params.properties ?? {};
|
||||
|
||||
expect(props.file_path).toEqual(props.path);
|
||||
expect(params.required ?? []).not.toContain("path");
|
||||
expect(params.required ?? []).not.toContain("file_path");
|
||||
});
|
||||
|
||||
it("normalizes file_path to path and enforces required groups at runtime", async () => {
|
||||
const execute = vi.fn(async (_id, args) => args);
|
||||
const tool: AgentTool = {
|
||||
name: "write",
|
||||
description: "test",
|
||||
parameters: {
|
||||
type: "object",
|
||||
required: ["path", "content"],
|
||||
properties: {
|
||||
path: { type: "string" },
|
||||
content: { type: "string" },
|
||||
},
|
||||
},
|
||||
execute,
|
||||
};
|
||||
|
||||
const wrapped = __testing.wrapToolParamNormalization(tool, [{ keys: ["path", "file_path"] }]);
|
||||
|
||||
await wrapped.execute("tool-1", { file_path: "foo.txt", content: "x" });
|
||||
expect(execute).toHaveBeenCalledWith(
|
||||
"tool-1",
|
||||
{ path: "foo.txt", content: "x" },
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
await expect(wrapped.execute("tool-2", { content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
await expect(wrapped.execute("tool-3", { file_path: " ", content: "x" })).rejects.toThrow(
|
||||
/Missing required parameter/,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("applies sandbox path guards to file_path alias", async () => {
|
||||
const tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-sbx-"));
|
||||
const outsidePath = path.join(os.tmpdir(), "clawdbot-outside.txt");
|
||||
await fs.writeFile(outsidePath, "outside", "utf8");
|
||||
try {
|
||||
const readTool = createSandboxedReadTool(tmpDir);
|
||||
await expect(readTool.execute("tool-sbx-1", { file_path: outsidePath })).rejects.toThrow();
|
||||
} finally {
|
||||
await fs.rm(tmpDir, { recursive: true, force: true });
|
||||
await fs.rm(outsidePath, { force: true });
|
||||
}
|
||||
});
|
||||
it("falls back to process.cwd() when workspaceDir not provided", () => {
|
||||
const prevCwd = process.cwd();
|
||||
const tools = createClawdbotCodingTools();
|
||||
// Tools should be created without error
|
||||
expect(tools.some((tool) => tool.name === "read")).toBe(true);
|
||||
expect(tools.some((tool) => tool.name === "write")).toBe(true);
|
||||
expect(tools.some((tool) => tool.name === "edit")).toBe(true);
|
||||
// cwd should be unchanged
|
||||
expect(process.cwd()).toBe(prevCwd);
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,14 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { createClawdbotTools } from "./clawdbot-tools.js";
|
||||
import { __testing, createClawdbotCodingTools } from "./pi-tools.js";
|
||||
import { createSandboxedReadTool } from "./pi-tools.read.js";
|
||||
import { createBrowserTool } from "./tools/browser-tool.js";
|
||||
|
||||
const defaultTools = createClawdbotCodingTools();
|
||||
|
||||
describe("createClawdbotCodingTools", () => {
|
||||
describe("Claude/Gemini alias support", () => {
|
||||
@@ -67,8 +74,144 @@ describe("createClawdbotCodingTools", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps browser tool schema OpenAI-compatible without normalization", () => {
|
||||
const browser = createBrowserTool();
|
||||
const schema = browser.parameters as { type?: unknown; anyOf?: unknown };
|
||||
expect(schema.type).toBe("object");
|
||||
expect(schema.anyOf).toBeUndefined();
|
||||
});
|
||||
it("mentions Chrome extension relay in browser tool description", () => {
|
||||
const browser = createBrowserTool();
|
||||
expect(browser.description).toMatch(/Chrome extension/i);
|
||||
expect(browser.description).toMatch(/profile="chrome"/i);
|
||||
});
|
||||
it("keeps browser tool schema properties after normalization", () => {
|
||||
const browser = defaultTools.find((tool) => tool.name === "browser");
|
||||
expect(browser).toBeDefined();
|
||||
const parameters = browser?.parameters as {
|
||||
anyOf?: unknown[];
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
};
|
||||
expect(parameters.properties?.action).toBeDefined();
|
||||
expect(parameters.properties?.target).toBeDefined();
|
||||
expect(parameters.properties?.controlUrl).toBeDefined();
|
||||
expect(parameters.properties?.targetUrl).toBeDefined();
|
||||
expect(parameters.properties?.request).toBeDefined();
|
||||
expect(parameters.required ?? []).toContain("action");
|
||||
});
|
||||
it("exposes raw for gateway config.apply tool calls", () => {
|
||||
const gateway = defaultTools.find((tool) => tool.name === "gateway");
|
||||
expect(gateway).toBeDefined();
|
||||
|
||||
const parameters = gateway?.parameters as {
|
||||
type?: unknown;
|
||||
required?: string[];
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
expect(parameters.type).toBe("object");
|
||||
expect(parameters.properties?.raw).toBeDefined();
|
||||
expect(parameters.required ?? []).not.toContain("raw");
|
||||
});
|
||||
it("flattens anyOf-of-literals to enum for provider compatibility", () => {
|
||||
const browser = defaultTools.find((tool) => tool.name === "browser");
|
||||
expect(browser).toBeDefined();
|
||||
|
||||
const parameters = browser?.parameters as {
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
const action = parameters.properties?.action as
|
||||
| {
|
||||
type?: unknown;
|
||||
enum?: unknown[];
|
||||
anyOf?: unknown[];
|
||||
}
|
||||
| undefined;
|
||||
|
||||
expect(action?.type).toBe("string");
|
||||
expect(action?.anyOf).toBeUndefined();
|
||||
expect(Array.isArray(action?.enum)).toBe(true);
|
||||
expect(action?.enum).toContain("act");
|
||||
|
||||
const snapshotFormat = parameters.properties?.snapshotFormat as
|
||||
| {
|
||||
type?: unknown;
|
||||
enum?: unknown[];
|
||||
anyOf?: unknown[];
|
||||
}
|
||||
| undefined;
|
||||
expect(snapshotFormat?.type).toBe("string");
|
||||
expect(snapshotFormat?.anyOf).toBeUndefined();
|
||||
expect(snapshotFormat?.enum).toEqual(["aria", "ai"]);
|
||||
});
|
||||
it("inlines local $ref before removing unsupported keywords", () => {
|
||||
const cleaned = __testing.cleanToolSchemaForGemini({
|
||||
type: "object",
|
||||
properties: {
|
||||
foo: { $ref: "#/$defs/Foo" },
|
||||
},
|
||||
$defs: {
|
||||
Foo: { type: "string", enum: ["a", "b"] },
|
||||
},
|
||||
}) as {
|
||||
$defs?: unknown;
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
expect(cleaned.$defs).toBeUndefined();
|
||||
expect(cleaned.properties).toBeDefined();
|
||||
expect(cleaned.properties?.foo).toMatchObject({
|
||||
type: "string",
|
||||
enum: ["a", "b"],
|
||||
});
|
||||
});
|
||||
it("cleans tuple items schemas", () => {
|
||||
const cleaned = __testing.cleanToolSchemaForGemini({
|
||||
type: "object",
|
||||
properties: {
|
||||
tuples: {
|
||||
type: "array",
|
||||
items: [
|
||||
{ type: "string", format: "uuid" },
|
||||
{ type: "number", minimum: 1 },
|
||||
],
|
||||
},
|
||||
},
|
||||
}) as {
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
const tuples = cleaned.properties?.tuples as { items?: unknown } | undefined;
|
||||
const items = Array.isArray(tuples?.items) ? tuples?.items : [];
|
||||
const first = items[0] as { format?: unknown } | undefined;
|
||||
const second = items[1] as { minimum?: unknown } | undefined;
|
||||
|
||||
expect(first?.format).toBeUndefined();
|
||||
expect(second?.minimum).toBeUndefined();
|
||||
});
|
||||
it("drops null-only union variants without flattening other unions", () => {
|
||||
const cleaned = __testing.cleanToolSchemaForGemini({
|
||||
type: "object",
|
||||
properties: {
|
||||
parentId: { anyOf: [{ type: "string" }, { type: "null" }] },
|
||||
count: { oneOf: [{ type: "string" }, { type: "number" }] },
|
||||
},
|
||||
}) as {
|
||||
properties?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
const parentId = cleaned.properties?.parentId as
|
||||
| { type?: unknown; anyOf?: unknown; oneOf?: unknown }
|
||||
| undefined;
|
||||
const count = cleaned.properties?.count as
|
||||
| { type?: unknown; anyOf?: unknown; oneOf?: unknown }
|
||||
| undefined;
|
||||
|
||||
expect(parentId?.type).toBe("string");
|
||||
expect(parentId?.anyOf).toBeUndefined();
|
||||
expect(count?.oneOf).toBeDefined();
|
||||
});
|
||||
it("avoids anyOf/oneOf/allOf in tool schemas", () => {
|
||||
const tools = createClawdbotCodingTools();
|
||||
const offenders: Array<{
|
||||
name: string;
|
||||
keyword: string;
|
||||
@@ -96,7 +239,7 @@ describe("createClawdbotCodingTools", () => {
|
||||
}
|
||||
};
|
||||
|
||||
for (const tool of tools) {
|
||||
for (const tool of defaultTools) {
|
||||
walk(tool.parameters, "", tool.name);
|
||||
}
|
||||
|
||||
@@ -192,4 +335,131 @@ describe("createClawdbotCodingTools", () => {
|
||||
});
|
||||
expect(tools.map((tool) => tool.name)).toEqual(["read"]);
|
||||
});
|
||||
|
||||
it("applies tool profiles before allow/deny policies", () => {
|
||||
const tools = createClawdbotCodingTools({
|
||||
config: { tools: { profile: "messaging" } },
|
||||
});
|
||||
const names = new Set(tools.map((tool) => tool.name));
|
||||
expect(names.has("message")).toBe(true);
|
||||
expect(names.has("sessions_send")).toBe(true);
|
||||
expect(names.has("sessions_spawn")).toBe(false);
|
||||
expect(names.has("exec")).toBe(false);
|
||||
expect(names.has("browser")).toBe(false);
|
||||
});
|
||||
it("expands group shorthands in global tool policy", () => {
|
||||
const tools = createClawdbotCodingTools({
|
||||
config: { tools: { allow: ["group:fs"] } },
|
||||
});
|
||||
const names = new Set(tools.map((tool) => tool.name));
|
||||
expect(names.has("read")).toBe(true);
|
||||
expect(names.has("write")).toBe(true);
|
||||
expect(names.has("edit")).toBe(true);
|
||||
expect(names.has("exec")).toBe(false);
|
||||
expect(names.has("browser")).toBe(false);
|
||||
});
|
||||
it("expands group shorthands in global tool deny policy", () => {
|
||||
const tools = createClawdbotCodingTools({
|
||||
config: { tools: { deny: ["group:fs"] } },
|
||||
});
|
||||
const names = new Set(tools.map((tool) => tool.name));
|
||||
expect(names.has("read")).toBe(false);
|
||||
expect(names.has("write")).toBe(false);
|
||||
expect(names.has("edit")).toBe(false);
|
||||
expect(names.has("exec")).toBe(true);
|
||||
});
|
||||
it("lets agent profiles override global profiles", () => {
|
||||
const tools = createClawdbotCodingTools({
|
||||
sessionKey: "agent:work:main",
|
||||
config: {
|
||||
tools: { profile: "coding" },
|
||||
agents: {
|
||||
list: [{ id: "work", tools: { profile: "messaging" } }],
|
||||
},
|
||||
},
|
||||
});
|
||||
const names = new Set(tools.map((tool) => tool.name));
|
||||
expect(names.has("message")).toBe(true);
|
||||
expect(names.has("exec")).toBe(false);
|
||||
expect(names.has("read")).toBe(false);
|
||||
});
|
||||
it("removes unsupported JSON Schema keywords for Cloud Code Assist API compatibility", () => {
|
||||
// Helper to recursively check schema for unsupported keywords
|
||||
const unsupportedKeywords = new Set([
|
||||
"patternProperties",
|
||||
"additionalProperties",
|
||||
"$schema",
|
||||
"$id",
|
||||
"$ref",
|
||||
"$defs",
|
||||
"definitions",
|
||||
"examples",
|
||||
"minLength",
|
||||
"maxLength",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"multipleOf",
|
||||
"pattern",
|
||||
"format",
|
||||
"minItems",
|
||||
"maxItems",
|
||||
"uniqueItems",
|
||||
"minProperties",
|
||||
"maxProperties",
|
||||
]);
|
||||
|
||||
const findUnsupportedKeywords = (schema: unknown, path: string): string[] => {
|
||||
const found: string[] = [];
|
||||
if (!schema || typeof schema !== "object") return found;
|
||||
if (Array.isArray(schema)) {
|
||||
schema.forEach((item, i) => {
|
||||
found.push(...findUnsupportedKeywords(item, `${path}[${i}]`));
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
const record = schema as Record<string, unknown>;
|
||||
const properties =
|
||||
record.properties &&
|
||||
typeof record.properties === "object" &&
|
||||
!Array.isArray(record.properties)
|
||||
? (record.properties as Record<string, unknown>)
|
||||
: undefined;
|
||||
if (properties) {
|
||||
for (const [key, value] of Object.entries(properties)) {
|
||||
found.push(...findUnsupportedKeywords(value, `${path}.properties.${key}`));
|
||||
}
|
||||
}
|
||||
|
||||
for (const [key, value] of Object.entries(record)) {
|
||||
if (key === "properties") continue;
|
||||
if (unsupportedKeywords.has(key)) {
|
||||
found.push(`${path}.${key}`);
|
||||
}
|
||||
if (value && typeof value === "object") {
|
||||
found.push(...findUnsupportedKeywords(value, `${path}.${key}`));
|
||||
}
|
||||
}
|
||||
return found;
|
||||
};
|
||||
|
||||
for (const tool of defaultTools) {
|
||||
const violations = findUnsupportedKeywords(tool.parameters, `${tool.name}.parameters`);
|
||||
expect(violations).toEqual([]);
|
||||
}
|
||||
});
|
||||
it("applies sandbox path guards to file_path alias", async () => {
|
||||
const tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "clawdbot-sbx-"));
|
||||
const outsidePath = path.join(os.tmpdir(), "clawdbot-outside.txt");
|
||||
await fs.writeFile(outsidePath, "outside", "utf8");
|
||||
try {
|
||||
const readTool = createSandboxedReadTool(tmpDir);
|
||||
await expect(readTool.execute("sandbox-1", { file_path: outsidePath })).rejects.toThrow(
|
||||
/sandbox root/i,
|
||||
);
|
||||
} finally {
|
||||
await fs.rm(outsidePath, { force: true });
|
||||
await fs.rm(tmpDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -44,6 +44,7 @@ import {
|
||||
collectExplicitAllowlist,
|
||||
expandPolicyWithPluginGroups,
|
||||
resolveToolProfilePolicy,
|
||||
stripPluginOnlyAllowlist,
|
||||
} from "./tool-policy.js";
|
||||
import { getPluginToolMeta } from "../plugins/tools.js";
|
||||
|
||||
@@ -298,12 +299,30 @@ export function createClawdbotCodingTools(options?: {
|
||||
tools,
|
||||
toolMeta: (tool) => getPluginToolMeta(tool as AnyAgentTool),
|
||||
});
|
||||
const profilePolicyExpanded = expandPolicyWithPluginGroups(profilePolicy, pluginGroups);
|
||||
const providerProfileExpanded = expandPolicyWithPluginGroups(providerProfilePolicy, pluginGroups);
|
||||
const globalPolicyExpanded = expandPolicyWithPluginGroups(globalPolicy, pluginGroups);
|
||||
const globalProviderExpanded = expandPolicyWithPluginGroups(globalProviderPolicy, pluginGroups);
|
||||
const agentPolicyExpanded = expandPolicyWithPluginGroups(agentPolicy, pluginGroups);
|
||||
const agentProviderExpanded = expandPolicyWithPluginGroups(agentProviderPolicy, pluginGroups);
|
||||
const profilePolicyExpanded = expandPolicyWithPluginGroups(
|
||||
stripPluginOnlyAllowlist(profilePolicy, pluginGroups),
|
||||
pluginGroups,
|
||||
);
|
||||
const providerProfileExpanded = expandPolicyWithPluginGroups(
|
||||
stripPluginOnlyAllowlist(providerProfilePolicy, pluginGroups),
|
||||
pluginGroups,
|
||||
);
|
||||
const globalPolicyExpanded = expandPolicyWithPluginGroups(
|
||||
stripPluginOnlyAllowlist(globalPolicy, pluginGroups),
|
||||
pluginGroups,
|
||||
);
|
||||
const globalProviderExpanded = expandPolicyWithPluginGroups(
|
||||
stripPluginOnlyAllowlist(globalProviderPolicy, pluginGroups),
|
||||
pluginGroups,
|
||||
);
|
||||
const agentPolicyExpanded = expandPolicyWithPluginGroups(
|
||||
stripPluginOnlyAllowlist(agentPolicy, pluginGroups),
|
||||
pluginGroups,
|
||||
);
|
||||
const agentProviderExpanded = expandPolicyWithPluginGroups(
|
||||
stripPluginOnlyAllowlist(agentProviderPolicy, pluginGroups),
|
||||
pluginGroups,
|
||||
);
|
||||
const sandboxPolicyExpanded = expandPolicyWithPluginGroups(sandbox?.tools, pluginGroups);
|
||||
const subagentPolicyExpanded = expandPolicyWithPluginGroups(subagentPolicy, pluginGroups);
|
||||
|
||||
|
||||
25
src/agents/tool-policy.plugin-only-allowlist.test.ts
Normal file
25
src/agents/tool-policy.plugin-only-allowlist.test.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
|
||||
import { stripPluginOnlyAllowlist, type PluginToolGroups } from "./tool-policy.js";
|
||||
|
||||
const pluginGroups: PluginToolGroups = {
|
||||
all: ["lobster", "workflow_tool"],
|
||||
byPlugin: new Map([["lobster", ["lobster", "workflow_tool"]]]),
|
||||
};
|
||||
|
||||
describe("stripPluginOnlyAllowlist", () => {
|
||||
it("strips allowlist when it only targets plugin tools", () => {
|
||||
const policy = stripPluginOnlyAllowlist({ allow: ["lobster"] }, pluginGroups);
|
||||
expect(policy?.allow).toBeUndefined();
|
||||
});
|
||||
|
||||
it("strips allowlist when it only targets plugin groups", () => {
|
||||
const policy = stripPluginOnlyAllowlist({ allow: ["group:plugins"] }, pluginGroups);
|
||||
expect(policy?.allow).toBeUndefined();
|
||||
});
|
||||
|
||||
it("keeps allowlist when it mixes plugin and core entries", () => {
|
||||
const policy = stripPluginOnlyAllowlist({ allow: ["lobster", "read"] }, pluginGroups);
|
||||
expect(policy?.allow).toEqual(["lobster", "read"]);
|
||||
});
|
||||
});
|
||||
@@ -178,6 +178,22 @@ export function expandPolicyWithPluginGroups(
|
||||
};
|
||||
}
|
||||
|
||||
export function stripPluginOnlyAllowlist(
|
||||
policy: ToolPolicyLike | undefined,
|
||||
groups: PluginToolGroups,
|
||||
): ToolPolicyLike | undefined {
|
||||
if (!policy?.allow || policy.allow.length === 0) return policy;
|
||||
const normalized = normalizeToolList(policy.allow);
|
||||
if (normalized.length === 0) return policy;
|
||||
const pluginIds = new Set(groups.byPlugin.keys());
|
||||
const pluginTools = new Set(groups.all);
|
||||
const isPluginEntry = (entry: string) =>
|
||||
entry === "group:plugins" || pluginIds.has(entry) || pluginTools.has(entry);
|
||||
const isPluginOnly = normalized.every((entry) => isPluginEntry(entry));
|
||||
if (!isPluginOnly) return policy;
|
||||
return { ...policy, allow: undefined };
|
||||
}
|
||||
|
||||
export function resolveToolProfilePolicy(profile?: string): ToolProfilePolicy | undefined {
|
||||
if (!profile) return undefined;
|
||||
const resolved = TOOL_PROFILES[profile as ToolProfileId];
|
||||
|
||||
@@ -209,4 +209,26 @@ describe("cron tool", () => {
|
||||
const text = cronCall.params?.payload?.text ?? "";
|
||||
expect(text).not.toContain("Recent context:");
|
||||
});
|
||||
|
||||
it("preserves explicit agentId null on add", async () => {
|
||||
callGatewayMock.mockResolvedValueOnce({ ok: true });
|
||||
|
||||
const tool = createCronTool({ agentSessionKey: "main" });
|
||||
await tool.execute("call6", {
|
||||
action: "add",
|
||||
job: {
|
||||
name: "reminder",
|
||||
schedule: { atMs: 123 },
|
||||
agentId: null,
|
||||
payload: { kind: "systemEvent", text: "Reminder: the thing." },
|
||||
},
|
||||
});
|
||||
|
||||
const call = callGatewayMock.mock.calls[0]?.[0] as {
|
||||
method?: string;
|
||||
params?: { agentId?: string | null };
|
||||
};
|
||||
expect(call.method).toBe("cron.add");
|
||||
expect(call.params?.agentId).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,9 +3,9 @@ import { normalizeCronJobCreate, normalizeCronJobPatch } from "../../cron/normal
|
||||
import { loadConfig } from "../../config/config.js";
|
||||
import { truncateUtf16Safe } from "../../utils.js";
|
||||
import { optionalStringEnum, stringEnum } from "../schema/typebox.js";
|
||||
import { resolveSessionAgentId } from "../agent-scope.js";
|
||||
import { type AnyAgentTool, jsonResult, readStringParam } from "./common.js";
|
||||
import { callGatewayTool, type GatewayCallOptions } from "./gateway.js";
|
||||
import { resolveSessionAgentId } from "../agent-scope.js";
|
||||
import { resolveInternalSessionKey, resolveMainSessionAlias } from "./sessions-helpers.js";
|
||||
|
||||
// NOTE: We use Type.Object({}, { additionalProperties: true }) for job/patch
|
||||
@@ -159,12 +159,12 @@ export function createCronTool(opts?: CronToolOptions): AnyAgentTool {
|
||||
throw new Error("job required");
|
||||
}
|
||||
const job = normalizeCronJobCreate(params.job) ?? params.job;
|
||||
if (job && typeof job === "object") {
|
||||
if (job && typeof job === "object" && !("agentId" in job)) {
|
||||
const cfg = loadConfig();
|
||||
const agentId = opts?.agentSessionKey
|
||||
? resolveSessionAgentId({ sessionKey: opts.agentSessionKey, config: cfg })
|
||||
: undefined;
|
||||
if (agentId && !("agentId" in (job as { agentId?: unknown }))) {
|
||||
if (agentId) {
|
||||
(job as { agentId?: string }).agentId = agentId;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,8 +40,6 @@ export async function handleDiscordGuildAction(
|
||||
isActionEnabled: ActionGate<DiscordActionConfig>,
|
||||
): Promise<AgentToolResult<unknown>> {
|
||||
const accountId = readStringParam(params, "accountId");
|
||||
const accountOpts = accountId ? { accountId } : {};
|
||||
|
||||
switch (action) {
|
||||
case "memberInfo": {
|
||||
if (!isActionEnabled("memberInfo")) {
|
||||
@@ -53,7 +51,9 @@ export async function handleDiscordGuildAction(
|
||||
const userId = readStringParam(params, "userId", {
|
||||
required: true,
|
||||
});
|
||||
const member = await fetchMemberInfoDiscord(guildId, userId, accountOpts);
|
||||
const member = accountId
|
||||
? await fetchMemberInfoDiscord(guildId, userId, { accountId })
|
||||
: await fetchMemberInfoDiscord(guildId, userId);
|
||||
return jsonResult({ ok: true, member });
|
||||
}
|
||||
case "roleInfo": {
|
||||
@@ -63,7 +63,9 @@ export async function handleDiscordGuildAction(
|
||||
const guildId = readStringParam(params, "guildId", {
|
||||
required: true,
|
||||
});
|
||||
const roles = await fetchRoleInfoDiscord(guildId, accountOpts);
|
||||
const roles = accountId
|
||||
? await fetchRoleInfoDiscord(guildId, { accountId })
|
||||
: await fetchRoleInfoDiscord(guildId);
|
||||
return jsonResult({ ok: true, roles });
|
||||
}
|
||||
case "emojiList": {
|
||||
@@ -73,7 +75,9 @@ export async function handleDiscordGuildAction(
|
||||
const guildId = readStringParam(params, "guildId", {
|
||||
required: true,
|
||||
});
|
||||
const emojis = await listGuildEmojisDiscord(guildId, accountOpts);
|
||||
const emojis = accountId
|
||||
? await listGuildEmojisDiscord(guildId, { accountId })
|
||||
: await listGuildEmojisDiscord(guildId);
|
||||
return jsonResult({ ok: true, emojis });
|
||||
}
|
||||
case "emojiUpload": {
|
||||
@@ -88,15 +92,22 @@ export async function handleDiscordGuildAction(
|
||||
required: true,
|
||||
});
|
||||
const roleIds = readStringArrayParam(params, "roleIds");
|
||||
const emoji = await uploadEmojiDiscord(
|
||||
{
|
||||
guildId,
|
||||
name,
|
||||
mediaUrl,
|
||||
roleIds: roleIds?.length ? roleIds : undefined,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const emoji = accountId
|
||||
? await uploadEmojiDiscord(
|
||||
{
|
||||
guildId,
|
||||
name,
|
||||
mediaUrl,
|
||||
roleIds: roleIds?.length ? roleIds : undefined,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await uploadEmojiDiscord({
|
||||
guildId,
|
||||
name,
|
||||
mediaUrl,
|
||||
roleIds: roleIds?.length ? roleIds : undefined,
|
||||
});
|
||||
return jsonResult({ ok: true, emoji });
|
||||
}
|
||||
case "stickerUpload": {
|
||||
@@ -114,16 +125,24 @@ export async function handleDiscordGuildAction(
|
||||
const mediaUrl = readStringParam(params, "mediaUrl", {
|
||||
required: true,
|
||||
});
|
||||
const sticker = await uploadStickerDiscord(
|
||||
{
|
||||
guildId,
|
||||
name,
|
||||
description,
|
||||
tags,
|
||||
mediaUrl,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const sticker = accountId
|
||||
? await uploadStickerDiscord(
|
||||
{
|
||||
guildId,
|
||||
name,
|
||||
description,
|
||||
tags,
|
||||
mediaUrl,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await uploadStickerDiscord({
|
||||
guildId,
|
||||
name,
|
||||
description,
|
||||
tags,
|
||||
mediaUrl,
|
||||
});
|
||||
return jsonResult({ ok: true, sticker });
|
||||
}
|
||||
case "roleAdd": {
|
||||
@@ -137,7 +156,11 @@ export async function handleDiscordGuildAction(
|
||||
required: true,
|
||||
});
|
||||
const roleId = readStringParam(params, "roleId", { required: true });
|
||||
await addRoleDiscord({ guildId, userId, roleId }, accountOpts);
|
||||
if (accountId) {
|
||||
await addRoleDiscord({ guildId, userId, roleId }, { accountId });
|
||||
} else {
|
||||
await addRoleDiscord({ guildId, userId, roleId });
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
case "roleRemove": {
|
||||
@@ -151,7 +174,11 @@ export async function handleDiscordGuildAction(
|
||||
required: true,
|
||||
});
|
||||
const roleId = readStringParam(params, "roleId", { required: true });
|
||||
await removeRoleDiscord({ guildId, userId, roleId }, accountOpts);
|
||||
if (accountId) {
|
||||
await removeRoleDiscord({ guildId, userId, roleId }, { accountId });
|
||||
} else {
|
||||
await removeRoleDiscord({ guildId, userId, roleId });
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
case "channelInfo": {
|
||||
@@ -161,7 +188,9 @@ export async function handleDiscordGuildAction(
|
||||
const channelId = readStringParam(params, "channelId", {
|
||||
required: true,
|
||||
});
|
||||
const channel = await fetchChannelInfoDiscord(channelId, accountOpts);
|
||||
const channel = accountId
|
||||
? await fetchChannelInfoDiscord(channelId, { accountId })
|
||||
: await fetchChannelInfoDiscord(channelId);
|
||||
return jsonResult({ ok: true, channel });
|
||||
}
|
||||
case "channelList": {
|
||||
@@ -171,7 +200,9 @@ export async function handleDiscordGuildAction(
|
||||
const guildId = readStringParam(params, "guildId", {
|
||||
required: true,
|
||||
});
|
||||
const channels = await listGuildChannelsDiscord(guildId, accountOpts);
|
||||
const channels = accountId
|
||||
? await listGuildChannelsDiscord(guildId, { accountId })
|
||||
: await listGuildChannelsDiscord(guildId);
|
||||
return jsonResult({ ok: true, channels });
|
||||
}
|
||||
case "voiceStatus": {
|
||||
@@ -184,7 +215,9 @@ export async function handleDiscordGuildAction(
|
||||
const userId = readStringParam(params, "userId", {
|
||||
required: true,
|
||||
});
|
||||
const voice = await fetchVoiceStatusDiscord(guildId, userId, accountOpts);
|
||||
const voice = accountId
|
||||
? await fetchVoiceStatusDiscord(guildId, userId, { accountId })
|
||||
: await fetchVoiceStatusDiscord(guildId, userId);
|
||||
return jsonResult({ ok: true, voice });
|
||||
}
|
||||
case "eventList": {
|
||||
@@ -194,7 +227,9 @@ export async function handleDiscordGuildAction(
|
||||
const guildId = readStringParam(params, "guildId", {
|
||||
required: true,
|
||||
});
|
||||
const events = await listScheduledEventsDiscord(guildId, accountOpts);
|
||||
const events = accountId
|
||||
? await listScheduledEventsDiscord(guildId, { accountId })
|
||||
: await listScheduledEventsDiscord(guildId);
|
||||
return jsonResult({ ok: true, events });
|
||||
}
|
||||
case "eventCreate": {
|
||||
@@ -224,7 +259,9 @@ export async function handleDiscordGuildAction(
|
||||
entity_metadata: entityType === 3 && location ? { location } : undefined,
|
||||
privacy_level: 2,
|
||||
};
|
||||
const event = await createScheduledEventDiscord(guildId, payload, accountOpts);
|
||||
const event = accountId
|
||||
? await createScheduledEventDiscord(guildId, payload, { accountId })
|
||||
: await createScheduledEventDiscord(guildId, payload);
|
||||
return jsonResult({ ok: true, event });
|
||||
}
|
||||
case "channelCreate": {
|
||||
@@ -238,18 +275,28 @@ export async function handleDiscordGuildAction(
|
||||
const topic = readStringParam(params, "topic");
|
||||
const position = readNumberParam(params, "position", { integer: true });
|
||||
const nsfw = params.nsfw as boolean | undefined;
|
||||
const channel = await createChannelDiscord(
|
||||
{
|
||||
guildId,
|
||||
name,
|
||||
type: type ?? undefined,
|
||||
parentId: parentId ?? undefined,
|
||||
topic: topic ?? undefined,
|
||||
position: position ?? undefined,
|
||||
nsfw,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const channel = accountId
|
||||
? await createChannelDiscord(
|
||||
{
|
||||
guildId,
|
||||
name,
|
||||
type: type ?? undefined,
|
||||
parentId: parentId ?? undefined,
|
||||
topic: topic ?? undefined,
|
||||
position: position ?? undefined,
|
||||
nsfw,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await createChannelDiscord({
|
||||
guildId,
|
||||
name,
|
||||
type: type ?? undefined,
|
||||
parentId: parentId ?? undefined,
|
||||
topic: topic ?? undefined,
|
||||
position: position ?? undefined,
|
||||
nsfw,
|
||||
});
|
||||
return jsonResult({ ok: true, channel });
|
||||
}
|
||||
case "channelEdit": {
|
||||
@@ -267,18 +314,28 @@ export async function handleDiscordGuildAction(
|
||||
const rateLimitPerUser = readNumberParam(params, "rateLimitPerUser", {
|
||||
integer: true,
|
||||
});
|
||||
const channel = await editChannelDiscord(
|
||||
{
|
||||
channelId,
|
||||
name: name ?? undefined,
|
||||
topic: topic ?? undefined,
|
||||
position: position ?? undefined,
|
||||
parentId,
|
||||
nsfw,
|
||||
rateLimitPerUser: rateLimitPerUser ?? undefined,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const channel = accountId
|
||||
? await editChannelDiscord(
|
||||
{
|
||||
channelId,
|
||||
name: name ?? undefined,
|
||||
topic: topic ?? undefined,
|
||||
position: position ?? undefined,
|
||||
parentId,
|
||||
nsfw,
|
||||
rateLimitPerUser: rateLimitPerUser ?? undefined,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await editChannelDiscord({
|
||||
channelId,
|
||||
name: name ?? undefined,
|
||||
topic: topic ?? undefined,
|
||||
position: position ?? undefined,
|
||||
parentId,
|
||||
nsfw,
|
||||
rateLimitPerUser: rateLimitPerUser ?? undefined,
|
||||
});
|
||||
return jsonResult({ ok: true, channel });
|
||||
}
|
||||
case "channelDelete": {
|
||||
@@ -288,7 +345,9 @@ export async function handleDiscordGuildAction(
|
||||
const channelId = readStringParam(params, "channelId", {
|
||||
required: true,
|
||||
});
|
||||
const result = await deleteChannelDiscord(channelId, accountOpts);
|
||||
const result = accountId
|
||||
? await deleteChannelDiscord(channelId, { accountId })
|
||||
: await deleteChannelDiscord(channelId);
|
||||
return jsonResult(result);
|
||||
}
|
||||
case "channelMove": {
|
||||
@@ -301,15 +360,24 @@ export async function handleDiscordGuildAction(
|
||||
});
|
||||
const parentId = readParentIdParam(params);
|
||||
const position = readNumberParam(params, "position", { integer: true });
|
||||
await moveChannelDiscord(
|
||||
{
|
||||
if (accountId) {
|
||||
await moveChannelDiscord(
|
||||
{
|
||||
guildId,
|
||||
channelId,
|
||||
parentId,
|
||||
position: position ?? undefined,
|
||||
},
|
||||
{ accountId },
|
||||
);
|
||||
} else {
|
||||
await moveChannelDiscord({
|
||||
guildId,
|
||||
channelId,
|
||||
parentId,
|
||||
position: position ?? undefined,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
});
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
case "categoryCreate": {
|
||||
@@ -319,15 +387,22 @@ export async function handleDiscordGuildAction(
|
||||
const guildId = readStringParam(params, "guildId", { required: true });
|
||||
const name = readStringParam(params, "name", { required: true });
|
||||
const position = readNumberParam(params, "position", { integer: true });
|
||||
const channel = await createChannelDiscord(
|
||||
{
|
||||
guildId,
|
||||
name,
|
||||
type: 4,
|
||||
position: position ?? undefined,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const channel = accountId
|
||||
? await createChannelDiscord(
|
||||
{
|
||||
guildId,
|
||||
name,
|
||||
type: 4,
|
||||
position: position ?? undefined,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await createChannelDiscord({
|
||||
guildId,
|
||||
name,
|
||||
type: 4,
|
||||
position: position ?? undefined,
|
||||
});
|
||||
return jsonResult({ ok: true, category: channel });
|
||||
}
|
||||
case "categoryEdit": {
|
||||
@@ -339,14 +414,20 @@ export async function handleDiscordGuildAction(
|
||||
});
|
||||
const name = readStringParam(params, "name");
|
||||
const position = readNumberParam(params, "position", { integer: true });
|
||||
const channel = await editChannelDiscord(
|
||||
{
|
||||
channelId: categoryId,
|
||||
name: name ?? undefined,
|
||||
position: position ?? undefined,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const channel = accountId
|
||||
? await editChannelDiscord(
|
||||
{
|
||||
channelId: categoryId,
|
||||
name: name ?? undefined,
|
||||
position: position ?? undefined,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await editChannelDiscord({
|
||||
channelId: categoryId,
|
||||
name: name ?? undefined,
|
||||
position: position ?? undefined,
|
||||
});
|
||||
return jsonResult({ ok: true, category: channel });
|
||||
}
|
||||
case "categoryDelete": {
|
||||
@@ -356,7 +437,9 @@ export async function handleDiscordGuildAction(
|
||||
const categoryId = readStringParam(params, "categoryId", {
|
||||
required: true,
|
||||
});
|
||||
const result = await deleteChannelDiscord(categoryId, accountOpts);
|
||||
const result = accountId
|
||||
? await deleteChannelDiscord(categoryId, { accountId })
|
||||
: await deleteChannelDiscord(categoryId);
|
||||
return jsonResult(result);
|
||||
}
|
||||
case "channelPermissionSet": {
|
||||
@@ -373,16 +456,26 @@ export async function handleDiscordGuildAction(
|
||||
const targetType = targetTypeRaw === "member" ? 1 : 0;
|
||||
const allow = readStringParam(params, "allow");
|
||||
const deny = readStringParam(params, "deny");
|
||||
await setChannelPermissionDiscord(
|
||||
{
|
||||
if (accountId) {
|
||||
await setChannelPermissionDiscord(
|
||||
{
|
||||
channelId,
|
||||
targetId,
|
||||
targetType,
|
||||
allow: allow ?? undefined,
|
||||
deny: deny ?? undefined,
|
||||
},
|
||||
{ accountId },
|
||||
);
|
||||
} else {
|
||||
await setChannelPermissionDiscord({
|
||||
channelId,
|
||||
targetId,
|
||||
targetType,
|
||||
allow: allow ?? undefined,
|
||||
deny: deny ?? undefined,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
});
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
case "channelPermissionRemove": {
|
||||
@@ -393,7 +486,11 @@ export async function handleDiscordGuildAction(
|
||||
required: true,
|
||||
});
|
||||
const targetId = readStringParam(params, "targetId", { required: true });
|
||||
await removeChannelPermissionDiscord(channelId, targetId, accountOpts);
|
||||
if (accountId) {
|
||||
await removeChannelPermissionDiscord(channelId, targetId, { accountId });
|
||||
} else {
|
||||
await removeChannelPermissionDiscord(channelId, targetId);
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -58,6 +58,7 @@ export async function handleDiscordMessagingAction(
|
||||
required: true,
|
||||
}),
|
||||
);
|
||||
const accountId = readStringParam(params, "accountId");
|
||||
const normalizeMessage = (message: unknown) => {
|
||||
if (!message || typeof message !== "object") return message;
|
||||
return withNormalizedTimestamp(
|
||||
@@ -65,8 +66,6 @@ export async function handleDiscordMessagingAction(
|
||||
(message as { timestamp?: unknown }).timestamp,
|
||||
);
|
||||
};
|
||||
const accountId = readStringParam(params, "accountId");
|
||||
const accountOpts = accountId ? { accountId } : {};
|
||||
switch (action) {
|
||||
case "react": {
|
||||
if (!isActionEnabled("reactions")) {
|
||||
@@ -80,14 +79,24 @@ export async function handleDiscordMessagingAction(
|
||||
removeErrorMessage: "Emoji is required to remove a Discord reaction.",
|
||||
});
|
||||
if (remove) {
|
||||
await removeReactionDiscord(channelId, messageId, emoji, accountOpts);
|
||||
if (accountId) {
|
||||
await removeReactionDiscord(channelId, messageId, emoji, { accountId });
|
||||
} else {
|
||||
await removeReactionDiscord(channelId, messageId, emoji);
|
||||
}
|
||||
return jsonResult({ ok: true, removed: emoji });
|
||||
}
|
||||
if (isEmpty) {
|
||||
const removed = await removeOwnReactionsDiscord(channelId, messageId, accountOpts);
|
||||
const removed = accountId
|
||||
? await removeOwnReactionsDiscord(channelId, messageId, { accountId })
|
||||
: await removeOwnReactionsDiscord(channelId, messageId);
|
||||
return jsonResult({ ok: true, removed: removed.removed });
|
||||
}
|
||||
await reactMessageDiscord(channelId, messageId, emoji, accountOpts);
|
||||
if (accountId) {
|
||||
await reactMessageDiscord(channelId, messageId, emoji, { accountId });
|
||||
} else {
|
||||
await reactMessageDiscord(channelId, messageId, emoji);
|
||||
}
|
||||
return jsonResult({ ok: true, added: emoji });
|
||||
}
|
||||
case "reactions": {
|
||||
@@ -102,7 +111,7 @@ export async function handleDiscordMessagingAction(
|
||||
const limit =
|
||||
typeof limitRaw === "number" && Number.isFinite(limitRaw) ? limitRaw : undefined;
|
||||
const reactions = await fetchReactionsDiscord(channelId, messageId, {
|
||||
...accountOpts,
|
||||
...(accountId ? { accountId } : {}),
|
||||
limit,
|
||||
});
|
||||
return jsonResult({ ok: true, reactions });
|
||||
@@ -118,8 +127,8 @@ export async function handleDiscordMessagingAction(
|
||||
label: "stickerIds",
|
||||
});
|
||||
await sendStickerDiscord(to, stickerIds, {
|
||||
...(accountId ? { accountId } : {}),
|
||||
content,
|
||||
accountId: accountId ?? undefined,
|
||||
});
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
@@ -146,7 +155,7 @@ export async function handleDiscordMessagingAction(
|
||||
await sendPollDiscord(
|
||||
to,
|
||||
{ question, options: answers, maxSelections, durationHours },
|
||||
{ content, accountId: accountId ?? undefined },
|
||||
{ ...(accountId ? { accountId } : {}), content },
|
||||
);
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
@@ -155,7 +164,9 @@ export async function handleDiscordMessagingAction(
|
||||
throw new Error("Discord permissions are disabled.");
|
||||
}
|
||||
const channelId = resolveChannelId();
|
||||
const permissions = await fetchChannelPermissionsDiscord(channelId, accountOpts);
|
||||
const permissions = accountId
|
||||
? await fetchChannelPermissionsDiscord(channelId, { accountId })
|
||||
: await fetchChannelPermissionsDiscord(channelId);
|
||||
return jsonResult({ ok: true, permissions });
|
||||
}
|
||||
case "fetchMessage": {
|
||||
@@ -177,7 +188,9 @@ export async function handleDiscordMessagingAction(
|
||||
"Discord message fetch requires guildId, channelId, and messageId (or a valid messageLink).",
|
||||
);
|
||||
}
|
||||
const message = await fetchMessageDiscord(channelId, messageId, accountOpts);
|
||||
const message = accountId
|
||||
? await fetchMessageDiscord(channelId, messageId, { accountId })
|
||||
: await fetchMessageDiscord(channelId, messageId);
|
||||
return jsonResult({
|
||||
ok: true,
|
||||
message: normalizeMessage(message),
|
||||
@@ -191,19 +204,18 @@ export async function handleDiscordMessagingAction(
|
||||
throw new Error("Discord message reads are disabled.");
|
||||
}
|
||||
const channelId = resolveChannelId();
|
||||
const messages = await readMessagesDiscord(
|
||||
channelId,
|
||||
{
|
||||
limit:
|
||||
typeof params.limit === "number" && Number.isFinite(params.limit)
|
||||
? params.limit
|
||||
: undefined,
|
||||
before: readStringParam(params, "before"),
|
||||
after: readStringParam(params, "after"),
|
||||
around: readStringParam(params, "around"),
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const query = {
|
||||
limit:
|
||||
typeof params.limit === "number" && Number.isFinite(params.limit)
|
||||
? params.limit
|
||||
: undefined,
|
||||
before: readStringParam(params, "before"),
|
||||
after: readStringParam(params, "after"),
|
||||
around: readStringParam(params, "around"),
|
||||
};
|
||||
const messages = accountId
|
||||
? await readMessagesDiscord(channelId, query, { accountId })
|
||||
: await readMessagesDiscord(channelId, query);
|
||||
return jsonResult({
|
||||
ok: true,
|
||||
messages: messages.map((message) => normalizeMessage(message)),
|
||||
@@ -222,7 +234,7 @@ export async function handleDiscordMessagingAction(
|
||||
const embeds =
|
||||
Array.isArray(params.embeds) && params.embeds.length > 0 ? params.embeds : undefined;
|
||||
const result = await sendMessageDiscord(to, content, {
|
||||
accountId: accountId ?? undefined,
|
||||
...(accountId ? { accountId } : {}),
|
||||
mediaUrl,
|
||||
replyTo,
|
||||
embeds,
|
||||
@@ -240,14 +252,9 @@ export async function handleDiscordMessagingAction(
|
||||
const content = readStringParam(params, "content", {
|
||||
required: true,
|
||||
});
|
||||
const message = await editMessageDiscord(
|
||||
channelId,
|
||||
messageId,
|
||||
{
|
||||
content,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const message = accountId
|
||||
? await editMessageDiscord(channelId, messageId, { content }, { accountId })
|
||||
: await editMessageDiscord(channelId, messageId, { content });
|
||||
return jsonResult({ ok: true, message });
|
||||
}
|
||||
case "deleteMessage": {
|
||||
@@ -258,7 +265,11 @@ export async function handleDiscordMessagingAction(
|
||||
const messageId = readStringParam(params, "messageId", {
|
||||
required: true,
|
||||
});
|
||||
await deleteMessageDiscord(channelId, messageId, accountOpts);
|
||||
if (accountId) {
|
||||
await deleteMessageDiscord(channelId, messageId, { accountId });
|
||||
} else {
|
||||
await deleteMessageDiscord(channelId, messageId);
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
case "threadCreate": {
|
||||
@@ -273,15 +284,13 @@ export async function handleDiscordMessagingAction(
|
||||
typeof autoArchiveMinutesRaw === "number" && Number.isFinite(autoArchiveMinutesRaw)
|
||||
? autoArchiveMinutesRaw
|
||||
: undefined;
|
||||
const thread = await createThreadDiscord(
|
||||
channelId,
|
||||
{
|
||||
name,
|
||||
messageId,
|
||||
autoArchiveMinutes,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const thread = accountId
|
||||
? await createThreadDiscord(
|
||||
channelId,
|
||||
{ name, messageId, autoArchiveMinutes },
|
||||
{ accountId },
|
||||
)
|
||||
: await createThreadDiscord(channelId, { name, messageId, autoArchiveMinutes });
|
||||
return jsonResult({ ok: true, thread });
|
||||
}
|
||||
case "threadList": {
|
||||
@@ -299,16 +308,24 @@ export async function handleDiscordMessagingAction(
|
||||
typeof params.limit === "number" && Number.isFinite(params.limit)
|
||||
? params.limit
|
||||
: undefined;
|
||||
const threads = await listThreadsDiscord(
|
||||
{
|
||||
guildId,
|
||||
channelId,
|
||||
includeArchived,
|
||||
before,
|
||||
limit,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const threads = accountId
|
||||
? await listThreadsDiscord(
|
||||
{
|
||||
guildId,
|
||||
channelId,
|
||||
includeArchived,
|
||||
before,
|
||||
limit,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await listThreadsDiscord({
|
||||
guildId,
|
||||
channelId,
|
||||
includeArchived,
|
||||
before,
|
||||
limit,
|
||||
});
|
||||
return jsonResult({ ok: true, threads });
|
||||
}
|
||||
case "threadReply": {
|
||||
@@ -322,7 +339,7 @@ export async function handleDiscordMessagingAction(
|
||||
const mediaUrl = readStringParam(params, "mediaUrl");
|
||||
const replyTo = readStringParam(params, "replyTo");
|
||||
const result = await sendMessageDiscord(`channel:${channelId}`, content, {
|
||||
accountId: accountId ?? undefined,
|
||||
...(accountId ? { accountId } : {}),
|
||||
mediaUrl,
|
||||
replyTo,
|
||||
});
|
||||
@@ -336,7 +353,11 @@ export async function handleDiscordMessagingAction(
|
||||
const messageId = readStringParam(params, "messageId", {
|
||||
required: true,
|
||||
});
|
||||
await pinMessageDiscord(channelId, messageId, accountOpts);
|
||||
if (accountId) {
|
||||
await pinMessageDiscord(channelId, messageId, { accountId });
|
||||
} else {
|
||||
await pinMessageDiscord(channelId, messageId);
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
case "unpinMessage": {
|
||||
@@ -347,7 +368,11 @@ export async function handleDiscordMessagingAction(
|
||||
const messageId = readStringParam(params, "messageId", {
|
||||
required: true,
|
||||
});
|
||||
await unpinMessageDiscord(channelId, messageId, accountOpts);
|
||||
if (accountId) {
|
||||
await unpinMessageDiscord(channelId, messageId, { accountId });
|
||||
} else {
|
||||
await unpinMessageDiscord(channelId, messageId);
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
case "listPins": {
|
||||
@@ -355,7 +380,9 @@ export async function handleDiscordMessagingAction(
|
||||
throw new Error("Discord pins are disabled.");
|
||||
}
|
||||
const channelId = resolveChannelId();
|
||||
const pins = await listPinsDiscord(channelId, accountOpts);
|
||||
const pins = accountId
|
||||
? await listPinsDiscord(channelId, { accountId })
|
||||
: await listPinsDiscord(channelId);
|
||||
return jsonResult({ ok: true, pins: pins.map((pin) => normalizeMessage(pin)) });
|
||||
}
|
||||
case "searchMessages": {
|
||||
@@ -378,16 +405,24 @@ export async function handleDiscordMessagingAction(
|
||||
: undefined;
|
||||
const channelIdList = [...(channelIds ?? []), ...(channelId ? [channelId] : [])];
|
||||
const authorIdList = [...(authorIds ?? []), ...(authorId ? [authorId] : [])];
|
||||
const results = await searchMessagesDiscord(
|
||||
{
|
||||
guildId,
|
||||
content,
|
||||
channelIds: channelIdList.length ? channelIdList : undefined,
|
||||
authorIds: authorIdList.length ? authorIdList : undefined,
|
||||
limit,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const results = accountId
|
||||
? await searchMessagesDiscord(
|
||||
{
|
||||
guildId,
|
||||
content,
|
||||
channelIds: channelIdList.length ? channelIdList : undefined,
|
||||
authorIds: authorIdList.length ? authorIdList : undefined,
|
||||
limit,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await searchMessagesDiscord({
|
||||
guildId,
|
||||
content,
|
||||
channelIds: channelIdList.length ? channelIdList : undefined,
|
||||
authorIds: authorIdList.length ? authorIdList : undefined,
|
||||
limit,
|
||||
});
|
||||
if (!results || typeof results !== "object") {
|
||||
return jsonResult({ ok: true, results });
|
||||
}
|
||||
|
||||
@@ -9,8 +9,6 @@ export async function handleDiscordModerationAction(
|
||||
isActionEnabled: ActionGate<DiscordActionConfig>,
|
||||
): Promise<AgentToolResult<unknown>> {
|
||||
const accountId = readStringParam(params, "accountId");
|
||||
const accountOpts = accountId ? { accountId } : {};
|
||||
|
||||
switch (action) {
|
||||
case "timeout": {
|
||||
if (!isActionEnabled("moderation", false)) {
|
||||
@@ -28,16 +26,24 @@ export async function handleDiscordModerationAction(
|
||||
: undefined;
|
||||
const until = readStringParam(params, "until");
|
||||
const reason = readStringParam(params, "reason");
|
||||
const member = await timeoutMemberDiscord(
|
||||
{
|
||||
guildId,
|
||||
userId,
|
||||
durationMinutes,
|
||||
until,
|
||||
reason,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
const member = accountId
|
||||
? await timeoutMemberDiscord(
|
||||
{
|
||||
guildId,
|
||||
userId,
|
||||
durationMinutes,
|
||||
until,
|
||||
reason,
|
||||
},
|
||||
{ accountId },
|
||||
)
|
||||
: await timeoutMemberDiscord({
|
||||
guildId,
|
||||
userId,
|
||||
durationMinutes,
|
||||
until,
|
||||
reason,
|
||||
});
|
||||
return jsonResult({ ok: true, member });
|
||||
}
|
||||
case "kick": {
|
||||
@@ -51,7 +57,11 @@ export async function handleDiscordModerationAction(
|
||||
required: true,
|
||||
});
|
||||
const reason = readStringParam(params, "reason");
|
||||
await kickMemberDiscord({ guildId, userId, reason }, accountOpts);
|
||||
if (accountId) {
|
||||
await kickMemberDiscord({ guildId, userId, reason }, { accountId });
|
||||
} else {
|
||||
await kickMemberDiscord({ guildId, userId, reason });
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
case "ban": {
|
||||
@@ -69,15 +79,24 @@ export async function handleDiscordModerationAction(
|
||||
typeof params.deleteMessageDays === "number" && Number.isFinite(params.deleteMessageDays)
|
||||
? params.deleteMessageDays
|
||||
: undefined;
|
||||
await banMemberDiscord(
|
||||
{
|
||||
if (accountId) {
|
||||
await banMemberDiscord(
|
||||
{
|
||||
guildId,
|
||||
userId,
|
||||
reason,
|
||||
deleteMessageDays,
|
||||
},
|
||||
{ accountId },
|
||||
);
|
||||
} else {
|
||||
await banMemberDiscord({
|
||||
guildId,
|
||||
userId,
|
||||
reason,
|
||||
deleteMessageDays,
|
||||
},
|
||||
accountOpts,
|
||||
);
|
||||
});
|
||||
}
|
||||
return jsonResult({ ok: true });
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -3,6 +3,7 @@ import { describe, expect, it, vi } from "vitest";
|
||||
import type { DiscordActionConfig } from "../../config/config.js";
|
||||
import { handleDiscordGuildAction } from "./discord-actions-guild.js";
|
||||
import { handleDiscordMessagingAction } from "./discord-actions-messaging.js";
|
||||
import { handleDiscordModerationAction } from "./discord-actions-moderation.js";
|
||||
|
||||
const createChannelDiscord = vi.fn(async () => ({
|
||||
id: "new-channel",
|
||||
@@ -20,6 +21,7 @@ const editMessageDiscord = vi.fn(async () => ({}));
|
||||
const fetchMessageDiscord = vi.fn(async () => ({}));
|
||||
const fetchChannelPermissionsDiscord = vi.fn(async () => ({}));
|
||||
const fetchReactionsDiscord = vi.fn(async () => ({}));
|
||||
const listGuildChannelsDiscord = vi.fn(async () => []);
|
||||
const listPinsDiscord = vi.fn(async () => ({}));
|
||||
const listThreadsDiscord = vi.fn(async () => ({}));
|
||||
const moveChannelDiscord = vi.fn(async () => ({ ok: true }));
|
||||
@@ -35,8 +37,12 @@ const sendPollDiscord = vi.fn(async () => ({}));
|
||||
const sendStickerDiscord = vi.fn(async () => ({}));
|
||||
const setChannelPermissionDiscord = vi.fn(async () => ({ ok: true }));
|
||||
const unpinMessageDiscord = vi.fn(async () => ({}));
|
||||
const timeoutMemberDiscord = vi.fn(async () => ({}));
|
||||
const kickMemberDiscord = vi.fn(async () => ({}));
|
||||
const banMemberDiscord = vi.fn(async () => ({}));
|
||||
|
||||
vi.mock("../../discord/send.js", () => ({
|
||||
banMemberDiscord: (...args: unknown[]) => banMemberDiscord(...args),
|
||||
createChannelDiscord: (...args: unknown[]) => createChannelDiscord(...args),
|
||||
createThreadDiscord: (...args: unknown[]) => createThreadDiscord(...args),
|
||||
deleteChannelDiscord: (...args: unknown[]) => deleteChannelDiscord(...args),
|
||||
@@ -46,6 +52,8 @@ vi.mock("../../discord/send.js", () => ({
|
||||
fetchMessageDiscord: (...args: unknown[]) => fetchMessageDiscord(...args),
|
||||
fetchChannelPermissionsDiscord: (...args: unknown[]) => fetchChannelPermissionsDiscord(...args),
|
||||
fetchReactionsDiscord: (...args: unknown[]) => fetchReactionsDiscord(...args),
|
||||
kickMemberDiscord: (...args: unknown[]) => kickMemberDiscord(...args),
|
||||
listGuildChannelsDiscord: (...args: unknown[]) => listGuildChannelsDiscord(...args),
|
||||
listPinsDiscord: (...args: unknown[]) => listPinsDiscord(...args),
|
||||
listThreadsDiscord: (...args: unknown[]) => listThreadsDiscord(...args),
|
||||
moveChannelDiscord: (...args: unknown[]) => moveChannelDiscord(...args),
|
||||
@@ -60,12 +68,15 @@ vi.mock("../../discord/send.js", () => ({
|
||||
sendPollDiscord: (...args: unknown[]) => sendPollDiscord(...args),
|
||||
sendStickerDiscord: (...args: unknown[]) => sendStickerDiscord(...args),
|
||||
setChannelPermissionDiscord: (...args: unknown[]) => setChannelPermissionDiscord(...args),
|
||||
timeoutMemberDiscord: (...args: unknown[]) => timeoutMemberDiscord(...args),
|
||||
unpinMessageDiscord: (...args: unknown[]) => unpinMessageDiscord(...args),
|
||||
}));
|
||||
|
||||
const enableAllActions = () => true;
|
||||
|
||||
const disabledActions = (key: keyof DiscordActionConfig) => key !== "reactions";
|
||||
const channelInfoEnabled = (key: keyof DiscordActionConfig) => key === "channelInfo";
|
||||
const moderationEnabled = (key: keyof DiscordActionConfig) => key === "moderation";
|
||||
|
||||
describe("handleDiscordMessagingAction", () => {
|
||||
it("adds reactions", async () => {
|
||||
@@ -81,6 +92,20 @@ describe("handleDiscordMessagingAction", () => {
|
||||
expect(reactMessageDiscord).toHaveBeenCalledWith("C1", "M1", "✅", {});
|
||||
});
|
||||
|
||||
it("forwards accountId for reactions", async () => {
|
||||
await handleDiscordMessagingAction(
|
||||
"react",
|
||||
{
|
||||
channelId: "C1",
|
||||
messageId: "M1",
|
||||
emoji: "✅",
|
||||
accountId: "ops",
|
||||
},
|
||||
enableAllActions,
|
||||
);
|
||||
expect(reactMessageDiscord).toHaveBeenCalledWith("C1", "M1", "✅", { accountId: "ops" });
|
||||
});
|
||||
|
||||
it("removes reactions on empty emoji", async () => {
|
||||
await handleDiscordMessagingAction(
|
||||
"react",
|
||||
@@ -248,6 +273,15 @@ describe("handleDiscordGuildAction - channel management", () => {
|
||||
).rejects.toThrow(/Discord channel management is disabled/);
|
||||
});
|
||||
|
||||
it("forwards accountId for channelList", async () => {
|
||||
await handleDiscordGuildAction(
|
||||
"channelList",
|
||||
{ guildId: "G1", accountId: "ops" },
|
||||
channelInfoEnabled,
|
||||
);
|
||||
expect(listGuildChannelsDiscord).toHaveBeenCalledWith("G1", { accountId: "ops" });
|
||||
});
|
||||
|
||||
it("edits a channel", async () => {
|
||||
await handleDiscordGuildAction(
|
||||
"channelEdit",
|
||||
@@ -481,3 +515,26 @@ describe("handleDiscordGuildAction - channel management", () => {
|
||||
expect(removeChannelPermissionDiscord).toHaveBeenCalledWith("C1", "R1", {});
|
||||
});
|
||||
});
|
||||
|
||||
describe("handleDiscordModerationAction", () => {
|
||||
it("forwards accountId for timeout", async () => {
|
||||
await handleDiscordModerationAction(
|
||||
"timeout",
|
||||
{
|
||||
guildId: "G1",
|
||||
userId: "U1",
|
||||
durationMinutes: 5,
|
||||
accountId: "ops",
|
||||
},
|
||||
moderationEnabled,
|
||||
);
|
||||
expect(timeoutMemberDiscord).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
guildId: "G1",
|
||||
userId: "U1",
|
||||
durationMinutes: 5,
|
||||
}),
|
||||
{ accountId: "ops" },
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user