From c00ea63bb08adbef90a16615111b4937ce9ee0a4 Mon Sep 17 00:00:00 2001
From: Peter Steinberger
Date: Sun, 18 Jan 2026 03:09:28 +0000
Subject: [PATCH] refactor: split memory manager internals

---
 src/memory/manager-search.ts | 181 +++++++++
 src/memory/manager.ts        | 700 ++++------------------------------
 src/memory/memory-schema.ts  |  95 +++++
 src/memory/openai-batch.ts   | 360 ++++++++++++++++++
 src/memory/sqlite-vec.ts     |  25 ++
 5 files changed, 740 insertions(+), 621 deletions(-)
 create mode 100644 src/memory/manager-search.ts
 create mode 100644 src/memory/memory-schema.ts
 create mode 100644 src/memory/openai-batch.ts
 create mode 100644 src/memory/sqlite-vec.ts

diff --git a/src/memory/manager-search.ts b/src/memory/manager-search.ts
new file mode 100644
index 000000000..0cd6492b1
--- /dev/null
+++ b/src/memory/manager-search.ts
@@ -0,0 +1,181 @@
+import type { DatabaseSync } from "node:sqlite";
+
+import { truncateUtf16Safe } from "../utils.js";
+import { cosineSimilarity, parseEmbedding } from "./internal.js";
+
+const vectorToBlob = (embedding: number[]): Buffer => Buffer.from(new Float32Array(embedding).buffer);
+
+export type SearchSource = string;
+
+export type SearchRowResult = {
+  id: string;
+  path: string;
+  startLine: number;
+  endLine: number;
+  score: number;
+  snippet: string;
+  source: SearchSource;
+};
+
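+/**
+ * Vector search over indexed chunks: prefers the sqlite-vec virtual table
+ * (vec_distance_cosine) and falls back to brute-force cosine similarity over
+ * every stored embedding when the extension is unavailable.
+ *
+ * Illustrative call; the wiring (db handle, model name, filters) is the
+ * caller's, not part of this module:
+ *
+ *   const hits = await searchVector({
+ *     db,
+ *     vectorTable: "chunks_vec",
+ *     providerModel: "text-embedding-3-small",
+ *     queryVec,
+ *     limit: 8,
+ *     snippetMaxChars: 700,
+ *     ensureVectorReady: async () => false, // force the cosine fallback
+ *     sourceFilterVec: { sql: "", params: [] },
+ *     sourceFilterChunks: { sql: "", params: [] },
+ *   });
+ */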
+export async function searchVector(params: {
+  db: DatabaseSync;
+  vectorTable: string;
+  providerModel: string;
+  queryVec: number[];
+  limit: number;
+  snippetMaxChars: number;
+  ensureVectorReady: (dimensions: number) => Promise<boolean>;
+  sourceFilterVec: { sql: string; params: SearchSource[] };
+  sourceFilterChunks: { sql: string; params: SearchSource[] };
+}): Promise<SearchRowResult[]> {
+  if (params.queryVec.length === 0 || params.limit <= 0) return [];
+  if (await params.ensureVectorReady(params.queryVec.length)) {
+    const rows = params.db
+      .prepare(
+        `SELECT c.id, c.path, c.start_line, c.end_line, c.text,\n` +
+          ` c.source,\n` +
+          ` vec_distance_cosine(v.embedding, ?) AS dist\n` +
+          ` FROM ${params.vectorTable} v\n` +
+          ` JOIN chunks c ON c.id = v.id\n` +
+          ` WHERE c.model = ?${params.sourceFilterVec.sql}\n` +
+          ` ORDER BY dist ASC\n` +
+          ` LIMIT ?`,
+      )
+      .all(
+        vectorToBlob(params.queryVec),
+        params.providerModel,
+        ...params.sourceFilterVec.params,
+        params.limit,
+      ) as Array<{
+      id: string;
+      path: string;
+      start_line: number;
+      end_line: number;
+      text: string;
+      source: SearchSource;
+      dist: number;
+    }>;
+    return rows.map((row) => ({
+      id: row.id,
+      path: row.path,
+      startLine: row.start_line,
+      endLine: row.end_line,
+      score: 1 - row.dist,
+      snippet: truncateUtf16Safe(row.text, params.snippetMaxChars),
+      source: row.source,
+    }));
+  }
+
+  const candidates = listChunks({
+    db: params.db,
+    providerModel: params.providerModel,
+    sourceFilter: params.sourceFilterChunks,
+  });
+  const scored = candidates
+    .map((chunk) => ({
+      chunk,
+      score: cosineSimilarity(params.queryVec, chunk.embedding),
+    }))
+    .filter((entry) => Number.isFinite(entry.score));
+  return scored
+    .sort((a, b) => b.score - a.score)
+    .slice(0, params.limit)
+    .map((entry) => ({
+      id: entry.chunk.id,
+      path: entry.chunk.path,
+      startLine: entry.chunk.startLine,
+      endLine: entry.chunk.endLine,
+      score: entry.score,
+      snippet: truncateUtf16Safe(entry.chunk.text, params.snippetMaxChars),
+      source: entry.chunk.source,
+    }));
+}
+
+export function listChunks(params: {
+  db: DatabaseSync;
+  providerModel: string;
+  sourceFilter: { sql: string; params: SearchSource[] };
+}): Array<{
+  id: string;
+  path: string;
+  startLine: number;
+  endLine: number;
+  text: string;
+  embedding: number[];
+  source: SearchSource;
+}> {
+  const rows = params.db
+    .prepare(
+      `SELECT id, path, start_line, end_line, text, embedding, source\n` +
+        ` FROM chunks\n` +
+        ` WHERE model = ?${params.sourceFilter.sql}`,
+    )
+    .all(params.providerModel, ...params.sourceFilter.params) as Array<{
+    id: string;
+    path: string;
+    start_line: number;
+    end_line: number;
+    text: string;
+    embedding: string;
+    source: SearchSource;
+  }>;
+
+  return rows.map((row) => ({
+    id: row.id,
+    path: row.path,
+    startLine: row.start_line,
+    endLine: row.end_line,
+    text: row.text,
+    embedding: parseEmbedding(row.embedding),
+    source: row.source,
+  }));
+}
+
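+/**
+ * Keyword search over the FTS5 mirror table, ranked by bm25(). Returns []
+ * when the limit is non-positive or the query tokenizes to nothing; the
+ * injected bm25RankToScore decides how FTS ranks fold into hybrid scores.
+ */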
+export async function searchKeyword(params: {
+  db: DatabaseSync;
+  ftsTable: string;
+  providerModel: string;
+  query: string;
+  limit: number;
+  snippetMaxChars: number;
+  sourceFilter: { sql: string; params: SearchSource[] };
+  buildFtsQuery: (raw: string) => string | null;
+  bm25RankToScore: (rank: number) => number;
+}): Promise<Array<SearchRowResult & { textScore: number }>> {
+  if (params.limit <= 0) return [];
+  const ftsQuery = params.buildFtsQuery(params.query);
+  if (!ftsQuery) return [];
+
+  const rows = params.db
+    .prepare(
+      `SELECT id, path, source, start_line, end_line, text,\n` +
+        ` bm25(${params.ftsTable}) AS rank\n` +
+        ` FROM ${params.ftsTable}\n` +
+        ` WHERE ${params.ftsTable} MATCH ? AND model = ?${params.sourceFilter.sql}\n` +
+        ` ORDER BY rank ASC\n` +
+        ` LIMIT ?`,
+    )
+    .all(ftsQuery, params.providerModel, ...params.sourceFilter.params, params.limit) as Array<{
+    id: string;
+    path: string;
+    source: SearchSource;
+    start_line: number;
+    end_line: number;
+    text: string;
+    rank: number;
+  }>;
+
+  return rows.map((row) => {
+    const textScore = params.bm25RankToScore(row.rank);
+    return {
+      id: row.id,
+      path: row.path,
+      startLine: row.start_line,
+      endLine: row.end_line,
+      score: textScore,
+      textScore,
+      snippet: truncateUtf16Safe(row.text, params.snippetMaxChars),
+      source: row.source,
+    };
+  });
+}
diff --git a/src/memory/manager.ts b/src/memory/manager.ts
index 3078416e6..db11bdeb6 100644
--- a/src/memory/manager.ts
+++ b/src/memory/manager.ts
@@ -13,16 +13,20 @@ import { createSubsystemLogger } from "../logging.js";
 import { onSessionTranscriptUpdate } from "../sessions/transcript-events.js";
 import { resolveUserPath, truncateUtf16Safe } from "../utils.js";
 import { colorize, isRich, theme } from "../terminal/theme.js";
 import {
   createEmbeddingProvider,
   type EmbeddingProvider,
   type EmbeddingProviderResult,
   type OpenAiEmbeddingClient,
 } from "./embeddings.js";
+import {
+  OPENAI_BATCH_ENDPOINT,
+  type OpenAiBatchRequest,
+  runOpenAiEmbeddingBatches,
+} from "./openai-batch.js";
 import {
   buildFileEntry,
   chunkMarkdown,
-  cosineSimilarity,
   ensureDir,
   hashText,
   isMemoryPath,
@@ -32,7 +38,11 @@ import {
   normalizeRelPath,
   parseEmbedding,
 } from "./internal.js";
+import { bm25RankToScore, buildFtsQuery, mergeHybridResults } from "./hybrid.js";
+import { searchKeyword, searchVector } from "./manager-search.js";
+import { ensureMemoryIndexSchema } from "./memory-schema.js";
 import { requireNodeSqlite } from "./sqlite.js";
+import { loadSqliteVecExtension } from "./sqlite-vec.js";
 
 type MemorySource = "memory" | "sessions";
 
@@ -76,40 +86,6 @@ type MemorySyncProgressState = {
   report: (update: MemorySyncProgressUpdate) => void;
 };
 
-type OpenAiBatchRequest = {
-  custom_id: string;
-  method: "POST";
-  url: "/v1/embeddings";
-  body: {
-    model: string;
-    input: string;
-  };
-};
-
-type OpenAiBatchStatus = {
-  id?: string;
-  status?: string;
-  output_file_id?: string | null;
-  error_file_id?: string | null;
-  request_counts?: {
-    total?: number;
-    completed?: number;
-    failed?: number;
-  };
-};
-
-type OpenAiBatchOutputLine = {
-  custom_id?: string;
-  response?: {
-    status_code?: number;
-    body?: {
-      data?: Array<{ embedding?: number[]; index?: number }>;
-      error?: { message?: string };
-    };
-  };
-  error?: { message?: string };
-};
-
 const META_KEY = "memory_index_meta_v1";
 const SNIPPET_MAX_CHARS = 700;
 const VECTOR_TABLE = "chunks_vec";
@@ -122,9 +98,6 @@ const EMBEDDING_INDEX_CONCURRENCY = 4;
 const EMBEDDING_RETRY_MAX_ATTEMPTS = 3;
 const EMBEDDING_RETRY_BASE_DELAY_MS = 500;
 const EMBEDDING_RETRY_MAX_DELAY_MS = 8000;
-const OPENAI_BATCH_ENDPOINT = "/v1/embeddings";
-const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
-const OPENAI_BATCH_MAX_REQUESTS = 50000;
 
 const log = createSubsystemLogger("memory");
 
@@ -321,70 +294,22 @@
     queryVec: number[],
     limit: number,
   ): Promise<Array<MemorySearchResult & { id: string }>> {
-    if (queryVec.length === 0 || limit <= 0) return [];
-    if (await this.ensureVectorReady(queryVec.length)) {
-      const sourceFilter = this.buildSourceFilter("c");
-      const rows = this.db
-        .prepare(
-          `SELECT c.id, c.path, c.start_line, c.end_line, c.text,\n` +
-            ` c.source,\n` +
-            ` vec_distance_cosine(v.embedding, ?) AS dist\n` +
-            ` FROM ${VECTOR_TABLE} v\n` +
-            ` JOIN chunks c ON c.id = v.id\n` +
-            ` WHERE c.model = ?${sourceFilter.sql}\n` +
-            ` ORDER BY dist ASC\n` +
-            ` LIMIT ?`,
-        )
-        .all(vectorToBlob(queryVec), this.provider.model, ...sourceFilter.params, limit) as Array<{
-        id: string;
-        path: string;
-        start_line: number;
-        end_line: number;
-        text: string;
-        source: MemorySource;
-        dist: number;
-      }>;
-      return rows.map((row) => ({
-        id: row.id,
-        path: row.path,
-        startLine: row.start_line,
-        endLine: row.end_line,
-        score: 1 - row.dist,
-        snippet: truncateUtf16Safe(row.text, SNIPPET_MAX_CHARS),
-        source: row.source,
-      }));
-    }
-
-    const candidates = this.listChunks();
-    const scored = candidates
-      .map((chunk) => ({
-        chunk,
-        score: cosineSimilarity(queryVec, chunk.embedding),
-      }))
-      .filter((entry) => Number.isFinite(entry.score));
-    return scored
-      .sort((a, b) => b.score - a.score)
-      .slice(0, limit)
-      .map((entry) => ({
-        id: entry.chunk.id,
-        path: entry.chunk.path,
-        startLine: entry.chunk.startLine,
-        endLine: entry.chunk.endLine,
-        score: entry.score,
-        snippet: truncateUtf16Safe(entry.chunk.text, SNIPPET_MAX_CHARS),
-        source: entry.chunk.source,
-      }));
+    const results = await searchVector({
+      db: this.db,
+      vectorTable: VECTOR_TABLE,
+      providerModel: this.provider.model,
+      queryVec,
+      limit,
+      snippetMaxChars: SNIPPET_MAX_CHARS,
+      ensureVectorReady: async (dimensions) => await this.ensureVectorReady(dimensions),
+      sourceFilterVec: this.buildSourceFilter("c"),
+      sourceFilterChunks: this.buildSourceFilter(),
+    });
+    return results.map((entry) => entry as MemorySearchResult & { id: string });
   }
 
   private buildFtsQuery(raw: string): string | null {
-    const tokens =
-      raw
-        .match(/[A-Za-z0-9_]+/g)
-        ?.map((t) => t.trim())
-        .filter(Boolean) ?? [];
-    if (tokens.length === 0) return null;
-    const quoted = tokens.map((t) => `"${t.replaceAll('"', "")}"`);
-    return quoted.join(" AND ");
+    return buildFtsQuery(raw);
   }
 
   private async searchKeyword(
@@ -392,42 +317,19 @@
     query: string,
     limit: number,
   ): Promise<Array<MemorySearchResult & { id: string; textScore: number }>> {
     if (!this.fts.enabled || !this.fts.available) return [];
-    if (limit <= 0) return [];
-    const ftsQuery = this.buildFtsQuery(query);
-    if (!ftsQuery) return [];
     const sourceFilter = this.buildSourceFilter();
-    const rows = this.db
-      .prepare(
-        `SELECT id, path, source, start_line, end_line, text,\n` +
-          ` bm25(${FTS_TABLE}) AS rank\n` +
-          ` FROM ${FTS_TABLE}\n` +
-          ` WHERE ${FTS_TABLE} MATCH ? AND model = ?${sourceFilter.sql}\n` +
-          ` ORDER BY rank ASC\n` +
-          ` LIMIT ?`,
-      )
-      .all(ftsQuery, this.provider.model, ...sourceFilter.params, limit) as Array<{
-      id: string;
-      path: string;
-      source: MemorySource;
-      start_line: number;
-      end_line: number;
-      text: string;
-      rank: number;
-    }>;
-    return rows.map((row) => {
-      const rank = Number.isFinite(row.rank) ? Math.max(0, row.rank) : 999;
-      const textScore = 1 / (1 + rank);
-      return {
-        id: row.id,
-        path: row.path,
-        startLine: row.start_line,
-        endLine: row.end_line,
-        score: textScore,
-        textScore,
-        snippet: truncateUtf16Safe(row.text, SNIPPET_MAX_CHARS),
-        source: row.source,
-      };
+    const results = await searchKeyword({
+      db: this.db,
+      ftsTable: FTS_TABLE,
+      providerModel: this.provider.model,
+      query,
+      limit,
+      snippetMaxChars: SNIPPET_MAX_CHARS,
+      sourceFilter,
+      buildFtsQuery: (raw) => this.buildFtsQuery(raw),
+      bm25RankToScore,
     });
+    return results.map((entry) => entry as MemorySearchResult & { id: string; textScore: number });
   }
 
   private mergeHybridResults(params: {
@@ -436,22 +338,8 @@
     vectorWeight: number;
     textWeight: number;
   }): MemorySearchResult[] {
-    const byId = new Map<
-      string,
-      {
-        id: string;
-        path: string;
-        startLine: number;
-        endLine: number;
-        source: MemorySource;
-        snippet: string;
-        vectorScore: number;
-        textScore: number;
-      }
-    >();
-
-    for (const r of params.vector) {
-      byId.set(r.id, {
+    const merged = mergeHybridResults({
+      vector: params.vector.map((r) => ({
         id: r.id,
         path: r.path,
         startLine: r.startLine,
@@ -459,40 +347,20 @@
         source: r.source,
         snippet: r.snippet,
         vectorScore: r.score,
-        textScore: 0,
-      });
-    }
-    for (const r of params.keyword) {
-      const existing = byId.get(r.id);
-      if (existing) {
-        existing.textScore = r.textScore;
-        if (r.snippet && r.snippet.length > 0) existing.snippet = r.snippet;
-      } else {
-        byId.set(r.id, {
-          id: r.id,
-          path: r.path,
-          startLine: r.startLine,
-          endLine: r.endLine,
-          source: r.source,
-          snippet: r.snippet,
-          vectorScore: 0,
-          textScore: r.textScore,
-        });
-      }
-    }
-
-    const merged = Array.from(byId.values()).map((entry) => {
-      const score = params.vectorWeight * entry.vectorScore + params.textWeight * entry.textScore;
-      return {
-        path: entry.path,
-        startLine: entry.startLine,
-        endLine: entry.endLine,
-        score,
-        snippet: entry.snippet,
-        source: entry.source,
-      } satisfies MemorySearchResult;
+      })),
+      keyword: params.keyword.map((r) => ({
+        id: r.id,
+        path: r.path,
+        startLine: r.startLine,
+        endLine: r.endLine,
+        source: r.source,
+        snippet: r.snippet,
+        textScore: r.textScore,
+      })),
+      vectorWeight: params.vectorWeight,
+      textWeight: params.textWeight,
     });
-    return merged.sort((a, b) => b.score - a.score);
+    return merged.map((entry) => entry as MemorySearchResult);
   }
 
   async sync(params?: {
@@ -693,17 +561,12 @@
       return false;
     }
     try {
-      const sqliteVec = await import("sqlite-vec");
-      const extensionPath = this.vector.extensionPath?.trim()
+      const resolvedPath = this.vector.extensionPath?.trim()
         ? resolveUserPath(this.vector.extensionPath)
-        : sqliteVec.getLoadablePath();
-      this.db.enableLoadExtension(true);
-      if (this.vector.extensionPath?.trim()) {
-        this.db.loadExtension(extensionPath);
-      } else {
-        sqliteVec.load(this.db);
-      }
-      this.vector.extensionPath = extensionPath;
+        : undefined;
+      const loaded = await loadSqliteVecExtension({ db: this.db, extensionPath: resolvedPath });
+      if (!loaded.ok) throw new Error(loaded.error ?? "unknown sqlite-vec load error");
+      this.vector.extensionPath = loaded.extensionPath;
       this.vector.available = true;
       return true;
     } catch (err) {
@@ -746,14 +609,6 @@
     return { sql: ` AND ${column} IN (${placeholders})`, params: sources };
   }
 
-  private ensureColumn(table: "files" | "chunks", column: string, definition: string): void {
-    const rows = this.db.prepare(`PRAGMA table_info(${table})`).all() as Array<{
-      name: string;
-    }>;
-    if (rows.some((row) => row.name === column)) return;
-    this.db.exec(`ALTER TABLE ${table} ADD COLUMN ${column} ${definition}`);
-  }
-
   private openDatabase(): DatabaseSync {
     const dbPath = resolveUserPath(this.settings.store.path);
     const dir = path.dirname(dbPath);
@@ -763,75 +618,17 @@
   }
 
   private ensureSchema() {
-    this.db.exec(`
-      CREATE TABLE IF NOT EXISTS meta (
-        key TEXT PRIMARY KEY,
-        value TEXT NOT NULL
-      );
-    `);
-    this.db.exec(`
-      CREATE TABLE IF NOT EXISTS files (
-        path TEXT PRIMARY KEY,
-        source TEXT NOT NULL DEFAULT 'memory',
-        hash TEXT NOT NULL,
-        mtime INTEGER NOT NULL,
-        size INTEGER NOT NULL
-      );
-    `);
-    this.db.exec(`
-      CREATE TABLE IF NOT EXISTS chunks (
-        id TEXT PRIMARY KEY,
-        path TEXT NOT NULL,
-        source TEXT NOT NULL DEFAULT 'memory',
-        start_line INTEGER NOT NULL,
-        end_line INTEGER NOT NULL,
-        hash TEXT NOT NULL,
-        model TEXT NOT NULL,
-        text TEXT NOT NULL,
-        embedding TEXT NOT NULL,
-        updated_at INTEGER NOT NULL
-      );
-    `);
-    this.db.exec(`
-      CREATE TABLE IF NOT EXISTS ${EMBEDDING_CACHE_TABLE} (
-        provider TEXT NOT NULL,
-        model TEXT NOT NULL,
-        provider_key TEXT NOT NULL,
-        hash TEXT NOT NULL,
-        embedding TEXT NOT NULL,
-        dims INTEGER,
-        updated_at INTEGER NOT NULL,
-        PRIMARY KEY (provider, model, provider_key, hash)
-      );
-    `);
-    this.db.exec(
-      `CREATE INDEX IF NOT EXISTS idx_embedding_cache_updated_at ON ${EMBEDDING_CACHE_TABLE}(updated_at);`,
-    );
-    if (this.fts.enabled) {
-      try {
-        this.db.exec(
-          `CREATE VIRTUAL TABLE IF NOT EXISTS ${FTS_TABLE} USING fts5(\n` +
-            ` text,\n` +
-            ` id UNINDEXED,\n` +
-            ` path UNINDEXED,\n` +
-            ` source UNINDEXED,\n` +
-            ` model UNINDEXED,\n` +
-            ` start_line UNINDEXED,\n` +
-            ` end_line UNINDEXED\n` +
-            `);`,
-        );
-        this.fts.available = true;
-      } catch (err) {
-        const message = err instanceof Error ? err.message : String(err);
-        this.fts.available = false;
-        this.fts.loadError = message;
-        log.warn(`fts unavailable: ${message}`);
-      }
+    const result = ensureMemoryIndexSchema({
+      db: this.db,
+      embeddingCacheTable: EMBEDDING_CACHE_TABLE,
+      ftsTable: FTS_TABLE,
+      ftsEnabled: this.fts.enabled,
+    });
+    this.fts.available = result.ftsAvailable;
+    if (result.ftsError) {
+      this.fts.loadError = result.ftsError;
+      log.warn(`fts unavailable: ${result.ftsError}`);
     }
-    this.ensureColumn("files", "source", "TEXT NOT NULL DEFAULT 'memory'");
-    this.ensureColumn("chunks", "source", "TEXT NOT NULL DEFAULT 'memory'");
-    this.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_path ON chunks(path);`);
-    this.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_source ON chunks(source);`);
   }
 
   private ensureWatcher() {
@@ -905,42 +702,6 @@
     }, this.settings.sync.watchDebounceMs);
   }
 
-  private listChunks(): Array<{
-    id: string;
-    path: string;
-    startLine: number;
-    endLine: number;
-    text: string;
-    embedding: number[];
-    source: MemorySource;
-  }> {
-    const sourceFilter = this.buildSourceFilter();
-    const rows = this.db
-      .prepare(
-        `SELECT id, path, start_line, end_line, text, embedding, source
-         FROM chunks
-         WHERE model = ?${sourceFilter.sql}`,
-      )
-      .all(this.provider.model, ...sourceFilter.params) as Array<{
-      id: string;
-      path: string;
-      start_line: number;
-      end_line: number;
-      text: string;
-      embedding: string;
-      source: MemorySource;
-    }>;
-    return rows.map((row) => ({
-      id: row.id,
-      path: row.path,
-      startLine: row.start_line,
-      endLine: row.end_line,
-      text: row.text,
-      embedding: parseEmbedding(row.embedding),
-      source: row.source,
-    }));
-  }
-
   private shouldSyncSessions(
     params?: { reason?: string; force?: boolean },
     needsFullReindex = false,
@@ -1475,23 +1236,6 @@
     return embeddings;
   }
 
-  private getOpenAiBaseUrl(): string {
-    return this.openAi?.baseUrl?.replace(/\/$/, "") ?? "";
-  }
-
-  private getOpenAiHeaders(params: { json: boolean }): Record<string, string> {
-    const headers = this.openAi?.headers ? { ...this.openAi.headers } : {};
-    if (params.json) {
-      if (!headers["Content-Type"] && !headers["content-type"]) {
-        headers["Content-Type"] = "application/json";
-      }
-    } else {
-      delete headers["Content-Type"];
-      delete headers["content-type"];
-    }
-    return headers;
-  }
-
   private computeProviderKey(): string {
     if (this.provider.id === "openai" && this.openAi) {
       const entries = Object.entries(this.openAi.headers)
@@ -1510,225 +1254,6 @@
     return hashText(JSON.stringify({ provider: this.provider.id, model: this.provider.model }));
   }
 
-  private buildOpenAiBatchRequests(
-    chunks: MemoryChunk[],
-    entry: MemoryFileEntry | SessionFileEntry,
-    source: MemorySource,
-  ): { requests: OpenAiBatchRequest[]; mapping: Map<string, number> } {
-    const requests: OpenAiBatchRequest[] = [];
-    const mapping = new Map<string, number>();
-    for (let i = 0; i < chunks.length; i += 1) {
-      const chunk = chunks[i];
-      const customId = hashText(
-        `${source}:${entry.path}:${chunk.startLine}:${chunk.endLine}:${chunk.hash}:${i}`,
-      );
-      mapping.set(customId, i);
-      requests.push({
-        custom_id: customId,
-        method: "POST",
-        url: OPENAI_BATCH_ENDPOINT,
-        body: {
-          model: this.openAi?.model ?? this.provider.model,
-          input: chunk.text,
-        },
-      });
-    }
-    return { requests, mapping };
-  }
-
-  private splitOpenAiBatchRequests(requests: OpenAiBatchRequest[]): OpenAiBatchRequest[][] {
-    if (requests.length <= OPENAI_BATCH_MAX_REQUESTS) return [requests];
-    const groups: OpenAiBatchRequest[][] = [];
-    for (let i = 0; i < requests.length; i += OPENAI_BATCH_MAX_REQUESTS) {
-      groups.push(requests.slice(i, i + OPENAI_BATCH_MAX_REQUESTS));
-    }
-    return groups;
-  }
-
-  private async submitOpenAiBatch(requests: OpenAiBatchRequest[]): Promise<OpenAiBatchStatus> {
-    if (!this.openAi) {
-      throw new Error("OpenAI batch requested without an OpenAI embedding client.");
-    }
-    const baseUrl = this.getOpenAiBaseUrl();
-    const jsonl = requests.map((request) => JSON.stringify(request)).join("\n");
-    const form = new FormData();
-    form.append("purpose", "batch");
-    form.append(
-      "file",
-      new Blob([jsonl], { type: "application/jsonl" }),
-      "memory-embeddings.jsonl",
-    );
-
-    const fileRes = await fetch(`${baseUrl}/files`, {
-      method: "POST",
-      headers: this.getOpenAiHeaders({ json: false }),
-      body: form,
-    });
-    if (!fileRes.ok) {
-      const text = await fileRes.text();
-      throw new Error(`openai batch file upload failed: ${fileRes.status} ${text}`);
-    }
-    const filePayload = (await fileRes.json()) as { id?: string };
-    if (!filePayload.id) {
-      throw new Error("openai batch file upload failed: missing file id");
-    }
-
-    const batchRes = await fetch(`${baseUrl}/batches`, {
-      method: "POST",
-      headers: this.getOpenAiHeaders({ json: true }),
-      body: JSON.stringify({
-        input_file_id: filePayload.id,
-        endpoint: OPENAI_BATCH_ENDPOINT,
-        completion_window: OPENAI_BATCH_COMPLETION_WINDOW,
-        metadata: {
-          source: "clawdbot-memory",
-          agent: this.agentId,
-        },
-      }),
-    });
-    if (!batchRes.ok) {
-      const text = await batchRes.text();
-      throw new Error(`openai batch create failed: ${batchRes.status} ${text}`);
-    }
-    return (await batchRes.json()) as OpenAiBatchStatus;
-  }
-
-  private async fetchOpenAiBatchStatus(batchId: string): Promise<OpenAiBatchStatus> {
-    const baseUrl = this.getOpenAiBaseUrl();
-    const res = await fetch(`${baseUrl}/batches/${batchId}`, {
-      headers: this.getOpenAiHeaders({ json: true }),
-    });
-    if (!res.ok) {
-      const text = await res.text();
-      throw new Error(`openai batch status failed: ${res.status} ${text}`);
-    }
-    return (await res.json()) as OpenAiBatchStatus;
-  }
-
-  private async fetchOpenAiFileContent(fileId: string): Promise<string> {
-    const baseUrl = this.getOpenAiBaseUrl();
-    const res = await fetch(`${baseUrl}/files/${fileId}/content`, {
-      headers: this.getOpenAiHeaders({ json: true }),
-    });
-    if (!res.ok) {
-      const text = await res.text();
-      throw new Error(`openai batch file content failed: ${res.status} ${text}`);
-    }
-    return await res.text();
-  }
-
-  private parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
-    if (!text.trim()) return [];
-    return text
-      .split("\n")
-      .map((line) => line.trim())
-      .filter(Boolean)
-      .map((line) => JSON.parse(line) as OpenAiBatchOutputLine);
-  }
-
-  private async readOpenAiBatchError(errorFileId: string): Promise<string | undefined> {
-    try {
-      const content = await this.fetchOpenAiFileContent(errorFileId);
-      const lines = this.parseOpenAiBatchOutput(content);
-      const first = lines.find((line) => line.error?.message || line.response?.body?.error);
-      const message =
-        first?.error?.message ??
-        (typeof first?.response?.body?.error?.message === "string"
-          ? first?.response?.body?.error?.message
-          : undefined);
-      return message;
-    } catch (err) {
-      const message = err instanceof Error ? err.message : String(err);
-      return message ? `error file unavailable: ${message}` : undefined;
-    }
-  }
-
-  private async waitForOpenAiBatch(
-    batchId: string,
-    initial?: OpenAiBatchStatus,
-  ): Promise<{ outputFileId: string; errorFileId?: string }> {
-    const start = Date.now();
-    let current: OpenAiBatchStatus | undefined = initial;
-    while (true) {
-      const status = current ?? (await this.fetchOpenAiBatchStatus(batchId));
-      const state = status.status ?? "unknown";
-      if (state === "completed") {
-        log.debug(`openai batch ${batchId} ${state}`, {
-          consoleMessage: this.formatOpenAiBatchConsoleMessage({
-            batchId,
-            state,
-            counts: status.request_counts,
-          }),
-        });
-        if (!status.output_file_id) {
-          throw new Error(`openai batch ${batchId} completed without output file`);
-        }
-        return {
-          outputFileId: status.output_file_id,
-          errorFileId: status.error_file_id ?? undefined,
-        };
-      }
-      if (["failed", "expired", "cancelled", "canceled"].includes(state)) {
-        const detail = status.error_file_id
-          ? await this.readOpenAiBatchError(status.error_file_id)
-          : undefined;
-        const suffix = detail ? `: ${detail}` : "";
-        throw new Error(`openai batch ${batchId} ${state}${suffix}`);
-      }
-      if (!this.batch.wait) {
-        throw new Error(`openai batch ${batchId} still ${state}; wait disabled`);
-      }
-      if (Date.now() - start > this.batch.timeoutMs) {
-        throw new Error(`openai batch ${batchId} timed out after ${this.batch.timeoutMs}ms`);
-      }
-      log.debug(`openai batch ${batchId} ${state}; waiting ${this.batch.pollIntervalMs}ms`, {
-        consoleMessage: this.formatOpenAiBatchConsoleMessage({
-          batchId,
-          state,
-          waitMs: this.batch.pollIntervalMs,
-          counts: status.request_counts,
-        }),
-      });
-      await new Promise((resolve) => setTimeout(resolve, this.batch.pollIntervalMs));
-      current = undefined;
-    }
-  }
-
-  private formatOpenAiBatchConsoleMessage(params: {
-    batchId: string;
-    state: string;
-    waitMs?: number;
-    counts?: OpenAiBatchStatus["request_counts"];
-  }): string {
-    const rich = isRich();
-    const normalized = params.state.toLowerCase();
-    const successStates = new Set(["completed", "succeeded"]);
-    const errorStates = new Set(["failed", "expired", "cancelled", "canceled"]);
-    const warnStates = new Set(["finalizing", "validating"]);
-    let color = theme.info;
-    if (successStates.has(normalized)) color = theme.success;
-    else if (errorStates.has(normalized)) color = theme.error;
-    else if (warnStates.has(normalized)) color = theme.warn;
-    const status = colorize(rich, color, params.state);
-    const progress = this.formatOpenAiBatchProgress(params.counts);
-    const suffix = typeof params.waitMs === "number" ? `; waiting ${params.waitMs}ms` : "";
-    const progressText = progress ? ` ${progress}` : "";
-    return `openai batch ${params.batchId} ${status}${progressText}${suffix}`;
-  }
-
-  private formatOpenAiBatchProgress(
-    counts?: OpenAiBatchStatus["request_counts"],
-  ): string | undefined {
-    if (!counts) return undefined;
-    const total = counts.total ?? 0;
-    if (!Number.isFinite(total) || total <= 0) return undefined;
-    const completed = Math.max(0, counts.completed ?? 0);
-    const failed = Math.max(0, counts.failed ?? 0);
-    const percent = Math.min(100, Math.max(0, Math.round((completed / total) * 100)));
-    const failureSuffix = failed > 0 ? `, ${failed} failed` : "";
-    return `(${completed}/${total} ${percent}%${failureSuffix})`;
-  }
-
   private async embedChunksWithBatch(
     chunks: MemoryChunk[],
     entry: MemoryFileEntry | SessionFileEntry,
@@ -1755,13 +1280,13 @@
     if (missing.length === 0) return embeddings;
 
     const requests: OpenAiBatchRequest[] = [];
-    const mapping = new Map<string, number>();
+    const mapping = new Map<string, { index: number; hash: string }>();
     for (const item of missing) {
      const chunk = item.chunk;
       const customId = hashText(
         `${source}:${entry.path}:${chunk.startLine}:${chunk.endLine}:${chunk.hash}:${item.index}`,
       );
-      mapping.set(customId, item.index);
+      mapping.set(customId, { index: item.index, hash: chunk.hash });
       requests.push({
         custom_id: customId,
         method: "POST",
@@ -1772,91 +1297,24 @@
        },
       });
     }
-    const groups = this.splitOpenAiBatchRequests(requests);
-    log.debug("memory embeddings: openai batch submit", {
-      source,
-      chunks: chunks.length,
-      requests: requests.length,
-      groups: groups.length,
+    const byCustomId = await runOpenAiEmbeddingBatches({
+      openAi: this.openAi,
+      agentId: this.agentId,
+      requests,
       wait: this.batch.wait,
       concurrency: this.batch.concurrency,
       pollIntervalMs: this.batch.pollIntervalMs,
      timeoutMs: this.batch.timeoutMs,
+      debug: (message, data) => log.debug(message, { ...data, source, chunks: chunks.length }),
     });
+
     const toCache: Array<{ hash: string; embedding: number[] }> = [];
-
-    const tasks = groups.map((group, groupIndex) => async () => {
-      const batchInfo = await this.submitOpenAiBatch(group);
-      if (!batchInfo.id) {
-        throw new Error("openai batch create failed: missing batch id");
-      }
-      log.debug("memory embeddings: openai batch created", {
-        batchId: batchInfo.id,
-        status: batchInfo.status,
-        group: groupIndex + 1,
-        groups: groups.length,
-        requests: group.length,
-      });
-      if (!this.batch.wait && batchInfo.status !== "completed") {
-        throw new Error(
-          `openai batch ${batchInfo.id} submitted; enable remote.batch.wait to await completion`,
-        );
-      }
-      const completed =
-        batchInfo.status === "completed"
-          ? {
-              outputFileId: batchInfo.output_file_id ?? "",
-              errorFileId: batchInfo.error_file_id ?? undefined,
-            }
-          : await this.waitForOpenAiBatch(batchInfo.id, batchInfo);
-      if (!completed.outputFileId) {
-        throw new Error(`openai batch ${batchInfo.id} completed without output file`);
-      }
-      const content = await this.fetchOpenAiFileContent(completed.outputFileId);
-      const outputLines = this.parseOpenAiBatchOutput(content);
-      const errors: string[] = [];
-      const remaining = new Set(group.map((request) => request.custom_id));
-      for (const line of outputLines) {
-        const customId = line.custom_id;
-        if (!customId) continue;
-        const index = mapping.get(customId);
-        if (index === undefined) continue;
-        remaining.delete(customId);
-        if (line.error?.message) {
-          errors.push(`${customId}: ${line.error.message}`);
-          continue;
-        }
-        const response = line.response;
-        const statusCode = response?.status_code ?? 0;
-        if (statusCode >= 400) {
-          const message =
-            response?.body?.error?.message ??
-            (typeof response?.body === "string" ? response.body : undefined) ??
-            "unknown error";
-          errors.push(`${customId}: ${message}`);
-          continue;
-        }
-        const data = response?.body?.data ?? [];
-        const embedding = data[0]?.embedding ?? [];
-        if (embedding.length === 0) {
-          errors.push(`${customId}: empty embedding`);
-          continue;
-        }
-        embeddings[index] = embedding;
-        const chunk = chunks[index];
-        if (chunk) toCache.push({ hash: chunk.hash, embedding });
-      }
-      if (errors.length > 0) {
-        throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
-      }
-      if (remaining.size > 0) {
-        throw new Error(
-          `openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`,
-        );
-      }
-    });
-    await this.runWithConcurrency(tasks, this.batch.concurrency);
-
+    for (const [customId, embedding] of byCustomId.entries()) {
+      const mapped = mapping.get(customId);
+      if (!mapped) continue;
+      embeddings[mapped.index] = embedding;
+      toCache.push({ hash: mapped.hash, embedding });
+    }
     this.upsertEmbeddingCache(toCache);
     return embeddings;
   }
diff --git a/src/memory/memory-schema.ts b/src/memory/memory-schema.ts
new file mode 100644
index 000000000..741793793
--- /dev/null
+++ b/src/memory/memory-schema.ts
@@ -0,0 +1,95 @@
+import type { DatabaseSync } from "node:sqlite";
+
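+/**
+ * Creates the meta/files/chunks tables, the embedding cache, and (when
+ * enabled) the FTS5 mirror. Safe to run on every startup: all statements
+ * are CREATE ... IF NOT EXISTS, and ensureColumn() backfills the `source`
+ * column on databases created before it existed.
+ */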
+export function ensureMemoryIndexSchema(params: {
+  db: DatabaseSync;
+  embeddingCacheTable: string;
+  ftsTable: string;
+  ftsEnabled: boolean;
+}): { ftsAvailable: boolean; ftsError?: string } {
+  params.db.exec(`
+    CREATE TABLE IF NOT EXISTS meta (
+      key TEXT PRIMARY KEY,
+      value TEXT NOT NULL
+    );
+  `);
+  params.db.exec(`
+    CREATE TABLE IF NOT EXISTS files (
+      path TEXT PRIMARY KEY,
+      source TEXT NOT NULL DEFAULT 'memory',
+      hash TEXT NOT NULL,
+      mtime INTEGER NOT NULL,
+      size INTEGER NOT NULL
+    );
+  `);
+  params.db.exec(`
+    CREATE TABLE IF NOT EXISTS chunks (
+      id TEXT PRIMARY KEY,
+      path TEXT NOT NULL,
+      source TEXT NOT NULL DEFAULT 'memory',
+      start_line INTEGER NOT NULL,
+      end_line INTEGER NOT NULL,
+      hash TEXT NOT NULL,
+      model TEXT NOT NULL,
+      text TEXT NOT NULL,
+      embedding TEXT NOT NULL,
+      updated_at INTEGER NOT NULL
+    );
+  `);
+  params.db.exec(`
+    CREATE TABLE IF NOT EXISTS ${params.embeddingCacheTable} (
+      provider TEXT NOT NULL,
+      model TEXT NOT NULL,
+      provider_key TEXT NOT NULL,
+      hash TEXT NOT NULL,
+      embedding TEXT NOT NULL,
+      dims INTEGER,
+      updated_at INTEGER NOT NULL,
+      PRIMARY KEY (provider, model, provider_key, hash)
+    );
+  `);
+  params.db.exec(
+    `CREATE INDEX IF NOT EXISTS idx_embedding_cache_updated_at ON ${params.embeddingCacheTable}(updated_at);`,
+  );
+
+  let ftsAvailable = false;
+  let ftsError: string | undefined;
+  if (params.ftsEnabled) {
+    try {
+      params.db.exec(
+        `CREATE VIRTUAL TABLE IF NOT EXISTS ${params.ftsTable} USING fts5(\n` +
+          ` text,\n` +
+          ` id UNINDEXED,\n` +
+          ` path UNINDEXED,\n` +
+          ` source UNINDEXED,\n` +
+          ` model UNINDEXED,\n` +
+          ` start_line UNINDEXED,\n` +
+          ` end_line UNINDEXED\n` +
+          `);`,
+      );
+      ftsAvailable = true;
+    } catch (err) {
+      const message = err instanceof Error ? err.message : String(err);
+      ftsAvailable = false;
+      ftsError = message;
+    }
+  }
+
+  ensureColumn(params.db, "files", "source", "TEXT NOT NULL DEFAULT 'memory'");
+  ensureColumn(params.db, "chunks", "source", "TEXT NOT NULL DEFAULT 'memory'");
+  params.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_path ON chunks(path);`);
+  params.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_source ON chunks(source);`);
+
+  return { ftsAvailable, ...(ftsError ? { ftsError } : {}) };
+}
+
+function ensureColumn(
+  db: DatabaseSync,
+  table: "files" | "chunks",
+  column: string,
+  definition: string,
+): void {
+  const rows = db.prepare(`PRAGMA table_info(${table})`).all() as Array<{ name: string }>;
+  if (rows.some((row) => row.name === column)) return;
+  db.exec(`ALTER TABLE ${table} ADD COLUMN ${column} ${definition}`);
+}
+
diff --git a/src/memory/openai-batch.ts b/src/memory/openai-batch.ts
new file mode 100644
index 000000000..eb21ff68f
--- /dev/null
+++ b/src/memory/openai-batch.ts
@@ -0,0 +1,360 @@
+import type { OpenAiEmbeddingClient } from "./embeddings.js";
+import { hashText } from "./internal.js";
+
+export type OpenAiBatchRequest = {
+  custom_id: string;
+  method: "POST";
+  url: "/v1/embeddings";
+  body: {
+    model: string;
+    input: string;
+  };
+};
+
+export type OpenAiBatchStatus = {
+  id?: string;
+  status?: string;
+  output_file_id?: string | null;
+  error_file_id?: string | null;
+};
+
+export type OpenAiBatchOutputLine = {
+  custom_id?: string;
+  response?: {
+    status_code?: number;
+    body?: {
+      data?: Array<{ embedding?: number[]; index?: number }>;
+      error?: { message?: string };
+    };
+  };
+  error?: { message?: string };
+};
+
+export const OPENAI_BATCH_ENDPOINT = "/v1/embeddings";
+const OPENAI_BATCH_COMPLETION_WINDOW = "24h";
+const OPENAI_BATCH_MAX_REQUESTS = 50000;
+
+function getOpenAiBaseUrl(openAi: OpenAiEmbeddingClient): string {
+  return openAi.baseUrl?.replace(/\/$/, "") ?? "";
+}
+
+function getOpenAiHeaders(
+  openAi: OpenAiEmbeddingClient,
+  params: { json: boolean },
+): Record<string, string> {
+  const headers = openAi.headers ? { ...openAi.headers } : {};
+  if (params.json) {
+    if (!headers["Content-Type"] && !headers["content-type"]) {
+      headers["Content-Type"] = "application/json";
+    }
+  } else {
+    delete headers["Content-Type"];
+    delete headers["content-type"];
+  }
+  return headers;
+}
+
+function splitOpenAiBatchRequests(requests: OpenAiBatchRequest[]): OpenAiBatchRequest[][] {
+  if (requests.length <= OPENAI_BATCH_MAX_REQUESTS) return [requests];
+  const groups: OpenAiBatchRequest[][] = [];
+  for (let i = 0; i < requests.length; i += OPENAI_BATCH_MAX_REQUESTS) {
+    groups.push(requests.slice(i, i + OPENAI_BATCH_MAX_REQUESTS));
+  }
+  return groups;
+}
+
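+/**
+ * Two-step submission against the OpenAI Batch API: upload the JSONL
+ * payload via POST /files (purpose "batch"), then create the batch via
+ * POST /batches referencing the uploaded file id. Non-2xx responses throw
+ * with the response body included.
+ */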
+async function submitOpenAiBatch(params: {
+  openAi: OpenAiEmbeddingClient;
+  requests: OpenAiBatchRequest[];
+  agentId: string;
+}): Promise<OpenAiBatchStatus> {
+  const baseUrl = getOpenAiBaseUrl(params.openAi);
+  const jsonl = params.requests.map((request) => JSON.stringify(request)).join("\n");
+  const form = new FormData();
+  form.append("purpose", "batch");
+  form.append(
+    "file",
+    new Blob([jsonl], { type: "application/jsonl" }),
+    `memory-embeddings.${hashText(String(Date.now()))}.jsonl`,
+  );
+
+  const fileRes = await fetch(`${baseUrl}/files`, {
+    method: "POST",
+    headers: getOpenAiHeaders(params.openAi, { json: false }),
+    body: form,
+  });
+  if (!fileRes.ok) {
+    const text = await fileRes.text();
+    throw new Error(`openai batch file upload failed: ${fileRes.status} ${text}`);
+  }
+  const filePayload = (await fileRes.json()) as { id?: string };
+  if (!filePayload.id) {
+    throw new Error("openai batch file upload failed: missing file id");
+  }
+
+  const batchRes = await fetch(`${baseUrl}/batches`, {
+    method: "POST",
+    headers: getOpenAiHeaders(params.openAi, { json: true }),
+    body: JSON.stringify({
+      input_file_id: filePayload.id,
+      endpoint: OPENAI_BATCH_ENDPOINT,
+      completion_window: OPENAI_BATCH_COMPLETION_WINDOW,
+      metadata: {
+        source: "clawdbot-memory",
+        agent: params.agentId,
+      },
+    }),
+  });
+  if (!batchRes.ok) {
+    const text = await batchRes.text();
+    throw new Error(`openai batch create failed: ${batchRes.status} ${text}`);
+  }
+  return (await batchRes.json()) as OpenAiBatchStatus;
+}
+
+async function fetchOpenAiBatchStatus(params: {
+  openAi: OpenAiEmbeddingClient;
+  batchId: string;
+}): Promise<OpenAiBatchStatus> {
+  const baseUrl = getOpenAiBaseUrl(params.openAi);
+  const res = await fetch(`${baseUrl}/batches/${params.batchId}`, {
+    headers: getOpenAiHeaders(params.openAi, { json: true }),
+  });
+  if (!res.ok) {
+    const text = await res.text();
+    throw new Error(`openai batch status failed: ${res.status} ${text}`);
+  }
+  return (await res.json()) as OpenAiBatchStatus;
+}
+
+async function fetchOpenAiFileContent(params: {
+  openAi: OpenAiEmbeddingClient;
+  fileId: string;
+}): Promise<string> {
+  const baseUrl = getOpenAiBaseUrl(params.openAi);
+  const res = await fetch(`${baseUrl}/files/${params.fileId}/content`, {
+    headers: getOpenAiHeaders(params.openAi, { json: true }),
+  });
+  if (!res.ok) {
+    const text = await res.text();
+    throw new Error(`openai batch file content failed: ${res.status} ${text}`);
+  }
+  return await res.text();
+}
+
+function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] {
+  if (!text.trim()) return [];
+  return text
+    .split("\n")
+    .map((line) => line.trim())
+    .filter(Boolean)
+    .map((line) => JSON.parse(line) as OpenAiBatchOutputLine);
+}
+
+async function readOpenAiBatchError(params: {
+  openAi: OpenAiEmbeddingClient;
+  errorFileId: string;
+}): Promise<string | undefined> {
+  try {
+    const content = await fetchOpenAiFileContent({ openAi: params.openAi, fileId: params.errorFileId });
+    const lines = parseOpenAiBatchOutput(content);
+    const first = lines.find((line) => line.error?.message || line.response?.body?.error);
+    const message =
+      first?.error?.message ??
+      (typeof first?.response?.body?.error?.message === "string"
+        ? first?.response?.body?.error?.message
+        : undefined);
+    return message;
+  } catch (err) {
+    const message = err instanceof Error ? err.message : String(err);
+    return message ? `error file unavailable: ${message}` : undefined;
+  }
+}
+
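+/**
+ * Polls GET /batches/{id} until a terminal state. "completed" resolves with
+ * the output (and optional error) file ids; failed/expired/cancelled throws,
+ * enriched with the first message from the error file when one exists.
+ */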
+async function waitForOpenAiBatch(params: {
+  openAi: OpenAiEmbeddingClient;
+  batchId: string;
+  wait: boolean;
+  pollIntervalMs: number;
+  timeoutMs: number;
+  debug?: (message: string, data?: Record<string, unknown>) => void;
+  initial?: OpenAiBatchStatus;
+}): Promise<{ outputFileId: string; errorFileId?: string }> {
+  const start = Date.now();
+  let current: OpenAiBatchStatus | undefined = params.initial;
+  while (true) {
+    const status =
+      current ??
+      (await fetchOpenAiBatchStatus({
+        openAi: params.openAi,
+        batchId: params.batchId,
+      }));
+    const state = status.status ?? "unknown";
+    if (state === "completed") {
+      if (!status.output_file_id) {
+        throw new Error(`openai batch ${params.batchId} completed without output file`);
+      }
+      return {
+        outputFileId: status.output_file_id,
+        errorFileId: status.error_file_id ?? undefined,
+      };
+    }
+    if (["failed", "expired", "cancelled", "canceled"].includes(state)) {
+      const detail = status.error_file_id
+        ? await readOpenAiBatchError({ openAi: params.openAi, errorFileId: status.error_file_id })
+        : undefined;
+      const suffix = detail ? `: ${detail}` : "";
+      throw new Error(`openai batch ${params.batchId} ${state}${suffix}`);
+    }
+    if (!params.wait) {
+      throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`);
+    }
+    if (Date.now() - start > params.timeoutMs) {
+      throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`);
+    }
+    params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`);
+    await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs));
+    current = undefined;
+  }
+}
+
+async function runWithConcurrency<T>(tasks: Array<() => Promise<T>>, limit: number): Promise<T[]> {
+  if (tasks.length === 0) return [];
+  const resolvedLimit = Math.max(1, Math.min(limit, tasks.length));
+  const results: T[] = Array.from<T>({ length: tasks.length });
+  let next = 0;
+  let firstError: unknown = null;
+
+  const workers = Array.from({ length: resolvedLimit }, async () => {
+    while (true) {
+      if (firstError) return;
+      const index = next;
+      next += 1;
+      if (index >= tasks.length) return;
+      try {
+        results[index] = await tasks[index]();
+      } catch (err) {
+        firstError = err;
+        return;
+      }
+    }
+  });
+
+  await Promise.allSettled(workers);
+  if (firstError) throw firstError;
+  return results;
+}
+
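+/**
+ * Splits requests into chunks of at most OPENAI_BATCH_MAX_REQUESTS, submits
+ * them with bounded concurrency, and resolves to a custom_id -> embedding
+ * map. A per-line error or a missing response fails the whole call.
+ *
+ * Illustrative call; the openAi client and requests come from the caller:
+ *
+ *   const byId = await runOpenAiEmbeddingBatches({
+ *     openAi,
+ *     agentId: "main",
+ *     requests,
+ *     wait: true,
+ *     pollIntervalMs: 5_000,
+ *     timeoutMs: 3_600_000,
+ *     concurrency: 2,
+ *   });
+ */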
+export async function runOpenAiEmbeddingBatches(params: {
+  openAi: OpenAiEmbeddingClient;
+  agentId: string;
+  requests: OpenAiBatchRequest[];
+  wait: boolean;
+  pollIntervalMs: number;
+  timeoutMs: number;
+  concurrency: number;
+  debug?: (message: string, data?: Record<string, unknown>) => void;
+}): Promise<Map<string, number[]>> {
+  if (params.requests.length === 0) return new Map();
+  const groups = splitOpenAiBatchRequests(params.requests);
+  const byCustomId = new Map<string, number[]>();
+
+  const tasks = groups.map((group, groupIndex) => async () => {
+    const batchInfo = await submitOpenAiBatch({
+      openAi: params.openAi,
+      requests: group,
+      agentId: params.agentId,
+    });
+    if (!batchInfo.id) {
+      throw new Error("openai batch create failed: missing batch id");
+    }
+
+    params.debug?.("memory embeddings: openai batch created", {
+      batchId: batchInfo.id,
+      status: batchInfo.status,
+      group: groupIndex + 1,
+      groups: groups.length,
+      requests: group.length,
+    });
+
+    if (!params.wait && batchInfo.status !== "completed") {
+      throw new Error(
+        `openai batch ${batchInfo.id} submitted; enable remote.batch.wait to await completion`,
+      );
+    }
+
+    const completed =
+      batchInfo.status === "completed"
+        ? {
+            outputFileId: batchInfo.output_file_id ?? "",
+            errorFileId: batchInfo.error_file_id ?? undefined,
+          }
+        : await waitForOpenAiBatch({
+            openAi: params.openAi,
+            batchId: batchInfo.id,
+            wait: params.wait,
+            pollIntervalMs: params.pollIntervalMs,
+            timeoutMs: params.timeoutMs,
+            debug: params.debug,
+            initial: batchInfo,
+          });
+    if (!completed.outputFileId) {
+      throw new Error(`openai batch ${batchInfo.id} completed without output file`);
+    }
+
+    const content = await fetchOpenAiFileContent({
+      openAi: params.openAi,
+      fileId: completed.outputFileId,
+    });
+    const outputLines = parseOpenAiBatchOutput(content);
+    const errors: string[] = [];
+    const remaining = new Set(group.map((request) => request.custom_id));
+
+    for (const line of outputLines) {
+      const customId = line.custom_id;
+      if (!customId) continue;
+      remaining.delete(customId);
+      if (line.error?.message) {
+        errors.push(`${customId}: ${line.error.message}`);
+        continue;
+      }
+      const response = line.response;
+      const statusCode = response?.status_code ?? 0;
+      if (statusCode >= 400) {
+        const message =
+          response?.body?.error?.message ??
+          (typeof response?.body === "string" ? response.body : undefined) ??
+          "unknown error";
+        errors.push(`${customId}: ${message}`);
+        continue;
+      }
+      const data = response?.body?.data ?? [];
+      const embedding = data[0]?.embedding ?? [];
+      if (embedding.length === 0) {
+        errors.push(`${customId}: empty embedding`);
+        continue;
+      }
+      byCustomId.set(customId, embedding);
+    }
+
+    if (errors.length > 0) {
+      throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`);
+    }
+    if (remaining.size > 0) {
+      throw new Error(`openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`);
+    }
+  });
+
+  params.debug?.("memory embeddings: openai batch submit", {
+    requests: params.requests.length,
+    groups: groups.length,
+    wait: params.wait,
+    concurrency: params.concurrency,
+    pollIntervalMs: params.pollIntervalMs,
+    timeoutMs: params.timeoutMs,
+  });
+
+  await runWithConcurrency(tasks, params.concurrency);
+  return byCustomId;
+}
+
diff --git a/src/memory/sqlite-vec.ts b/src/memory/sqlite-vec.ts
new file mode 100644
index 000000000..288e16375
--- /dev/null
+++ b/src/memory/sqlite-vec.ts
@@ -0,0 +1,25 @@
+import type { DatabaseSync } from "node:sqlite";
+
+export async function loadSqliteVecExtension(params: {
+  db: DatabaseSync;
+  extensionPath?: string;
+}): Promise<{ ok: boolean; extensionPath?: string; error?: string }> {
+  try {
+    const sqliteVec = await import("sqlite-vec");
+    const resolvedPath = params.extensionPath?.trim() ? params.extensionPath.trim() : undefined;
+    const extensionPath = resolvedPath ?? sqliteVec.getLoadablePath();
+
+    params.db.enableLoadExtension(true);
+    if (resolvedPath) {
+      params.db.loadExtension(extensionPath);
+    } else {
+      sqliteVec.load(params.db);
+    }
+
+    return { ok: true, extensionPath };
+  } catch (err) {
+    const message = err instanceof Error ? err.message : String(err);
+    return { ok: false, error: message };
+  }
+}