feat(memory): add gemini embeddings + auto select providers

Co-authored-by: Gustavo Madeira Santana <gumadeiras@gmail.com>
This commit is contained in:
Peter Steinberger
2026-01-18 15:29:16 +00:00
parent 7252938339
commit be7191879a
11 changed files with 536 additions and 352 deletions

View File

@@ -0,0 +1,145 @@
import { resolveApiKeyForProvider } from "../agents/model-auth.js";
import { createSubsystemLogger } from "../logging.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
export type GeminiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
model: string;
modelPath: string;
};
const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
const debugEmbeddings = process.env.CLAWDBOT_DEBUG_MEMORY_EMBEDDINGS === "1";
const log = createSubsystemLogger("memory/embeddings");
const debugLog = (message: string, meta?: Record<string, unknown>) => {
if (!debugEmbeddings) return;
const suffix = meta ? ` ${JSON.stringify(meta)}` : "";
log.raw(`${message}${suffix}`);
};
function resolveRemoteApiKey(remoteApiKey?: string): string | undefined {
const trimmed = remoteApiKey?.trim();
if (!trimmed) return undefined;
if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") {
return process.env[trimmed]?.trim();
}
return trimmed;
}
function normalizeGeminiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) return DEFAULT_GEMINI_EMBEDDING_MODEL;
const withoutPrefix = trimmed.replace(/^models\//, "");
if (withoutPrefix.startsWith("gemini/")) return withoutPrefix.slice("gemini/".length);
if (withoutPrefix.startsWith("google/")) return withoutPrefix.slice("google/".length);
return withoutPrefix;
}
function normalizeGeminiBaseUrl(raw: string): string {
const trimmed = raw.replace(/\/+$/, "");
const openAiIndex = trimmed.indexOf("/openai");
if (openAiIndex > -1) return trimmed.slice(0, openAiIndex);
return trimmed;
}
function buildGeminiModelPath(model: string): string {
return model.startsWith("models/") ? model : `models/${model}`;
}
export async function createGeminiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
const client = await resolveGeminiEmbeddingClient(options);
const baseUrl = client.baseUrl.replace(/\/$/, "");
const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
const embedQuery = async (text: string): Promise<number[]> => {
if (!text.trim()) return [];
const res = await fetch(embedUrl, {
method: "POST",
headers: client.headers,
body: JSON.stringify({
content: { parts: [{ text }] },
taskType: "RETRIEVAL_QUERY",
}),
});
if (!res.ok) {
const payload = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${payload}`);
}
const payload = (await res.json()) as { embedding?: { values?: number[] } };
return payload.embedding?.values ?? [];
};
const embedBatch = async (texts: string[]): Promise<number[][]> => {
if (texts.length === 0) return [];
const requests = texts.map((text) => ({
model: client.modelPath,
content: { parts: [{ text }] },
taskType: "RETRIEVAL_DOCUMENT",
}));
const res = await fetch(batchUrl, {
method: "POST",
headers: client.headers,
body: JSON.stringify({ requests }),
});
if (!res.ok) {
const payload = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${payload}`);
}
const payload = (await res.json()) as { embeddings?: Array<{ values?: number[] }> };
const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
return texts.map((_, index) => embeddings[index]?.values ?? []);
};
return {
provider: {
id: "gemini",
model: client.model,
embedQuery,
embedBatch,
},
client,
};
}
export async function resolveGeminiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
const remote = options.remote;
const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
const remoteBaseUrl = remote?.baseUrl?.trim();
const { apiKey } = remoteApiKey
? { apiKey: remoteApiKey }
: await resolveApiKeyForProvider({
provider: "google",
cfg: options.config,
agentDir: options.agentDir,
});
const providerConfig = options.config.models?.providers?.google;
const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
const headers: Record<string, string> = {
"Content-Type": "application/json",
"x-goog-api-key": apiKey,
...headerOverrides,
};
const model = normalizeGeminiModel(options.model);
const modelPath = buildGeminiModelPath(model);
debugLog("memory embeddings: gemini client", {
rawBaseUrl,
baseUrl,
model,
modelPath,
embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
});
return { baseUrl, headers, model, modelPath };
}

View File

@@ -0,0 +1,83 @@
import { resolveApiKeyForProvider } from "../agents/model-auth.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
export type OpenAiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
model: string;
};
export const DEFAULT_OPENAI_EMBEDDING_MODEL = "text-embedding-3-small";
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
export function normalizeOpenAiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) return DEFAULT_OPENAI_EMBEDDING_MODEL;
if (trimmed.startsWith("openai/")) return trimmed.slice("openai/".length);
return trimmed;
}
export async function createOpenAiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
const client = await resolveOpenAiEmbeddingClient(options);
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
const embed = async (input: string[]): Promise<number[][]> => {
if (input.length === 0) return [];
const res = await fetch(url, {
method: "POST",
headers: client.headers,
body: JSON.stringify({ model: client.model, input }),
});
if (!res.ok) {
const text = await res.text();
throw new Error(`openai embeddings failed: ${res.status} ${text}`);
}
const payload = (await res.json()) as {
data?: Array<{ embedding?: number[] }>;
};
const data = payload.data ?? [];
return data.map((entry) => entry.embedding ?? []);
};
return {
provider: {
id: "openai",
model: client.model,
embedQuery: async (text) => {
const [vec] = await embed([text]);
return vec ?? [];
},
embedBatch: embed,
},
client,
};
}
export async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> {
const remote = options.remote;
const remoteApiKey = remote?.apiKey?.trim();
const remoteBaseUrl = remote?.baseUrl?.trim();
const { apiKey } = remoteApiKey
? { apiKey: remoteApiKey }
: await resolveApiKeyForProvider({
provider: "openai",
cfg: options.config,
agentDir: options.agentDir,
});
const providerConfig = options.config.models?.providers?.openai;
const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_OPENAI_BASE_URL;
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
const headers: Record<string, string> = {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
...headerOverrides,
};
const model = normalizeOpenAiModel(options.model);
return { baseUrl, headers, model };
}

View File

@@ -108,7 +108,7 @@ describe("embedding provider remote overrides", () => {
expect(headers.Authorization).toBe("Bearer provider-key");
});
it("uses gemini embedContent endpoint with x-goog-api-key", async () => {
it("builds Gemini embeddings requests with api key header", async () => {
const fetchMock = vi.fn(async () => ({
ok: true,
status: 200,
@@ -119,29 +119,94 @@ describe("embedding provider remote overrides", () => {
const { createEmbeddingProvider } = await import("./embeddings.js");
const authModule = await import("../agents/model-auth.js");
vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({
apiKey: "gemini-key",
apiKey: "provider-key",
});
const cfg = {
models: {
providers: {
google: {
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
},
},
},
};
const result = await createEmbeddingProvider({
config: {} as never,
config: cfg as never,
provider: "gemini",
remote: {
baseUrl: "https://gemini.example/v1beta",
apiKey: "gemini-key",
},
model: "gemini-embedding-001",
model: "text-embedding-004",
fallback: "openai",
});
await result.provider.embedQuery("hello");
const [url, init] = fetchMock.mock.calls[0] ?? [];
expect(url).toBe("https://gemini.example/v1beta/models/gemini-embedding-001:embedContent");
expect(url).toBe(
"https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent",
);
const headers = (init?.headers ?? {}) as Record<string, string>;
expect(headers["x-goog-api-key"]).toBe("gemini-key");
expect(headers["Content-Type"]).toBe("application/json");
});
});
describe("embedding provider auto selection", () => {
afterEach(() => {
vi.resetAllMocks();
vi.resetModules();
vi.unstubAllGlobals();
});
it("prefers openai when a key resolves", async () => {
const { createEmbeddingProvider } = await import("./embeddings.js");
const authModule = await import("../agents/model-auth.js");
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
if (provider === "openai") {
return { apiKey: "openai-key", source: "env: OPENAI_API_KEY" };
}
throw new Error(`No API key found for provider "${provider}".`);
});
const result = await createEmbeddingProvider({
config: {} as never,
provider: "auto",
model: "",
fallback: "none",
});
expect(result.requestedProvider).toBe("auto");
expect(result.provider.id).toBe("openai");
});
it("uses gemini when openai is missing", async () => {
const { createEmbeddingProvider } = await import("./embeddings.js");
const authModule = await import("../agents/model-auth.js");
vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => {
if (provider === "openai") {
throw new Error('No API key found for provider "openai".');
}
if (provider === "google") {
return { apiKey: "gemini-key", source: "env: GEMINI_API_KEY" };
}
throw new Error(`Unexpected provider ${provider}`);
});
const result = await createEmbeddingProvider({
config: {} as never,
provider: "auto",
model: "",
fallback: "none",
});
expect(result.requestedProvider).toBe("auto");
expect(result.provider.id).toBe("gemini");
});
});
describe("embedding provider local fallback", () => {
afterEach(() => {
vi.resetAllMocks();

View File

@@ -1,8 +1,21 @@
import fsSync from "node:fs";
import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp";
import { resolveApiKeyForProvider } from "../agents/model-auth.js";
import type { ClawdbotConfig } from "../config/config.js";
import { resolveUserPath } from "../utils.js";
import {
createGeminiEmbeddingProvider,
type GeminiEmbeddingClient,
} from "./embeddings-gemini.js";
import {
createOpenAiEmbeddingProvider,
type OpenAiEmbeddingClient,
} from "./embeddings-openai.js";
import { importNodeLlamaCpp } from "./node-llama.js";
export type { GeminiEmbeddingClient } from "./embeddings-gemini.js";
export type { OpenAiEmbeddingClient } from "./embeddings-openai.js";
export type EmbeddingProvider = {
id: string;
model: string;
@@ -12,230 +25,49 @@ export type EmbeddingProvider = {
export type EmbeddingProviderResult = {
provider: EmbeddingProvider;
requestedProvider: "openai" | "gemini" | "local";
fallbackFrom?: "local";
requestedProvider: "openai" | "local" | "gemini" | "auto";
fallbackFrom?: "openai" | "local" | "gemini";
fallbackReason?: string;
openAi?: OpenAiEmbeddingClient;
gemini?: GeminiEmbeddingClient;
};
export type OpenAiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
model: string;
};
export type GeminiEmbeddingClient = {
baseUrl: string;
headers: Record<string, string>;
model: string;
};
export type EmbeddingProviderOptions = {
config: ClawdbotConfig;
agentDir?: string;
provider: "openai" | "gemini" | "local";
provider: "openai" | "local" | "gemini" | "auto";
remote?: {
baseUrl?: string;
apiKey?: string;
headers?: Record<string, string>;
};
model: string;
fallback: "openai" | "none";
fallback: "openai" | "gemini" | "local" | "none";
local?: {
modelPath?: string;
modelCacheDir?: string;
};
};
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
const DEFAULT_LOCAL_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
const DEFAULT_GEMINI_MODEL = "gemini-embedding-001";
function normalizeOpenAiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) return "text-embedding-3-small";
if (trimmed.startsWith("openai/")) return trimmed.slice("openai/".length);
return trimmed;
}
function normalizeGeminiModel(model: string): string {
const trimmed = model.trim();
if (!trimmed) return DEFAULT_GEMINI_MODEL;
if (trimmed.startsWith("models/")) return trimmed.slice("models/".length);
if (trimmed.startsWith("google/")) return trimmed.slice("google/".length);
return trimmed;
}
async function createOpenAiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
const client = await resolveOpenAiEmbeddingClient(options);
const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`;
const embed = async (input: string[]): Promise<number[][]> => {
if (input.length === 0) return [];
const res = await fetch(url, {
method: "POST",
headers: client.headers,
body: JSON.stringify({ model: client.model, input }),
});
if (!res.ok) {
const text = await res.text();
throw new Error(`openai embeddings failed: ${res.status} ${text}`);
}
const payload = (await res.json()) as {
data?: Array<{ embedding?: number[] }>;
};
const data = payload.data ?? [];
return data.map((entry) => entry.embedding ?? []);
};
return {
provider: {
id: "openai",
model: client.model,
embedQuery: async (text) => {
const [vec] = await embed([text]);
return vec ?? [];
},
embedBatch: embed,
},
client,
};
}
function extractGeminiEmbeddingValues(entry: unknown): number[] {
if (!entry || typeof entry !== "object") return [];
const record = entry as { values?: unknown; embedding?: { values?: unknown } };
const values = record.values ?? record.embedding?.values;
if (!Array.isArray(values)) return [];
return values.filter((value): value is number => typeof value === "number");
}
function parseGeminiEmbeddings(payload: unknown): number[][] {
if (!payload || typeof payload !== "object") return [];
const data = payload as { embedding?: unknown; embeddings?: unknown[] };
if (Array.isArray(data.embeddings)) {
return data.embeddings.map((entry) => extractGeminiEmbeddingValues(entry));
function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean {
const modelPath = options.local?.modelPath?.trim();
if (!modelPath) return false;
if (/^(hf:|https?:)/i.test(modelPath)) return false;
const resolved = resolveUserPath(modelPath);
try {
return fsSync.statSync(resolved).isFile();
} catch {
return false;
}
if (data.embedding) {
return [extractGeminiEmbeddingValues(data.embedding)];
}
return [];
}
async function createGeminiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
const client = await resolveGeminiEmbeddingClient(options);
const baseUrl = client.baseUrl.replace(/\/$/, "");
const model = `models/${client.model}`;
const embedContent = async (input: string): Promise<number[]> => {
const res = await fetch(`${baseUrl}/${model}:embedContent`, {
method: "POST",
headers: client.headers,
body: JSON.stringify({
model,
content: { parts: [{ text: input }] },
}),
});
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
}
const payload = await res.json();
const embeddings = parseGeminiEmbeddings(payload);
return embeddings[0] ?? [];
};
const embedBatch = async (input: string[]): Promise<number[][]> => {
if (input.length === 0) return [];
const res = await fetch(`${baseUrl}/${model}:batchEmbedContents`, {
method: "POST",
headers: client.headers,
body: JSON.stringify({
requests: input.map((text) => ({
model,
content: { parts: [{ text }] },
})),
}),
});
if (!res.ok) {
const text = await res.text();
throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
}
const payload = await res.json();
const embeddings = parseGeminiEmbeddings(payload);
return embeddings;
};
return {
provider: {
id: "gemini",
model: client.model,
embedQuery: embedContent,
embedBatch,
},
client,
};
function isMissingApiKeyError(err: unknown): boolean {
const message = formatError(err);
return message.includes("No API key found for provider");
}
async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> {
const remote = options.remote;
const remoteApiKey = remote?.apiKey?.trim();
const remoteBaseUrl = remote?.baseUrl?.trim();
const { apiKey } = remoteApiKey
? { apiKey: remoteApiKey }
: await resolveApiKeyForProvider({
provider: "openai",
cfg: options.config,
agentDir: options.agentDir,
});
const providerConfig = options.config.models?.providers?.openai;
const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_OPENAI_BASE_URL;
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
const headers: Record<string, string> = {
"Content-Type": "application/json",
Authorization: `Bearer ${apiKey}`,
...headerOverrides,
};
const model = normalizeOpenAiModel(options.model);
return { baseUrl, headers, model };
}
async function resolveGeminiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
const remote = options.remote;
const remoteApiKey = remote?.apiKey?.trim();
const remoteBaseUrl = remote?.baseUrl?.trim();
const { apiKey } = remoteApiKey
? { apiKey: remoteApiKey }
: await resolveApiKeyForProvider({
provider: "google",
cfg: options.config,
agentDir: options.agentDir,
});
const providerConfig = options.config.models?.providers?.google;
const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
const headers: Record<string, string> = {
"Content-Type": "application/json",
"x-goog-api-key": apiKey,
...headerOverrides,
};
const model = normalizeGeminiModel(options.model);
return { baseUrl, headers, model };
}
async function createLocalEmbeddingProvider(
options: EmbeddingProviderOptions,
@@ -289,35 +121,80 @@ export async function createEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
const requestedProvider = options.provider;
if (options.provider === "gemini") {
const { provider, client } = await createGeminiEmbeddingProvider(options);
return { provider, requestedProvider, gemini: client };
}
if (options.provider === "local") {
try {
const fallback = options.fallback;
const createProvider = async (id: "openai" | "local" | "gemini") => {
if (id === "local") {
const provider = await createLocalEmbeddingProvider(options);
return { provider, requestedProvider };
} catch (err) {
const reason = formatLocalSetupError(err);
if (options.fallback === "openai") {
try {
const { provider, client } = await createOpenAiEmbeddingProvider(options);
return {
provider,
requestedProvider,
fallbackFrom: "local",
fallbackReason: reason,
openAi: client,
};
} catch (fallbackErr) {
throw new Error(`${reason}\n\nFallback to OpenAI failed: ${formatError(fallbackErr)}`);
}
}
throw new Error(reason);
return { provider };
}
if (id === "gemini") {
const { provider, client } = await createGeminiEmbeddingProvider(options);
return { provider, gemini: client };
}
const { provider, client } = await createOpenAiEmbeddingProvider(options);
return { provider, openAi: client };
};
const formatPrimaryError = (err: unknown, provider: "openai" | "local" | "gemini") =>
provider === "local" ? formatLocalSetupError(err) : formatError(err);
if (requestedProvider === "auto") {
const missingKeyErrors: string[] = [];
let localError: string | null = null;
if (canAutoSelectLocal(options)) {
try {
const local = await createProvider("local");
return { ...local, requestedProvider };
} catch (err) {
localError = formatLocalSetupError(err);
}
}
for (const provider of ["openai", "gemini"] as const) {
try {
const result = await createProvider(provider);
return { ...result, requestedProvider };
} catch (err) {
const message = formatPrimaryError(err, provider);
if (isMissingApiKeyError(err)) {
missingKeyErrors.push(message);
continue;
}
throw new Error(message);
}
}
const details = [...missingKeyErrors, localError].filter(Boolean) as string[];
if (details.length > 0) {
throw new Error(details.join("\n\n"));
}
throw new Error("No embeddings provider available.");
}
try {
const primary = await createProvider(requestedProvider);
return { ...primary, requestedProvider };
} catch (primaryErr) {
const reason = formatPrimaryError(primaryErr, requestedProvider);
if (fallback && fallback !== "none" && fallback !== requestedProvider) {
try {
const fallbackResult = await createProvider(fallback);
return {
...fallbackResult,
requestedProvider,
fallbackFrom: requestedProvider,
fallbackReason: reason,
};
} catch (fallbackErr) {
throw new Error(
`${reason}\n\nFallback to ${fallback} failed: ${formatError(fallbackErr)}`,
);
}
}
throw new Error(reason);
}
const { provider, client } = await createOpenAiEmbeddingProvider(options);
return { provider, requestedProvider, openAi: client };
}
function formatError(err: unknown): string {