feat: add dynamic Bedrock model discovery
Add automatic discovery of AWS Bedrock models using the ListFoundationModels API. When AWS credentials are detected, models that support streaming and text output are automatically discovered and made available.

- Add @aws-sdk/client-bedrock dependency
- Add discoverBedrockModels() with caching (default 1 hour)
- Add resolveImplicitBedrockProvider() for auto-registration
- Add BedrockDiscoveryConfig for optional filtering by provider/region
- Filter to active, streaming, text-output models only
- Update docs/bedrock.md with auto-discovery documentation
This commit is contained in:
committed by
Peter Steinberger
parent
c66b1fd18b
commit
8effb557d5
96
src/agents/bedrock-discovery.test.ts
Normal file
96
src/agents/bedrock-discovery.test.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import type { BedrockClient } from "@aws-sdk/client-bedrock";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
// Shared stub standing in for BedrockClient.send; each test queues its own response.
const sendMock = vi.fn();

/** Returns a stand-in client whose only member is the stubbed `send`. */
function clientFactory(): BedrockClient {
  const stub = { send: sendMock };
  return stub as unknown as BedrockClient;
}

/** Shape of the model summaries fed to the stubbed ListFoundationModels call. */
type SummaryFixture = {
  modelId: string;
  modelName: string;
  providerName: string;
  inputModalities: string[];
  outputModalities: string[];
  responseStreamingSupported: boolean;
  modelLifecycle: { status: string };
};

/**
 * Builds a summary that passes every discovery filter (active, streaming,
 * text in / text out) unless a field is overridden.
 */
function makeSummary(
  fields: Pick<SummaryFixture, "modelId" | "modelName" | "providerName"> & Partial<SummaryFixture>,
): SummaryFixture {
  return {
    inputModalities: ["TEXT"],
    outputModalities: ["TEXT"],
    responseStreamingSupported: true,
    modelLifecycle: { status: "ACTIVE" },
    ...fields,
  };
}

/** Imports the module under test and clears its cache so tests stay independent. */
async function loadDiscovery() {
  const discoveryModule = await import("./bedrock-discovery.js");
  discoveryModule.resetBedrockDiscoveryCacheForTest();
  return discoveryModule.discoverBedrockModels;
}

describe("bedrock discovery", () => {
  beforeEach(() => {
    sendMock.mockReset();
  });

  it("filters to active streaming text models and maps modalities", async () => {
    const discoverBedrockModels = await loadDiscovery();

    const modelSummaries = [
      // The only entry expected to survive filtering.
      makeSummary({
        modelId: "anthropic.claude-3-7-sonnet-20250219-v1:0",
        modelName: "Claude 3.7 Sonnet",
        providerName: "anthropic",
        inputModalities: ["TEXT", "IMAGE"],
      }),
      // Rejected: no streaming support.
      makeSummary({
        modelId: "anthropic.claude-3-haiku-20240307-v1:0",
        modelName: "Claude 3 Haiku",
        providerName: "anthropic",
        responseStreamingSupported: false,
      }),
      // Rejected: lifecycle is not ACTIVE.
      makeSummary({
        modelId: "meta.llama3-8b-instruct-v1:0",
        modelName: "Llama 3 8B",
        providerName: "meta",
        modelLifecycle: { status: "INACTIVE" },
      }),
      // Rejected: output is not text.
      makeSummary({
        modelId: "amazon.titan-embed-text-v1",
        modelName: "Titan Embed",
        providerName: "amazon",
        outputModalities: ["EMBEDDING"],
      }),
    ];
    sendMock.mockResolvedValueOnce({ modelSummaries });

    const models = await discoverBedrockModels({ region: "us-east-1", clientFactory });

    expect(models).toHaveLength(1);
    expect(models[0]).toMatchObject({
      id: "anthropic.claude-3-7-sonnet-20250219-v1:0",
      name: "Claude 3.7 Sonnet",
      reasoning: false,
      input: ["text", "image"],
      contextWindow: 128000,
      maxTokens: 8192,
    });
  });

  it("applies provider filter", async () => {
    const discoverBedrockModels = await loadDiscovery();

    sendMock.mockResolvedValueOnce({
      modelSummaries: [
        makeSummary({
          modelId: "anthropic.claude-3-7-sonnet-20250219-v1:0",
          modelName: "Claude 3.7 Sonnet",
          providerName: "anthropic",
        }),
      ],
    });

    const models = await discoverBedrockModels({
      region: "us-east-1",
      config: { providerFilter: ["amazon"] },
      clientFactory,
    });

    expect(models).toHaveLength(0);
  });
});
|
||||
184
src/agents/bedrock-discovery.ts
Normal file
184
src/agents/bedrock-discovery.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
import {
|
||||
BedrockClient,
|
||||
ListFoundationModelsCommand,
|
||||
type ListFoundationModelsCommandOutput,
|
||||
} from "@aws-sdk/client-bedrock";
|
||||
|
||||
import type { BedrockDiscoveryConfig, ModelDefinitionConfig } from "../config/types.js";
|
||||
|
||||
/** Cache lifetime in seconds used when the config supplies no `refreshInterval` (one hour). */
const DEFAULT_REFRESH_INTERVAL_SECONDS = 60 * 60;

/** Context window reported for every discovered model. */
const DEFAULT_CONTEXT_WINDOW = 128000;

/** Maximum output tokens reported for every discovered model. */
const DEFAULT_MAX_TOKENS = 8192;

/** Zeroed cost table attached to discovered models. */
const DEFAULT_COST = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };

type BedrockModelSummaryList = NonNullable<ListFoundationModelsCommandOutput["modelSummaries"]>;

/** One element of the `modelSummaries` array in a ListFoundationModels response. */
type BedrockModelSummary = BedrockModelSummaryList[number];

/** Per-key cache record: either a pending request or a resolved model list. */
type BedrockDiscoveryCacheEntry = {
  /** Millisecond timestamp (same clock as the `now` parameter) after which `value` is stale. */
  expiresAt: number;
  /** Models from the last successful discovery, once it has completed. */
  value?: ModelDefinitionConfig[];
  /** Pending discovery handed to callers that arrive while it is running. */
  inFlight?: Promise<ModelDefinitionConfig[]>;
};

/** Discovery results keyed by the string that `buildCacheKey` produces. */
const discoveryCache = new Map<string, BedrockDiscoveryCacheEntry>();

/** Set once a discovery failure has been logged, so the warning is emitted a single time. */
let hasLoggedBedrockError = false;
|
||||
|
||||
/**
 * Canonicalises a provider filter so equivalent filters compare equal.
 *
 * Entries are trimmed and lower-cased; blank entries and duplicates are
 * dropped and the result is sorted.
 *
 * @param filter - Raw provider names from configuration; may be omitted.
 * @returns The cleaned, sorted list. Empty when there is nothing to filter on.
 */
function normalizeProviderFilter(filter?: string[]): string[] {
  const unique = new Set<string>();
  for (const raw of filter ?? []) {
    const cleaned = raw.trim().toLowerCase();
    if (cleaned.length > 0) unique.add(cleaned);
  }
  return [...unique].sort();
}
|
||||
|
||||
/**
 * Derives the discovery-cache key for one combination of settings.
 *
 * @param params - Region, normalised provider filter and refresh interval.
 * @returns The JSON serialisation of `params`, so every field takes part in the key.
 */
function buildCacheKey(params: {
  region: string;
  providerFilter: string[];
  refreshIntervalSeconds: number;
}): string {
  // Serialise the object exactly as received; the key follows the caller's property order.
  const serialized = JSON.stringify(params);
  return serialized;
}
|
||||
|
||||
/**
 * Reports whether a modality list contains text.
 *
 * @param modalities - Modality names in any letter case; may be omitted.
 * @returns `true` when some entry equals "text" ignoring case.
 */
function includesTextModalities(modalities?: Array<string>): boolean {
  if (!modalities) return false;
  for (const modality of modalities) {
    if (modality.toLowerCase() === "text") return true;
  }
  return false;
}
|
||||
|
||||
/**
 * Reports whether a model's lifecycle status is ACTIVE.
 *
 * @param summary - Model summary from ListFoundationModels.
 * @returns `true` only when a string status is present and equals "ACTIVE" ignoring case.
 */
function isActive(summary: BedrockModelSummary): boolean {
  const lifecycleStatus = summary.modelLifecycle?.status;
  if (typeof lifecycleStatus !== "string") return false;
  return lifecycleStatus.toUpperCase() === "ACTIVE";
}
|
||||
|
||||
/**
 * Translates Bedrock input modalities into the config's input kinds.
 *
 * Only text and image are recognised (case-insensitively); other modalities
 * are ignored. Each kind appears once, in first-seen order.
 *
 * @param summary - Model summary from ListFoundationModels.
 * @returns The recognised kinds, or `["text"]` when none were recognised.
 */
function mapInputModalities(summary: BedrockModelSummary): Array<"text" | "image"> {
  const recognized: Array<"text" | "image"> = [];
  for (const raw of summary.inputModalities ?? []) {
    const candidate = raw.toLowerCase();
    const isSupported = candidate === "text" || candidate === "image";
    if (isSupported && !recognized.includes(candidate)) {
      recognized.push(candidate);
    }
  }
  // Fall back to text so a model never ends up with no input kind at all.
  return recognized.length > 0 ? recognized : ["text"];
}
|
||||
|
||||
/**
 * Guesses whether a model supports reasoning, from its naming alone.
 *
 * @param summary - Model summary from ListFoundationModels.
 * @returns `true` when the model ID or name contains "reasoning" or "thinking" (ignoring case).
 */
function inferReasoningSupport(summary: BedrockModelSummary): boolean {
  const searchable = [summary.modelId ?? "", summary.modelName ?? ""].join(" ").toLowerCase();
  return ["reasoning", "thinking"].some((keyword) => searchable.includes(keyword));
}
|
||||
|
||||
/**
 * Context window assigned to a discovered model.
 *
 * @returns Always the module default; no per-model value is derived.
 */
function inferContextWindow(): number {
  const contextWindow = DEFAULT_CONTEXT_WINDOW;
  return contextWindow;
}
|
||||
|
||||
/**
 * Maximum output tokens assigned to a discovered model.
 *
 * @returns Always the module default; no per-model value is derived.
 */
function inferMaxTokens(): number {
  const maxTokens = DEFAULT_MAX_TOKENS;
  return maxTokens;
}
|
||||
|
||||
/**
 * Decides whether a model belongs to one of the requested providers.
 *
 * A model matches when either its `providerName` or the vendor prefix of its
 * model ID (the text before the first ".") appears in the filter. Checking
 * both means a display name such as "Mistral AI" still matches the entry
 * "mistral", and a blank `providerName` no longer hides a usable ID prefix.
 *
 * @param summary - Model summary from ListFoundationModels.
 * @param filter - Provider names, expected already trimmed and lower-cased.
 *   An empty filter accepts every model.
 * @returns `true` when the model passes the filter.
 */
function matchesProviderFilter(summary: BedrockModelSummary, filter: string[]): boolean {
  if (filter.length === 0) return true;

  const idPrefix = typeof summary.modelId === "string" ? summary.modelId.split(".")[0] : undefined;
  const candidates = [summary.providerName, idPrefix];

  return candidates.some((candidate) => {
    if (typeof candidate !== "string") return false;
    const normalized = candidate.trim().toLowerCase();
    return normalized.length > 0 && filter.includes(normalized);
  });
}
|
||||
|
||||
/**
 * Decides whether a listed model should be exposed as a discovered model.
 *
 * A model is kept only when it has a non-blank ID, passes the provider
 * filter, supports response streaming, produces text output and is active.
 *
 * @param summary - Model summary from ListFoundationModels.
 * @param filter - Normalised provider filter; empty accepts every provider.
 * @returns `true` when every criterion holds.
 */
function shouldIncludeSummary(summary: BedrockModelSummary, filter: string[]): boolean {
  const hasModelId = Boolean(summary.modelId?.trim());
  return (
    hasModelId &&
    matchesProviderFilter(summary, filter) &&
    summary.responseStreamingSupported === true &&
    includesTextModalities(summary.outputModalities) &&
    isActive(summary)
  );
}
|
||||
|
||||
/**
 * Converts a Bedrock model summary into a model definition.
 *
 * The name falls back to the ID when the summary has no usable name. Context
 * window and max tokens come from module defaults, and reasoning support is
 * inferred from the model's naming.
 *
 * @param summary - Model summary from ListFoundationModels.
 * @returns A model definition carrying its own copy of the default cost table.
 */
function toModelDefinition(summary: BedrockModelSummary): ModelDefinitionConfig {
  const id = summary.modelId?.trim() ?? "";
  return {
    id,
    name: summary.modelName?.trim() || id,
    reasoning: inferReasoningSupport(summary),
    input: mapInputModalities(summary),
    // Copy rather than share: these objects are cached and handed to callers,
    // so a change to one model's cost must not leak into every other model.
    cost: { ...DEFAULT_COST },
    contextWindow: inferContextWindow(),
    maxTokens: inferMaxTokens(),
  };
}
|
||||
|
||||
/**
 * Restores this module's mutable state for tests: re-arms the one-time
 * failure warning and drops every cached discovery result.
 */
export function resetBedrockDiscoveryCacheForTest(): void {
  hasLoggedBedrockError = false;
  discoveryCache.clear();
}
|
||||
|
||||
/**
 * Lists Bedrock foundation models in a region and returns those usable as chat
 * models (active, streaming-capable, text output), sorted by name.
 *
 * Results are cached per region / provider filter / refresh interval. While a
 * request is pending, further callers with the same key share it. Failures
 * never reject: the cache entry is dropped, a warning is logged once, and
 * every caller, including those sharing the pending request, receives `[]`.
 *
 * @param params.region - AWS region whose model catalogue is queried.
 * @param params.config - Optional discovery settings. `refreshInterval` is in
 *   seconds (default one hour); `0` disables caching.
 * @param params.now - Clock override returning milliseconds; defaults to `Date.now`.
 * @param params.clientFactory - Override for constructing the Bedrock client.
 * @returns The discovered model definitions, or `[]` when discovery fails.
 */
export async function discoverBedrockModels(params: {
  region: string;
  config?: BedrockDiscoveryConfig;
  now?: () => number;
  clientFactory?: (region: string) => BedrockClient;
}): Promise<ModelDefinitionConfig[]> {
  const refreshIntervalSeconds = Math.max(
    0,
    Math.floor(params.config?.refreshInterval ?? DEFAULT_REFRESH_INTERVAL_SECONDS),
  );
  const providerFilter = normalizeProviderFilter(params.config?.providerFilter);
  const cacheKey = buildCacheKey({
    region: params.region,
    providerFilter,
    refreshIntervalSeconds,
  });
  const now = params.now?.() ?? Date.now();
  const cachingEnabled = refreshIntervalSeconds > 0;
  const expiresAt = now + refreshIntervalSeconds * 1000;

  if (cachingEnabled) {
    const cached = discoveryCache.get(cacheKey);
    if (cached?.value && cached.expiresAt > now) {
      return cached.value;
    }
    if (cached?.inFlight) {
      return cached.inFlight;
    }
  }

  // Raw request. Being async, it turns a throwing client factory or
  // constructor into a rejection instead of a synchronous throw.
  const listModels = async (): Promise<ModelDefinitionConfig[]> => {
    const clientFactory =
      params.clientFactory ?? ((region: string) => new BedrockClient({ region }));
    const client = clientFactory(params.region);
    const response = await client.send(new ListFoundationModelsCommand({}));
    const discovered: ModelDefinitionConfig[] = [];
    for (const summary of response.modelSummaries ?? []) {
      if (!shouldIncludeSummary(summary, providerFilter)) continue;
      discovered.push(toModelDefinition(summary));
    }
    return discovered.sort((a, b) => a.name.localeCompare(b.name));
  };
  const request = listModels();

  // The shared promise carries the error handling itself, so callers that pick
  // it up from the cache get the same `[]` fallback as the caller who started
  // it. `await` always yields, so the bookkeeping below runs only after the
  // in-flight entry has been stored.
  const discoveryPromise = (async (): Promise<ModelDefinitionConfig[]> => {
    try {
      const value = await request;
      if (cachingEnabled) {
        discoveryCache.set(cacheKey, { expiresAt, value });
      }
      return value;
    } catch (error) {
      if (cachingEnabled) {
        // Drop the pending entry so the next call retries.
        discoveryCache.delete(cacheKey);
      }
      if (!hasLoggedBedrockError) {
        hasLoggedBedrockError = true;
        console.warn(`[bedrock-discovery] Failed to list models: ${String(error)}`);
      }
      return [];
    }
  })();

  if (cachingEnabled) {
    discoveryCache.set(cacheKey, { expiresAt, inFlight: discoveryPromise });
  }

  return discoveryPromise;
}
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
} from "../providers/github-copilot-token.js";
|
||||
import { ensureAuthProfileStore, listProfilesForProvider } from "./auth-profiles.js";
|
||||
import { resolveAwsSdkEnvVarName, resolveEnvApiKey } from "./model-auth.js";
|
||||
import { discoverBedrockModels } from "./bedrock-discovery.js";
|
||||
import {
|
||||
buildSyntheticModelDefinition,
|
||||
SYNTHETIC_BASE_URL,
|
||||
@@ -375,3 +376,27 @@ export async function resolveImplicitCopilotProvider(params: {
|
||||
models: [],
|
||||
} satisfies ProviderConfig;
|
||||
}
|
||||
|
||||
/**
 * Builds an implicit Amazon Bedrock provider from automatically discovered models.
 *
 * Discovery runs when `models.bedrockDiscovery.enabled` is `true`, or when it
 * is unset and AWS credentials are detected. It is skipped when `enabled` is
 * `false`.
 *
 * The region is the first non-blank value among the discovery config's
 * `region`, `AWS_REGION` and `AWS_DEFAULT_REGION`, defaulting to "us-east-1".
 * Blank values are skipped so an empty setting cannot yield a malformed endpoint.
 *
 * @param params.agentDir - Agent directory; not read by this function.
 * @param params.config - Application config holding the optional discovery settings.
 * @param params.env - Environment to read region variables from; defaults to `process.env`.
 * @returns The provider config, or `null` when discovery is disabled, no
 *   credentials are detected, or no models were discovered.
 */
export async function resolveImplicitBedrockProvider(params: {
  agentDir: string;
  config?: ClawdbotConfig;
  env?: NodeJS.ProcessEnv;
}): Promise<ProviderConfig | null> {
  const env = params.env ?? process.env;
  const discoveryConfig = params.config?.models?.bedrockDiscovery;
  const enabled = discoveryConfig?.enabled;
  // NOTE(review): the credential probe is called without `env`, so a custom
  // `params.env` only affects region lookup. Confirm whether it should be forwarded.
  const hasAwsCreds = resolveAwsSdkEnvVarName() !== undefined;
  if (enabled === false) return null;
  if (enabled !== true && !hasAwsCreds) return null;

  // `??` alone would accept "" or whitespace, so pick the first candidate
  // that still has content after trimming.
  const region =
    [discoveryConfig?.region, env.AWS_REGION, env.AWS_DEFAULT_REGION]
      .map((candidate) => candidate?.trim())
      .find((candidate): candidate is string => Boolean(candidate)) ?? "us-east-1";

  const models = await discoverBedrockModels({ region, config: discoveryConfig });
  if (models.length === 0) return null;

  return {
    baseUrl: `https://bedrock-runtime.${region}.amazonaws.com`,
    api: "bedrock-converse-stream",
    auth: "aws-sdk",
    models,
  } satisfies ProviderConfig;
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import { resolveClawdbotAgentDir } from "./agent-paths.js";
|
||||
import {
|
||||
normalizeProviders,
|
||||
type ProviderConfig,
|
||||
resolveImplicitBedrockProvider,
|
||||
resolveImplicitCopilotProvider,
|
||||
resolveImplicitProviders,
|
||||
} from "./models-config.providers.js";
|
||||
@@ -84,6 +85,13 @@ export async function ensureClawdbotModelsJson(
|
||||
implicit: implicitProviders,
|
||||
explicit: explicitProviders,
|
||||
});
|
||||
const implicitBedrock = await resolveImplicitBedrockProvider({ agentDir, config: cfg });
|
||||
if (implicitBedrock) {
|
||||
const existing = providers["amazon-bedrock"];
|
||||
providers["amazon-bedrock"] = existing
|
||||
? mergeProviderModels(implicitBedrock, existing)
|
||||
: implicitBedrock;
|
||||
}
|
||||
const implicitCopilot = await resolveImplicitCopilotProvider({ agentDir });
|
||||
if (implicitCopilot && !providers["github-copilot"]) {
|
||||
providers["github-copilot"] = implicitCopilot;
|
||||
|
||||
@@ -11,6 +11,7 @@ const resolveAuthStorePathForDisplay = vi
|
||||
.mockReturnValue("/tmp/clawdbot-agent/auth-profiles.json");
|
||||
const resolveProfileUnusableUntilForDisplay = vi.fn().mockReturnValue(null);
|
||||
const resolveEnvApiKey = vi.fn().mockReturnValue(undefined);
|
||||
const resolveAwsSdkEnvVarName = vi.fn().mockReturnValue(undefined);
|
||||
const getCustomProviderApiKey = vi.fn().mockReturnValue(undefined);
|
||||
const discoverAuthStorage = vi.fn().mockReturnValue({});
|
||||
const discoverModels = vi.fn();
|
||||
@@ -39,6 +40,7 @@ vi.mock("../agents/auth-profiles.js", () => ({
|
||||
|
||||
vi.mock("../agents/model-auth.js", () => ({
|
||||
resolveEnvApiKey,
|
||||
resolveAwsSdkEnvVarName,
|
||||
getCustomProviderApiKey,
|
||||
}));
|
||||
|
||||
|
||||
@@ -4,7 +4,11 @@ import { discoverAuthStorage, discoverModels } from "@mariozechner/pi-coding-age
|
||||
import { resolveClawdbotAgentDir } from "../../agents/agent-paths.js";
|
||||
import type { AuthProfileStore } from "../../agents/auth-profiles.js";
|
||||
import { listProfilesForProvider } from "../../agents/auth-profiles.js";
|
||||
import { getCustomProviderApiKey, resolveEnvApiKey } from "../../agents/model-auth.js";
|
||||
import {
|
||||
getCustomProviderApiKey,
|
||||
resolveAwsSdkEnvVarName,
|
||||
resolveEnvApiKey,
|
||||
} from "../../agents/model-auth.js";
|
||||
import { ensureClawdbotModelsJson } from "../../agents/models-config.js";
|
||||
import type { ClawdbotConfig } from "../../config/config.js";
|
||||
import type { ModelRow } from "./list.types.js";
|
||||
@@ -28,6 +32,7 @@ const isLocalBaseUrl = (baseUrl: string) => {
|
||||
|
||||
const hasAuthForProvider = (provider: string, cfg: ClawdbotConfig, authStore: AuthProfileStore) => {
|
||||
if (listProfilesForProvider(authStore, provider).length > 0) return true;
|
||||
if (provider === "amazon-bedrock" && resolveAwsSdkEnvVarName()) return true;
|
||||
if (resolveEnvApiKey(provider)) return true;
|
||||
if (getCustomProviderApiKey(cfg, provider)) return true;
|
||||
return false;
|
||||
|
||||
@@ -43,7 +43,15 @@ export type ModelProviderConfig = {
|
||||
models: ModelDefinitionConfig[];
|
||||
};
|
||||
|
||||
/** Optional settings controlling automatic discovery of Amazon Bedrock models. */
export type BedrockDiscoveryConfig = {
  /**
   * `false` turns discovery off and `true` forces it on. When unset,
   * discovery runs only if AWS credentials are detected.
   */
  enabled?: boolean;
  /** AWS region to query. When unset, the AWS region environment variables are used, then "us-east-1". */
  region?: string;
  /** Restricts discovery to these providers; compared case-insensitively. Empty or unset means all. */
  providerFilter?: Array<string>;
  /** Seconds a discovery result stays cached. Defaults to 3600; `0` disables caching. */
  refreshInterval?: number;
};
|
||||
|
||||
/** Top-level `models` section of the configuration. */
export type ModelsConfig = {
  /** Either "merge" or "replace". */
  mode?: "merge" | "replace";
  /** Explicitly configured providers, keyed by provider name. */
  providers?: Record<string, ModelProviderConfig>;
  /** Settings for automatic Amazon Bedrock model discovery. */
  bedrockDiscovery?: BedrockDiscoveryConfig;
};
|
||||
|
||||
@@ -59,10 +59,21 @@ export const ModelProviderSchema = z
|
||||
})
|
||||
.strict();
|
||||
|
||||
// Field validators for the Bedrock discovery settings; every field is optional.
const bedrockDiscoveryShape = {
  enabled: z.boolean().optional(),
  region: z.string().optional(),
  providerFilter: z.array(z.string()).optional(),
  // Whole, non-negative number.
  refreshInterval: z.number().int().nonnegative().optional(),
};

/**
 * Validates the optional `bedrockDiscovery` config section. Unknown keys are
 * rejected, and the section itself may be omitted.
 */
export const BedrockDiscoverySchema = z.object(bedrockDiscoveryShape).strict().optional();
|
||||
|
||||
// Accepted values for `models.mode`.
const modelsModeSchema = z.union([z.literal("merge"), z.literal("replace")]);

/**
 * Validates the optional `models` config section: mode, explicitly configured
 * providers and Bedrock discovery settings. Unknown keys are rejected.
 */
export const ModelsConfigSchema = z
  .object({
    mode: modelsModeSchema.optional(),
    providers: z.record(z.string(), ModelProviderSchema).optional(),
    bedrockDiscovery: BedrockDiscoverySchema,
  })
  .strict()
  .optional();
|
||||
|
||||
Reference in New Issue
Block a user