feat: add gemini memory embeddings

This commit is contained in:
Peter Steinberger
2026-01-18 09:09:13 +00:00
parent b015c7e5ad
commit a3a4996adb
13 changed files with 482 additions and 24 deletions

View File

@@ -12,10 +12,11 @@ export type EmbeddingProvider = {
/**
 * Outcome of building an embedding provider.
 *
 * Besides the ready-to-use provider, it records which backend the caller
 * asked for and the resolved remote client settings for that backend.
 */
export type EmbeddingProviderResult = {
  /** The provider instance callers use to embed text. */
  provider: EmbeddingProvider;
  /** Backend the caller asked for. */
  requestedProvider: "openai" | "gemini" | "local";
  // NOTE(review): the fallback fields are presumably populated when the local
  // backend could not be created and another backend was used instead; the
  // code that sets them is not visible here — confirm.
  fallbackFrom?: "local";
  fallbackReason?: string;
  /** Resolved OpenAI connection settings, when an OpenAI client was built. */
  openAi?: OpenAiEmbeddingClient;
  /** Resolved Gemini connection settings, when a Gemini client was built. */
  gemini?: GeminiEmbeddingClient;
};
export type OpenAiEmbeddingClient = {
@@ -24,10 +25,16 @@ export type OpenAiEmbeddingClient = {
model: string;
};
/**
 * Resolved connection settings for the Gemini embeddings endpoint.
 */
export type GeminiEmbeddingClient = {
  /** Root of the API; request URLs are built by appending `/models/<model>:<rpc>`. */
  baseUrl: string;
  /** HTTP headers sent with every embedding request. */
  headers: Record<string, string>;
  /** Model name without the `models/` prefix. */
  model: string;
};
export type EmbeddingProviderOptions = {
config: ClawdbotConfig;
agentDir?: string;
provider: "openai" | "local";
provider: "openai" | "gemini" | "local";
remote?: {
baseUrl?: string;
apiKey?: string;
@@ -43,6 +50,8 @@ export type EmbeddingProviderOptions = {
// Default root URL for the OpenAI API.
// NOTE(review): presumably the fallback when no OpenAI base URL is configured — usage is outside this view.
const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1";
// Default model reference for the local embedding backend.
// NOTE(review): the "hf:" prefix presumably denotes a Hugging Face model — confirm against the local provider.
const DEFAULT_LOCAL_MODEL = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf";
// Gemini API root used when neither the remote override nor the google provider config supplies a base URL.
const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
// Gemini embedding model used when the requested model name is blank.
const DEFAULT_GEMINI_MODEL = "gemini-embedding-001";
function normalizeOpenAiModel(model: string): string {
const trimmed = model.trim();
@@ -51,6 +60,14 @@ function normalizeOpenAiModel(model: string): string {
return trimmed;
}
/**
 * Reduces a user-supplied Gemini model reference to its bare model name.
 *
 * Surrounding whitespace is trimmed; a blank name falls back to
 * `DEFAULT_GEMINI_MODEL`. A single leading `models/` or `google/` prefix
 * is removed (`models/` is checked first); anything else is returned as is.
 *
 * @param model - Model reference as configured by the caller.
 * @returns The model name without a provider/resource prefix.
 */
function normalizeGeminiModel(model: string): string {
  const name = model.trim();
  if (name.length === 0) return DEFAULT_GEMINI_MODEL;
  for (const prefix of ["models/", "google/"]) {
    if (name.startsWith(prefix)) return name.slice(prefix.length);
  }
  return name;
}
async function createOpenAiEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> {
@@ -89,6 +106,83 @@ async function createOpenAiEmbeddingProvider(
};
}
/**
 * Pulls the numeric vector out of one embedding entry of a Gemini response.
 *
 * Accepts either `{ values: [...] }` or `{ embedding: { values: [...] } }`;
 * the top-level `values` wins whenever it is neither `null` nor `undefined`.
 * Elements that are not of type `number` are dropped.
 *
 * @param entry - One untrusted entry from the parsed response body.
 * @returns The numeric values, or an empty array when the entry has no usable vector.
 */
function extractGeminiEmbeddingValues(entry: unknown): number[] {
  if (typeof entry !== "object" || entry === null) return [];

  const shape = entry as { values?: unknown; embedding?: { values?: unknown } };
  let candidate: unknown = shape.values;
  if (candidate === undefined || candidate === null) {
    candidate = shape.embedding?.values;
  }
  if (!Array.isArray(candidate)) return [];

  const numbers: number[] = [];
  for (const item of candidate) {
    if (typeof item === "number") numbers.push(item);
  }
  return numbers;
}
/**
 * Converts a parsed Gemini embeddings response into a list of vectors.
 *
 * Handles both response shapes: a batch body carrying an `embeddings`
 * array (one vector per entry, in order) and a single-item body carrying
 * `embedding`. The batch shape takes precedence when both are present.
 *
 * @param payload - Untrusted parsed JSON body of the response.
 * @returns One vector per embedding found; empty when the payload has neither shape.
 */
function parseGeminiEmbeddings(payload: unknown): number[][] {
  if (typeof payload !== "object" || payload === null) return [];

  const body = payload as { embedding?: unknown; embeddings?: unknown[] };
  const batch = body.embeddings;
  if (Array.isArray(batch)) {
    return batch.map(extractGeminiEmbeddingValues);
  }
  return body.embedding ? [extractGeminiEmbeddingValues(body.embedding)] : [];
}
/**
 * Builds an embedding provider backed by the Gemini REST API.
 *
 * Single queries go to `<baseUrl>/models/<model>:embedContent`; batches go
 * to `<baseUrl>/models/<model>:batchEmbedContents` in one request. An empty
 * batch short-circuits without touching the network.
 *
 * @param options - Provider options used to resolve the Gemini client settings.
 * @returns The provider together with the resolved client settings.
 * @throws Error when the API answers with a non-2xx status; the message
 *   carries the status code and the response body text.
 */
async function createGeminiEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
  const client = await resolveGeminiEmbeddingClient(options);
  // Drop one trailing slash so the joined URL has no double slash.
  const endpointRoot = client.baseUrl.replace(/\/$/, "");
  const modelPath = `models/${client.model}`;

  // Request item shared by the single and batch endpoints.
  const toRequest = (text: string) => ({
    model: modelPath,
    content: { parts: [{ text }] },
  });

  // POSTs a JSON body to the given RPC of the model and parses the vectors.
  const postForEmbeddings = async (rpc: string, body: unknown): Promise<number[][]> => {
    const response = await fetch(`${endpointRoot}/${modelPath}:${rpc}`, {
      method: "POST",
      headers: client.headers,
      body: JSON.stringify(body),
    });
    if (!response.ok) {
      const detail = await response.text();
      throw new Error(`gemini embeddings failed: ${response.status} ${detail}`);
    }
    return parseGeminiEmbeddings(await response.json());
  };

  const embedQuery = async (input: string): Promise<number[]> => {
    const vectors = await postForEmbeddings("embedContent", toRequest(input));
    // A response without a usable embedding yields an empty vector.
    return vectors[0] ?? [];
  };

  const embedBatch = async (input: string[]): Promise<number[][]> => {
    if (input.length === 0) return [];
    return postForEmbeddings("batchEmbedContents", {
      requests: input.map((text) => toRequest(text)),
    });
  };

  return {
    provider: {
      id: "gemini",
      model: client.model,
      embedQuery,
      embedBatch,
    },
    client,
  };
}
async function resolveOpenAiEmbeddingClient(
options: EmbeddingProviderOptions,
): Promise<OpenAiEmbeddingClient> {
@@ -116,6 +210,33 @@ async function resolveOpenAiEmbeddingClient(
return { baseUrl, headers, model };
}
/**
 * Resolves the base URL, headers and model for Gemini embedding requests.
 *
 * Precedence:
 * - API key: non-blank `remote.apiKey`, otherwise the key resolved for the
 *   `google` provider via `resolveApiKeyForProvider`.
 * - Base URL: non-blank `remote.baseUrl`, then the `google` provider config,
 *   then `DEFAULT_GEMINI_BASE_URL`.
 * - Headers: built-in defaults, overridden by the `google` provider config
 *   headers, overridden in turn by `remote.headers`.
 *
 * @param options - Provider options carrying config, remote overrides and model.
 * @returns The resolved client settings.
 */
async function resolveGeminiEmbeddingClient(
  options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
  const remote = options.remote;
  const overrideKey = remote?.apiKey?.trim();
  const overrideBaseUrl = remote?.baseUrl?.trim();

  // Only consult the credential store when no explicit key was supplied.
  const credentials = overrideKey
    ? { apiKey: overrideKey }
    : await resolveApiKeyForProvider({
        provider: "google",
        cfg: options.config,
        agentDir: options.agentDir,
      });

  const googleConfig = options.config.models?.providers?.google;

  return {
    baseUrl: overrideBaseUrl || googleConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL,
    headers: {
      "Content-Type": "application/json",
      "x-goog-api-key": credentials.apiKey,
      // Later spreads win, so configured headers may replace the defaults above.
      ...googleConfig?.headers,
      ...remote?.headers,
    },
    model: normalizeGeminiModel(options.model),
  };
}
async function createLocalEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProvider> {
@@ -168,6 +289,10 @@ export async function createEmbeddingProvider(
options: EmbeddingProviderOptions,
): Promise<EmbeddingProviderResult> {
const requestedProvider = options.provider;
if (options.provider === "gemini") {
const { provider, client } = await createGeminiEmbeddingProvider(options);
return { provider, requestedProvider, gemini: client };
}
if (options.provider === "local") {
try {
const provider = await createLocalEmbeddingProvider(options);