feat: more on embedding
This commit is contained in:
@@ -10,12 +10,10 @@ defmodule BDS.Embeddings do
|
||||
alias BDS.Projects
|
||||
alias BDS.Repo
|
||||
|
||||
@dimensions 384
|
||||
@duplicate_threshold 0.5
|
||||
@model_id "Xenova/multilingual-e5-small"
|
||||
|
||||
def model_id, do: @model_id
|
||||
def dimensions, do: @dimensions
|
||||
def model_id, do: configured_backend().model_info().model_id
|
||||
def dimensions, do: configured_backend().model_info().dimensions
|
||||
|
||||
def sync_post(%Post{} = post) do
|
||||
if enabled_for_project?(post.project_id) do
|
||||
@@ -29,7 +27,7 @@ defmodule BDS.Embeddings do
|
||||
|
||||
existing_key ->
|
||||
label = existing_key_label(existing_key) || next_label()
|
||||
vector = vectorize(raw_text, post.language)
|
||||
{:ok, vector} = embed_text(raw_text, post.language)
|
||||
|
||||
(existing_key || %Key{})
|
||||
|> Key.changeset(%{
|
||||
@@ -245,6 +243,11 @@ defmodule BDS.Embeddings do
|
||||
defp existing_key_label(nil), do: nil
|
||||
defp existing_key_label(%Key{label: label}), do: label
|
||||
|
||||
defp configured_backend do
|
||||
Application.get_env(:bds, :embeddings, [])
|
||||
|> Keyword.get(:backend, BDS.Embeddings.Backends.InApp)
|
||||
end
|
||||
|
||||
defp next_label do
|
||||
Repo.one(from key in Key, select: max(key.label))
|
||||
|> case do
|
||||
@@ -277,40 +280,12 @@ defmodule BDS.Embeddings do
|
||||
|
||||
defp compose_embedding_source(title, content), do: "#{title || ""}\n\n#{content || ""}"
|
||||
|
||||
defp embed_text(raw_text, language) do
|
||||
configured_backend().embed("query: " <> raw_text, language: language)
|
||||
end
|
||||
|
||||
defp hash_text(text), do: :crypto.hash(:sha256, text) |> Base.encode16(case: :lower)
|
||||
|
||||
defp vectorize(text, language) do
|
||||
stemmed = BDS.Search.stem(text, language)
|
||||
tokens = tokenize(stemmed)
|
||||
bigrams = tokens |> Enum.chunk_every(2, 1, :discard) |> Enum.map(&Enum.join(&1, "::"))
|
||||
weighted_tokens = tokens ++ bigrams
|
||||
vector_array = :array.new(@dimensions, default: 0.0)
|
||||
|
||||
vector =
|
||||
Enum.reduce(weighted_tokens, vector_array, fn token, acc ->
|
||||
index = :erlang.phash2(token, @dimensions)
|
||||
:array.set(index, :array.get(index, acc) + 1.0, acc)
|
||||
end)
|
||||
|> :array.to_list()
|
||||
|
||||
normalize(vector)
|
||||
end
|
||||
|
||||
defp tokenize(text) do
|
||||
Regex.scan(~r/[[:alnum:]]+/u, String.downcase(text))
|
||||
|> List.flatten()
|
||||
end
|
||||
|
||||
defp normalize(vector) do
|
||||
norm = :math.sqrt(Enum.reduce(vector, 0.0, fn value, acc -> acc + value * value end))
|
||||
|
||||
if norm == 0.0 do
|
||||
vector
|
||||
else
|
||||
Enum.map(vector, &(&1 / norm))
|
||||
end
|
||||
end
|
||||
|
||||
defp decode_vector(nil), do: []
|
||||
defp decode_vector(vector), do: Jason.decode!(vector)
|
||||
|
||||
|
||||
6
lib/bds/embeddings/backend.ex
Normal file
6
lib/bds/embeddings/backend.ex
Normal file
@@ -0,0 +1,6 @@
|
||||
defmodule BDS.Embeddings.Backend do
|
||||
@moduledoc false
|
||||
|
||||
@callback model_info() :: %{model_id: String.t(), dimensions: pos_integer()}
|
||||
@callback embed(String.t(), keyword()) :: {:ok, [number()]} | {:error, term()}
|
||||
end
|
||||
60
lib/bds/embeddings/backends/in_app.ex
Normal file
60
lib/bds/embeddings/backends/in_app.ex
Normal file
@@ -0,0 +1,60 @@
|
||||
defmodule BDS.Embeddings.Backends.InApp do
|
||||
@moduledoc false
|
||||
|
||||
@behaviour BDS.Embeddings.Backend
|
||||
|
||||
@impl true
|
||||
def model_info do
|
||||
config = Application.get_env(:bds, :embeddings, [])
|
||||
|
||||
%{
|
||||
model_id: Keyword.get(config, :model_id, "Xenova/multilingual-e5-small"),
|
||||
dimensions: Keyword.get(config, :dimensions, 384)
|
||||
}
|
||||
end
|
||||
|
||||
@impl true
|
||||
def embed(text, opts) when is_binary(text) and is_list(opts) do
|
||||
language = Keyword.get(opts, :language)
|
||||
dimensions = model_info().dimensions
|
||||
|
||||
vector =
|
||||
text
|
||||
|> BDS.Search.stem(language)
|
||||
|> tokenize()
|
||||
|> weighted_terms()
|
||||
|> project_to_vector(dimensions)
|
||||
|> normalize()
|
||||
|
||||
{:ok, vector}
|
||||
end
|
||||
|
||||
defp tokenize(text) do
|
||||
Regex.scan(~r/[[:alnum:]]+/u, String.downcase(text))
|
||||
|> List.flatten()
|
||||
end
|
||||
|
||||
defp weighted_terms(tokens) do
|
||||
bigrams = tokens |> Enum.chunk_every(2, 1, :discard) |> Enum.map(&Enum.join(&1, "::"))
|
||||
tokens ++ bigrams
|
||||
end
|
||||
|
||||
defp project_to_vector(terms, dimensions) do
|
||||
terms
|
||||
|> Enum.reduce(:array.new(dimensions, default: 0.0), fn term, acc ->
|
||||
index = :erlang.phash2(term, dimensions)
|
||||
:array.set(index, :array.get(index, acc) + 1.0, acc)
|
||||
end)
|
||||
|> :array.to_list()
|
||||
end
|
||||
|
||||
defp normalize(vector) do
|
||||
norm = :math.sqrt(Enum.reduce(vector, 0.0, fn value, acc -> acc + value * value end))
|
||||
|
||||
if norm == 0.0 do
|
||||
vector
|
||||
else
|
||||
Enum.map(vector, &(&1 / norm))
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -547,6 +547,7 @@ defmodule BDS.Posts do
|
||||
|> Post.changeset(attrs)
|
||||
|> Repo.insert_or_update!()
|
||||
|> tap(&Search.sync_post/1)
|
||||
|> tap(&Embeddings.sync_post/1)
|
||||
end
|
||||
|
||||
defp parse_post_status(status) when is_atom(status), do: status
|
||||
|
||||
Reference in New Issue
Block a user