perf: A1-14b replace O(n^2) embedding snapshot with hnswlib HNSW index and debounced persistence
This commit is contained in:
@@ -166,11 +166,9 @@ defmodule BDS.Embeddings do
|
||||
|
||||
case Repo.get_by(Key, post_id: post.id, project_id: post.project_id) do
|
||||
%Key{content_hash: ^content_hash} ->
|
||||
if Keyword.get(opts, :refresh_index, true) and
|
||||
snapshot_content_hash(post.project_id, post.id) != content_hash do
|
||||
:ok = rebuild_snapshot(post.project_id)
|
||||
end
|
||||
|
||||
# Embedding is already current. The HNSW index self-heals on query
|
||||
# (find_similar/find_duplicates rebuild when no index is loaded), so
|
||||
# there is nothing to refresh here.
|
||||
:ok
|
||||
|
||||
existing_key ->
|
||||
@@ -361,28 +359,28 @@ defmodule BDS.Embeddings do
|
||||
{:error, :not_found} ->
|
||||
{:ok, []}
|
||||
|
||||
{:ok, post, source_vector} ->
|
||||
similar =
|
||||
case Index.neighbors(post.project_id, post.id, limit) do
|
||||
{:ok, neighbors} ->
|
||||
neighbors
|
||||
{:ok, _post, nil} ->
|
||||
{:ok, []}
|
||||
|
||||
{:error, :missing} ->
|
||||
Repo.all(
|
||||
from key in Key,
|
||||
where: key.project_id == ^post.project_id and key.post_id != ^post.id
|
||||
)
|
||||
|> Enum.map(fn key ->
|
||||
%{
|
||||
post_id: key.post_id,
|
||||
score: cosine_similarity(source_vector, decode_vector(key.vector))
|
||||
}
|
||||
end)
|
||||
|> Enum.sort_by(& &1.score, :desc)
|
||||
|> Enum.take(max(limit, 0))
|
||||
end
|
||||
{:ok, post, %Key{} = key} ->
|
||||
{:ok, query_similar(post.project_id, key, limit)}
|
||||
end
|
||||
end
|
||||
|
||||
{:ok, similar}
|
||||
# Queries the HNSW index for a post's neighbours, rebuilding the index from
|
||||
# the DB vectors if it is not currently loaded (e.g. after a restart).
|
||||
defp query_similar(project_id, %Key{} = key, limit) do
|
||||
case Index.neighbors(project_id, key.label, key.vector, limit) do
|
||||
{:ok, neighbors} ->
|
||||
neighbors
|
||||
|
||||
{:error, :missing} ->
|
||||
:ok = rebuild_snapshot(project_id)
|
||||
|
||||
case Index.neighbors(project_id, key.label, key.vector, limit) do
|
||||
{:ok, neighbors} -> neighbors
|
||||
{:error, :missing} -> []
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -395,8 +393,12 @@ defmodule BDS.Embeddings do
|
||||
{:error, :not_found} ->
|
||||
{:ok, %{}}
|
||||
|
||||
{:ok, post, source_vector} ->
|
||||
{:ok, _post, nil} ->
|
||||
{:ok, %{}}
|
||||
|
||||
{:ok, post, %Key{} = source_key} ->
|
||||
target_ids = Enum.uniq(target_post_ids)
|
||||
source_vector = decode_vector(source_key.vector)
|
||||
|
||||
scores =
|
||||
Repo.all(
|
||||
@@ -452,46 +454,18 @@ defmodule BDS.Embeddings do
|
||||
if enabled_for_project?(project_id) do
|
||||
on_progress = progress_callback(opts)
|
||||
dismissed = dismissed_pair_keys(project_id)
|
||||
entries = load_index_entries(project_id)
|
||||
|
||||
pairs =
|
||||
case duplicate_pairs_with_rebuild(project_id, entries, on_progress) do
|
||||
{:ok, pairs} -> pairs
|
||||
{:error, :missing} -> []
|
||||
end
|
||||
|
||||
duplicates =
|
||||
case Index.duplicate_pairs(project_id, @duplicate_threshold, on_progress: on_progress) do
|
||||
{:ok, pairs} ->
|
||||
pairs
|
||||
|> Enum.reject(fn pair -> pair_key(pair.post_id_a, pair.post_id_b) in dismissed end)
|
||||
|> enrich_duplicate_pairs(project_id)
|
||||
|
||||
{:error, :missing} ->
|
||||
keys =
|
||||
Repo.all(
|
||||
from key in Key,
|
||||
where: key.project_id == ^project_id,
|
||||
order_by: [asc: key.post_id]
|
||||
)
|
||||
|
||||
total_keys = length(keys)
|
||||
|
||||
:ok = report_rebuild_started(on_progress, total_keys, "embedding entries")
|
||||
|
||||
keys
|
||||
|> Enum.with_index(1)
|
||||
|> Enum.flat_map(fn {left, index} ->
|
||||
:ok = report_rebuild_progress(on_progress, index, total_keys, "embedding entries")
|
||||
|
||||
for right <- keys,
|
||||
left.post_id < right.post_id,
|
||||
pair_key(left.post_id, right.post_id) not in dismissed,
|
||||
similarity =
|
||||
cosine_similarity(decode_vector(left.vector), decode_vector(right.vector)),
|
||||
similarity >= @duplicate_threshold do
|
||||
%{
|
||||
post_id_a: left.post_id,
|
||||
post_id_b: right.post_id,
|
||||
score: similarity
|
||||
}
|
||||
end
|
||||
end)
|
||||
|> enrich_duplicate_pairs(project_id)
|
||||
end
|
||||
pairs
|
||||
|> Enum.reject(fn pair -> pair_key(pair.post_id_a, pair.post_id_b) in dismissed end)
|
||||
|> enrich_duplicate_pairs(project_id)
|
||||
|
||||
:ok = report_rebuild_phase(on_progress, 0.99, "Resolving duplicate candidates")
|
||||
{:ok, duplicates}
|
||||
@@ -555,17 +529,33 @@ defmodule BDS.Embeddings do
|
||||
with {:ok, post} <- fetch_post(post_id) do
|
||||
if enabled_for_project?(post.project_id) do
|
||||
:ok = ensure_key(post)
|
||||
|
||||
case Repo.get_by(Key, post_id: post.id, project_id: post.project_id) do
|
||||
nil -> {:ok, post, []}
|
||||
key -> {:ok, post, decode_vector(key.vector)}
|
||||
end
|
||||
{:ok, post, Repo.get_by(Key, post_id: post.id, project_id: post.project_id)}
|
||||
else
|
||||
{:disabled, post.project_id}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
defp duplicate_pairs_with_rebuild(project_id, entries, on_progress) do
|
||||
case Index.duplicate_pairs(project_id, entries, @duplicate_threshold, on_progress: on_progress) do
|
||||
{:ok, pairs} ->
|
||||
{:ok, pairs}
|
||||
|
||||
{:error, :missing} ->
|
||||
:ok = rebuild_snapshot(project_id)
|
||||
Index.duplicate_pairs(project_id, entries, @duplicate_threshold, on_progress: on_progress)
|
||||
end
|
||||
end
|
||||
|
||||
defp load_index_entries(project_id) do
|
||||
Repo.all(
|
||||
from key in Key,
|
||||
where: key.project_id == ^project_id,
|
||||
order_by: [asc: key.post_id]
|
||||
)
|
||||
|> Enum.map(fn key -> %{label: key.label, post_id: key.post_id, vector: key.vector} end)
|
||||
end
|
||||
|
||||
defp ensure_key(%Post{} = post) do
|
||||
case Repo.get_by(Key, post_id: post.id, project_id: post.project_id) do
|
||||
nil -> sync_post(post)
|
||||
@@ -704,7 +694,7 @@ defmodule BDS.Embeddings do
|
||||
end
|
||||
|
||||
defp rebuild_snapshot(project_id) do
|
||||
Index.rebuild(project_id, model_id: model_id(), dimensions: dimensions())
|
||||
Index.put(project_id, dimensions(), load_index_entries(project_id))
|
||||
end
|
||||
|
||||
defp progress_callback(opts), do: ProgressReporter.callback(opts)
|
||||
@@ -729,13 +719,6 @@ defmodule BDS.Embeddings do
|
||||
defp report_rebuild_phase(callback, value, label),
|
||||
do: ProgressReporter.report_phase(callback, value, label)
|
||||
|
||||
defp snapshot_content_hash(project_id, post_id) do
|
||||
case Index.read(project_id) do
|
||||
{:ok, snapshot} -> get_in(snapshot, ["entries", post_id, "content_hash"])
|
||||
_other -> nil
|
||||
end
|
||||
end
|
||||
|
||||
defp current_embedding_status(nil, _expected_hash), do: "missing"
|
||||
|
||||
defp current_embedding_status(%Key{vector: vector}, _expected_hash) when vector in [nil, ""],
|
||||
|
||||
Reference in New Issue
Block a user