feat: finalisation (hopefully) for embedding
This commit is contained in:
@@ -4,18 +4,111 @@ defmodule BDS.Embeddings do
|
||||
import Ecto.Query
|
||||
|
||||
alias BDS.Embeddings.DismissedDuplicatePair
|
||||
alias BDS.Embeddings.Index
|
||||
alias BDS.Embeddings.Key
|
||||
alias BDS.Metadata
|
||||
alias BDS.Posts.Post
|
||||
alias BDS.Projects
|
||||
alias BDS.Repo
|
||||
|
||||
@duplicate_threshold 0.5
|
||||
@duplicate_threshold 0.92
|
||||
@exact_match_score 0.999999
|
||||
|
||||
def model_id, do: configured_backend().model_info().model_id
|
||||
def dimensions, do: configured_backend().model_info().dimensions
|
||||
def index_path(project_id), do: Index.path(project_id)
|
||||
def reindex_all(project_id), do: rebuild_project(project_id)
|
||||
|
||||
def get_indexing_progress(project_id) when is_binary(project_id) do
|
||||
indexed =
|
||||
Repo.one(
|
||||
from key in Key,
|
||||
where: key.project_id == ^project_id,
|
||||
select: count(key.post_id, :distinct)
|
||||
) || 0
|
||||
|
||||
total =
|
||||
Repo.one(
|
||||
from post in Post,
|
||||
where: post.project_id == ^project_id,
|
||||
select: count(post.id)
|
||||
) || 0
|
||||
|
||||
{:ok, %{indexed: indexed, total: total}}
|
||||
end
|
||||
|
||||
def sync_post(%Post{} = post) do
|
||||
sync_post(post, refresh_index: true)
|
||||
end
|
||||
|
||||
def sync_post(post_id) when is_binary(post_id) do
|
||||
case Repo.get(Post, post_id) do
|
||||
nil -> :ok
|
||||
post -> sync_post(post)
|
||||
end
|
||||
end
|
||||
|
||||
def rebuild_project(project_id) when is_binary(project_id) do
|
||||
if enabled_for_project?(project_id) do
|
||||
posts =
|
||||
Repo.all(from post in Post, where: post.project_id == ^project_id, order_by: [asc: post.created_at, asc: post.slug])
|
||||
|
||||
post_ids = Enum.map(posts, & &1.id)
|
||||
|
||||
Repo.delete_all(
|
||||
from key in Key,
|
||||
where: key.project_id == ^project_id and key.post_id not in ^post_ids
|
||||
)
|
||||
|
||||
Enum.each(posts, &sync_post(&1, refresh_index: false))
|
||||
:ok = rebuild_snapshot(project_id)
|
||||
{:ok, post_ids}
|
||||
else
|
||||
{:ok, []}
|
||||
end
|
||||
end
|
||||
|
||||
def diff_reports(project_id) when is_binary(project_id) do
|
||||
if enabled_for_project?(project_id) do
|
||||
snapshot_entries =
|
||||
case Index.read(project_id) do
|
||||
{:ok, snapshot} -> Map.get(snapshot, "entries", %{})
|
||||
_other -> %{}
|
||||
end
|
||||
|
||||
keys_by_post =
|
||||
Repo.all(from key in Key, where: key.project_id == ^project_id)
|
||||
|> Map.new(&{&1.post_id, &1})
|
||||
|
||||
Repo.all(from post in Post, where: post.project_id == ^project_id)
|
||||
|> Enum.flat_map(fn post ->
|
||||
expected_hash = post_content_hash(post)
|
||||
key = Map.get(keys_by_post, post.id)
|
||||
snapshot_entry = Map.get(snapshot_entries, post.id)
|
||||
|
||||
differences =
|
||||
[
|
||||
diff_field("content_hash", key && key.content_hash, expected_hash),
|
||||
diff_field(
|
||||
"snapshot_content_hash",
|
||||
snapshot_entry && snapshot_entry["content_hash"],
|
||||
key && key.content_hash
|
||||
)
|
||||
]
|
||||
|> Enum.reject(&is_nil/1)
|
||||
|
||||
if differences == [] do
|
||||
[]
|
||||
else
|
||||
[%{entity_type: "embedding", entity_id: post.id, differences: differences}]
|
||||
end
|
||||
end)
|
||||
else
|
||||
[]
|
||||
end
|
||||
end
|
||||
|
||||
defp sync_post(%Post{} = post, opts) do
|
||||
if enabled_for_project?(post.project_id) do
|
||||
body = resolve_post_body(post)
|
||||
raw_text = compose_embedding_source(post.title, body)
|
||||
@@ -39,6 +132,10 @@ defmodule BDS.Embeddings do
|
||||
})
|
||||
|> Repo.insert_or_update()
|
||||
|
||||
if Keyword.get(opts, :refresh_index, true) do
|
||||
:ok = rebuild_snapshot(post.project_id)
|
||||
end
|
||||
|
||||
:ok
|
||||
end
|
||||
else
|
||||
@@ -46,15 +143,22 @@ defmodule BDS.Embeddings do
|
||||
end
|
||||
end
|
||||
|
||||
def sync_post(post_id) when is_binary(post_id) do
|
||||
case Repo.get(Post, post_id) do
|
||||
nil -> :ok
|
||||
post -> sync_post(post)
|
||||
end
|
||||
end
|
||||
|
||||
def remove_post(post_id) when is_binary(post_id) do
|
||||
project_id =
|
||||
case Repo.get_by(Key, post_id: post_id) do
|
||||
%Key{project_id: project_id} -> project_id
|
||||
nil -> case Repo.get(Post, post_id) do
|
||||
%Post{project_id: project_id} -> project_id
|
||||
nil -> nil
|
||||
end
|
||||
end
|
||||
|
||||
Repo.delete_all(from key in Key, where: key.post_id == ^post_id)
|
||||
|
||||
if is_binary(project_id) and enabled_for_project?(project_id) do
|
||||
:ok = rebuild_snapshot(project_id)
|
||||
end
|
||||
|
||||
:ok
|
||||
end
|
||||
|
||||
@@ -70,10 +174,16 @@ defmodule BDS.Embeddings do
|
||||
case Repo.get_by(Key, post_id: post.id, project_id: project_id) do
|
||||
%Key{content_hash: ^content_hash} -> :ok
|
||||
_other ->
|
||||
:ok = sync_post(%{post | content: if(post.content in [nil, ""], do: body, else: post.content)})
|
||||
:ok =
|
||||
sync_post(
|
||||
%{post | content: if(post.content in [nil, ""], do: body, else: post.content)},
|
||||
refresh_index: false
|
||||
)
|
||||
end
|
||||
end)
|
||||
|
||||
:ok = rebuild_snapshot(project_id)
|
||||
|
||||
indexed = Repo.all(from key in Key, where: key.project_id == ^project_id, select: key.post_id)
|
||||
|
||||
{:ok, indexed}
|
||||
@@ -88,10 +198,14 @@ defmodule BDS.Embeddings do
|
||||
{:error, :not_found} -> {:ok, []}
|
||||
{:ok, post, source_vector} ->
|
||||
similar =
|
||||
Repo.all(from key in Key, where: key.project_id == ^post.project_id and key.post_id != ^post.id)
|
||||
|> Enum.map(fn key -> %{post_id: key.post_id, score: cosine_similarity(source_vector, decode_vector(key.vector))} end)
|
||||
|> Enum.sort_by(& &1.score, :desc)
|
||||
|> Enum.take(max(limit, 0))
|
||||
case Index.neighbors(post.project_id, post.id, limit) do
|
||||
{:ok, neighbors} -> neighbors
|
||||
{:error, :missing} ->
|
||||
Repo.all(from key in Key, where: key.project_id == ^post.project_id and key.post_id != ^post.id)
|
||||
|> Enum.map(fn key -> %{post_id: key.post_id, score: cosine_similarity(source_vector, decode_vector(key.vector))} end)
|
||||
|> Enum.sort_by(& &1.score, :desc)
|
||||
|> Enum.take(max(limit, 0))
|
||||
end
|
||||
|
||||
{:ok, similar}
|
||||
end
|
||||
@@ -150,22 +264,31 @@ defmodule BDS.Embeddings do
|
||||
def find_duplicates(project_id) when is_binary(project_id) do
|
||||
if enabled_for_project?(project_id) do
|
||||
dismissed = dismissed_pair_keys(project_id)
|
||||
keys = Repo.all(from key in Key, where: key.project_id == ^project_id, order_by: [asc: key.post_id])
|
||||
|
||||
duplicates =
|
||||
for left <- keys,
|
||||
right <- keys,
|
||||
left.post_id < right.post_id,
|
||||
pair_key(left.post_id, right.post_id) not in dismissed,
|
||||
similarity = cosine_similarity(decode_vector(left.vector), decode_vector(right.vector)),
|
||||
similarity >= @duplicate_threshold do
|
||||
%{
|
||||
post_id_a: left.post_id,
|
||||
post_id_b: right.post_id,
|
||||
score: similarity
|
||||
}
|
||||
case Index.duplicate_pairs(project_id, @duplicate_threshold) do
|
||||
{:ok, pairs} ->
|
||||
pairs
|
||||
|> Enum.reject(fn pair -> pair_key(pair.post_id_a, pair.post_id_b) in dismissed end)
|
||||
|> enrich_duplicate_pairs(project_id)
|
||||
|
||||
{:error, :missing} ->
|
||||
keys = Repo.all(from key in Key, where: key.project_id == ^project_id, order_by: [asc: key.post_id])
|
||||
|
||||
for left <- keys,
|
||||
right <- keys,
|
||||
left.post_id < right.post_id,
|
||||
pair_key(left.post_id, right.post_id) not in dismissed,
|
||||
similarity = cosine_similarity(decode_vector(left.vector), decode_vector(right.vector)),
|
||||
similarity >= @duplicate_threshold do
|
||||
%{
|
||||
post_id_a: left.post_id,
|
||||
post_id_b: right.post_id,
|
||||
score: similarity
|
||||
}
|
||||
end
|
||||
|> enrich_duplicate_pairs(project_id)
|
||||
end
|
||||
|> Enum.sort_by(& &1.score, :desc)
|
||||
|
||||
{:ok, duplicates}
|
||||
else
|
||||
@@ -204,6 +327,26 @@ defmodule BDS.Embeddings do
|
||||
end
|
||||
end
|
||||
|
||||
def dismiss_duplicate_pairs(pair_ids) when is_list(pair_ids) do
|
||||
pair_ids
|
||||
|> Enum.filter(fn
|
||||
{post_id_a, post_id_b} when is_binary(post_id_a) and is_binary(post_id_b) -> true
|
||||
_other -> false
|
||||
end)
|
||||
|> Enum.map(fn {post_id_a, post_id_b} -> sort_pair(post_id_a, post_id_b) end)
|
||||
|> Enum.uniq()
|
||||
|> Enum.reduce_while({:ok, []}, fn {post_id_a, post_id_b}, {:ok, acc} ->
|
||||
case dismiss_duplicate_pair(post_id_a, post_id_b) do
|
||||
{:ok, saved_pair} -> {:cont, {:ok, [saved_pair | acc]}}
|
||||
{:error, reason} -> {:halt, {:error, reason}}
|
||||
end
|
||||
end)
|
||||
|> case do
|
||||
{:ok, saved_pairs} -> {:ok, Enum.reverse(saved_pairs)}
|
||||
{:error, reason} -> {:error, reason}
|
||||
end
|
||||
end
|
||||
|
||||
defp source_post_and_vector(post_id) do
|
||||
with {:ok, post} <- fetch_post(post_id) do
|
||||
if enabled_for_project?(post.project_id) do
|
||||
@@ -233,6 +376,37 @@ defmodule BDS.Embeddings do
|
||||
end
|
||||
end
|
||||
|
||||
defp enrich_duplicate_pairs(pairs, project_id) do
|
||||
posts_by_id =
|
||||
pairs
|
||||
|> Enum.flat_map(&[&1.post_id_a, &1.post_id_b])
|
||||
|> Enum.uniq()
|
||||
|> then(fn post_ids ->
|
||||
Repo.all(from post in Post, where: post.project_id == ^project_id and post.id in ^post_ids)
|
||||
|> Map.new(&{&1.id, &1})
|
||||
end)
|
||||
|
||||
pairs
|
||||
|> Enum.map(fn pair ->
|
||||
post_a = Map.fetch!(posts_by_id, pair.post_id_a)
|
||||
post_b = Map.fetch!(posts_by_id, pair.post_id_b)
|
||||
exact_match = exact_duplicate_match?(pair.score, post_a, post_b)
|
||||
|
||||
pair
|
||||
|> Map.put(:title_a, post_a.title || "")
|
||||
|> Map.put(:title_b, post_b.title || "")
|
||||
|> Map.put(:similarity, pair.score)
|
||||
|> Map.put(:exact_match, exact_match)
|
||||
end)
|
||||
|> Enum.sort_by(fn pair -> {not pair.exact_match, -pair.score, pair.post_id_a, pair.post_id_b} end)
|
||||
end
|
||||
|
||||
defp exact_duplicate_match?(score, %Post{} = post_a, %Post{} = post_b) do
|
||||
score >= @exact_match_score and
|
||||
(post_a.title || "") == (post_b.title || "") and
|
||||
resolve_post_body(post_a) == resolve_post_body(post_b)
|
||||
end
|
||||
|
||||
defp enabled_for_project?(project_id) do
|
||||
case Metadata.get_project_metadata(project_id) do
|
||||
{:ok, metadata} -> metadata.semantic_similarity_enabled == true
|
||||
@@ -280,10 +454,30 @@ defmodule BDS.Embeddings do
|
||||
|
||||
defp compose_embedding_source(title, content), do: "#{title || ""}\n\n#{content || ""}"
|
||||
|
||||
defp post_content_hash(%Post{} = post) do
|
||||
body = resolve_post_body(post)
|
||||
hash_text(compose_embedding_source(post.title, body))
|
||||
end
|
||||
|
||||
defp embed_text(raw_text, language) do
|
||||
configured_backend().embed("query: " <> raw_text, language: language)
|
||||
end
|
||||
|
||||
defp rebuild_snapshot(project_id) do
|
||||
Index.rebuild(project_id, model_id: model_id(), dimensions: dimensions())
|
||||
end
|
||||
|
||||
defp diff_field(name, db_value, file_value) do
|
||||
db_value = if(is_binary(db_value), do: db_value, else: db_value || "")
|
||||
file_value = if(is_binary(file_value), do: file_value, else: file_value || "")
|
||||
|
||||
if db_value == file_value do
|
||||
nil
|
||||
else
|
||||
%{name: name, db_value: db_value, file_value: file_value}
|
||||
end
|
||||
end
|
||||
|
||||
defp hash_text(text), do: :crypto.hash(:sha256, text) |> Base.encode16(case: :lower)
|
||||
|
||||
defp decode_vector(nil), do: []
|
||||
|
||||
Reference in New Issue
Block a user