refactor: add aws-sdk auth mode and tighten provider auth
This commit is contained in:
@@ -41,6 +41,7 @@ export AWS_BEARER_TOKEN_BEDROCK="..."
|
||||
"amazon-bedrock": {
|
||||
baseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
api: "bedrock-converse-stream",
|
||||
auth: "aws-sdk",
|
||||
models: [
|
||||
{
|
||||
id: "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
@@ -67,6 +68,9 @@ export AWS_BEARER_TOKEN_BEDROCK="..."
|
||||
|
||||
- Bedrock requires **model access** enabled in your AWS account/region.
|
||||
- If you use profiles, set `AWS_PROFILE` on the gateway host.
|
||||
- Clawdbot surfaces the credential source in this order: `AWS_BEARER_TOKEN_BEDROCK`,
|
||||
then `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY`, then `AWS_PROFILE`, then the
|
||||
default AWS SDK chain.
|
||||
- Reasoning support depends on the model; check the Bedrock model card for
|
||||
current capabilities.
|
||||
- If you prefer a managed key flow, you can also place an OpenAI‑compatible
|
||||
|
||||
@@ -18,7 +18,7 @@ import {
|
||||
ensureAuthProfileStore,
|
||||
saveAuthProfileStore,
|
||||
} from "./auth-profiles.js";
|
||||
import { getApiKeyForModel } from "./model-auth.js";
|
||||
import { getApiKeyForModel, requireApiKey } from "./model-auth.js";
|
||||
import { normalizeProviderId, parseModelRef } from "./model-selection.js";
|
||||
import { ensureClawdbotModelsJson } from "./models-config.js";
|
||||
|
||||
@@ -178,7 +178,8 @@ describeLive("live anthropic setup-token", () => {
|
||||
profileId: tokenSource.profileId,
|
||||
agentDir: tokenSource.agentDir,
|
||||
});
|
||||
const tokenError = validateAnthropicSetupToken(apiKeyInfo.apiKey);
|
||||
const apiKey = requireApiKey(apiKeyInfo, model.provider);
|
||||
const tokenError = validateAnthropicSetupToken(apiKey);
|
||||
if (tokenError) {
|
||||
throw new Error(`Resolved profile is not a setup-token: ${tokenError}`);
|
||||
}
|
||||
@@ -195,7 +196,7 @@ describeLive("live anthropic setup-token", () => {
|
||||
],
|
||||
},
|
||||
{
|
||||
apiKey: apiKeyInfo.apiKey,
|
||||
apiKey,
|
||||
maxTokens: 64,
|
||||
temperature: 0,
|
||||
},
|
||||
|
||||
@@ -280,4 +280,187 @@ describe("getApiKeyForModel", () => {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it("prefers Bedrock bearer token over access keys and profile", async () => {
|
||||
const previous = {
|
||||
bearer: process.env.AWS_BEARER_TOKEN_BEDROCK,
|
||||
access: process.env.AWS_ACCESS_KEY_ID,
|
||||
secret: process.env.AWS_SECRET_ACCESS_KEY,
|
||||
profile: process.env.AWS_PROFILE,
|
||||
};
|
||||
|
||||
try {
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK = "bedrock-token";
|
||||
process.env.AWS_ACCESS_KEY_ID = "access-key";
|
||||
process.env.AWS_SECRET_ACCESS_KEY = "secret-key";
|
||||
process.env.AWS_PROFILE = "profile";
|
||||
|
||||
vi.resetModules();
|
||||
const { resolveApiKeyForProvider } = await import("./model-auth.js");
|
||||
|
||||
const resolved = await resolveApiKeyForProvider({
|
||||
provider: "amazon-bedrock",
|
||||
store: { version: 1, profiles: {} },
|
||||
cfg: {
|
||||
models: {
|
||||
providers: {
|
||||
"amazon-bedrock": {
|
||||
baseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
api: "bedrock-converse-stream",
|
||||
auth: "aws-sdk",
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
});
|
||||
|
||||
expect(resolved.mode).toBe("aws-sdk");
|
||||
expect(resolved.apiKey).toBeUndefined();
|
||||
expect(resolved.source).toContain("AWS_BEARER_TOKEN_BEDROCK");
|
||||
} finally {
|
||||
if (previous.bearer === undefined) {
|
||||
delete process.env.AWS_BEARER_TOKEN_BEDROCK;
|
||||
} else {
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK = previous.bearer;
|
||||
}
|
||||
if (previous.access === undefined) {
|
||||
delete process.env.AWS_ACCESS_KEY_ID;
|
||||
} else {
|
||||
process.env.AWS_ACCESS_KEY_ID = previous.access;
|
||||
}
|
||||
if (previous.secret === undefined) {
|
||||
delete process.env.AWS_SECRET_ACCESS_KEY;
|
||||
} else {
|
||||
process.env.AWS_SECRET_ACCESS_KEY = previous.secret;
|
||||
}
|
||||
if (previous.profile === undefined) {
|
||||
delete process.env.AWS_PROFILE;
|
||||
} else {
|
||||
process.env.AWS_PROFILE = previous.profile;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it("prefers Bedrock access keys over profile", async () => {
|
||||
const previous = {
|
||||
bearer: process.env.AWS_BEARER_TOKEN_BEDROCK,
|
||||
access: process.env.AWS_ACCESS_KEY_ID,
|
||||
secret: process.env.AWS_SECRET_ACCESS_KEY,
|
||||
profile: process.env.AWS_PROFILE,
|
||||
};
|
||||
|
||||
try {
|
||||
delete process.env.AWS_BEARER_TOKEN_BEDROCK;
|
||||
process.env.AWS_ACCESS_KEY_ID = "access-key";
|
||||
process.env.AWS_SECRET_ACCESS_KEY = "secret-key";
|
||||
process.env.AWS_PROFILE = "profile";
|
||||
|
||||
vi.resetModules();
|
||||
const { resolveApiKeyForProvider } = await import("./model-auth.js");
|
||||
|
||||
const resolved = await resolveApiKeyForProvider({
|
||||
provider: "amazon-bedrock",
|
||||
store: { version: 1, profiles: {} },
|
||||
cfg: {
|
||||
models: {
|
||||
providers: {
|
||||
"amazon-bedrock": {
|
||||
baseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
api: "bedrock-converse-stream",
|
||||
auth: "aws-sdk",
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
});
|
||||
|
||||
expect(resolved.mode).toBe("aws-sdk");
|
||||
expect(resolved.apiKey).toBeUndefined();
|
||||
expect(resolved.source).toContain("AWS_ACCESS_KEY_ID");
|
||||
} finally {
|
||||
if (previous.bearer === undefined) {
|
||||
delete process.env.AWS_BEARER_TOKEN_BEDROCK;
|
||||
} else {
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK = previous.bearer;
|
||||
}
|
||||
if (previous.access === undefined) {
|
||||
delete process.env.AWS_ACCESS_KEY_ID;
|
||||
} else {
|
||||
process.env.AWS_ACCESS_KEY_ID = previous.access;
|
||||
}
|
||||
if (previous.secret === undefined) {
|
||||
delete process.env.AWS_SECRET_ACCESS_KEY;
|
||||
} else {
|
||||
process.env.AWS_SECRET_ACCESS_KEY = previous.secret;
|
||||
}
|
||||
if (previous.profile === undefined) {
|
||||
delete process.env.AWS_PROFILE;
|
||||
} else {
|
||||
process.env.AWS_PROFILE = previous.profile;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it("uses Bedrock profile when access keys are missing", async () => {
|
||||
const previous = {
|
||||
bearer: process.env.AWS_BEARER_TOKEN_BEDROCK,
|
||||
access: process.env.AWS_ACCESS_KEY_ID,
|
||||
secret: process.env.AWS_SECRET_ACCESS_KEY,
|
||||
profile: process.env.AWS_PROFILE,
|
||||
};
|
||||
|
||||
try {
|
||||
delete process.env.AWS_BEARER_TOKEN_BEDROCK;
|
||||
delete process.env.AWS_ACCESS_KEY_ID;
|
||||
delete process.env.AWS_SECRET_ACCESS_KEY;
|
||||
process.env.AWS_PROFILE = "profile";
|
||||
|
||||
vi.resetModules();
|
||||
const { resolveApiKeyForProvider } = await import("./model-auth.js");
|
||||
|
||||
const resolved = await resolveApiKeyForProvider({
|
||||
provider: "amazon-bedrock",
|
||||
store: { version: 1, profiles: {} },
|
||||
cfg: {
|
||||
models: {
|
||||
providers: {
|
||||
"amazon-bedrock": {
|
||||
baseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
api: "bedrock-converse-stream",
|
||||
auth: "aws-sdk",
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
} as never,
|
||||
});
|
||||
|
||||
expect(resolved.mode).toBe("aws-sdk");
|
||||
expect(resolved.apiKey).toBeUndefined();
|
||||
expect(resolved.source).toContain("AWS_PROFILE");
|
||||
} finally {
|
||||
if (previous.bearer === undefined) {
|
||||
delete process.env.AWS_BEARER_TOKEN_BEDROCK;
|
||||
} else {
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK = previous.bearer;
|
||||
}
|
||||
if (previous.access === undefined) {
|
||||
delete process.env.AWS_ACCESS_KEY_ID;
|
||||
} else {
|
||||
process.env.AWS_ACCESS_KEY_ID = previous.access;
|
||||
}
|
||||
if (previous.secret === undefined) {
|
||||
delete process.env.AWS_SECRET_ACCESS_KEY;
|
||||
} else {
|
||||
process.env.AWS_SECRET_ACCESS_KEY = previous.secret;
|
||||
}
|
||||
if (previous.profile === undefined) {
|
||||
delete process.env.AWS_PROFILE;
|
||||
} else {
|
||||
process.env.AWS_PROFILE = previous.profile;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2,7 +2,7 @@ import path from "node:path";
|
||||
|
||||
import { type Api, getEnvApiKey, type Model } from "@mariozechner/pi-ai";
|
||||
import type { ClawdbotConfig } from "../config/config.js";
|
||||
import type { ModelProviderConfig } from "../config/types.js";
|
||||
import type { ModelProviderAuthMode, ModelProviderConfig } from "../config/types.js";
|
||||
import { getShellEnvAppliedKeys } from "../infra/shell-env.js";
|
||||
import { formatCliCommand } from "../cli/command-format.js";
|
||||
import {
|
||||
@@ -17,16 +17,115 @@ import { normalizeProviderId } from "./model-selection.js";
|
||||
|
||||
export { ensureAuthProfileStore, resolveAuthProfileOrder } from "./auth-profiles.js";
|
||||
|
||||
const AWS_BEARER_ENV = "AWS_BEARER_TOKEN_BEDROCK";
|
||||
const AWS_ACCESS_KEY_ENV = "AWS_ACCESS_KEY_ID";
|
||||
const AWS_SECRET_KEY_ENV = "AWS_SECRET_ACCESS_KEY";
|
||||
const AWS_PROFILE_ENV = "AWS_PROFILE";
|
||||
|
||||
function resolveProviderConfig(
|
||||
cfg: ClawdbotConfig | undefined,
|
||||
provider: string,
|
||||
): ModelProviderConfig | undefined {
|
||||
const providers = cfg?.models?.providers ?? {};
|
||||
const direct = providers[provider] as ModelProviderConfig | undefined;
|
||||
if (direct) return direct;
|
||||
const normalized = normalizeProviderId(provider);
|
||||
if (normalized === provider) {
|
||||
const matched = Object.entries(providers).find(
|
||||
([key]) => normalizeProviderId(key) === normalized,
|
||||
);
|
||||
return matched?.[1] as ModelProviderConfig | undefined;
|
||||
}
|
||||
return (
|
||||
(providers[normalized] as ModelProviderConfig | undefined) ??
|
||||
(Object.entries(providers).find(
|
||||
([key]) => normalizeProviderId(key) === normalized,
|
||||
)?.[1] as ModelProviderConfig | undefined)
|
||||
);
|
||||
}
|
||||
|
||||
export function getCustomProviderApiKey(
|
||||
cfg: ClawdbotConfig | undefined,
|
||||
provider: string,
|
||||
): string | undefined {
|
||||
const providers = cfg?.models?.providers ?? {};
|
||||
const entry = providers[provider] as ModelProviderConfig | undefined;
|
||||
const entry = resolveProviderConfig(cfg, provider);
|
||||
const key = entry?.apiKey?.trim();
|
||||
return key || undefined;
|
||||
}
|
||||
|
||||
function resolveProviderAuthOverride(
|
||||
cfg: ClawdbotConfig | undefined,
|
||||
provider: string,
|
||||
): ModelProviderAuthMode | undefined {
|
||||
const entry = resolveProviderConfig(cfg, provider);
|
||||
const auth = entry?.auth;
|
||||
if (auth === "api-key" || auth === "aws-sdk" || auth === "oauth" || auth === "token") {
|
||||
return auth;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function resolveEnvSourceLabel(params: {
|
||||
applied: Set<string>;
|
||||
envVars: string[];
|
||||
label: string;
|
||||
}): string {
|
||||
const shellApplied = params.envVars.some((envVar) => params.applied.has(envVar));
|
||||
const prefix = shellApplied ? "shell env: " : "env: ";
|
||||
return `${prefix}${params.label}`;
|
||||
}
|
||||
|
||||
export function resolveAwsSdkEnvVarName(): string | undefined {
|
||||
if (process.env[AWS_BEARER_ENV]?.trim()) return AWS_BEARER_ENV;
|
||||
if (process.env[AWS_ACCESS_KEY_ENV]?.trim() && process.env[AWS_SECRET_KEY_ENV]?.trim()) {
|
||||
return AWS_ACCESS_KEY_ENV;
|
||||
}
|
||||
if (process.env[AWS_PROFILE_ENV]?.trim()) return AWS_PROFILE_ENV;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function resolveAwsSdkAuthInfo(): { mode: "aws-sdk"; source: string } {
|
||||
const applied = new Set(getShellEnvAppliedKeys());
|
||||
if (process.env[AWS_BEARER_ENV]?.trim()) {
|
||||
return {
|
||||
mode: "aws-sdk",
|
||||
source: resolveEnvSourceLabel({
|
||||
applied,
|
||||
envVars: [AWS_BEARER_ENV],
|
||||
label: AWS_BEARER_ENV,
|
||||
}),
|
||||
};
|
||||
}
|
||||
if (process.env[AWS_ACCESS_KEY_ENV]?.trim() && process.env[AWS_SECRET_KEY_ENV]?.trim()) {
|
||||
return {
|
||||
mode: "aws-sdk",
|
||||
source: resolveEnvSourceLabel({
|
||||
applied,
|
||||
envVars: [AWS_ACCESS_KEY_ENV, AWS_SECRET_KEY_ENV],
|
||||
label: `${AWS_ACCESS_KEY_ENV} + ${AWS_SECRET_KEY_ENV}`,
|
||||
}),
|
||||
};
|
||||
}
|
||||
if (process.env[AWS_PROFILE_ENV]?.trim()) {
|
||||
return {
|
||||
mode: "aws-sdk",
|
||||
source: resolveEnvSourceLabel({
|
||||
applied,
|
||||
envVars: [AWS_PROFILE_ENV],
|
||||
label: AWS_PROFILE_ENV,
|
||||
}),
|
||||
};
|
||||
}
|
||||
return { mode: "aws-sdk", source: "aws-sdk default chain" };
|
||||
}
|
||||
|
||||
export type ResolvedProviderAuth = {
|
||||
apiKey?: string;
|
||||
profileId?: string;
|
||||
source: string;
|
||||
mode: "api-key" | "oauth" | "token" | "aws-sdk";
|
||||
};
|
||||
|
||||
export async function resolveApiKeyForProvider(params: {
|
||||
provider: string;
|
||||
cfg?: ClawdbotConfig;
|
||||
@@ -34,7 +133,7 @@ export async function resolveApiKeyForProvider(params: {
|
||||
preferredProfile?: string;
|
||||
store?: AuthProfileStore;
|
||||
agentDir?: string;
|
||||
}): Promise<{ apiKey: string; profileId?: string; source: string }> {
|
||||
}): Promise<ResolvedProviderAuth> {
|
||||
const { provider, cfg, profileId, preferredProfile } = params;
|
||||
const store = params.store ?? ensureAuthProfileStore(params.agentDir);
|
||||
|
||||
@@ -48,13 +147,20 @@ export async function resolveApiKeyForProvider(params: {
|
||||
if (!resolved) {
|
||||
throw new Error(`No credentials found for profile "${profileId}".`);
|
||||
}
|
||||
const mode = store.profiles[profileId]?.type;
|
||||
return {
|
||||
apiKey: resolved.apiKey,
|
||||
profileId,
|
||||
source: `profile:${profileId}`,
|
||||
mode: mode === "oauth" ? "oauth" : mode === "token" ? "token" : "api-key",
|
||||
};
|
||||
}
|
||||
|
||||
const authOverride = resolveProviderAuthOverride(cfg, provider);
|
||||
if (authOverride === "aws-sdk") {
|
||||
return resolveAwsSdkAuthInfo();
|
||||
}
|
||||
|
||||
const order = resolveAuthProfileOrder({
|
||||
cfg,
|
||||
store,
|
||||
@@ -70,10 +176,12 @@ export async function resolveApiKeyForProvider(params: {
|
||||
agentDir: params.agentDir,
|
||||
});
|
||||
if (resolved) {
|
||||
const mode = store.profiles[candidate]?.type;
|
||||
return {
|
||||
apiKey: resolved.apiKey,
|
||||
profileId: candidate,
|
||||
source: `profile:${candidate}`,
|
||||
mode: mode === "oauth" ? "oauth" : mode === "token" ? "token" : "api-key",
|
||||
};
|
||||
}
|
||||
} catch {}
|
||||
@@ -81,12 +189,21 @@ export async function resolveApiKeyForProvider(params: {
|
||||
|
||||
const envResolved = resolveEnvApiKey(provider);
|
||||
if (envResolved) {
|
||||
return { apiKey: envResolved.apiKey, source: envResolved.source };
|
||||
return {
|
||||
apiKey: envResolved.apiKey,
|
||||
source: envResolved.source,
|
||||
mode: envResolved.source.includes("OAUTH_TOKEN") ? "oauth" : "api-key",
|
||||
};
|
||||
}
|
||||
|
||||
const customKey = getCustomProviderApiKey(cfg, provider);
|
||||
if (customKey) {
|
||||
return { apiKey: customKey, source: "models.json" };
|
||||
return { apiKey: customKey, source: "models.json", mode: "api-key" };
|
||||
}
|
||||
|
||||
const normalized = normalizeProviderId(provider);
|
||||
if (authOverride === undefined && normalized === "amazon-bedrock") {
|
||||
return resolveAwsSdkAuthInfo();
|
||||
}
|
||||
|
||||
if (provider === "openai") {
|
||||
@@ -110,7 +227,7 @@ export async function resolveApiKeyForProvider(params: {
|
||||
}
|
||||
|
||||
export type EnvApiKeyResult = { apiKey: string; source: string };
|
||||
export type ModelAuthMode = "api-key" | "oauth" | "token" | "mixed" | "unknown";
|
||||
export type ModelAuthMode = "api-key" | "oauth" | "token" | "mixed" | "aws-sdk" | "unknown";
|
||||
|
||||
export function resolveEnvApiKey(provider: string): EnvApiKeyResult | null {
|
||||
const normalized = normalizeProviderId(provider);
|
||||
@@ -181,6 +298,9 @@ export function resolveModelAuthMode(
|
||||
const resolved = provider?.trim();
|
||||
if (!resolved) return undefined;
|
||||
|
||||
const authOverride = resolveProviderAuthOverride(cfg, resolved);
|
||||
if (authOverride === "aws-sdk") return "aws-sdk";
|
||||
|
||||
const authStore = store ?? ensureAuthProfileStore();
|
||||
const profiles = listProfilesForProvider(authStore, resolved);
|
||||
if (profiles.length > 0) {
|
||||
@@ -198,6 +318,10 @@ export function resolveModelAuthMode(
|
||||
if (modes.has("api_key")) return "api-key";
|
||||
}
|
||||
|
||||
if (authOverride === undefined && normalizeProviderId(resolved) === "amazon-bedrock") {
|
||||
return "aws-sdk";
|
||||
}
|
||||
|
||||
const envKey = resolveEnvApiKey(resolved);
|
||||
if (envKey?.apiKey) {
|
||||
return envKey.source.includes("OAUTH_TOKEN") ? "oauth" : "api-key";
|
||||
@@ -215,7 +339,7 @@ export async function getApiKeyForModel(params: {
|
||||
preferredProfile?: string;
|
||||
store?: AuthProfileStore;
|
||||
agentDir?: string;
|
||||
}): Promise<{ apiKey: string; profileId?: string; source: string }> {
|
||||
}): Promise<ResolvedProviderAuth> {
|
||||
return resolveApiKeyForProvider({
|
||||
provider: params.model.provider,
|
||||
cfg: params.cfg,
|
||||
@@ -225,3 +349,11 @@ export async function getApiKeyForModel(params: {
|
||||
agentDir: params.agentDir,
|
||||
});
|
||||
}
|
||||
|
||||
export function requireApiKey(auth: ResolvedProviderAuth, provider: string): string {
|
||||
const key = auth.apiKey?.trim();
|
||||
if (key) return key;
|
||||
throw new Error(
|
||||
`No API key resolved for provider "${provider}" (auth mode: ${auth.mode}).`,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import {
|
||||
resolveCopilotApiToken,
|
||||
} from "../providers/github-copilot-token.js";
|
||||
import { ensureAuthProfileStore, listProfilesForProvider } from "./auth-profiles.js";
|
||||
import { resolveEnvApiKey } from "./model-auth.js";
|
||||
import { resolveAwsSdkEnvVarName, resolveEnvApiKey } from "./model-auth.js";
|
||||
import {
|
||||
buildSyntheticModelDefinition,
|
||||
SYNTHETIC_BASE_URL,
|
||||
@@ -74,6 +74,10 @@ function resolveEnvApiKeyVarName(provider: string): string | undefined {
|
||||
return match ? match[1] : undefined;
|
||||
}
|
||||
|
||||
function resolveAwsSdkApiKeyVarName(): string {
|
||||
return resolveAwsSdkEnvVarName() ?? "AWS_PROFILE";
|
||||
}
|
||||
|
||||
function resolveApiKeyFromProfiles(params: {
|
||||
provider: string;
|
||||
store: ReturnType<typeof ensureAuthProfileStore>;
|
||||
@@ -138,15 +142,23 @@ export function normalizeProviders(params: {
|
||||
const hasModels =
|
||||
Array.isArray(normalizedProvider.models) && normalizedProvider.models.length > 0;
|
||||
if (hasModels && !normalizedProvider.apiKey?.trim()) {
|
||||
const fromEnv = resolveEnvApiKeyVarName(normalizedKey);
|
||||
const fromProfiles = resolveApiKeyFromProfiles({
|
||||
provider: normalizedKey,
|
||||
store: authStore,
|
||||
});
|
||||
const apiKey = fromEnv ?? fromProfiles;
|
||||
if (apiKey?.trim()) {
|
||||
const authMode =
|
||||
normalizedProvider.auth ?? (normalizedKey === "amazon-bedrock" ? "aws-sdk" : undefined);
|
||||
if (authMode === "aws-sdk") {
|
||||
const apiKey = resolveAwsSdkApiKeyVarName();
|
||||
mutated = true;
|
||||
normalizedProvider = { ...normalizedProvider, apiKey };
|
||||
} else {
|
||||
const fromEnv = resolveEnvApiKeyVarName(normalizedKey);
|
||||
const fromProfiles = resolveApiKeyFromProfiles({
|
||||
provider: normalizedKey,
|
||||
store: authStore,
|
||||
});
|
||||
const apiKey = fromEnv ?? fromProfiles;
|
||||
if (apiKey?.trim()) {
|
||||
mutated = true;
|
||||
normalizedProvider = { ...normalizedProvider, apiKey };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
isAnthropicRateLimitError,
|
||||
} from "./live-auth-keys.js";
|
||||
import { isModernModelRef } from "./live-model-filter.js";
|
||||
import { getApiKeyForModel } from "./model-auth.js";
|
||||
import { getApiKeyForModel, requireApiKey } from "./model-auth.js";
|
||||
import { ensureClawdbotModelsJson } from "./models-config.js";
|
||||
import { isRateLimitErrorMessage } from "./pi-embedded-helpers/errors.js";
|
||||
|
||||
@@ -226,7 +226,7 @@ describeLive("live models (profile keys)", () => {
|
||||
const apiKey =
|
||||
model.provider === "anthropic" && anthropicKeys.length > 0
|
||||
? anthropicKeys[attempt]
|
||||
: apiKeyInfo.apiKey;
|
||||
: requireApiKey(apiKeyInfo, model.provider);
|
||||
try {
|
||||
// Special regression: OpenAI requires replayed `reasoning` items for tool-only turns.
|
||||
if (
|
||||
|
||||
@@ -115,7 +115,13 @@ export async function compactEmbeddedPiSession(params: {
|
||||
agentDir,
|
||||
});
|
||||
|
||||
if (model.provider === "github-copilot") {
|
||||
if (!apiKeyInfo.apiKey) {
|
||||
if (apiKeyInfo.mode !== "aws-sdk") {
|
||||
throw new Error(
|
||||
`No API key resolved for provider "${model.provider}" (auth mode: ${apiKeyInfo.mode}).`,
|
||||
);
|
||||
}
|
||||
} else if (model.provider === "github-copilot") {
|
||||
const { resolveCopilotApiToken } =
|
||||
await import("../../providers/github-copilot-token.js");
|
||||
const copilotToken = await resolveCopilotApiToken({
|
||||
|
||||
29
src/agents/pi-embedded-runner/model.test.ts
Normal file
29
src/agents/pi-embedded-runner/model.test.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
|
||||
import { buildInlineProviderModels } from "./model.js";
|
||||
|
||||
const makeModel = (id: string) => ({
|
||||
id,
|
||||
name: id,
|
||||
reasoning: false,
|
||||
input: ["text"] as const,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 1,
|
||||
maxTokens: 1,
|
||||
});
|
||||
|
||||
describe("buildInlineProviderModels", () => {
|
||||
it("attaches provider ids to inline models", () => {
|
||||
const providers = {
|
||||
" alpha ": { models: [makeModel("alpha-model")] },
|
||||
beta: { models: [makeModel("beta-model")] },
|
||||
};
|
||||
|
||||
const result = buildInlineProviderModels(providers);
|
||||
|
||||
expect(result).toEqual([
|
||||
{ ...makeModel("alpha-model"), provider: "alpha" },
|
||||
{ ...makeModel("beta-model"), provider: "beta" },
|
||||
]);
|
||||
});
|
||||
});
|
||||
@@ -2,9 +2,23 @@ import type { Api, Model } from "@mariozechner/pi-ai";
|
||||
import { discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
import type { ClawdbotConfig } from "../../config/config.js";
|
||||
import type { ModelDefinitionConfig } from "../../config/types.js";
|
||||
import { resolveClawdbotAgentDir } from "../agent-paths.js";
|
||||
import { DEFAULT_CONTEXT_TOKENS } from "../defaults.js";
|
||||
import { normalizeModelCompat } from "../model-compat.js";
|
||||
import { normalizeProviderId } from "../model-selection.js";
|
||||
|
||||
type InlineModelEntry = ModelDefinitionConfig & { provider: string };
|
||||
|
||||
export function buildInlineProviderModels(
|
||||
providers: Record<string, { models?: ModelDefinitionConfig[] }>,
|
||||
): InlineModelEntry[] {
|
||||
return Object.entries(providers).flatMap(([providerId, entry]) => {
|
||||
const trimmed = providerId.trim();
|
||||
if (!trimmed) return [];
|
||||
return (entry?.models ?? []).map((model) => ({ ...model, provider: trimmed }));
|
||||
});
|
||||
}
|
||||
|
||||
export function buildModelAliasLines(cfg?: ClawdbotConfig) {
|
||||
const models = cfg?.agents?.defaults?.models ?? {};
|
||||
@@ -38,12 +52,12 @@ export function resolveModel(
|
||||
const model = modelRegistry.find(provider, modelId) as Model<Api> | null;
|
||||
if (!model) {
|
||||
const providers = cfg?.models?.providers ?? {};
|
||||
const inlineModels =
|
||||
providers[provider]?.models?.map((entry) => ({ ...entry, provider })) ??
|
||||
Object.values(providers)
|
||||
.flatMap((entry) => entry?.models ?? [])
|
||||
.map((entry) => ({ ...entry, provider }));
|
||||
const inlineMatch = inlineModels.find((entry) => entry.id === modelId);
|
||||
const inlineModels = buildInlineProviderModels(providers);
|
||||
const normalizedProvider = normalizeProviderId(provider);
|
||||
const inlineMatch = inlineModels.find(
|
||||
(entry) =>
|
||||
normalizeProviderId(entry.provider) === normalizedProvider && entry.id === modelId,
|
||||
);
|
||||
if (inlineMatch) {
|
||||
const normalized = normalizeModelCompat(inlineMatch as Model<Api>);
|
||||
return {
|
||||
|
||||
@@ -21,6 +21,7 @@ import {
|
||||
ensureAuthProfileStore,
|
||||
getApiKeyForModel,
|
||||
resolveAuthProfileOrder,
|
||||
type ResolvedProviderAuth,
|
||||
} from "../model-auth.js";
|
||||
import { ensureClawdbotModelsJson } from "../models-config.js";
|
||||
import {
|
||||
@@ -47,11 +48,7 @@ import { buildEmbeddedRunPayloads } from "./run/payloads.js";
|
||||
import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js";
|
||||
import { describeUnknownError } from "./utils.js";
|
||||
|
||||
type ApiKeyInfo = {
|
||||
apiKey: string;
|
||||
profileId?: string;
|
||||
source: string;
|
||||
};
|
||||
type ApiKeyInfo = ResolvedProviderAuth;
|
||||
|
||||
export async function runEmbeddedPiAgent(
|
||||
params: RunEmbeddedPiAgentParams,
|
||||
@@ -151,6 +148,15 @@ export async function runEmbeddedPiAgent(
|
||||
|
||||
const applyApiKeyInfo = async (candidate?: string): Promise<void> => {
|
||||
apiKeyInfo = await resolveApiKeyForCandidate(candidate);
|
||||
if (!apiKeyInfo.apiKey) {
|
||||
if (apiKeyInfo.mode !== "aws-sdk") {
|
||||
throw new Error(
|
||||
`No API key resolved for provider "${model.provider}" (auth mode: ${apiKeyInfo.mode}).`,
|
||||
);
|
||||
}
|
||||
lastProfileId = apiKeyInfo.profileId;
|
||||
return;
|
||||
}
|
||||
if (model.provider === "github-copilot") {
|
||||
const { resolveCopilotApiToken } =
|
||||
await import("../../providers/github-copilot-token.js");
|
||||
|
||||
@@ -17,7 +17,7 @@ import { loadWebMedia } from "../../web/media.js";
|
||||
import { ensureAuthProfileStore, listProfilesForProvider } from "../auth-profiles.js";
|
||||
import { DEFAULT_MODEL, DEFAULT_PROVIDER } from "../defaults.js";
|
||||
import { minimaxUnderstandImage } from "../minimax-vlm.js";
|
||||
import { getApiKeyForModel, resolveEnvApiKey } from "../model-auth.js";
|
||||
import { getApiKeyForModel, requireApiKey, resolveEnvApiKey } from "../model-auth.js";
|
||||
import { runWithImageModelFallback } from "../model-fallback.js";
|
||||
import { resolveConfiguredModelRef } from "../model-selection.js";
|
||||
import { ensureClawdbotModelsJson } from "../models-config.js";
|
||||
@@ -252,12 +252,13 @@ async function runImagePrompt(params: {
|
||||
cfg: effectiveCfg,
|
||||
agentDir: params.agentDir,
|
||||
});
|
||||
authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey);
|
||||
const apiKey = requireApiKey(apiKeyInfo, model.provider);
|
||||
authStorage.setRuntimeApiKey(model.provider, apiKey);
|
||||
const imageDataUrl = `data:${params.mimeType};base64,${params.base64}`;
|
||||
|
||||
if (model.provider === "minimax") {
|
||||
const text = await minimaxUnderstandImage({
|
||||
apiKey: apiKeyInfo.apiKey,
|
||||
apiKey,
|
||||
prompt: params.prompt,
|
||||
imageDataUrl,
|
||||
modelBaseUrl: model.baseUrl,
|
||||
@@ -267,7 +268,7 @@ async function runImagePrompt(params: {
|
||||
|
||||
const context = buildImageContext(params.prompt, params.base64, params.mimeType);
|
||||
const message = (await complete(model, context, {
|
||||
apiKey: apiKeyInfo.apiKey,
|
||||
apiKey,
|
||||
maxTokens: 512,
|
||||
})) as AssistantMessage;
|
||||
const text = coerceImageAssistantText({
|
||||
|
||||
@@ -13,6 +13,8 @@ export type ModelCompatConfig = {
|
||||
maxTokensField?: "max_completion_tokens" | "max_tokens";
|
||||
};
|
||||
|
||||
export type ModelProviderAuthMode = "api-key" | "aws-sdk" | "oauth" | "token";
|
||||
|
||||
export type ModelDefinitionConfig = {
|
||||
id: string;
|
||||
name: string;
|
||||
@@ -34,6 +36,7 @@ export type ModelDefinitionConfig = {
|
||||
export type ModelProviderConfig = {
|
||||
baseUrl: string;
|
||||
apiKey?: string;
|
||||
auth?: ModelProviderAuthMode;
|
||||
api?: ModelApi;
|
||||
headers?: Record<string, string>;
|
||||
authHeader?: boolean;
|
||||
|
||||
@@ -49,6 +49,14 @@ export const ModelProviderSchema = z
|
||||
.object({
|
||||
baseUrl: z.string().min(1),
|
||||
apiKey: z.string().optional(),
|
||||
auth: z
|
||||
.union([
|
||||
z.literal("api-key"),
|
||||
z.literal("aws-sdk"),
|
||||
z.literal("oauth"),
|
||||
z.literal("token"),
|
||||
])
|
||||
.optional(),
|
||||
api: ModelApiSchema.optional(),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
authHeader: z.boolean().optional(),
|
||||
|
||||
@@ -13,7 +13,12 @@ vi.mock("../agents/model-auth.js", () => ({
|
||||
resolveApiKeyForProvider: vi.fn(async () => ({
|
||||
apiKey: "test-key",
|
||||
source: "test",
|
||||
mode: "api-key",
|
||||
})),
|
||||
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
|
||||
if (auth?.apiKey) return auth.apiKey;
|
||||
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth?.mode}).`);
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("../media/fetch.js", () => ({
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { Api, AssistantMessage, Context, Model } from "@mariozechner/pi-ai"
|
||||
import { complete } from "@mariozechner/pi-ai";
|
||||
import { discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
import { getApiKeyForModel } from "../../agents/model-auth.js";
|
||||
import { getApiKeyForModel, requireApiKey } from "../../agents/model-auth.js";
|
||||
import { ensureClawdbotModelsJson } from "../../agents/models-config.js";
|
||||
import { minimaxUnderstandImage } from "../../agents/minimax-vlm.js";
|
||||
import { coerceImageAssistantText } from "../../agents/tools/image-tool.helpers.js";
|
||||
@@ -28,12 +28,13 @@ export async function describeImageWithModel(
|
||||
profileId: params.profile,
|
||||
preferredProfile: params.preferredProfile,
|
||||
});
|
||||
authStorage.setRuntimeApiKey(model.provider, apiKeyInfo.apiKey);
|
||||
const apiKey = requireApiKey(apiKeyInfo, model.provider);
|
||||
authStorage.setRuntimeApiKey(model.provider, apiKey);
|
||||
|
||||
const base64 = params.buffer.toString("base64");
|
||||
if (model.provider === "minimax") {
|
||||
const text = await minimaxUnderstandImage({
|
||||
apiKey: apiKeyInfo.apiKey,
|
||||
apiKey,
|
||||
prompt: params.prompt ?? "Describe the image.",
|
||||
imageDataUrl: `data:${params.mime ?? "image/jpeg"};base64,${base64}`,
|
||||
modelBaseUrl: model.baseUrl,
|
||||
@@ -54,7 +55,7 @@ export async function describeImageWithModel(
|
||||
],
|
||||
};
|
||||
const message = (await complete(model, context, {
|
||||
apiKey: apiKeyInfo.apiKey,
|
||||
apiKey,
|
||||
maxTokens: params.maxTokens ?? 512,
|
||||
})) as AssistantMessage;
|
||||
const text = coerceImageAssistantText({
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { ClawdbotConfig } from "../config/config.js";
|
||||
import type { MsgContext } from "../auto-reply/templating.js";
|
||||
import { applyTemplate } from "../auto-reply/templating.js";
|
||||
import { resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { logVerbose, shouldLogVerbose } from "../globals.js";
|
||||
import { runExec } from "../process/exec.js";
|
||||
import type {
|
||||
@@ -300,13 +300,14 @@ async function runProviderEntry(params: {
|
||||
maxBytes,
|
||||
timeoutMs,
|
||||
});
|
||||
const key = await resolveApiKeyForProvider({
|
||||
const auth = await resolveApiKeyForProvider({
|
||||
provider: providerId,
|
||||
cfg,
|
||||
profileId: entry.profile,
|
||||
preferredProfile: entry.preferredProfile,
|
||||
agentDir: params.agentDir,
|
||||
});
|
||||
const apiKey = requireApiKey(auth, providerId);
|
||||
const providerConfig = cfg.models?.providers?.[providerId];
|
||||
const baseUrl = entry.baseUrl ?? params.config?.baseUrl ?? providerConfig?.baseUrl;
|
||||
const mergedHeaders = {
|
||||
@@ -325,7 +326,7 @@ async function runProviderEntry(params: {
|
||||
buffer: media.buffer,
|
||||
fileName: media.fileName,
|
||||
mime: media.mime,
|
||||
apiKey: key.apiKey,
|
||||
apiKey,
|
||||
baseUrl,
|
||||
headers,
|
||||
model,
|
||||
@@ -359,19 +360,20 @@ async function runProviderEntry(params: {
|
||||
`Video attachment ${params.attachmentIndex + 1} base64 payload ${estimatedBase64Bytes} exceeds ${maxBase64Bytes}`,
|
||||
);
|
||||
}
|
||||
const key = await resolveApiKeyForProvider({
|
||||
const auth = await resolveApiKeyForProvider({
|
||||
provider: providerId,
|
||||
cfg,
|
||||
profileId: entry.profile,
|
||||
preferredProfile: entry.preferredProfile,
|
||||
agentDir: params.agentDir,
|
||||
});
|
||||
const apiKey = requireApiKey(auth, providerId);
|
||||
const providerConfig = cfg.models?.providers?.[providerId];
|
||||
const result = await provider.describeVideo({
|
||||
buffer: media.buffer,
|
||||
fileName: media.fileName,
|
||||
mime: media.mime,
|
||||
apiKey: key.apiKey,
|
||||
apiKey,
|
||||
baseUrl: providerConfig?.baseUrl,
|
||||
headers: providerConfig?.headers,
|
||||
model: entry.model,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { isTruthyEnvValue } from "../infra/env.js";
|
||||
import { createSubsystemLogger } from "../logging/subsystem.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
@@ -115,13 +115,16 @@ export async function resolveGeminiEmbeddingClient(
|
||||
const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
|
||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||
|
||||
const { apiKey } = remoteApiKey
|
||||
? { apiKey: remoteApiKey }
|
||||
: await resolveApiKeyForProvider({
|
||||
provider: "google",
|
||||
cfg: options.config,
|
||||
agentDir: options.agentDir,
|
||||
});
|
||||
const apiKey = remoteApiKey
|
||||
? remoteApiKey
|
||||
: requireApiKey(
|
||||
await resolveApiKeyForProvider({
|
||||
provider: "google",
|
||||
cfg: options.config,
|
||||
agentDir: options.agentDir,
|
||||
}),
|
||||
"google",
|
||||
);
|
||||
|
||||
const providerConfig = options.config.models?.providers?.google;
|
||||
const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
|
||||
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
|
||||
|
||||
export type OpenAiEmbeddingClient = {
|
||||
@@ -62,13 +62,16 @@ export async function resolveOpenAiEmbeddingClient(
|
||||
const remoteApiKey = remote?.apiKey?.trim();
|
||||
const remoteBaseUrl = remote?.baseUrl?.trim();
|
||||
|
||||
const { apiKey } = remoteApiKey
|
||||
? { apiKey: remoteApiKey }
|
||||
: await resolveApiKeyForProvider({
|
||||
provider: "openai",
|
||||
cfg: options.config,
|
||||
agentDir: options.agentDir,
|
||||
});
|
||||
const apiKey = remoteApiKey
|
||||
? remoteApiKey
|
||||
: requireApiKey(
|
||||
await resolveApiKeyForProvider({
|
||||
provider: "openai",
|
||||
cfg: options.config,
|
||||
agentDir: options.agentDir,
|
||||
}),
|
||||
"openai",
|
||||
);
|
||||
|
||||
const providerConfig = options.config.models?.providers?.openai;
|
||||
const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_OPENAI_BASE_URL;
|
||||
|
||||
@@ -2,6 +2,10 @@ import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
vi.mock("../agents/model-auth.js", () => ({
|
||||
resolveApiKeyForProvider: vi.fn(),
|
||||
requireApiKey: (auth: { apiKey?: string; mode?: string }, provider: string) => {
|
||||
if (auth?.apiKey) return auth.apiKey;
|
||||
throw new Error(`No API key resolved for provider "${provider}" (auth mode: ${auth?.mode}).`);
|
||||
},
|
||||
}));
|
||||
|
||||
const createFetchMock = () =>
|
||||
@@ -26,6 +30,8 @@ describe("embedding provider remote overrides", () => {
|
||||
const authModule = await import("../agents/model-auth.js");
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "provider-key",
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
|
||||
const cfg = {
|
||||
@@ -78,6 +84,8 @@ describe("embedding provider remote overrides", () => {
|
||||
const authModule = await import("../agents/model-auth.js");
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "provider-key",
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
|
||||
const cfg = {
|
||||
@@ -120,6 +128,8 @@ describe("embedding provider remote overrides", () => {
|
||||
const authModule = await import("../agents/model-auth.js");
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "provider-key",
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
|
||||
const cfg = {
|
||||
@@ -166,7 +176,7 @@ describe("embedding provider auto selection", () => {
|
||||
const authModule = await import("../agents/model-auth.js");
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
|
||||
if (provider === "openai") {
|
||||
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY" };
|
||||
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" };
|
||||
}
|
||||
throw new Error(`No API key found for provider "${provider}".`);
|
||||
});
|
||||
@@ -190,7 +200,7 @@ describe("embedding provider auto selection", () => {
|
||||
throw new Error('No API key found for provider "openai".');
|
||||
}
|
||||
if (provider === "google") {
|
||||
return { apiKey: "gemini-key", source: "env: GEMINI_API_KEY" };
|
||||
return { apiKey: "gemini-key", source: "env: GEMINI_API_KEY", mode: "api-key" };
|
||||
}
|
||||
throw new Error(`Unexpected provider ${provider}`);
|
||||
});
|
||||
@@ -231,6 +241,8 @@ describe("embedding provider local fallback", () => {
|
||||
const authModule = await import("../agents/model-auth.js");
|
||||
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
|
||||
apiKey: "provider-key",
|
||||
mode: "api-key",
|
||||
source: "test",
|
||||
});
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
|
||||
Reference in New Issue
Block a user