feat: add image model config + tool

This commit is contained in:
Peter Steinberger
2026-01-04 19:35:00 +01:00
parent 0716a624a8
commit 78998dba9e
20 changed files with 856 additions and 144 deletions

View File

@@ -9,12 +9,16 @@ import { createSessionsHistoryTool } from "./tools/sessions-history-tool.js";
import { createSessionsListTool } from "./tools/sessions-list-tool.js";
import { createSessionsSendTool } from "./tools/sessions-send-tool.js";
import { createSlackTool } from "./tools/slack-tool.js";
import { createImageTool } from "./tools/image-tool.js";
import type { ClawdbotConfig } from "../config/config.js";
export function createClawdbotTools(options?: {
browserControlUrl?: string;
agentSessionKey?: string;
agentSurface?: string;
config?: ClawdbotConfig;
}): AnyAgentTool[] {
const imageTool = createImageTool({ config: options?.config });
return [
createBrowserTool({ defaultControlUrl: options?.browserControlUrl }),
createCanvasTool(),
@@ -29,5 +33,6 @@ export function createClawdbotTools(options?: {
agentSessionKey: options?.agentSessionKey,
agentSurface: options?.agentSurface,
}),
...(imageTool ? [imageTool] : []),
];
}

143
src/agents/model-auth.ts Normal file
View File

@@ -0,0 +1,143 @@
import fsSync from "node:fs";
import os from "node:os";
import path from "node:path";
import {
type Api,
getEnvApiKey,
getOAuthApiKey,
type Model,
type OAuthCredentials,
type OAuthProvider,
} from "@mariozechner/pi-ai";
import { discoverAuthStorage } from "@mariozechner/pi-coding-agent";
import { CONFIG_DIR, resolveUserPath } from "../utils.js";
const OAUTH_FILENAME = "oauth.json";
const DEFAULT_OAUTH_DIR = path.join(CONFIG_DIR, "credentials");
let oauthStorageConfigured = false;
type OAuthStorage = Record<string, OAuthCredentials>;
function resolveClawdbotOAuthPath(): string {
const overrideDir =
process.env.CLAWDBOT_OAUTH_DIR?.trim() || DEFAULT_OAUTH_DIR;
return path.join(resolveUserPath(overrideDir), OAUTH_FILENAME);
}
function loadOAuthStorageAt(pathname: string): OAuthStorage | null {
if (!fsSync.existsSync(pathname)) return null;
try {
const content = fsSync.readFileSync(pathname, "utf8");
const json = JSON.parse(content) as OAuthStorage;
if (!json || typeof json !== "object") return null;
return json;
} catch {
return null;
}
}
function hasAnthropicOAuth(storage: OAuthStorage): boolean {
const entry = storage.anthropic as
| {
refresh?: string;
refresh_token?: string;
refreshToken?: string;
access?: string;
access_token?: string;
accessToken?: string;
}
| undefined;
if (!entry) return false;
const refresh =
entry.refresh ?? entry.refresh_token ?? entry.refreshToken ?? "";
const access = entry.access ?? entry.access_token ?? entry.accessToken ?? "";
return Boolean(refresh.trim() && access.trim());
}
function saveOAuthStorageAt(pathname: string, storage: OAuthStorage): void {
const dir = path.dirname(pathname);
fsSync.mkdirSync(dir, { recursive: true, mode: 0o700 });
fsSync.writeFileSync(
pathname,
`${JSON.stringify(storage, null, 2)}\n`,
"utf8",
);
fsSync.chmodSync(pathname, 0o600);
}
function legacyOAuthPaths(): string[] {
const paths: string[] = [];
const piOverride = process.env.PI_CODING_AGENT_DIR?.trim();
if (piOverride) {
paths.push(path.join(resolveUserPath(piOverride), OAUTH_FILENAME));
}
paths.push(path.join(os.homedir(), ".pi", "agent", OAUTH_FILENAME));
paths.push(path.join(os.homedir(), ".claude", OAUTH_FILENAME));
paths.push(path.join(os.homedir(), ".config", "claude", OAUTH_FILENAME));
paths.push(path.join(os.homedir(), ".config", "anthropic", OAUTH_FILENAME));
return Array.from(new Set(paths));
}
function importLegacyOAuthIfNeeded(destPath: string): void {
if (fsSync.existsSync(destPath)) return;
for (const legacyPath of legacyOAuthPaths()) {
const storage = loadOAuthStorageAt(legacyPath);
if (!storage || !hasAnthropicOAuth(storage)) continue;
saveOAuthStorageAt(destPath, storage);
return;
}
}
export function ensureOAuthStorage(): void {
if (oauthStorageConfigured) return;
oauthStorageConfigured = true;
const oauthPath = resolveClawdbotOAuthPath();
importLegacyOAuthIfNeeded(oauthPath);
}
function isOAuthProvider(provider: string): provider is OAuthProvider {
return (
provider === "anthropic" ||
provider === "anthropic-oauth" ||
provider === "google" ||
provider === "openai" ||
provider === "openai-compatible" ||
provider === "github-copilot" ||
provider === "google-gemini-cli" ||
provider === "google-antigravity"
);
}
export async function getApiKeyForModel(
model: Model<Api>,
authStorage: ReturnType<typeof discoverAuthStorage>,
): Promise<string> {
const storedKey = await authStorage.getApiKey(model.provider);
if (storedKey) return storedKey;
ensureOAuthStorage();
if (model.provider === "anthropic") {
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
if (oauthEnv?.trim()) return oauthEnv.trim();
}
const envKey = getEnvApiKey(model.provider);
if (envKey) return envKey;
if (isOAuthProvider(model.provider)) {
const oauthPath = resolveClawdbotOAuthPath();
const storage = loadOAuthStorageAt(oauthPath);
if (storage) {
try {
const result = await getOAuthApiKey(model.provider, storage);
if (result?.apiKey) {
storage[model.provider] = result.newCredentials;
saveOAuthStorageAt(oauthPath, storage);
return result.apiKey;
}
} catch {
// fall through to error below
}
}
}
throw new Error(`No API key found for provider "${model.provider}"`);
}

View File

@@ -44,6 +44,54 @@ function buildAllowedModelKeys(
return keys.size > 0 ? keys : null;
}
function resolveImageFallbackCandidates(params: {
cfg: ClawdbotConfig | undefined;
defaultProvider: string;
modelOverride?: string;
}): ModelCandidate[] {
const aliasIndex = buildModelAliasIndex({
cfg: params.cfg ?? {},
defaultProvider: params.defaultProvider,
});
const allowlist = buildAllowedModelKeys(params.cfg, params.defaultProvider);
const seen = new Set<string>();
const candidates: ModelCandidate[] = [];
const addCandidate = (
candidate: ModelCandidate,
enforceAllowlist: boolean,
) => {
if (!candidate.provider || !candidate.model) return;
const key = modelKey(candidate.provider, candidate.model);
if (seen.has(key)) return;
if (enforceAllowlist && allowlist && !allowlist.has(key)) return;
seen.add(key);
candidates.push(candidate);
};
const addRaw = (raw: string, enforceAllowlist: boolean) => {
const resolved = resolveModelRefFromString({
raw: String(raw ?? ""),
defaultProvider: params.defaultProvider,
aliasIndex,
});
if (!resolved) return;
addCandidate(resolved.ref, enforceAllowlist);
};
if (params.modelOverride?.trim()) {
addRaw(params.modelOverride, false);
} else if (params.cfg?.agent?.imageModel?.trim()) {
addRaw(params.cfg.agent.imageModel, false);
}
for (const raw of params.cfg?.agent?.imageModelFallbacks ?? []) {
addRaw(raw, true);
}
return candidates;
}
function resolveFallbackCandidates(params: {
cfg: ClawdbotConfig | undefined;
provider: string;
@@ -151,3 +199,78 @@ export async function runWithModelFallback<T>(params: {
{ cause: lastError instanceof Error ? lastError : undefined },
);
}
export async function runWithImageModelFallback<T>(params: {
cfg: ClawdbotConfig | undefined;
modelOverride?: string;
run: (provider: string, model: string) => Promise<T>;
onError?: (attempt: {
provider: string;
model: string;
error: unknown;
attempt: number;
total: number;
}) => void | Promise<void>;
}): Promise<{
result: T;
provider: string;
model: string;
attempts: FallbackAttempt[];
}> {
const candidates = resolveImageFallbackCandidates({
cfg: params.cfg,
defaultProvider: DEFAULT_PROVIDER,
modelOverride: params.modelOverride,
});
if (candidates.length === 0) {
throw new Error(
"No image model configured. Set agent.imageModel or agent.imageModelFallbacks.",
);
}
const attempts: FallbackAttempt[] = [];
let lastError: unknown;
for (let i = 0; i < candidates.length; i += 1) {
const candidate = candidates[i] as ModelCandidate;
try {
const result = await params.run(candidate.provider, candidate.model);
return {
result,
provider: candidate.provider,
model: candidate.model,
attempts,
};
} catch (err) {
if (isAbortError(err)) throw err;
lastError = err;
attempts.push({
provider: candidate.provider,
model: candidate.model,
error: err instanceof Error ? err.message : String(err),
});
await params.onError?.({
provider: candidate.provider,
model: candidate.model,
error: err,
attempt: i + 1,
total: candidates.length,
});
}
}
if (attempts.length <= 1 && lastError) throw lastError;
const summary =
attempts.length > 0
? attempts
.map(
(attempt) =>
`${attempt.provider}/${attempt.model}: ${attempt.error}`,
)
.join(" | ")
: "unknown";
throw new Error(
`All image models failed (${attempts.length || candidates.length}): ${summary}`,
{ cause: lastError instanceof Error ? lastError : undefined },
);
}

View File

@@ -348,7 +348,7 @@ export async function scanOpenRouterModels(
};
const toolResult = await probeTool(model, apiKey, timeoutMs);
const imageResult = toolResult.ok
const imageResult = model.input.includes("image")
? await probeImage(ensureImageInput(model), apiKey, timeoutMs)
: { ok: false, latencyMs: null, skipped: true };

View File

@@ -1,17 +1,11 @@
import fsSync from "node:fs";
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import type { AgentMessage, ThinkingLevel } from "@mariozechner/pi-agent-core";
import {
type Api,
type AssistantMessage,
getEnvApiKey,
getOAuthApiKey,
type Model,
type OAuthCredentials,
type OAuthProvider,
} from "@mariozechner/pi-ai";
import {
buildSystemPrompt,
@@ -25,7 +19,6 @@ import {
import type { ThinkLevel, VerboseLevel } from "../auto-reply/thinking.js";
import { formatToolAggregate } from "../auto-reply/tool-meta.js";
import type { ClawdbotConfig } from "../config/config.js";
import { resolveOAuthPath } from "../config/paths.js";
import { getMachineDisplayName } from "../infra/machine-name.js";
import { createSubsystemLogger } from "../logging.js";
import { splitMediaFromOutput } from "../media/parse.js";
@@ -37,6 +30,7 @@ import { resolveUserPath } from "../utils.js";
import { resolveClawdbotAgentDir } from "./agent-paths.js";
import type { BashElevatedDefaults } from "./bash-tools.js";
import { DEFAULT_MODEL, DEFAULT_PROVIDER } from "./defaults.js";
import { getApiKeyForModel } from "./model-auth.js";
import { ensureClawdbotModelsJson } from "./models-config.js";
import {
buildBootstrapContextFiles,
@@ -106,10 +100,6 @@ type EmbeddedRunWaiter = {
};
const EMBEDDED_RUN_WAITERS = new Map<string, Set<EmbeddedRunWaiter>>();
const OAUTH_FILENAME = "oauth.json";
let oauthStorageConfigured = false;
type OAuthStorage = Record<string, OAuthCredentials>;
type EmbeddedSandboxInfo = {
enabled: boolean;
workspaceDir?: string;
@@ -139,90 +129,6 @@ export function buildEmbeddedSandboxInfo(
};
}
function resolveClawdbotOAuthPath(): string {
return resolveOAuthPath();
}
function loadOAuthStorageAt(pathname: string): OAuthStorage | null {
if (!fsSync.existsSync(pathname)) return null;
try {
const content = fsSync.readFileSync(pathname, "utf8");
const json = JSON.parse(content) as OAuthStorage;
if (!json || typeof json !== "object") return null;
return json;
} catch {
return null;
}
}
function hasAnthropicOAuth(storage: OAuthStorage): boolean {
const entry = storage.anthropic as
| {
refresh?: string;
refresh_token?: string;
refreshToken?: string;
access?: string;
access_token?: string;
accessToken?: string;
}
| undefined;
if (!entry) return false;
const refresh =
entry.refresh ?? entry.refresh_token ?? entry.refreshToken ?? "";
const access = entry.access ?? entry.access_token ?? entry.accessToken ?? "";
return Boolean(refresh.trim() && access.trim());
}
function saveOAuthStorageAt(pathname: string, storage: OAuthStorage): void {
const dir = path.dirname(pathname);
fsSync.mkdirSync(dir, { recursive: true, mode: 0o700 });
fsSync.writeFileSync(
pathname,
`${JSON.stringify(storage, null, 2)}\n`,
"utf8",
);
fsSync.chmodSync(pathname, 0o600);
}
function legacyOAuthPaths(): string[] {
const paths: string[] = [];
const piOverride = process.env.PI_CODING_AGENT_DIR?.trim();
if (piOverride) {
paths.push(path.join(resolveUserPath(piOverride), OAUTH_FILENAME));
}
paths.push(path.join(os.homedir(), ".pi", "agent", OAUTH_FILENAME));
paths.push(path.join(os.homedir(), ".claude", OAUTH_FILENAME));
paths.push(path.join(os.homedir(), ".config", "claude", OAUTH_FILENAME));
paths.push(path.join(os.homedir(), ".config", "anthropic", OAUTH_FILENAME));
return Array.from(new Set(paths));
}
function importLegacyOAuthIfNeeded(destPath: string): void {
if (fsSync.existsSync(destPath)) return;
for (const legacyPath of legacyOAuthPaths()) {
const storage = loadOAuthStorageAt(legacyPath);
if (!storage || !hasAnthropicOAuth(storage)) continue;
saveOAuthStorageAt(destPath, storage);
return;
}
}
function ensureOAuthStorage(): void {
if (oauthStorageConfigured) return;
oauthStorageConfigured = true;
const oauthPath = resolveClawdbotOAuthPath();
importLegacyOAuthIfNeeded(oauthPath);
}
function isOAuthProvider(provider: string): provider is OAuthProvider {
return (
provider === "anthropic" ||
provider === "github-copilot" ||
provider === "google-gemini-cli" ||
provider === "google-antigravity"
);
}
export function queueEmbeddedPiMessage(
sessionId: string,
text: string,
@@ -325,38 +231,6 @@ function resolveModel(
return { model, authStorage, modelRegistry };
}
async function getApiKeyForModel(
model: Model<Api>,
authStorage: ReturnType<typeof discoverAuthStorage>,
): Promise<string> {
const storedKey = await authStorage.getApiKey(model.provider);
if (storedKey) return storedKey;
ensureOAuthStorage();
if (model.provider === "anthropic") {
const oauthEnv = process.env.ANTHROPIC_OAUTH_TOKEN;
if (oauthEnv?.trim()) return oauthEnv.trim();
}
const envKey = getEnvApiKey(model.provider);
if (envKey) return envKey;
if (isOAuthProvider(model.provider)) {
const oauthPath = resolveClawdbotOAuthPath();
const storage = loadOAuthStorageAt(oauthPath);
if (storage) {
try {
const result = await getOAuthApiKey(model.provider, storage);
if (result?.apiKey) {
storage[model.provider] = result.newCredentials;
saveOAuthStorageAt(oauthPath, storage);
return result.apiKey;
}
} catch {
// fall through to error below
}
}
}
throw new Error(`No API key found for provider "${model.provider}"`);
}
function resolvePromptSkills(
snapshot: SkillSnapshot,
entries: SkillEntry[],
@@ -502,6 +376,7 @@ export async function runEmbeddedPiAgent(params: {
sandbox,
surface: params.surface,
sessionKey: params.sessionKey ?? params.sessionId,
config: params.config,
});
const machineName = await getMachineDisplayName();
const runtimeInfo = {

View File

@@ -17,6 +17,7 @@ import {
type ProcessToolDefaults,
} from "./bash-tools.js";
import { createClawdbotTools } from "./clawdbot-tools.js";
import type { ClawdbotConfig } from "../config/config.js";
import type { SandboxContext, SandboxToolPolicy } from "./sandbox.js";
import { assertSandboxPath } from "./sandbox-paths.js";
import { sanitizeToolResultImages } from "./tool-images.js";
@@ -452,6 +453,7 @@ export function createClawdbotCodingTools(options?: {
surface?: string;
sandbox?: SandboxContext | null;
sessionKey?: string;
config?: ClawdbotConfig;
}): AnyAgentTool[] {
const bashToolName = "bash";
const sandbox = options?.sandbox?.enabled ? options.sandbox : undefined;
@@ -497,6 +499,7 @@ export function createClawdbotCodingTools(options?: {
browserControlUrl: sandbox?.browser?.controlUrl,
agentSessionKey: options?.sessionKey,
agentSurface: options?.surface,
config: options?.config,
}),
];
const allowDiscord = shouldIncludeDiscordTool(options?.surface);

View File

@@ -0,0 +1,157 @@
import { type Api, type AssistantMessage, complete, type Context, type Model } from "@mariozechner/pi-ai";
import { discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent";
import { Type } from "@sinclair/typebox";
import type { ClawdbotConfig } from "../../config/config.js";
import { loadWebMedia } from "../../web/media.js";
import { resolveClawdbotAgentDir } from "../agent-paths.js";
import { getApiKeyForModel } from "../model-auth.js";
import { runWithImageModelFallback } from "../model-fallback.js";
import { ensureClawdbotModelsJson } from "../models-config.js";
import { extractAssistantText } from "../pi-embedded-utils.js";
import { resolveUserPath } from "../../utils.js";
import type { AnyAgentTool } from "./common.js";
const DEFAULT_PROMPT = "Describe the image.";
function ensureImageToolConfigured(cfg?: ClawdbotConfig): boolean {
const primary = cfg?.agent?.imageModel?.trim();
const fallbacks = cfg?.agent?.imageModelFallbacks ?? [];
return Boolean(primary || fallbacks.length > 0);
}
function pickMaxBytes(cfg?: ClawdbotConfig, maxBytesMb?: number): number | undefined {
if (typeof maxBytesMb === "number" && Number.isFinite(maxBytesMb) && maxBytesMb > 0) {
return Math.floor(maxBytesMb * 1024 * 1024);
}
const configured = cfg?.agent?.mediaMaxMb;
if (typeof configured === "number" && Number.isFinite(configured) && configured > 0) {
return Math.floor(configured * 1024 * 1024);
}
return undefined;
}
function buildImageContext(prompt: string, base64: string, mimeType: string): Context {
return {
messages: [
{
role: "user",
content: [
{ type: "text", text: prompt },
{ type: "image", data: base64, mimeType },
],
timestamp: Date.now(),
},
],
};
}
async function runImagePrompt(params: {
cfg?: ClawdbotConfig;
modelOverride?: string;
prompt: string;
base64: string;
mimeType: string;
}): Promise<{ text: string; provider: string; model: string }> {
const agentDir = resolveClawdbotAgentDir();
await ensureClawdbotModelsJson(params.cfg);
const authStorage = discoverAuthStorage(agentDir);
const modelRegistry = discoverModels(authStorage, agentDir);
const result = await runWithImageModelFallback({
cfg: params.cfg,
modelOverride: params.modelOverride,
run: async (provider, modelId) => {
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
if (!model) {
throw new Error(`Unknown model: ${provider}/${modelId}`);
}
if (!model.input?.includes("image")) {
throw new Error(`Model does not support images: ${provider}/${modelId}`);
}
const apiKey = await getApiKeyForModel(model, authStorage);
authStorage.setRuntimeApiKey(model.provider, apiKey);
const context = buildImageContext(
params.prompt,
params.base64,
params.mimeType,
);
const message = (await complete(model, context, {
apiKey,
maxTokens: 512,
temperature: 0,
})) as AssistantMessage;
return message;
},
});
const text = extractAssistantText(result.result);
return {
text: text || "(no text returned)",
provider: result.provider,
model: result.model,
};
}
export function createImageTool(options?: {
config?: ClawdbotConfig;
}): AnyAgentTool | null {
if (!ensureImageToolConfigured(options?.config)) return null;
return {
label: "Image",
name: "image",
description:
"Analyze an image with the configured image model (agent.imageModel). Provide a prompt and image path or URL.",
parameters: Type.Object({
prompt: Type.Optional(Type.String()),
image: Type.String(),
model: Type.Optional(Type.String()),
maxBytesMb: Type.Optional(Type.Number()),
}),
execute: async (_toolCallId, args) => {
const record =
args && typeof args === "object"
? (args as Record<string, unknown>)
: {};
const imageRaw =
typeof record.image === "string" ? record.image.trim() : "";
if (!imageRaw) throw new Error("image required");
const promptRaw =
typeof record.prompt === "string" && record.prompt.trim()
? record.prompt.trim()
: DEFAULT_PROMPT;
const modelOverride =
typeof record.model === "string" && record.model.trim()
? record.model.trim()
: undefined;
const maxBytesMb =
typeof record.maxBytesMb === "number" ? record.maxBytesMb : undefined;
const maxBytes = pickMaxBytes(options?.config, maxBytesMb);
const resolvedImage = imageRaw.startsWith("~")
? resolveUserPath(imageRaw)
: imageRaw;
const media = await loadWebMedia(resolvedImage, maxBytes);
if (media.kind !== "image") {
throw new Error(`Unsupported media type: ${media.kind}`);
}
const mimeType = media.contentType ?? "image/png";
const base64 = media.buffer.toString("base64");
const result = await runImagePrompt({
cfg: options?.config,
modelOverride,
prompt: promptRaw,
base64,
mimeType,
});
return {
content: [{ type: "text", text: result.text }],
details: {
model: `${result.provider}/${result.model}`,
image: resolvedImage,
},
};
},
};
}