588 lines
18 KiB
Elixir
588 lines
18 KiB
Elixir
defmodule BDS.Embeddings.Index do
|
|
@moduledoc """
|
|
Per-project approximate-nearest-neighbour index over post embeddings.
|
|
|
|
Backed by an HNSW graph (hnswlib) per the A1-14b / `specs/embedding.allium`
|
|
requirement — cosine space, connectivity M=16, efConstruction=128,
|
|
efSearch=64. This replaces the previous O(n²) brute-force cosine snapshot:
|
|
building is O(n·log n) and queries are O(log n).
|
|
|
|
The process is intentionally **database-free**: callers (running in their own
|
|
process, e.g. under the test SQL sandbox) read embedding vectors from the DB
|
|
and hand them in. This GenServer owns only the in-memory HNSW graphs, the
|
|
`label → post_id` maps, and file persistence.
|
|
|
|
Persistence (DebouncedPersistence invariant): the index file
|
|
(`embeddings.usearch`) plus a small sidecar holding the dimension and the
|
|
label→post_id map are written behind a 5s debounce, and force-saved on
|
|
project switch / shutdown. On a cold query the index is lazily reloaded from
|
|
those files; if they are absent the caller rebuilds from the DB vectors.
|
|
"""
|
|
|
|
use GenServer
|
|
|
|
alias BDS.Projects
|
|
alias BDS.ProgressReporter
|
|
|
|
# Number of neighbours requested per entry by scan_duplicates/4.
@neighbor_limit 21
|
|
# Delay in milliseconds before a scheduled {:save, project_id} message is delivered.
@debounce_ms 5_000
|
|
# Distance space passed to HNSWLib when creating or loading an index.
@space :cosine
|
|
# HNSW connectivity (`m:` option) used when creating an index.
@m 16
|
|
# `ef_construction:` option used when creating an index.
@ef_construction 128
|
|
# Value handed to HNSWLib.Index.set_ef/2 after an index is built or loaded.
@ef_search 64
|
|
# State-map key holding bookkeeping that is not tied to a project (the flush_all waiter list).
@meta_key :"$meta"
|
|
|
|
# ─── Public API ─────────────────────────────────────────────
|
|
|
|
@doc """
Starts the index server and registers it under this module's name.

`opts` is handed to `init/1`, which currently ignores it.
"""
def start_link(opts \\ []),
  do: GenServer.start_link(__MODULE__, opts, name: __MODULE__)
|
|
|
|
@doc "Returns the on-disk location of the HNSW index file for a project."
def path(project_id) when is_binary(project_id) do
  project_id
  |> Projects.project_cache_dir()
  |> Path.join("embeddings.usearch")
end
|
|
|
|
@doc """
Replaces a project's index with one built from `entries` and arranges a
debounced save.

Each entry is a `%{label:, post_id:, vector:}` map whose `vector` is the
packed little-endian Float32 BLOB. The call blocks (no timeout) until the
server replies `:ok`.
"""
def put(project_id, dimensions, entries)
    when is_binary(project_id) and is_integer(dimensions) and is_list(entries) do
  request = {:put, project_id, dimensions, entries}
  GenServer.call(__MODULE__, request, :infinity)
end
|
|
|
|
@doc """
Returns up to `limit` nearest neighbours of `query_vector` (the post's packed
BLOB), excluding `query_label`.

Replies `{:ok, [%{post_id: _, score: _}]}` or `{:error, :missing}` when no
index is available for the project.

`limit` must be an integer. It is validated here, in the caller's process,
because the server computes `limit + 1`; an invalid value would otherwise
raise inside the shared GenServer and take every project's in-memory index
down with it.
"""
def neighbors(project_id, query_label, query_vector, limit)
    when is_binary(project_id) and is_integer(query_label) and is_binary(query_vector) and
           is_integer(limit) do
  GenServer.call(
    __MODULE__,
    {:neighbors, project_id, query_label, query_vector, limit},
    :infinity
  )
end
|
|
|
|
@doc """
Looks for near-duplicate pairs scoring at or above `threshold` by asking the
HNSW graph for the neighbours of every entry.

Replies `{:ok, pairs}` on success, or `{:error, :missing}` when the project
has no usable index. `opts` is forwarded to the scan for progress reporting.
"""
def duplicate_pairs(project_id, entries, threshold, opts \\ [])
    when is_binary(project_id) and is_list(entries) and is_number(threshold) do
  request = {:duplicate_pairs, project_id, entries, threshold, opts}
  GenServer.call(__MODULE__, request, :infinity)
end
|
|
|
|
@doc "Writes a project's index to disk immediately instead of waiting for the debounce (e.g. on project switch)."
def flush(project_id) when is_binary(project_id),
  do: GenServer.call(__MODULE__, {:flush, project_id}, :infinity)
|
|
|
|
@doc "Writes every project's index to disk immediately (e.g. on shutdown)."
def flush_all, do: GenServer.call(__MODULE__, :flush_all, :infinity)
|
|
|
|
@doc "Removes a project's in-memory index from the server (e.g. on project deletion)."
def forget(project_id) when is_binary(project_id),
  do: GenServer.call(__MODULE__, {:forget, project_id}, :infinity)
|
|
|
|
# ─── GenServer ──────────────────────────────────────────────
|
|
|
|
@impl true
def init(_opts) do
  # Exits are trapped by this process; terminate/2 force-saves on the way out.
  Process.flag(:trap_exit, true)

  # The state is one map: project ids (binaries) map to entries, and the
  # reserved meta key carries the callers waiting on flush_all.
  initial_meta = %{flush_all_waiters: []}
  {:ok, Map.new([{@meta_key, initial_meta}])}
end
|
|
|
|
@impl true
def handle_call({:put, project_id, dimensions, entries}, from, state) do
  # The pending debounce timer is cancelled before the build starts. The
  # freshly built entry comes back with timer: nil, so a timer left running
  # here would be orphaned and fire a redundant save rather than coalescing.
  next_state =
    state
    |> cancel_pending_save(project_id)
    |> start_build(project_id, dimensions, entries, from)

  # The reply is sent later, once the build task has finished.
  {:noreply, next_state}
end
|
|
|
|
def handle_call({:neighbors, project_id, query_label, query_vector, limit}, _from, state) do
  # Resolve the reply and the (possibly lazily loaded) state first, then
  # build the callback result in one place.
  {reply, next_state} =
    case ensure_loaded(state, project_id) do
      {:missing, unchanged} ->
        {{:error, :missing}, unchanged}

      # An entry exists but holds no graph (empty build or build in flight).
      {:ok, %{index: nil}, loaded} ->
        {{:error, :missing}, loaded}

      {:ok, entry, loaded} ->
        {{:ok, query_neighbors(entry, query_label, query_vector, limit)}, loaded}
    end

  {:reply, reply, next_state}
end
|
|
|
|
def handle_call({:duplicate_pairs, project_id, entries, threshold, opts}, from, state) do
  missing = {:error, :missing}

  case ensure_loaded(state, project_id) do
    {:missing, unchanged} ->
      {:reply, missing, unchanged}

    # An entry without a graph cannot be scanned.
    {:ok, %{index: nil}, loaded} ->
      {:reply, missing, loaded}

    # The scan runs in a supervised task; the caller is answered when the
    # task's result message arrives.
    {:ok, entry, loaded} ->
      {:noreply, start_duplicate_scan(loaded, project_id, entry, entries, threshold, opts, from)}
  end
end
|
|
|
|
def handle_call({:flush, project_id}, from, state) do
  case Map.get(state, project_id) do
    # A build is running: park the caller on the build so it is answered
    # when the build completes.
    %{build: %{} = build} = entry ->
      waiting = %{build | flush_waiters: [from | build.flush_waiters]}
      {:noreply, Map.put(state, project_id, %{entry | build: waiting})}

    # No build in flight (or no entry at all): save right away.
    _no_build ->
      {:reply, :ok, save_now(state, project_id)}
  end
end
|
|
|
|
def handle_call(:flush_all, from, state) do
  case builds_in_progress?(state) do
    # At least one build is running: remember the caller and answer once
    # the last build has finished.
    true ->
      queued =
        update_meta(state, fn meta ->
          Map.update!(meta, :flush_all_waiters, fn waiters -> [from | waiters] end)
        end)

      {:noreply, queued}

    false ->
      {:reply, :ok, flush_all_projects(state)}
  end
end
|
|
|
|
def handle_call({:forget, project_id}, _from, state) do
  remaining = forget_project(state, project_id)
  {:reply, :ok, remaining}
end
|
|
|
|
@impl true
# Debounce timer fired: write the project's index to disk.
def handle_info({:save, project_id}, state) do
  saved = save_now(state, project_id)
  {:noreply, saved}
end
|
|
|
|
# Result message from a supervised task. The ref identifies either a
# project's running build or one of its duplicate scans; anything else is
# ignored.
def handle_info({ref, result}, state) when is_reference(ref) do
  with :error <- find_build_owner(state, ref),
       :error <- find_scan_owner(state, ref) do
    {:noreply, state}
  else
    # Build finished: `result` is the freshly built entry.
    {:ok, project_id, entry} ->
      Process.demonitor(ref, [:flush])
      {:noreply, complete_build(state, project_id, entry, result)}

    # Scan finished: `result` is the scan's return value for the caller.
    {:ok, project_id, entry, %{from: from}} ->
      Process.demonitor(ref, [:flush])
      GenServer.reply(from, {:ok, result})
      remaining_scans = Map.delete(entry.scans, ref)
      {:noreply, Map.put(state, project_id, %{entry | scans: remaining_scans})}
  end
end
|
|
|
|
# Monitor notification for a task that went down. If the ref still belongs
# to a tracked build or scan, the server exits with a tagged reason;
# untracked refs are ignored.
def handle_info({:DOWN, ref, :process, _pid, reason}, state) when is_reference(ref) do
  with :error <- find_build_owner(state, ref),
       :error <- find_scan_owner(state, ref) do
    {:noreply, state}
  else
    {:ok, _project_id, _entry} ->
      exit({:index_build_failed, reason})

    {:ok, _project_id, _entry, _scan} ->
      exit({:duplicate_scan_failed, reason})
  end
end
|
|
|
|
# Any other message is dropped without touching the state.
def handle_info(_message, state) do
  {:noreply, state}
end
|
|
|
|
@impl true
# Attempts a save for every known project on the way out. save_now/2 leaves
# entries with a build in flight untouched.
def terminate(_reason, state) do
  state
  |> project_ids()
  |> Enum.each(fn project_id -> save_now(state, project_id) end)

  :ok
end
|
|
|
|
# ─── Build / query ──────────────────────────────────────────
|
|
|
|
# Builds an index entry from `entries`. An empty list yields an entry with
# no graph at all.
defp build_entry(dimensions, []) do
  %{index: nil, labels: %{}, dim: dimensions, timer: nil}
end

defp build_entry(dimensions, entries) do
  capacity = length(entries)

  {:ok, graph} =
    HNSWLib.Index.new(@space, dimensions, capacity, m: @m, ef_construction: @ef_construction)

  :ok = HNSWLib.Index.set_ef(graph, @ef_search)

  # Concatenate every packed vector, then view the bytes as a
  # {capacity, dimensions} f32 matrix.
  packed = IO.iodata_to_binary(Enum.map(entries, fn entry -> entry.vector end))

  matrix =
    packed
    |> Nx.from_binary(:f32)
    |> Nx.reshape({capacity, dimensions})

  ids = Enum.map(entries, fn entry -> entry.label end)
  :ok = HNSWLib.Index.add_items(graph, matrix, ids: ids)

  post_ids_by_label = Map.new(entries, fn entry -> {entry.label, entry.post_id} end)

  %{index: graph, labels: post_ids_by_label, dim: dimensions, timer: nil}
end
|
|
|
|
# Queries the graph for `limit + 1` hits (one extra so the query post itself
# can be dropped), maps labels to post ids, discards labels with no known
# post, and returns at most `limit` `%{post_id:, score:}` maps.
defp query_neighbors(%{index: index, labels: labels}, query_label, query_vector, limit) do
  hits = query(index, query_vector, limit + 1)

  candidates =
    for {label, score} <- hits,
        label != query_label,
        post_id = Map.get(labels, label),
        not is_nil(post_id) do
      %{post_id: post_id, score: score}
    end

  Enum.take(candidates, max(limit, 0))
end
|
|
|
|
# Walks every entry, queries its neighbours and collects pairs whose score is
# at or above `threshold`. Pairs are keyed by their sorted post ids, so a pair
# found from both sides is kept once (the later hit wins). The result is
# sorted by descending score.
defp scan_duplicates(%{index: index, labels: labels}, entries, threshold, opts) do
  on_progress = ProgressReporter.callback(opts)
  total = length(entries)
  :ok = report_scan_started(on_progress, total, "embedding entries")

  {pairs, _processed} =
    Enum.reduce(entries, {%{}, 0}, fn entry, {found, processed} ->
      position = processed + 1
      :ok = report_scan_progress(on_progress, position, total, "embedding entries")

      found =
        index
        |> query(entry.vector, @neighbor_limit)
        |> Enum.reduce(found, fn {label, score}, acc ->
          other_post_id = Map.get(labels, label)

          if label != entry.label and not is_nil(other_post_id) and score >= threshold do
            {post_id_a, post_id_b} = sort_pair(entry.post_id, other_post_id)
            pair = %{post_id_a: post_id_a, post_id_b: post_id_b, score: score}
            Map.put(acc, {post_id_a, post_id_b}, pair)
          else
            acc
          end
        end)

      {found, position}
    end)

  pairs
  |> Map.values()
  |> Enum.sort_by(fn pair -> pair.score end, :desc)
end
|
|
|
|
# Runs a knn query and returns [{label, similarity}] in the order the library
# reports them. `k` is clamped to the number of stored items; an empty index
# or a failed count lookup yields [].
defp query(index, query_vector, k) do
  case HNSWLib.Index.get_current_count(index) do
    {:ok, count} when count > 0 ->
      knn_similarities(index, query_vector, min(k, count))

    _empty_or_error ->
      []
  end
end

# Performs the knn lookup itself. Cosine distance d is converted to a
# similarity as max(0, 1 - d); a library error yields [].
defp knn_similarities(index, query_vector, k) do
  case HNSWLib.Index.knn_query(index, query_vector, k: k) do
    {:error, _reason} ->
      []

    {:ok, labels, distances} ->
      similarities =
        for distance <- Nx.to_flat_list(distances), do: max(0.0, 1.0 - distance)

      labels
      |> Nx.to_flat_list()
      |> Enum.zip(similarities)
  end
end
|
|
|
|
# ─── Persistence ────────────────────────────────────────────
|
|
|
|
# (Re)arms the debounce timer for a project: any timer already running is
# cancelled and a new {:save, project_id} message is scheduled. Raises if the
# project has no entry in the state.
defp schedule_save(state, project_id) do
  Map.update!(state, project_id, fn entry ->
    if is_reference(entry.timer), do: Process.cancel_timer(entry.timer)

    new_timer = Process.send_after(self(), {:save, project_id}, @debounce_ms)
    %{entry | timer: new_timer}
  end)
end
|
|
|
|
# Cancels a project's debounce timer, if one is armed, and clears it from the
# entry. States without such a timer are returned unchanged.
defp cancel_pending_save(state, project_id) do
  with %{timer: timer} = entry when is_reference(timer) <- Map.get(state, project_id) do
    Process.cancel_timer(timer)
    Map.put(state, project_id, %{entry | timer: nil})
  else
    _no_timer -> state
  end
end
|
|
|
|
# Persists a project's entry immediately and clears its debounce timer.
# Unknown projects and projects with a build in flight are left as they are.
defp save_now(state, project_id) do
  case Map.get(state, project_id) do
    nil ->
      state

    # A build is running; nothing is written and the state is unchanged.
    %{build: %{}} ->
      state

    %{timer: timer} = entry ->
      if is_reference(timer), do: Process.cancel_timer(timer)
      persist(project_id, entry)
      Map.put(state, project_id, %{entry | timer: nil})
  end
end
|
|
|
|
# Writes a project's index file and its sidecar to disk. Entries without a
# graph have nothing to write.
#
# Persistence is best effort: the function always returns :ok so a disk
# problem never takes the server down. Failures are logged instead of being
# swallowed silently, and the sidecar is only written after the index file
# was saved successfully, so a failed index save cannot leave a new label map
# next to an older index file.
defp persist(_project_id, %{index: nil}), do: :ok

defp persist(project_id, %{index: index, labels: labels, dim: dim}) do
  index_path = path(project_id)
  File.mkdir_p!(Path.dirname(index_path))

  with :ok <- HNSWLib.Index.save_index(index, index_path),
       :ok <- write_meta(index_path, dim, labels) do
    :ok
  else
    failure ->
      log_persist_failure(project_id, failure)
      :ok
  end
rescue
  exception ->
    log_persist_failure(project_id, exception)
    :ok
end

# Emits a warning describing why an index could not be persisted.
defp log_persist_failure(project_id, reason) do
  require Logger

  Logger.warning(
    "Failed to persist embedding index for project #{project_id}: #{inspect(reason)}"
  )
end
|
|
|
|
# Writes the JSON sidecar next to the index file: the dimension plus the
# label -> post_id map encoded as a list of [label, post_id] pairs. Returns
# the result of File.write/2.
defp write_meta(index_path, dim, labels) do
  pairs = for {label, post_id} <- labels, do: [label, post_id]
  json = Jason.encode!(%{"dim" => dim, "labels" => pairs})

  index_path
  |> meta_path()
  |> File.write(json)
end
|
|
|
|
# Returns {:ok, entry, state} for a project, lazily loading its index from
# disk when the state has no entry for it yet. Returns {:missing, state} when
# nothing could be loaded.
defp ensure_loaded(state, project_id) do
  cached = Map.get(state, project_id)

  if is_nil(cached) do
    case load_from_disk(project_id) do
      :error ->
        {:missing, state}

      {:ok, loaded} ->
        entry = runtime_entry(loaded)
        {:ok, entry, Map.put(state, project_id, entry)}
    end
  else
    {:ok, cached, state}
  end
end
|
|
|
|
# Tries to restore a project's index from its files. Needs a readable
# sidecar, an existing index file and a successful library load; any failure,
# including a raised exception, collapses to :error.
defp load_from_disk(project_id) do
  index_path = path(project_id)

  with {:ok, %{dim: dim, labels: labels}} <- read_meta(index_path),
       true <- File.exists?(index_path),
       {:ok, index} <- HNSWLib.Index.load_index(@space, dim, index_path),
       :ok <- HNSWLib.Index.set_ef(index, @ef_search) do
    restored = %{index: index, labels: labels, dim: dim, timer: nil}
    {:ok, restored}
  else
    _failure -> :error
  end
rescue
  _exception -> :error
end
|
|
|
|
# Reads the JSON sidecar belonging to `index_path`. Returns
# {:ok, %{dim:, labels:}} or :error when the file is unreadable or does not
# decode to the expected shape.
defp read_meta(index_path) do
  case File.read(meta_path(index_path)) do
    {:ok, json} -> decode_meta(json)
    _unreadable -> :error
  end
end

# Decodes the sidecar JSON; "labels" is a list of [label, post_id] pairs that
# is turned back into a map.
defp decode_meta(json) do
  case Jason.decode(json) do
    {:ok, %{"dim" => dim, "labels" => pairs}} ->
      labels =
        pairs
        |> Enum.map(fn [label, post_id] -> {label, post_id} end)
        |> Map.new()

      {:ok, %{dim: dim, labels: labels}}

    _unexpected ->
      :error
  end
end
|
|
|
|
# Path of the JSON sidecar that accompanies an index file.
defp meta_path(index_path) do
  index_path <> ".meta.json"
end
|
|
|
|
# Fills in the runtime-only bookkeeping keys (:timer, :build, :scans) on an
# entry, keeping any values the entry already carries.
defp runtime_entry(entry) do
  entry
  |> Map.put_new(:timer, nil)
  |> Map.put_new(:build, nil)
  |> Map.put_new(:scans, %{})
end
|
|
|
|
# Registers a build request for a project. With no build running, a task is
# started right away. Otherwise the request is parked as the build's
# `next_request` (replacing any request parked earlier) and the caller joins
# the list of callers answered when building finally settles.
defp start_build(state, project_id, dimensions, entries, from) do
  blank = runtime_entry(%{index: nil, labels: %{}, dim: dimensions, timer: nil})

  entry =
    state
    |> Map.get(project_id, blank)
    |> Map.put(:dim, dimensions)

  build =
    case entry.build do
      nil ->
        task = start_build_task(project_id, dimensions, entries)

        %{
          ref: task.ref,
          pid: task.pid,
          callers: [from],
          flush_waiters: [],
          next_request: nil
        }

      running ->
        %{running | callers: [from | running.callers], next_request: {dimensions, entries}}
    end

  Map.put(state, project_id, %{entry | build: build})
end
|
|
|
|
# Starts the index build in a task under BDS.Tasks.TaskSupervisor that is not
# linked to this server. The optional test hook runs inside the task before
# building starts.
defp start_build_task(project_id, dimensions, entries) do
  job = fn ->
    maybe_run_test_hook({:before_build, project_id, self()})
    build_entry(dimensions, entries)
  end

  Task.Supervisor.async_nolink(BDS.Tasks.TaskSupervisor, job)
end
|
|
|
|
# Handles a finished build task for a project.
#
# If another request was parked while the task ran, its result is discarded
# and a new task is started for the parked request; callers and flush waiters
# keep waiting. Otherwise the built entry replaces the project's entry
# (keeping its running scans), all `put` callers are answered, and the index
# is either scheduled for a debounced save or, when flush callers are
# waiting, saved immediately.
#
# Flush waiters are answered only AFTER save_now/2 has run, so `flush/1`
# never returns before the write was attempted. This matches the direct
# flush path and flush_all.
defp complete_build(state, project_id, entry, built_entry) do
  build = entry.build

  case build.next_request do
    {next_dimensions, next_entries} ->
      task = start_build_task(project_id, next_dimensions, next_entries)

      build = %{build | ref: task.ref, pid: task.pid, next_request: nil}
      Map.put(state, project_id, %{entry | build: build})

    nil ->
      Enum.each(build.callers, &GenServer.reply(&1, :ok))

      entry = %{runtime_entry(built_entry) | scans: entry.scans}
      state = Map.put(state, project_id, entry)

      state =
        if build.flush_waiters == [] do
          schedule_save(state, project_id)
        else
          saved = save_now(state, project_id)
          Enum.each(build.flush_waiters, &GenServer.reply(&1, :ok))
          saved
        end

      maybe_finish_flush_all_waiters(state)
  end
end
|
|
|
|
# Runs the duplicate scan in a task under BDS.Tasks.TaskSupervisor and
# records it, keyed by the task ref, together with the caller to answer.
defp start_duplicate_scan(state, project_id, entry, entries, threshold, opts, from) do
  job = fn -> scan_duplicates(entry, entries, threshold, opts) end
  %{ref: ref, pid: pid} = Task.Supervisor.async_nolink(BDS.Tasks.TaskSupervisor, job)

  tracked = %{entry | scans: Map.put(entry.scans, ref, %{pid: pid, from: from})}
  Map.put(state, project_id, tracked)
end
|
|
|
|
# Drops a project from the state. Its debounce timer is cancelled, a running
# build task is terminated (its callers and flush waiters are answered :ok),
# and every running scan is terminated with {:error, :missing} sent to its
# caller. Pending flush_all callers are re-checked afterwards.
defp forget_project(state, project_id) do
  case Map.pop(state, project_id) do
    {nil, _unchanged} ->
      maybe_finish_flush_all_waiters(state)

    {entry, remaining} ->
      if is_reference(entry.timer), do: Process.cancel_timer(entry.timer)

      build = entry.build

      if build do
        _ = Task.Supervisor.terminate_child(BDS.Tasks.TaskSupervisor, build.pid)
        Enum.each(build.callers ++ build.flush_waiters, &GenServer.reply(&1, :ok))
      end

      Enum.each(entry.scans, fn {_ref, %{pid: pid, from: from}} ->
        _ = Task.Supervisor.terminate_child(BDS.Tasks.TaskSupervisor, pid)
        GenServer.reply(from, {:error, :missing})
      end)

      maybe_finish_flush_all_waiters(remaining)
  end
end
|
|
|
|
# Finds the project whose running build is monitored under `ref`. Returns
# {:ok, project_id, entry} or :error. Only binary keys are considered, which
# skips the meta key.
defp find_build_owner(state, ref) do
  Enum.find_value(state, :error, fn
    {project_id, %{build: %{ref: ^ref}} = entry} when is_binary(project_id) ->
      {:ok, project_id, entry}

    _other ->
      false
  end)
end
|
|
|
|
# Finds the project that tracks a duplicate scan under `ref`. Returns
# {:ok, project_id, entry, scan} or :error. Only binary keys are considered,
# which skips the meta key.
defp find_scan_owner(state, ref) do
  Enum.find_value(state, :error, fn
    {project_id, %{scans: scans} = entry} when is_binary(project_id) ->
      scan = Map.get(scans, ref)
      if is_nil(scan), do: false, else: {:ok, project_id, entry, scan}

    _other ->
      false
  end)
end
|
|
|
|
# Project ids are the binary keys of the state map; the atom meta key is
# excluded by the filter.
defp project_ids(state) do
  for key <- Map.keys(state), is_binary(key), do: key
end
|
|
|
|
# True when at least one project entry currently carries a build map.
defp builds_in_progress?(state) do
  state
  |> project_ids()
  |> Enum.any?(fn project_id ->
    case Map.get(state, project_id) do
      %{build: %{}} -> true
      _idle -> false
    end
  end)
end
|
|
|
|
# Runs save_now/2 for every project, threading the state through.
defp flush_all_projects(state) do
  state
  |> project_ids()
  |> Enum.reduce(state, fn project_id, acc -> save_now(acc, project_id) end)
end
|
|
|
|
# Answers the callers parked on flush_all once no build is running any more:
# every project is saved first, then each waiter gets :ok and the waiter list
# is cleared. With no waiters, or while a build is still running, the state
# is returned untouched.
defp maybe_finish_flush_all_waiters(state) do
  current = meta(state)
  waiters = current.flush_all_waiters

  if waiters == [] or builds_in_progress?(state) do
    state
  else
    flushed = flush_all_projects(state)
    Enum.each(waiters, fn waiter -> GenServer.reply(waiter, :ok) end)
    put_meta(flushed, %{current | flush_all_waiters: []})
  end
end
|
|
|
|
# Reads the bookkeeping map stored under the meta key, defaulting to an empty
# waiter list when it is absent.
defp meta(state) do
  Map.get(state, @meta_key, %{flush_all_waiters: []})
end
|
|
|
|
# Stores `meta` under the meta key of the state map.
defp put_meta(state, meta) do
  Map.put(state, @meta_key, meta)
end
|
|
|
|
# Applies `fun` to the current bookkeeping map and stores the result.
defp update_meta(state, fun) do
  updated = state |> meta() |> fun.()
  put_meta(state, updated)
end
|
|
|
|
# Invokes the 1-arity function configured under
# :bds / :embeddings_index_test_hook with `event`, if one is set; otherwise
# returns :ok.
defp maybe_run_test_hook(event) do
  hook = Application.get_env(:bds, :embeddings_index_test_hook)

  if is_function(hook, 1) do
    hook.(event)
  else
    :ok
  end
end
|
|
|
|
# Orders two post ids so the smaller one comes first, giving a pair the same
# key regardless of which side it was discovered from.
defp sort_pair(post_id_a, post_id_b) do
  if post_id_a <= post_id_b do
    {post_id_a, post_id_b}
  else
    {post_id_b, post_id_a}
  end
end
|
|
|
|
# Announces the start of a scan over `total` items through the progress
# callback, using the "Scanning" wording.
defp report_scan_started(callback, total, label) do
  options = [
    verb: "Scanning",
    start_progress: 0.0,
    empty_suffix: "to scan",
    message_style: :prefix_count
  ]

  ProgressReporter.report_count_started(callback, total, label, options)
end
|
|
|
|
# Reports that item `current` of `total` is being scanned through the
# progress callback, using the "Scanning" wording.
defp report_scan_progress(callback, current, total, label) do
  options = [verb: "Scanning", start_progress: 0.0, message_style: :prefix_count]
  ProgressReporter.report_count_progress(callback, current, total, label, options)
end
|
|
end
|