/**
 * EmbeddingEngine
 *
 * Provides semantic similarity features using local ONNX embeddings (multilingual-e5-small)
 * and HNSW vector search via USearch. All processing is fully local — no external API calls.
 *
 * Features:
 * - findSimilar: Find thematically related posts (InsertModal, "have I written this?")
 * - suggestTags: Infer tags from similar posts
 * - findDuplicates: Audit tool for near-duplicate post detection
 *
 * Architecture:
 * - Model stays loaded across project switches (one model, multiple indexes)
 * - USearch index file per project: {userData}/projects/{projectId}/embeddings.usearch
 * - Label→postId mapping in `embedding_keys` DB table (avoids bigint JSON issues)
 * - Vector cache persisted in `embedding_keys.vector` DB column as BLOB for instant reload
 */
import { EventEmitter } from 'events';
import * as path from 'path';
import * as fs from 'fs/promises';
import * as crypto from 'crypto';
import { v4 as uuidv4 } from 'uuid';
import { eq, and, inArray } from 'drizzle-orm';
import { getDatabase } from '../database';
import { embeddingKeys, dismissedDuplicatePairs, posts } from '../database/schema';

/** A neighbour returned by {@link EmbeddingEngine.findSimilar}. */
export interface SimilarPost {
  postId: string;
  similarity: number; // cosine similarity 0-1
}

/** A tag proposed by {@link EmbeddingEngine.suggestTags}. */
export interface TagSuggestion {
  name: string;
  score: number; // weighted frequency
}

/** Two posts whose embeddings are at least `threshold` similar. */
export interface DuplicatePair {
  postA: { id: string; title: string; slug: string; publishedAt?: Date };
  postB: { id: string; title: string; slug: string; publishedAt?: Date };
  similarity: number;
  /** Set when titles and bodies are byte-for-byte equal, not merely similar. */
  exactMatch?: boolean;
}

// Injected dependencies for testability
export interface EmbeddingEngineDeps {
  /** Return the path to the USearch index file for a project */
  getIndexPath: (projectId: string) => string;
  /** Create the embedding pipeline (dependency-injected for tests) */
  createPipeline?: () => Promise<EmbeddingPipeline>;
}

export interface EmbeddingPipeline {
  embed(text: string): Promise<Float32Array>;
}

/** Shorthand for the lazily imported USearch index class. */
type VectorIndex = import('usearch').Index;

/**
 * Parse the serialized tag list of a post row.
 *
 * Never throws: malformed JSON, non-array payloads and non-string entries all
 * degrade to "no tags" so a single bad row cannot break tag suggestions.
 */
function parseTagList(raw: unknown): string[] {
  const keepStrings = (items: unknown[]): string[] =>
    items.filter((item): item is string => typeof item === 'string');

  if (Array.isArray(raw)) return keepStrings(raw);
  if (typeof raw !== 'string' || raw === '') return [];
  try {
    const parsed: unknown = JSON.parse(raw);
    return Array.isArray(parsed) ? keepStrings(parsed) : [];
  } catch {
    return [];
  }
}

export class EmbeddingEngine extends EventEmitter {
  private deps: EmbeddingEngineDeps;
  private pipeline: EmbeddingPipeline | null = null;
  private pipelineLoadPromise: Promise<EmbeddingPipeline> | null = null;

  // USearch index (lazily loaded per-project)
  private index: VectorIndex | null = null;
  private currentProjectId: string | null = null;

  // Label->postId map (backed by DB, kept in memory for fast lookup)
  private labelToPostId: Map<bigint, string> = new Map();
  private postIdToLabel: Map<string, bigint> = new Map();
  private nextLabel: bigint = 1n;

  // In-memory vector cache -- loaded from DB on startup, updated during embedding.
  private vectorCache: Map<string, Float32Array> = new Map(); // postId -> vector

  // Debounced save timer
  private saveTimer: ReturnType<typeof setTimeout> | null = null;
  private readonly SAVE_DEBOUNCE_MS = 5000;

  // Model dimensions
  private readonly DIMENSIONS = 384;
  private readonly MODEL_ID = 'Xenova/multilingual-e5-small';

  // Neighbours requested per post while scanning for duplicates (self is added on top).
  private readonly DUPLICATE_NEIGHBORS = 20;

  constructor(deps: EmbeddingEngineDeps) {
    super();
    this.deps = deps;
  }

  // Lifecycle

  /**
   * Load the embedding pipeline if it is not loaded yet.
   *
   * Concurrent callers share one in-flight load. When the load fails the
   * cached promise is dropped so that a later call retries instead of
   * re-throwing the same stale rejection forever.
   *
   * @throws whatever the pipeline factory throws.
   */
  async initialize(): Promise<void> {
    if (this.pipeline) return;
    if (!this.pipelineLoadPromise) {
      this.pipelineLoadPromise = this.loadPipeline();
    }
    const pending = this.pipelineLoadPromise;
    try {
      this.pipeline = await pending;
    } catch (err) {
      if (this.pipelineLoadPromise === pending) {
        this.pipelineLoadPromise = null;
      }
      throw err;
    }
  }

  /** Build the pipeline: the injected factory when present, otherwise the local ONNX model. */
  private async loadPipeline(): Promise<EmbeddingPipeline> {
    if (this.deps.createPipeline) {
      return this.deps.createPipeline();
    }

    // Dynamic import to avoid loading heavy ONNX runtime at startup
    const { pipeline, env } = await import('@huggingface/transformers');

    // Configure cache for Electron -- use ~/.cache/huggingface
    env.useFSCache = true;

    const extractor = await pipeline('feature-extraction', this.MODEL_ID, {
      dtype: 'fp32',
    });

    return {
      embed: async (text: string): Promise<Float32Array> => {
        const output = await extractor(text, { pooling: 'mean', normalize: true });
        // v3: output.data is Float32Array
        return output.data as Float32Array;
      },
    };
  }

  /**
   * Flush the index to disk and drop all in-memory state, including the model.
   * State is reset even when the final save fails; the save error still propagates.
   */
  async shutdown(): Promise<void> {
    if (this.saveTimer) {
      clearTimeout(this.saveTimer);
      this.saveTimer = null;
    }
    try {
      if (this.index && this.currentProjectId) {
        await this.save();
      }
    } finally {
      this.index = null;
      this.currentProjectId = null;
      this.resetKeyState();
      this.pipeline = null;
      this.pipelineLoadPromise = null;
    }
  }

  // Project switching

  /**
   * Make `projectId` the active project: persist the current index, drop its
   * in-memory state and load (or create) the index of the new project.
   * No-op when the project is already active.
   */
  async setProjectContext(projectId: string): Promise<void> {
    if (this.currentProjectId === projectId) return;

    // Save and unload current index
    if (this.index && this.currentProjectId) {
      await this.save();
    }
    this.index = null;
    this.resetKeyState();
    this.currentProjectId = projectId;

    // Load (or create) index for new project
    await this.ensureIndexLoaded();
  }

  /** Clear the label maps, the vector cache and the label counter. */
  private resetKeyState(): void {
    this.labelToPostId.clear();
    this.postIdToLabel.clear();
    this.vectorCache.clear();
    this.nextLabel = 1n;
  }

  /** Create an empty index with the engine's fixed HNSW configuration. */
  private async createIndex(): Promise<VectorIndex> {
    const { Index, MetricKind, ScalarKind } = await import('usearch');
    return new Index({
      metric: MetricKind.Cos,
      quantization: ScalarKind.F32,
      dimensions: this.DIMENSIONS,
      connectivity: 16,
      expansion_add: 128,
      expansion_search: 64,
      multi: false,
    });
  }

  /** True when `filePath` is accessible. */
  private async fileExists(filePath: string): Promise<boolean> {
    try {
      await fs.access(filePath);
      return true;
    } catch {
      return false;
    }
  }

  /**
   * Load the index of the active project on first use.
   *
   * A missing index file means "start fresh". A file that exists but cannot be
   * loaded is logged and replaced by an empty index. In both cases vectors
   * cached in the DB are re-added afterwards, so the index agrees with the key map.
   */
  private async ensureIndexLoaded(): Promise<void> {
    if (this.index) return;
    const projectId = this.currentProjectId;
    if (!projectId) return;

    let index = await this.createIndex();
    this.index = index;

    const indexPath = this.deps.getIndexPath(projectId);
    if (await this.fileExists(indexPath)) {
      try {
        index.load(indexPath);
      } catch (err) {
        console.error('[EmbeddingEngine] failed to load index file, starting fresh:', err);
        // Do not keep a possibly half-loaded index around.
        index = await this.createIndex();
        this.index = index;
      }
    }

    // Load key mapping and vectors from DB
    await this.loadKeyMapFromDb(projectId);
    this.reconcileIndexWithCache(index);
  }

  /**
   * Re-add vectors that the key map knows about but the index does not contain.
   *
   * Key rows are written to the DB immediately while the index file is saved
   * on a debounce, so the two can disagree after an unclean exit. Without this
   * step such posts would stay unsearchable, because `embedPost` skips posts
   * whose content hash is unchanged.
   */
  private reconcileIndexWithCache(index: VectorIndex): void {
    let restored = 0;
    for (const [postId, label] of this.postIdToLabel) {
      const vector = this.vectorCache.get(postId);
      if (!vector || vector.length !== this.DIMENSIONS) continue;
      try {
        if (index.contains(label) === true) continue;
        index.add(label, vector);
        restored++;
      } catch (err) {
        // Best effort: one bad entry must not prevent the index from loading.
        console.error('[EmbeddingEngine] failed to restore vector for post', postId, err);
      }
    }
    if (restored > 0) this.scheduleSave();
  }

  /**
   * Decode a vector BLOB into a standalone Float32Array.
   *
   * The bytes are copied because a Buffer may be a view at a byte offset that
   * is not a multiple of 4, which the Float32Array constructor rejects.
   *
   * @returns null for empty blobs or blobs whose size is not a whole number of floats.
   */
  private decodeVector(blob: Uint8Array): Float32Array | null {
    const bytesPerFloat = Float32Array.BYTES_PER_ELEMENT;
    if (blob.byteLength === 0 || blob.byteLength % bytesPerFloat !== 0) return null;
    const aligned = new Uint8Array(blob.byteLength);
    aligned.set(blob);
    return new Float32Array(aligned.buffer);
  }

  /** Rebuild the label maps, label counter and vector cache from `embedding_keys`. */
  private async loadKeyMapFromDb(projectId: string): Promise<void> {
    const db = getDatabase().getLocal();
    const rows = await db
      .select()
      .from(embeddingKeys)
      .where(eq(embeddingKeys.projectId, projectId));

    this.resetKeyState();

    for (const row of rows) {
      const label = BigInt(row.label);
      this.labelToPostId.set(label, row.postId);
      this.postIdToLabel.set(row.postId, label);
      if (label >= this.nextLabel) {
        this.nextLabel = label + 1n;
      }
      if (row.vector) {
        const vector = this.decodeVector(row.vector as Buffer);
        // Undecodable vectors are skipped; getOrComputeVector re-embeds on demand.
        if (vector) this.vectorCache.set(row.postId, vector);
      }
    }
  }

  /**
   * Add a vector under `label`, first removing any entry already stored under it.
   *
   * A label can linger in the index file when its DB row was deleted but the
   * debounced save never ran. The index is created with `multi: false`, so the
   * stale key is removed first rather than adding a second entry under it.
   */
  private addToIndex(index: VectorIndex, label: bigint, vector: Float32Array): void {
    if (index.contains(label) === true) {
      index.remove(label);
    }
    index.add(label, vector);
  }

  // Core operations

  /**
   * Embed a post and (re)index it for the active project.
   *
   * No-op when the engine has no project/pipeline, or when the stored content
   * hash matches. The new embedding is computed before the old entry is
   * removed, so a failing embedding leaves the previous entry intact.
   */
  async embedPost(postId: string, title: string, content: string): Promise<void> {
    await this.initialize();
    await this.ensureIndexLoaded();

    const index = this.index;
    const projectId = this.currentProjectId;
    if (!index || !this.pipeline || !projectId) return;

    const rawText = `${title}\n\n${content}`;
    const hash = this.computeHash(rawText);

    const db = getDatabase().getLocal();
    const ownRow = () =>
      and(eq(embeddingKeys.postId, postId), eq(embeddingKeys.projectId, projectId));

    // Check if already indexed with same hash (no-op)
    const existing = await db.select().from(embeddingKeys).where(ownRow());
    const previous = existing[0];
    if (previous && previous.contentHash === hash) {
      return; // Unchanged, skip re-embedding
    }

    // Compute embedding
    const vector = await this.embedText(`query: ${rawText}`);

    // The project was switched or the index rebuilt while embedding: the
    // result no longer belongs to the active index, so drop it.
    if (this.index !== index || this.currentProjectId !== projectId) return;

    // Remove old vector if exists
    if (previous) {
      const oldLabel = BigInt(previous.label);
      try {
        index.remove(oldLabel);
      } catch {
        // Ignore remove errors -- label may not be in index
      }
      this.labelToPostId.delete(oldLabel);
      this.postIdToLabel.delete(postId);
      this.vectorCache.delete(postId);
      await db.delete(embeddingKeys).where(ownRow());
    }

    // Assign new label
    const label = this.nextLabel++;
    this.addToIndex(index, label, vector);
    this.labelToPostId.set(label, postId);
    this.postIdToLabel.set(postId, label);
    this.vectorCache.set(postId, vector);

    // Persist key mapping + vector (label is bigint in-memory, stored as number in SQLite)
    await db.insert(embeddingKeys).values({
      label: Number(label),
      postId,
      projectId,
      contentHash: hash,
      vector: Buffer.from(vector.buffer, vector.byteOffset, vector.byteLength),
    });

    this.scheduleSave();
  }

  /** Remove a post from the index, the in-memory maps and the `embedding_keys` table. */
  async removePost(postId: string): Promise<void> {
    await this.ensureIndexLoaded();
    const index = this.index;
    const projectId = this.currentProjectId;
    if (!index || !projectId) return;

    const label = this.postIdToLabel.get(postId);
    if (label === undefined) return;

    try {
      index.remove(label);
    } catch {
      // Ignore remove errors
    }

    this.labelToPostId.delete(label);
    this.postIdToLabel.delete(postId);
    this.vectorCache.delete(postId);

    const db = getDatabase().getLocal();
    await db
      .delete(embeddingKeys)
      .where(and(eq(embeddingKeys.postId, postId), eq(embeddingKeys.projectId, projectId)));

    this.scheduleSave();
  }

  /**
   * Find up to `k` posts most similar to `postId`, best match first.
   *
   * @returns an empty list when the post is not indexed, when it is the only
   * indexed post, or when `k` is not at least 1.
   */
  async findSimilar(postId: string, k = 5): Promise<SimilarPost[]> {
    await this.ensureIndexLoaded();
    const index = this.index;
    if (!index || !this.currentProjectId) return [];
    if (!(k >= 1)) return [];
    if (!this.postIdToLabel.has(postId)) return [];

    // Guard against empty index (USearch throws on empty index search)
    if (this.postIdToLabel.size < 2) return [];

    // Get or compute vector for this post
    const vector = await this.getOrComputeVector(postId);
    if (!vector) return [];

    // Search for k+1 (to exclude self) with HNSW
    const result = index.search(vector, k + 1, 0);
    if (!result) return [];

    const results: SimilarPost[] = [];
    for (let i = 0; i < result.keys.length; i++) {
      const foundLabel = result.keys[i]!;
      const foundPostId = this.labelToPostId.get(foundLabel);
      if (!foundPostId || foundPostId === postId) continue;

      results.push({
        postId: foundPostId,
        similarity: this.distanceToSimilarity(result.distances[i]!),
      });
    }

    return results.sort((a, b) => b.similarity - a.similarity).slice(0, k);
  }

  /** USearch cosine metric returns distance (0=identical); convert to similarity clamped at 0. */
  private distanceToSimilarity(distance: number): number {
    return Math.max(0, 1 - distance);
  }

  /**
   * Compute cosine similarity between a source post and a list of target posts.
   * Returns a map of targetPostId → similarity (0-1). Posts without embeddings are omitted.
   */
  async computeSimilarities(
    sourcePostId: string,
    targetPostIds: string[],
  ): Promise<Record<string, number>> {
    await this.ensureIndexLoaded();
    if (!this.index || !this.currentProjectId || targetPostIds.length === 0) return {};

    const sourceVec = await this.getOrComputeVector(sourcePostId);
    if (!sourceVec) return {};

    const result: Record<string, number> = {};
    for (const targetId of targetPostIds) {
      if (targetId === sourcePostId) continue;
      const targetVec = await this.getOrComputeVector(targetId);
      if (!targetVec) continue;
      result[targetId] = this.cosineSimilarity(sourceVec, targetVec);
    }
    return result;
  }

  /**
   * Cosine similarity clamped to [0, 1].
   * Returns 0 for zero vectors and for vectors of different length, which are
   * not comparable and would otherwise produce NaN.
   */
  private cosineSimilarity(a: Float32Array, b: Float32Array): number {
    if (a.length !== b.length) return 0;

    let dot = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
      const x = a[i]!;
      const y = b[i]!;
      dot += x * y;
      normA += x * x;
      normB += y * y;
    }
    const denom = Math.sqrt(normA) * Math.sqrt(normB);
    return denom === 0 ? 0 : Math.max(0, dot / denom);
  }

  // Derived features

  /**
   * Suggest up to 5 tags for a post from the tags of its 10 nearest neighbours.
   * Each tag's score is the sum of the similarities of the neighbours carrying it.
   *
   * @param excludeTags tags to leave out, compared case-insensitively.
   */
  async suggestTags(postId: string, excludeTags: string[]): Promise<TagSuggestion[]> {
    const similar = await this.findSimilar(postId, 10);
    if (similar.length === 0) return [];
    if (!this.currentProjectId) return [];

    // Get tags for similar posts
    const similarityByPost = new Map<string, number>();
    for (const item of similar) similarityByPost.set(item.postId, item.similarity);

    const db = getDatabase().getLocal();
    const postRows = await db
      .select({ id: posts.id, tags: posts.tags })
      .from(posts)
      .where(inArray(posts.id, Array.from(similarityByPost.keys())));

    const excludeSet = new Set(excludeTags.map((t) => t.toLowerCase()));
    const tagScores = new Map<string, number>();

    for (const row of postRows) {
      const weight = similarityByPost.get(row.id);
      if (weight === undefined) continue;

      for (const tag of parseTagList(row.tags)) {
        if (excludeSet.has(tag.toLowerCase())) continue;
        tagScores.set(tag, (tagScores.get(tag) ?? 0) + weight);
      }
    }

    return Array.from(tagScores.entries())
      .map(([name, score]) => ({ name, score }))
      .sort((a, b) => b.score - a.score)
      .slice(0, 5);
  }

  /**
   * Scan every indexed post for near-duplicates.
   *
   * @param threshold minimum similarity (0-1) for a pair to be reported.
   * @param onProgress called once per scanned post with (checked, total).
   * @returns pairs not previously dismissed, exact matches first, then by
   * descending similarity.
   */
  async findDuplicates(
    threshold = 0.92,
    onProgress?: (checked: number, total: number) => void,
  ): Promise<DuplicatePair[]> {
    await this.ensureIndexLoaded();
    const index = this.index;
    const projectId = this.currentProjectId;
    if (!index || !projectId) return [];

    const db = getDatabase().getLocal();

    // Get dismissed pairs
    const dismissed = await db
      .select()
      .from(dismissedDuplicatePairs)
      .where(eq(dismissedDuplicatePairs.projectId, projectId));
    const dismissedSet = new Set<string>(
      dismissed.map((d) => this.pairKey(d.postIdA, d.postIdB)),
    );

    // Get post info for all indexed posts
    const allPostIds = Array.from(this.postIdToLabel.keys());
    if (allPostIds.length === 0) return [];

    const postRows = await db
      .select({
        id: posts.id,
        title: posts.title,
        slug: posts.slug,
        content: posts.content,
        filePath: posts.filePath,
        publishedAt: posts.publishedAt,
      })
      .from(posts)
      .where(inArray(posts.id, allPostIds));

    const postMap = new Map(postRows.map((p) => [p.id, p]));

    // Cache for lazily-loaded post bodies (needed for exact-match detection)
    const bodyCache = new Map<string, string>();
    const getBody = async (postId: string): Promise<string> => {
      const cached = bodyCache.get(postId);
      if (cached !== undefined) return cached;
      const post = postMap.get(postId);
      const body = post ? await this.resolveBody(post.content, post.filePath) : '';
      bodyCache.set(postId, body);
      return body;
    };

    const pairs: DuplicatePair[] = [];
    const seenPairs = new Set<string>();

    for (let idx = 0; idx < allPostIds.length; idx++) {
      const postId = allPostIds[idx]!;
      onProgress?.(idx + 1, allPostIds.length);

      const vector = await this.getOrComputeVector(postId);
      if (!vector) continue;

      // One extra neighbour because the post itself is part of the result.
      const result = index.search(vector, this.DUPLICATE_NEIGHBORS + 1, 0);
      if (!result) continue;

      for (let i = 0; i < result.keys.length; i++) {
        const otherLabel = result.keys[i]!;
        const otherPostId = this.labelToPostId.get(otherLabel);
        if (!otherPostId || otherPostId === postId) continue;

        const similarity = this.distanceToSimilarity(result.distances[i]!);
        if (similarity < threshold) continue;

        const key = this.pairKey(postId, otherPostId);
        if (seenPairs.has(key) || dismissedSet.has(key)) continue;
        seenPairs.add(key);

        const postA = postMap.get(postId);
        const postB = postMap.get(otherPostId);
        if (!postA || !postB) continue;

        pairs.push({
          postA: {
            id: postA.id,
            title: postA.title,
            slug: postA.slug,
            publishedAt: postA.publishedAt ?? undefined,
          },
          postB: {
            id: postB.id,
            title: postB.title,
            slug: postB.slug,
            publishedAt: postB.publishedAt ?? undefined,
          },
          similarity,
        });
      }
    }

    // For pairs at 100% embedding similarity, compare actual bodies to find true exact duplicates
    for (const pair of pairs) {
      if (Math.round(pair.similarity * 100) < 100) continue;
      // Titles are already in memory; only read bodies when the titles match.
      if (pair.postA.title !== pair.postB.title) continue;
      const bodyA = await getBody(pair.postA.id);
      const bodyB = await getBody(pair.postB.id);
      if (bodyA === bodyB) {
        pair.exactMatch = true;
      }
    }

    return pairs.sort((a, b) => {
      if (a.exactMatch && !b.exactMatch) return -1;
      if (!a.exactMatch && b.exactMatch) return 1;
      return b.similarity - a.similarity;
    });
  }

  /** Record that a duplicate pair was reviewed so findDuplicates stops reporting it. */
  async dismissPair(postIdA: string, postIdB: string): Promise<void> {
    const projectId = this.currentProjectId;
    if (!projectId) return;
    const db = getDatabase().getLocal();
    const [a, b] = this.sortedPairIds(postIdA, postIdB);
    await db
      .insert(dismissedDuplicatePairs)
      .values({
        id: uuidv4(),
        projectId,
        postIdA: a,
        postIdB: b,
        dismissedAt: new Date(),
      })
      .onConflictDoNothing();
  }

  /** Bulk variant of {@link dismissPair}; all rows share one timestamp. */
  async dismissPairs(pairIds: Array<[string, string]>): Promise<void> {
    const projectId = this.currentProjectId;
    if (!projectId) return;
    const db = getDatabase().getLocal();
    const now = new Date();
    const rows = pairIds.map(([idA, idB]) => {
      const [a, b] = this.sortedPairIds(idA, idB);
      return { id: uuidv4(), projectId, postIdA: a, postIdB: b, dismissedAt: now };
    });
    // Insert in batches of 100 to avoid SQLite variable limits
    for (let i = 0; i < rows.length; i += 100) {
      await db
        .insert(dismissedDuplicatePairs)
        .values(rows.slice(i, i + 100))
        .onConflictDoNothing();
    }
  }

  // Indexing management

  /** Number of indexed posts versus the number of posts in the active project. */
  async getIndexingProgress(): Promise<{ indexed: number; total: number }> {
    if (!this.currentProjectId) return { indexed: 0, total: 0 };
    await this.ensureIndexLoaded();

    const db = getDatabase().getLocal();
    const indexed = this.labelToPostId.size;
    const allPosts = await db
      .select({ id: posts.id })
      .from(posts)
      .where(eq(posts.projectId, this.currentProjectId));

    return { indexed, total: allPosts.length };
  }

  /**
   * Discard all embeddings of the active project and embed every post again.
   *
   * The model is loaded before anything is deleted, so a model that cannot be
   * loaded does not cost the existing index. The emptied index is written to
   * disk immediately so that no stale index file outlives the wiped key table.
   */
  async reindexAll(onProgress?: (indexed: number, total: number) => void): Promise<void> {
    await this.ensureIndexLoaded();
    const projectId = this.currentProjectId;
    if (!projectId) return;

    await this.initialize();

    const db = getDatabase().getLocal();

    // Clear existing index
    await db.delete(embeddingKeys).where(eq(embeddingKeys.projectId, projectId));

    this.index = await this.createIndex();
    this.resetKeyState();
    await this.save();

    await this.indexUnindexedPosts(onProgress);
  }

  /**
   * Embed every post of the active project whose content hash is missing or
   * outdated. The index is saved every 100 posts and once more at the end.
   *
   * @param onProgress called after each post with (done, totalToIndex).
   */
  async indexUnindexedPosts(onProgress?: (indexed: number, total: number) => void): Promise<void> {
    await this.initialize();
    await this.ensureIndexLoaded();
    const projectId = this.currentProjectId;
    if (!projectId) return;

    const db = getDatabase().getLocal();
    const allPosts = await db
      .select({
        id: posts.id,
        title: posts.title,
        content: posts.content,
        filePath: posts.filePath,
      })
      .from(posts)
      .where(eq(posts.projectId, projectId));

    // Resolve actual content for each post (read from file for published posts)
    const resolvedPosts: Array<{ id: string; title: string; content: string }> = [];
    for (const p of allPosts) {
      resolvedPosts.push({
        id: p.id,
        title: p.title,
        content: await this.resolveBody(p.content, p.filePath),
      });
    }

    // Get current hashes from DB for change detection
    const keyRows = await db
      .select()
      .from(embeddingKeys)
      .where(eq(embeddingKeys.projectId, projectId));
    const hashMap = new Map(keyRows.map((r) => [r.postId, r.contentHash]));

    const toIndex = resolvedPosts.filter(
      (p) => hashMap.get(p.id) !== this.computeHash(`${p.title}\n\n${p.content}`),
    );

    let count = 0;
    let batchCount = 0;
    const BATCH_SAVE_INTERVAL = 100;

    for (const post of toIndex) {
      await this.embedPost(post.id, post.title, post.content);
      count++;
      batchCount++;
      onProgress?.(count, toIndex.length);

      if (batchCount >= BATCH_SAVE_INTERVAL) {
        await this.save();
        batchCount = 0;
      }
    }

    if (batchCount > 0) {
      await this.save();
    }
  }

  // Persistence

  /** Write the index of the active project to disk now, cancelling any pending debounced save. */
  async save(): Promise<void> {
    if (this.saveTimer) {
      clearTimeout(this.saveTimer);
      this.saveTimer = null;
    }
    if (!this.index || !this.currentProjectId) return;

    const indexPath = this.deps.getIndexPath(this.currentProjectId);
    await fs.mkdir(path.dirname(indexPath), { recursive: true });
    this.index.save(indexPath);
  }

  /** (Re)start the debounce timer that saves the index after SAVE_DEBOUNCE_MS. */
  private scheduleSave(): void {
    if (this.saveTimer) clearTimeout(this.saveTimer);
    this.saveTimer = setTimeout(() => {
      this.save().catch((err) => console.error('[EmbeddingEngine] save error:', err));
    }, this.SAVE_DEBOUNCE_MS);
  }

  // Helpers

  /**
   * Get vector for a postId from in-memory cache (loaded from DB at startup).
   * Falls back to re-computing from post content only if not in cache.
   */
  private async getOrComputeVector(postId: string): Promise<Float32Array | null> {
    const cached = this.vectorCache.get(postId);
    if (cached) return cached;

    // Re-embed from post content
    await this.initialize();
    if (!this.pipeline || !this.currentProjectId) return null;

    const resolved = await this.resolvePostContent(postId);
    if (!resolved) return null;

    const vector = await this.embedText(`query: ${resolved.title}\n\n${resolved.content}`);
    this.vectorCache.set(postId, vector);
    return vector;
  }

  /**
   * Run the pipeline on `text`.
   * @throws Error when the pipeline has not been initialized.
   */
  private async embedText(text: string): Promise<Float32Array> {
    if (!this.pipeline) throw new Error('EmbeddingEngine not initialized');
    return this.pipeline.embed(text);
  }

  /**
   * Pick the body text of a post: the DB content when present, otherwise the
   * front-matter-stripped content of the file at `filePath`. A missing or
   * unreadable file yields an empty string.
   */
  private async resolveBody(
    content: string | null | undefined,
    filePath: string | null | undefined,
  ): Promise<string> {
    if (content) return content;
    if (!filePath) return '';
    try {
      const raw = await fs.readFile(filePath, 'utf-8');
      const matter = (await import('gray-matter')).default;
      return matter(raw).content;
    } catch {
      // File not found or unreadable — fall back to empty
      return '';
    }
  }

  /**
   * Resolve the actual body text for a post.
   * Draft posts have content in the DB; published posts have it on the filesystem.
   */
  private async resolvePostContent(
    postId: string,
  ): Promise<{ title: string; content: string } | null> {
    const projectId = this.currentProjectId;
    if (!projectId) return null;

    const db = getDatabase().getLocal();
    const rows = await db
      .select({ title: posts.title, content: posts.content, filePath: posts.filePath })
      .from(posts)
      .where(and(eq(posts.id, postId), eq(posts.projectId, projectId)));

    const post = rows[0];
    if (!post) return null;

    return { title: post.title, content: await this.resolveBody(post.content, post.filePath) };
  }

  /** SHA-256 hex digest used for change detection. */
  private computeHash(text: string): string {
    return crypto.createHash('sha256').update(text).digest('hex');
  }

  /** Order-independent key for a pair of post ids. */
  private pairKey(idA: string, idB: string): string {
    const [a, b] = this.sortedPairIds(idA, idB);
    return `${a}::${b}`;
  }

  /** Return the two ids in ascending string order. */
  private sortedPairIds(idA: string, idB: string): [string, string] {
    return idA < idB ? [idA, idB] : [idB, idA];
  }
}