From a5adedea916c209aeb63bd46f3910ba850c28977 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Tue, 20 Jan 2026 07:53:25 +0000 Subject: [PATCH] refactor: add aws-sdk auth mode and tighten provider auth --- docs/bedrock.md | 4 + src/agents/anthropic.setup-token.live.test.ts | 7 +- src/agents/model-auth.test.ts | 183 ++++++++++++++++++ src/agents/model-auth.ts | 148 +++++++++++++- src/agents/models-config.providers.ts | 28 ++- src/agents/models.profiles.live.test.ts | 4 +- src/agents/pi-embedded-runner/compact.ts | 8 +- src/agents/pi-embedded-runner/model.test.ts | 29 +++ src/agents/pi-embedded-runner/model.ts | 26 ++- src/agents/pi-embedded-runner/run.ts | 16 +- src/agents/tools/image-tool.ts | 9 +- src/config/types.models.ts | 3 + src/config/zod-schema.core.ts | 8 + src/media-understanding/apply.test.ts | 5 + src/media-understanding/providers/image.ts | 9 +- src/media-understanding/runner.ts | 12 +- src/memory/embeddings-gemini.ts | 19 +- src/memory/embeddings-openai.ts | 19 +- src/memory/embeddings.test.ts | 16 +- 19 files changed, 489 insertions(+), 64 deletions(-) create mode 100644 src/agents/pi-embedded-runner/model.test.ts diff --git a/docs/bedrock.md b/docs/bedrock.md index fd7547cf1..6967c2f94 100644 --- a/docs/bedrock.md +++ b/docs/bedrock.md @@ -41,6 +41,7 @@ export AWS_BEARER_TOKEN_BEDROCK="..." "amazon-bedrock": { baseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com", api: "bedrock-converse-stream", + auth: "aws-sdk", models: [ { id: "anthropic.claude-3-7-sonnet-20250219-v1:0", @@ -67,6 +68,9 @@ export AWS_BEARER_TOKEN_BEDROCK="..." - Bedrock requires **model access** enabled in your AWS account/region. - If you use profiles, set `AWS_PROFILE` on the gateway host. +- Clawdbot surfaces the credential source in this order: `AWS_BEARER_TOKEN_BEDROCK`, + then `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY`, then `AWS_PROFILE`, then the + default AWS SDK chain. - Reasoning support depends on the model; check the Bedrock model card for current capabilities. - If you prefer a managed key flow, you can also place an OpenAI‑compatible diff --git a/src/agents/anthropic.setup-token.live.test.ts b/src/agents/anthropic.setup-token.live.test.ts index 16061b188..ac7860981 100644 --- a/src/agents/anthropic.setup-token.live.test.ts +++ b/src/agents/anthropic.setup-token.live.test.ts @@ -18,7 +18,7 @@ import { ensureAuthProfileStore, saveAuthProfileStore, } from "./auth-profiles.js"; -import { getApiKeyForModel } from "./model-auth.js"; +import { getApiKeyForModel, requireApiKey } from "./model-auth.js"; import { normalizeProviderId, parseModelRef } from "./model-selection.js"; import { ensureClawdbotModelsJson } from "./models-config.js"; @@ -178,7 +178,8 @@ describeLive("live anthropic setup-token", () => { profileId: tokenSource.profileId, agentDir: tokenSource.agentDir, }); - const tokenError = validateAnthropicSetupToken(apiKeyInfo.apiKey); + const apiKey = requireApiKey(apiKeyInfo, model.provider); + const tokenError = validateAnthropicSetupToken(apiKey); if (tokenError) { throw new Error(`Resolved profile is not a setup-token: ${tokenError}`); } @@ -195,7 +196,7 @@ describeLive("live anthropic setup-token", () => { ], }, { - apiKey: apiKeyInfo.apiKey, + apiKey, maxTokens: 64, temperature: 0, }, diff --git a/src/agents/model-auth.test.ts b/src/agents/model-auth.test.ts index ad1bccc90..3de3e33b2 100644 --- a/src/agents/model-auth.test.ts +++ b/src/agents/model-auth.test.ts @@ -280,4 +280,187 @@ describe("getApiKeyForModel", () => { } } }); + + it("prefers Bedrock bearer token over access keys and profile", async () => { + const previous = { + bearer: process.env.AWS_BEARER_TOKEN_BEDROCK, + access: process.env.AWS_ACCESS_KEY_ID, + secret: process.env.AWS_SECRET_ACCESS_KEY, + profile: process.env.AWS_PROFILE, + }; + + try { + process.env.AWS_BEARER_TOKEN_BEDROCK = "bedrock-token"; + process.env.AWS_ACCESS_KEY_ID = "access-key"; + process.env.AWS_SECRET_ACCESS_KEY = "secret-key"; + process.env.AWS_PROFILE = "profile"; + + vi.resetModules(); + const { resolveApiKeyForProvider } = await import("./model-auth.js"); + + const resolved = await resolveApiKeyForProvider({ + provider: "amazon-bedrock", + store: { version: 1, profiles: {} }, + cfg: { + models: { + providers: { + "amazon-bedrock": { + baseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com", + api: "bedrock-converse-stream", + auth: "aws-sdk", + models: [], + }, + }, + }, + } as never, + }); + + expect(resolved.mode).toBe("aws-sdk"); + expect(resolved.apiKey).toBeUndefined(); + expect(resolved.source).toContain("AWS_BEARER_TOKEN_BEDROCK"); + } finally { + if (previous.bearer === undefined) { + delete process.env.AWS_BEARER_TOKEN_BEDROCK; + } else { + process.env.AWS_BEARER_TOKEN_BEDROCK = previous.bearer; + } + if (previous.access === undefined) { + delete process.env.AWS_ACCESS_KEY_ID; + } else { + process.env.AWS_ACCESS_KEY_ID = previous.access; + } + if (previous.secret === undefined) { + delete process.env.AWS_SECRET_ACCESS_KEY; + } else { + process.env.AWS_SECRET_ACCESS_KEY = previous.secret; + } + if (previous.profile === undefined) { + delete process.env.AWS_PROFILE; + } else { + process.env.AWS_PROFILE = previous.profile; + } + } + }); + + it("prefers Bedrock access keys over profile", async () => { + const previous = { + bearer: process.env.AWS_BEARER_TOKEN_BEDROCK, + access: process.env.AWS_ACCESS_KEY_ID, + secret: process.env.AWS_SECRET_ACCESS_KEY, + profile: process.env.AWS_PROFILE, + }; + + try { + delete process.env.AWS_BEARER_TOKEN_BEDROCK; + process.env.AWS_ACCESS_KEY_ID = "access-key"; + process.env.AWS_SECRET_ACCESS_KEY = "secret-key"; + process.env.AWS_PROFILE = "profile"; + + vi.resetModules(); + const { resolveApiKeyForProvider } = await import("./model-auth.js"); + + const resolved = await resolveApiKeyForProvider({ + provider: "amazon-bedrock", + store: { version: 1, profiles: {} }, + cfg: { + models: { + providers: { + "amazon-bedrock": { + baseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com", + api: "bedrock-converse-stream", + auth: "aws-sdk", + models: [], + }, + }, + }, + } as never, + }); + + expect(resolved.mode).toBe("aws-sdk"); + expect(resolved.apiKey).toBeUndefined(); + expect(resolved.source).toContain("AWS_ACCESS_KEY_ID"); + } finally { + if (previous.bearer === undefined) { + delete process.env.AWS_BEARER_TOKEN_BEDROCK; + } else { + process.env.AWS_BEARER_TOKEN_BEDROCK = previous.bearer; + } + if (previous.access === undefined) { + delete process.env.AWS_ACCESS_KEY_ID; + } else { + process.env.AWS_ACCESS_KEY_ID = previous.access; + } + if (previous.secret === undefined) { + delete process.env.AWS_SECRET_ACCESS_KEY; + } else { + process.env.AWS_SECRET_ACCESS_KEY = previous.secret; + } + if (previous.profile === undefined) { + delete process.env.AWS_PROFILE; + } else { + process.env.AWS_PROFILE = previous.profile; + } + } + }); + + it("uses Bedrock profile when access keys are missing", async () => { + const previous = { + bearer: process.env.AWS_BEARER_TOKEN_BEDROCK, + access: process.env.AWS_ACCESS_KEY_ID, + secret: process.env.AWS_SECRET_ACCESS_KEY, + profile: process.env.AWS_PROFILE, + }; + + try { + delete process.env.AWS_BEARER_TOKEN_BEDROCK; + delete process.env.AWS_ACCESS_KEY_ID; + delete process.env.AWS_SECRET_ACCESS_KEY; + process.env.AWS_PROFILE = "profile"; + + vi.resetModules(); + const { resolveApiKeyForProvider } = await import("./model-auth.js"); + + const resolved = await resolveApiKeyForProvider({ + provider: "amazon-bedrock", + store: { version: 1, profiles: {} }, + cfg: { + models: { + providers: { + "amazon-bedrock": { + baseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com", + api: "bedrock-converse-stream", + auth: "aws-sdk", + models: [], + }, + }, + }, + } as never, + }); + + expect(resolved.mode).toBe("aws-sdk"); + expect(resolved.apiKey).toBeUndefined(); + expect(resolved.source).toContain("AWS_PROFILE"); + } finally { + if (previous.bearer === undefined) { + delete process.env.AWS_BEARER_TOKEN_BEDROCK; + } else { + process.env.AWS_BEARER_TOKEN_BEDROCK = previous.bearer; + } + if (previous.access === undefined) { + delete process.env.AWS_ACCESS_KEY_ID; + } else { + process.env.AWS_ACCESS_KEY_ID = previous.access; + } + if (previous.secret === undefined) { + delete process.env.AWS_SECRET_ACCESS_KEY; + } else { + process.env.AWS_SECRET_ACCESS_KEY = previous.secret; + } + if (previous.profile === undefined) { + delete process.env.AWS_PROFILE; + } else { + process.env.AWS_PROFILE = previous.profile; + } + } + }); }); diff --git a/src/agents/model-auth.ts b/src/agents/model-auth.ts index e434f7dac..754d75e66 100644 --- a/src/agents/model-auth.ts +++ b/src/agents/model-auth.ts @@ -2,7 +2,7 @@ import path from "node:path"; import { type Api, getEnvApiKey, type Model } from "@mariozechner/pi-ai"; import type { ClawdbotConfig } from "../config/config.js"; -import type { ModelProviderConfig } from "../config/types.js"; +import type { ModelProviderAuthMode, ModelProviderConfig } from "../config/types.js"; import { getShellEnvAppliedKeys } from "../infra/shell-env.js"; import { formatCliCommand } from "../cli/command-format.js"; import { @@ -17,16 +17,115 @@ import { normalizeProviderId } from "./model-selection.js"; export { ensureAuthProfileStore, resolveAuthProfileOrder } from "./auth-profiles.js"; +const AWS_BEARER_ENV = "AWS_BEARER_TOKEN_BEDROCK"; +const AWS_ACCESS_KEY_ENV = "AWS_ACCESS_KEY_ID"; +const AWS_SECRET_KEY_ENV = "AWS_SECRET_ACCESS_KEY"; +const AWS_PROFILE_ENV = "AWS_PROFILE"; + +function resolveProviderConfig( + cfg: ClawdbotConfig | undefined, + provider: string, +): ModelProviderConfig | undefined { + const providers = cfg?.models?.providers ?? {}; + const direct = providers[provider] as ModelProviderConfig | undefined; + if (direct) return direct; + const normalized = normalizeProviderId(provider); + if (normalized === provider) { + const matched = Object.entries(providers).find( + ([key]) => normalizeProviderId(key) === normalized, + ); + return matched?.[1] as ModelProviderConfig | undefined; + } + return ( + (providers[normalized] as ModelProviderConfig | undefined) ?? + (Object.entries(providers).find( + ([key]) => normalizeProviderId(key) === normalized, + )?.[1] as ModelProviderConfig | undefined) + ); +} + export function getCustomProviderApiKey( cfg: ClawdbotConfig | undefined, provider: string, ): string | undefined { - const providers = cfg?.models?.providers ?? {}; - const entry = providers[provider] as ModelProviderConfig | undefined; + const entry = resolveProviderConfig(cfg, provider); const key = entry?.apiKey?.trim(); return key || undefined; } +function resolveProviderAuthOverride( + cfg: ClawdbotConfig | undefined, + provider: string, +): ModelProviderAuthMode | undefined { + const entry = resolveProviderConfig(cfg, provider); + const auth = entry?.auth; + if (auth === "api-key" || auth === "aws-sdk" || auth === "oauth" || auth === "token") { + return auth; + } + return undefined; +} + +function resolveEnvSourceLabel(params: { + applied: Set; + envVars: string[]; + label: string; +}): string { + const shellApplied = params.envVars.some((envVar) => params.applied.has(envVar)); + const prefix = shellApplied ? "shell env: " : "env: "; + return `${prefix}${params.label}`; +} + +export function resolveAwsSdkEnvVarName(): string | undefined { + if (process.env[AWS_BEARER_ENV]?.trim()) return AWS_BEARER_ENV; + if (process.env[AWS_ACCESS_KEY_ENV]?.trim() && process.env[AWS_SECRET_KEY_ENV]?.trim()) { + return AWS_ACCESS_KEY_ENV; + } + if (process.env[AWS_PROFILE_ENV]?.trim()) return AWS_PROFILE_ENV; + return undefined; +} + +function resolveAwsSdkAuthInfo(): { mode: "aws-sdk"; source: string } { + const applied = new Set(getShellEnvAppliedKeys()); + if (process.env[AWS_BEARER_ENV]?.trim()) { + return { + mode: "aws-sdk", + source: resolveEnvSourceLabel({ + applied, + envVars: [AWS_BEARER_ENV], + label: AWS_BEARER_ENV, + }), + }; + } + if (process.env[AWS_ACCESS_KEY_ENV]?.trim() && process.env[AWS_SECRET_KEY_ENV]?.trim()) { + return { + mode: "aws-sdk", + source: resolveEnvSourceLabel({ + applied, + envVars: [AWS_ACCESS_KEY_ENV, AWS_SECRET_KEY_ENV], + label: `${AWS_ACCESS_KEY_ENV} + ${AWS_SECRET_KEY_ENV}`, + }), + }; + } + if (process.env[AWS_PROFILE_ENV]?.trim()) { + return { + mode: "aws-sdk", + source: resolveEnvSourceLabel({ + applied, + envVars: [AWS_PROFILE_ENV], + label: AWS_PROFILE_ENV, + }), + }; + } + return { mode: "aws-sdk", source: "aws-sdk default chain" }; +} + +export type ResolvedProviderAuth = { + apiKey?: string; + profileId?: string; + source: string; + mode: "api-key" | "oauth" | "token" | "aws-sdk"; +}; + export async function resolveApiKeyForProvider(params: { provider: string; cfg?: ClawdbotConfig; @@ -34,7 +133,7 @@ export async function resolveApiKeyForProvider(params: { preferredProfile?: string; store?: AuthProfileStore; agentDir?: string; -}): Promise<{ apiKey: string; profileId?: string; source: string }> { +}): Promise { const { provider, cfg, profileId, preferredProfile } = params; const store = params.store ?? ensureAuthProfileStore(params.agentDir); @@ -48,13 +147,20 @@ export async function resolveApiKeyForProvider(params: { if (!resolved) { throw new Error(`No credentials found for profile "${profileId}".`); } + const mode = store.profiles[profileId]?.type; return { apiKey: resolved.apiKey, profileId, source: `profile:${profileId}`, + mode: mode === "oauth" ? "oauth" : mode === "token" ? "token" : "api-key", }; } + const authOverride = resolveProviderAuthOverride(cfg, provider); + if (authOverride === "aws-sdk") { + return resolveAwsSdkAuthInfo(); + } + const order = resolveAuthProfileOrder({ cfg, store, @@ -70,10 +176,12 @@ export async function resolveApiKeyForProvider(params: { agentDir: params.agentDir, }); if (resolved) { + const mode = store.profiles[candidate]?.type; return { apiKey: resolved.apiKey, profileId: candidate, source: `profile:${candidate}`, + mode: mode === "oauth" ? "oauth" : mode === "token" ? "token" : "api-key", }; } } catch {} @@ -81,12 +189,21 @@ export async function resolveApiKeyForProvider(params: { const envResolved = resolveEnvApiKey(provider); if (envResolved) { - return { apiKey: envResolved.apiKey, source: envResolved.source }; + return { + apiKey: envResolved.apiKey, + source: envResolved.source, + mode: envResolved.source.includes("OAUTH_TOKEN") ? "oauth" : "api-key", + }; } const customKey = getCustomProviderApiKey(cfg, provider); if (customKey) { - return { apiKey: customKey, source: "models.json" }; + return { apiKey: customKey, source: "models.json", mode: "api-key" }; + } + + const normalized = normalizeProviderId(provider); + if (authOverride === undefined && normalized === "amazon-bedrock") { + return resolveAwsSdkAuthInfo(); } if (provider === "openai") { @@ -110,7 +227,7 @@ export async function resolveApiKeyForProvider(params: { } export type EnvApiKeyResult = { apiKey: string; source: string }; -export type ModelAuthMode = "api-key" | "oauth" | "token" | "mixed" | "unknown"; +export type ModelAuthMode = "api-key" | "oauth" | "token" | "mixed" | "aws-sdk" | "unknown"; export function resolveEnvApiKey(provider: string): EnvApiKeyResult | null { const normalized = normalizeProviderId(provider); @@ -181,6 +298,9 @@ export function resolveModelAuthMode( const resolved = provider?.trim(); if (!resolved) return undefined; + const authOverride = resolveProviderAuthOverride(cfg, resolved); + if (authOverride === "aws-sdk") return "aws-sdk"; + const authStore = store ?? ensureAuthProfileStore(); const profiles = listProfilesForProvider(authStore, resolved); if (profiles.length > 0) { @@ -198,6 +318,10 @@ export function resolveModelAuthMode( if (modes.has("api_key")) return "api-key"; } + if (authOverride === undefined && normalizeProviderId(resolved) === "amazon-bedrock") { + return "aws-sdk"; + } + const envKey = resolveEnvApiKey(resolved); if (envKey?.apiKey) { return envKey.source.includes("OAUTH_TOKEN") ? "oauth" : "api-key"; @@ -215,7 +339,7 @@ export async function getApiKeyForModel(params: { preferredProfile?: string; store?: AuthProfileStore; agentDir?: string; -}): Promise<{ apiKey: string; profileId?: string; source: string }> { +}): Promise { return resolveApiKeyForProvider({ provider: params.model.provider, cfg: params.cfg, @@ -225,3 +349,11 @@ export async function getApiKeyForModel(params: { agentDir: params.agentDir, }); } + +export function requireApiKey(auth: ResolvedProviderAuth, provider: string): string { + const key = auth.apiKey?.trim(); + if (key) return key; + throw new Error( + `No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`, + ); +} diff --git a/src/agents/models-config.providers.ts b/src/agents/models-config.providers.ts index 0acb6aaaf..251f7b92b 100644 --- a/src/agents/models-config.providers.ts +++ b/src/agents/models-config.providers.ts @@ -4,7 +4,7 @@ import { resolveCopilotApiToken, } from "../providers/github-copilot-token.js"; import { ensureAuthProfileStore, listProfilesForProvider } from "./auth-profiles.js"; -import { resolveEnvApiKey } from "./model-auth.js"; +import { resolveAwsSdkEnvVarName, resolveEnvApiKey } from "./model-auth.js"; import { buildSyntheticModelDefinition, SYNTHETIC_BASE_URL, @@ -74,6 +74,10 @@ function resolveEnvApiKeyVarName(provider: string): string | undefined { return match ? match[1] : undefined; } +function resolveAwsSdkApiKeyVarName(): string { + return resolveAwsSdkEnvVarName() ?? "AWS_PROFILE"; +} + function resolveApiKeyFromProfiles(params: { provider: string; store: ReturnType; @@ -138,15 +142,23 @@ export function normalizeProviders(params: { const hasModels = Array.isArray(normalizedProvider.models) && normalizedProvider.models.length > 0; if (hasModels && !normalizedProvider.apiKey?.trim()) { - const fromEnv = resolveEnvApiKeyVarName(normalizedKey); - const fromProfiles = resolveApiKeyFromProfiles({ - provider: normalizedKey, - store: authStore, - }); - const apiKey = fromEnv ?? fromProfiles; - if (apiKey?.trim()) { + const authMode = + normalizedProvider.auth ?? (normalizedKey === "amazon-bedrock" ? "aws-sdk" : undefined); + if (authMode === "aws-sdk") { + const apiKey = resolveAwsSdkApiKeyVarName(); mutated = true; normalizedProvider = { ...normalizedProvider, apiKey }; + } else { + const fromEnv = resolveEnvApiKeyVarName(normalizedKey); + const fromProfiles = resolveApiKeyFromProfiles({ + provider: normalizedKey, + store: authStore, + }); + const apiKey = fromEnv ?? fromProfiles; + if (apiKey?.trim()) { + mutated = true; + normalizedProvider = { ...normalizedProvider, apiKey }; + } } } diff --git a/src/agents/models.profiles.live.test.ts b/src/agents/models.profiles.live.test.ts index b038343ed..032c82992 100644 --- a/src/agents/models.profiles.live.test.ts +++ b/src/agents/models.profiles.live.test.ts @@ -11,7 +11,7 @@ import { isAnthropicRateLimitError, } from "./live-auth-keys.js"; import { isModernModelRef } from "./live-model-filter.js"; -import { getApiKeyForModel } from "./model-auth.js"; +import { getApiKeyForModel, requireApiKey } from "./model-auth.js"; import { ensureClawdbotModelsJson } from "./models-config.js"; import { isRateLimitErrorMessage } from "./pi-embedded-helpers/errors.js"; @@ -226,7 +226,7 @@ describeLive("live models (profile keys)", () => { const apiKey = model.provider === "anthropic" && anthropicKeys.length > 0 ? anthropicKeys[attempt] - : apiKeyInfo.apiKey; + : requireApiKey(apiKeyInfo, model.provider); try { // Special regression: OpenAI requires replayed `reasoning` items for tool-only turns. if ( diff --git a/src/agents/pi-embedded-runner/compact.ts b/src/agents/pi-embedded-runner/compact.ts index 56a595892..d79654f02 100644 --- a/src/agents/pi-embedded-runner/compact.ts +++ b/src/agents/pi-embedded-runner/compact.ts @@ -115,7 +115,13 @@ export async function compactEmbeddedPiSession(params: { agentDir, }); - if (model.provider === "github-copilot") { + if (!apiKeyInfo.apiKey) { + if (apiKeyInfo.mode !== "aws-sdk") { + throw new Error( + `No API key resolved for provider "${model.provider}" (auth mode: ${apiKeyInfo.mode}).`, + ); + } + } else if (model.provider === "github-copilot") { const { resolveCopilotApiToken } = await import("../../providers/github-copilot-token.js"); const copilotToken = await resolveCopilotApiToken({ diff --git a/src/agents/pi-embedded-runner/model.test.ts b/src/agents/pi-embedded-runner/model.test.ts new file mode 100644 index 000000000..b59735623 --- /dev/null +++ b/src/agents/pi-embedded-runner/model.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, it } from "vitest"; + +import { buildInlineProviderModels } from "./model.js"; + +const makeModel = (id: string) => ({ + id, + name: id, + reasoning: false, + input: ["text"] as const, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 1, + maxTokens: 1, +}); + +describe("buildInlineProviderModels", () => { + it("attaches provider ids to inline models", () => { + const providers = { + " alpha ": { models: [makeModel("alpha-model")] }, + beta: { models: [makeModel("beta-model")] }, + }; + + const result = buildInlineProviderModels(providers); + + expect(result).toEqual([ + { ...makeModel("alpha-model"), provider: "alpha" }, + { ...makeModel("beta-model"), provider: "beta" }, + ]); + }); +}); diff --git a/src/agents/pi-embedded-runner/model.ts b/src/agents/pi-embedded-runner/model.ts index 2d146f80b..4fa836f98 100644 --- a/src/agents/pi-embedded-runner/model.ts +++ b/src/agents/pi-embedded-runner/model.ts @@ -2,9 +2,23 @@ import type { Api, Model } from "@mariozechner/pi-ai"; import { discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent"; import type { ClawdbotConfig } from "../../config/config.js"; +import type { ModelDefinitionConfig } from "../../config/types.js"; import { resolveClawdbotAgentDir } from "../agent-paths.js"; import { DEFAULT_CONTEXT_TOKENS } from "../defaults.js"; import { normalizeModelCompat } from "../model-compat.js"; +import { normalizeProviderId } from "../model-selection.js"; + +type InlineModelEntry = ModelDefinitionConfig & { provider: string }; + +export function buildInlineProviderModels( + providers: Record, +): InlineModelEntry[] { + return Object.entries(providers).flatMap(([providerId, entry]) => { + const trimmed = providerId.trim(); + if (!trimmed) return []; + return (entry?.models ?? []).map((model) => ({ ...model, provider: trimmed })); + }); +} export function buildModelAliasLines(cfg?: ClawdbotConfig) { const models = cfg?.agents?.defaults?.models ?? {}; @@ -38,12 +52,12 @@ export function resolveModel( const model = modelRegistry.find(provider, modelId) as Model | null; if (!model) { const providers = cfg?.models?.providers ?? {}; - const inlineModels = - providers[provider]?.models?.map((entry) => ({ ...entry, provider })) ?? - Object.values(providers) - .flatMap((entry) => entry?.models ?? []) - .map((entry) => ({ ...entry, provider })); - const inlineMatch = inlineModels.find((entry) => entry.id === modelId); + const inlineModels = buildInlineProviderModels(providers); + const normalizedProvider = normalizeProviderId(provider); + const inlineMatch = inlineModels.find( + (entry) => + normalizeProviderId(entry.provider) === normalizedProvider && entry.id === modelId, + ); if (inlineMatch) { const normalized = normalizeModelCompat(inlineMatch as Model); return { diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index 3e59ff5ed..31096d7cc 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -21,6 +21,7 @@ import { ensureAuthProfileStore, getApiKeyForModel, resolveAuthProfileOrder, + type ResolvedProviderAuth, } from "../model-auth.js"; import { ensureClawdbotModelsJson } from "../models-config.js"; import { @@ -47,11 +48,7 @@ import { buildEmbeddedRunPayloads } from "./run/payloads.js"; import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; import { describeUnknownError } from "./utils.js"; -type ApiKeyInfo = { - apiKey: string; - profileId?: string; - source: string; -}; +type ApiKeyInfo = ResolvedProviderAuth; export async function runEmbeddedPiAgent( params: RunEmbeddedPiAgentParams, @@ -151,6 +148,15 @@ export async function runEmbeddedPiAgent( const applyApiKeyInfo = async (candidate?: string): Promise => { apiKeyInfo = await resolveApiKeyForCandidate(candidate); + if (!apiKeyInfo.apiKey) { + if (apiKeyInfo.mode !== "aws-sdk") { + throw new Error( + `No API key resolved for provider "${model.provider}" (auth mode: ${apiKeyInfo.mode}).`, + ); + } + lastProfileId = apiKeyInfo.profileId; + return; + } if (model.provider === "github-copilot") { const { resolveCopilotApiToken } = await import("../../providers/github-copilot-token.js"); diff --git a/src/agents/tools/image-tool.ts b/src/agents/tools/image-tool.ts index 59e5df47d..a640830fd 100644 --- a/src/agents/tools/image-tool.ts +++ b/src/agents/tools/image-tool.ts @@ -17,7 +17,7 @@ import { loadWebMedia } from "../../web/media.js"; import { ensureAuthProfileStore, listProfilesForProvider } from "../auth-profiles.js"; import { DEFAULT_MODEL, DEFAULT_PROVIDER } from "../defaults.js"; import { minimaxUnderstandImage } from "../minimax-vlm.js"; -import { getApiKeyForModel, resolveEnvApiKey } from "../model-auth.js"; +import { getApiKeyForModel, requireApiKey, resolveEnvApiKey } from "../model-auth.js"; import { runWithImageModelFallback } from "../model-fallback.js"; import { resolveConfiguredModelRef } from "../model-selection.js"; import { ensureClawdbotModelsJson } from "../models-config.js"; @@ -252,12 +252,13 @@ async function runImagePrompt(params: { cfg: effectiveCfg, agentDir: params.agentDir, }); - authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey); + const apiKey = requireApiKey(apiKeyInfo, model.provider); + authStorage.setRuntimeApiKey(model.provider, apiKey); const imageDataUrl = `data:${params.mimeType};base64,${params.base64}`; if (model.provider === "minimax") { const text = await minimaxUnderstandImage({ - apiKey: apiKeyInfo.apiKey, + apiKey, prompt: params.prompt, imageDataUrl, modelBaseUrl: model.baseUrl, @@ -267,7 +268,7 @@ async function runImagePrompt(params: { const context = buildImageContext(params.prompt, params.base64, params.mimeType); const message = (await complete(model, context, { - apiKey: apiKeyInfo.apiKey, + apiKey, maxTokens: 512, })) as AssistantMessage; const text = coerceImageAssistantText({ diff --git a/src/config/types.models.ts b/src/config/types.models.ts index 92032ae1e..f11f368f1 100644 --- a/src/config/types.models.ts +++ b/src/config/types.models.ts @@ -13,6 +13,8 @@ export type ModelCompatConfig = { maxTokensField?: "max_completion_tokens" | "max_tokens"; }; +export type ModelProviderAuthMode = "api-key" | "aws-sdk" | "oauth" | "token"; + export type ModelDefinitionConfig = { id: string; name: string; @@ -34,6 +36,7 @@ export type ModelDefinitionConfig = { export type ModelProviderConfig = { baseUrl: string; apiKey?: string; + auth?: ModelProviderAuthMode; api?: ModelApi; headers?: Record; authHeader?: boolean; diff --git a/src/config/zod-schema.core.ts b/src/config/zod-schema.core.ts index 13a0f1668..4b38644ec 100644 --- a/src/config/zod-schema.core.ts +++ b/src/config/zod-schema.core.ts @@ -49,6 +49,14 @@ export const ModelProviderSchema = z .object({ baseUrl: z.string().min(1), apiKey: z.string().optional(), + auth: z + .union([ + z.literal("api-key"), + z.literal("aws-sdk"), + z.literal("oauth"), + z.literal("token"), + ]) + .optional(), api: ModelApiSchema.optional(), headers: z.record(z.string(), z.string()).optional(), authHeader: z.boolean().optional(), diff --git a/src/media-understanding/apply.test.ts b/src/media-understanding/apply.test.ts index e7bf08a24..cc45f3e29 100644 --- a/src/media-understanding/apply.test.ts +++ b/src/media-understanding/apply.test.ts @@ -13,7 +13,12 @@ vi.mock("../agents/model-auth.js", () => ({ resolveApiKeyForProvider: vi.fn(async () => ({ apiKey: "test-key", source: "test", + mode: "api-key", })), + requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => { + if (auth?.apiKey) return auth.apiKey; + throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth?.mode}).`); + }, })); vi.mock("../media/fetch.js", () => ({ diff --git a/src/media-understanding/providers/image.ts b/src/media-understanding/providers/image.ts index bd056253a..c2c4dbbd6 100644 --- a/src/media-understanding/providers/image.ts +++ b/src/media-understanding/providers/image.ts @@ -2,7 +2,7 @@ import type { Api, AssistantMessage, Context, Model } from "@mariozechner/pi-ai" import { complete } from "@mariozechner/pi-ai"; import { discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent"; -import { getApiKeyForModel } from "../../agents/model-auth.js"; +import { getApiKeyForModel, requireApiKey } from "../../agents/model-auth.js"; import { ensureClawdbotModelsJson } from "../../agents/models-config.js"; import { minimaxUnderstandImage } from "../../agents/minimax-vlm.js"; import { coerceImageAssistantText } from "../../agents/tools/image-tool.helpers.js"; @@ -28,12 +28,13 @@ export async function describeImageWithModel( profileId: params.profile, preferredProfile: params.preferredProfile, }); - authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey); + const apiKey = requireApiKey(apiKeyInfo, model.provider); + authStorage.setRuntimeApiKey(model.provider, apiKey); const base64 = params.buffer.toString("base64"); if (model.provider === "minimax") { const text = await minimaxUnderstandImage({ - apiKey: apiKeyInfo.apiKey, + apiKey, prompt: params.prompt ?? "Describe the image.", imageDataUrl: `data:${params.mime ?? "image/jpeg"};base64,${base64}`, modelBaseUrl: model.baseUrl, @@ -54,7 +55,7 @@ export async function describeImageWithModel( ], }; const message = (await complete(model, context, { - apiKey: apiKeyInfo.apiKey, + apiKey, maxTokens: params.maxTokens ?? 512, })) as AssistantMessage; const text = coerceImageAssistantText({ diff --git a/src/media-understanding/runner.ts b/src/media-understanding/runner.ts index dd5f4f0b6..eed797f26 100644 --- a/src/media-understanding/runner.ts +++ b/src/media-understanding/runner.ts @@ -1,7 +1,7 @@ import type { ClawdbotConfig } from "../config/config.js"; import type { MsgContext } from "../auto-reply/templating.js"; import { applyTemplate } from "../auto-reply/templating.js"; -import { resolveApiKeyForProvider } from "../agents/model-auth.js"; +import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; import { logVerbose, shouldLogVerbose } from "../globals.js"; import { runExec } from "../process/exec.js"; import type { @@ -300,13 +300,14 @@ async function runProviderEntry(params: { maxBytes, timeoutMs, }); - const key = await resolveApiKeyForProvider({ + const auth = await resolveApiKeyForProvider({ provider: providerId, cfg, profileId: entry.profile, preferredProfile: entry.preferredProfile, agentDir: params.agentDir, }); + const apiKey = requireApiKey(auth, providerId); const providerConfig = cfg.models?.providers?.[providerId]; const baseUrl = entry.baseUrl ?? params.config?.baseUrl ?? providerConfig?.baseUrl; const mergedHeaders = { @@ -325,7 +326,7 @@ async function runProviderEntry(params: { buffer: media.buffer, fileName: media.fileName, mime: media.mime, - apiKey: key.apiKey, + apiKey, baseUrl, headers, model, @@ -359,19 +360,20 @@ async function runProviderEntry(params: { `Video attachment ${params.attachmentIndex + 1} base64 payload ${estimatedBase64Bytes} exceeds ${maxBase64Bytes}`, ); } - const key = await resolveApiKeyForProvider({ + const auth = await resolveApiKeyForProvider({ provider: providerId, cfg, profileId: entry.profile, preferredProfile: entry.preferredProfile, agentDir: params.agentDir, }); + const apiKey = requireApiKey(auth, providerId); const providerConfig = cfg.models?.providers?.[providerId]; const result = await provider.describeVideo({ buffer: media.buffer, fileName: media.fileName, mime: media.mime, - apiKey: key.apiKey, + apiKey, baseUrl: providerConfig?.baseUrl, headers: providerConfig?.headers, model: entry.model, diff --git a/src/memory/embeddings-gemini.ts b/src/memory/embeddings-gemini.ts index 423522775..244384df6 100644 --- a/src/memory/embeddings-gemini.ts +++ b/src/memory/embeddings-gemini.ts @@ -1,4 +1,4 @@ -import { resolveApiKeyForProvider } from "../agents/model-auth.js"; +import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; import { isTruthyEnvValue } from "../infra/env.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; @@ -115,13 +115,16 @@ export async function resolveGeminiEmbeddingClient( const remoteApiKey = resolveRemoteApiKey(remote?.apiKey); const remoteBaseUrl = remote?.baseUrl?.trim(); - const { apiKey } = remoteApiKey - ? { apiKey: remoteApiKey } - : await resolveApiKeyForProvider({ - provider: "google", - cfg: options.config, - agentDir: options.agentDir, - }); + const apiKey = remoteApiKey + ? remoteApiKey + : requireApiKey( + await resolveApiKeyForProvider({ + provider: "google", + cfg: options.config, + agentDir: options.agentDir, + }), + "google", + ); const providerConfig = options.config.models?.providers?.google; const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL; diff --git a/src/memory/embeddings-openai.ts b/src/memory/embeddings-openai.ts index f2bed1e04..cfc53efae 100644 --- a/src/memory/embeddings-openai.ts +++ b/src/memory/embeddings-openai.ts @@ -1,4 +1,4 @@ -import { resolveApiKeyForProvider } from "../agents/model-auth.js"; +import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js"; import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; export type OpenAiEmbeddingClient = { @@ -62,13 +62,16 @@ export async function resolveOpenAiEmbeddingClient( const remoteApiKey = remote?.apiKey?.trim(); const remoteBaseUrl = remote?.baseUrl?.trim(); - const { apiKey } = remoteApiKey - ? { apiKey: remoteApiKey } - : await resolveApiKeyForProvider({ - provider: "openai", - cfg: options.config, - agentDir: options.agentDir, - }); + const apiKey = remoteApiKey + ? remoteApiKey + : requireApiKey( + await resolveApiKeyForProvider({ + provider: "openai", + cfg: options.config, + agentDir: options.agentDir, + }), + "openai", + ); const providerConfig = options.config.models?.providers?.openai; const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_OPENAI_BASE_URL; diff --git a/src/memory/embeddings.test.ts b/src/memory/embeddings.test.ts index 7c4f609aa..e37bca3cd 100644 --- a/src/memory/embeddings.test.ts +++ b/src/memory/embeddings.test.ts @@ -2,6 +2,10 @@ import { afterEach, describe, expect, it, vi } from "vitest"; vi.mock("../agents/model-auth.js", () => ({ resolveApiKeyForProvider: vi.fn(), + requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => { + if (auth?.apiKey) return auth.apiKey; + throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth?.mode}).`); + }, })); const createFetchMock = () => @@ -26,6 +30,8 @@ describe("embedding provider remote overrides", () => { const authModule = await import("../agents/model-auth.js"); vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ apiKey: "provider-key", + mode: "api-key", + source: "test", }); const cfg = { @@ -78,6 +84,8 @@ describe("embedding provider remote overrides", () => { const authModule = await import("../agents/model-auth.js"); vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ apiKey: "provider-key", + mode: "api-key", + source: "test", }); const cfg = { @@ -120,6 +128,8 @@ describe("embedding provider remote overrides", () => { const authModule = await import("../agents/model-auth.js"); vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ apiKey: "provider-key", + mode: "api-key", + source: "test", }); const cfg = { @@ -166,7 +176,7 @@ describe("embedding provider auto selection", () => { const authModule = await import("../agents/model-auth.js"); vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => { if (provider === "openai") { - return { apiKey: "openai-key", source: "env: OPENAI_API_KEY" }; + return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" }; } throw new Error(`No API key found for provider "${provider}".`); }); @@ -190,7 +200,7 @@ describe("embedding provider auto selection", () => { throw new Error('No API key found for provider "openai".'); } if (provider === "google") { - return { apiKey: "gemini-key", source: "env: GEMINI_API_KEY" }; + return { apiKey: "gemini-key", source: "env: GEMINI_API_KEY", mode: "api-key" }; } throw new Error(`Unexpected provider ${provider}`); }); @@ -231,6 +241,8 @@ describe("embedding provider local fallback", () => { const authModule = await import("../agents/model-auth.js"); vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ apiKey: "provider-key", + mode: "api-key", + source: "test", }); const result = await createEmbeddingProvider({