perf: batch CPU embedding inference and add A1-14c Apple GPU (EMLX) spec gap
This commit is contained in:
@@ -75,21 +75,7 @@ defmodule BDS.Embeddings do
|
||||
)
|
||||
|
||||
existing_keys = preload_keys_by_post_id(project_id, Enum.map(posts, & &1.id))
|
||||
base_label = max_label_value()
|
||||
|
||||
{rows, _next_label} =
|
||||
Enum.reduce(posts, {[], base_label + 1}, fn post, {acc, next_label} ->
|
||||
existing_key = Map.get(existing_keys, post.id)
|
||||
|
||||
case compute_key_data(post, existing_key, next_label) do
|
||||
:skip ->
|
||||
{acc, next_label}
|
||||
|
||||
{:upsert, row} ->
|
||||
bump = if existing_key, do: 0, else: 1
|
||||
{[row | acc], next_label + bump}
|
||||
end
|
||||
end)
|
||||
rows = build_key_rows(posts, existing_keys, max_label_value(), nil)
|
||||
|
||||
batch_upsert_keys(rows)
|
||||
:ok = rebuild_snapshot(project_id)
|
||||
@@ -113,9 +99,6 @@ defmodule BDS.Embeddings do
|
||||
)
|
||||
|
||||
post_ids = Enum.map(posts, & &1.id)
|
||||
total_posts = length(posts)
|
||||
|
||||
:ok = report_rebuild_started(on_progress, total_posts, "embedding entries")
|
||||
|
||||
Repo.delete_all(
|
||||
from key in Key,
|
||||
@@ -123,24 +106,7 @@ defmodule BDS.Embeddings do
|
||||
)
|
||||
|
||||
existing_keys = preload_keys_by_post_id(project_id)
|
||||
base_label = max_label_value()
|
||||
|
||||
{rows, _next_label} =
|
||||
posts
|
||||
|> Enum.with_index(1)
|
||||
|> Enum.reduce({[], base_label + 1}, fn {post, index}, {acc, next_label} ->
|
||||
:ok = report_rebuild_progress(on_progress, index, total_posts, "embedding entries")
|
||||
existing_key = Map.get(existing_keys, post.id)
|
||||
|
||||
case compute_key_data(post, existing_key, next_label) do
|
||||
:skip ->
|
||||
{acc, next_label}
|
||||
|
||||
{:upsert, row} ->
|
||||
bump = if existing_key, do: 0, else: 1
|
||||
{[row | acc], next_label + bump}
|
||||
end
|
||||
end)
|
||||
rows = build_key_rows(posts, existing_keys, max_label_value(), on_progress)
|
||||
|
||||
batch_upsert_keys(rows)
|
||||
|
||||
@@ -246,18 +212,83 @@ defmodule BDS.Embeddings do
|
||||
Repo.one(from key in Key, select: max(key.label)) || 0
|
||||
end
|
||||
|
||||
defp compute_key_data(%Post{} = post, existing_key, next_label) do
|
||||
body = resolve_post_body(post)
|
||||
raw_text = compose_embedding_source(post.title, body)
|
||||
content_hash = hash_text(raw_text)
|
||||
# Builds the upsert rows for a batch of posts.
#
# Posts whose content_hash is unchanged are skipped (ContentHashSkipsUnchanged);
# the rest are embedded in batches (see embed_pending/2) so model inference is
# not serialised one post at a time.
#
# Arguments:
#   * posts         - the posts to (re)index
#   * existing_keys - map of post id => already stored key (read for its
#                     `label` and `content_hash`)
#   * base_label    - highest label currently in use; new posts are numbered
#                     from `base_label + 1`
#   * on_progress   - progress reporter handed to the report_* helpers; the
#                     callers in this module also pass `nil`
#
# Returns a list of rows shaped
# `[label, post_id, project_id, content_hash, encoded_vector]`. Rows are
# accumulated by prepending, so they come out in reverse post order.
#
# Raises `KeyError` if embed_pending/2 returned no vector for a pending post.
defp build_key_rows(posts, existing_keys, base_label, on_progress) do
  prepared =
    Enum.map(posts, fn post ->
      raw_text = compose_embedding_source(post.title, resolve_post_body(post))
      existing = Map.get(existing_keys, post.id)
      content_hash = hash_text(raw_text)

      %{
        post: post,
        existing: existing,
        raw_text: raw_text,
        content_hash: content_hash,
        # A post needs a fresh embedding when it has no stored key yet or its
        # source text changed since the key was written.
        needs_embed?: is_nil(existing) or existing.content_hash != content_hash
      }
    end)

  pending = Enum.filter(prepared, & &1.needs_embed?)
  :ok = report_rebuild_started(on_progress, length(pending), "embedding entries")
  vectors_by_post_id = embed_pending(pending, on_progress)

  {rows, _next_label} =
    Enum.reduce(prepared, {[], base_label + 1}, fn entry, {acc, next_label} ->
      if entry.needs_embed? do
        vector = Map.fetch!(vectors_by_post_id, entry.post.id)

        # Existing keys keep their label; only brand-new keys consume the next
        # free integer.
        label = if entry.existing, do: entry.existing.label, else: next_label
        bump = if entry.existing, do: 0, else: 1

        row = [
          label,
          entry.post.id,
          entry.post.project_id,
          entry.content_hash,
          encode_vector(vector)
        ]

        {[row | acc], next_label + bump}
      else
        {acc, next_label}
      end
    end)

  rows
end
|
||||
|
||||
# Embeds every pending entry and returns a map of post id => vector.
#
# Entries are chunked per language into lists of at most batch_size/0 items;
# each chunk goes through embed_many/2 in one call, and progress is reported
# after every chunk with the running number of embedded entries.
defp embed_pending([], _on_progress), do: %{}

defp embed_pending(pending, on_progress) do
  total = length(pending)
  chunk_size = batch_size()

  # Group by language so the lexical stub stems consistently; the neural
  # backend is multilingual and ignores the language hint.
  language_chunks =
    pending
    |> Enum.group_by(& &1.post.language)
    |> Enum.flat_map(fn {language, entries} ->
      entries
      |> Enum.chunk_every(chunk_size)
      |> Enum.map(fn chunk -> {language, chunk} end)
    end)

  {vectors_by_post_id, _embedded} =
    Enum.reduce(language_chunks, {%{}, 0}, fn {language, chunk}, {vectors, embedded} ->
      texts = Enum.map(chunk, & &1.raw_text)
      {:ok, chunk_vectors} = embed_many(texts, language)

      chunk_result =
        chunk
        |> Enum.zip(chunk_vectors)
        |> Map.new(fn {entry, vector} -> {entry.post.id, vector} end)

      embedded = embedded + length(chunk)
      :ok = report_rebuild_progress(on_progress, embedded, total, "embedding entries")

      {Map.merge(vectors, chunk_result), embedded}
    end)

  vectors_by_post_id
end
|
||||
|
||||
defp batch_upsert_keys([]), do: :ok
|
||||
@@ -308,21 +339,7 @@ defmodule BDS.Embeddings do
|
||||
)
|
||||
|
||||
existing_keys = preload_keys_by_post_id(project_id)
|
||||
base_label = max_label_value()
|
||||
|
||||
{rows, _next_label} =
|
||||
Enum.reduce(posts, {[], base_label + 1}, fn post, {acc, next_label} ->
|
||||
existing_key = Map.get(existing_keys, post.id)
|
||||
|
||||
case compute_key_data(post, existing_key, next_label) do
|
||||
:skip ->
|
||||
{acc, next_label}
|
||||
|
||||
{:upsert, row} ->
|
||||
bump = if existing_key, do: 0, else: 1
|
||||
{[row | acc], next_label + bump}
|
||||
end
|
||||
end)
|
||||
rows = build_key_rows(posts, existing_keys, max_label_value(), nil)
|
||||
|
||||
batch_upsert_keys(rows)
|
||||
:ok = rebuild_snapshot(project_id)
|
||||
@@ -660,6 +677,32 @@ defmodule BDS.Embeddings do
|
||||
configured_backend().embed(raw_text, language: language)
|
||||
end
|
||||
|
||||
# Embeds a batch of texts in one shot.
#
# Backends that implement the optional embed_many/2 callback (e.g. the neural
# backend, which feeds them through the model as a single batched inference
# run) handle the whole list; others fall back to sequential single embeds.
#
# Returns whatever the backend's embed_many/2 returns, or `{:ok, vectors}`
# from the fallback, with vectors in the same order as `texts`. The fallback
# raises `MatchError` when a single embed does not return `{:ok, vector}`.
defp embed_many(texts, language) do
  backend = configured_backend()

  # function_exported?/3 does not load the module by itself. Without the
  # explicit ensure_loaded? a backend that has not been loaded yet would be
  # reported as lacking embed_many/2 and silently take the slow fallback.
  if Code.ensure_loaded?(backend) and function_exported?(backend, :embed_many, 2) do
    backend.embed_many(texts, language: language)
  else
    vectors =
      Enum.map(texts, fn text ->
        {:ok, vector} = backend.embed(text, language: language)
        vector
      end)

    {:ok, vectors}
  end
end
|
||||
|
||||
# Number of texts handed to embed_many/2 per call. Read from the
# `:batch_size` key of the `:bds, :embeddings` application env (default 16)
# and never lower than 1.
defp batch_size do
  settings = Application.get_env(:bds, :embeddings, [])
  configured = Keyword.get(settings, :batch_size, 16)
  max(configured, 1)
end
|
||||
|
||||
# Rebuilds the index for `project_id`, passing the current model id and
# vector dimensions as options. Returns the result of Index.rebuild/2.
defp rebuild_snapshot(project_id) do
  snapshot_opts = [model_id: model_id(), dimensions: dimensions()]
  Index.rebuild(project_id, snapshot_opts)
end
|
||||
|
||||
@@ -3,4 +3,15 @@ defmodule BDS.Embeddings.Backend do
|
||||
|
||||
@callback model_info() :: %{model_id: String.t(), dimensions: pos_integer()}
|
||||
@callback embed(String.t(), keyword()) :: {:ok, [number()]} | {:error, term()}
|
||||
|
||||
@doc """
|
||||
Embeds a list of texts in a single call.
|
||||
|
||||
Backends that can amortise work across inputs (e.g. running the neural model
|
||||
on a batched tensor) should implement this. The result list is aligned with
|
||||
the input list. Optional — callers fall back to repeated `embed/2`.
|
||||
"""
|
||||
@callback embed_many([String.t()], keyword()) :: {:ok, [[number()]]} | {:error, term()}
|
||||
|
||||
@optional_callbacks embed_many: 2
|
||||
end
|
||||
|
||||
@@ -37,6 +37,17 @@ defmodule BDS.Embeddings.Backends.InApp do
|
||||
{:ok, vector}
|
||||
end
|
||||
|
||||
# Batch variant of embed/2 for this backend: each text is embedded on its own
# with the same opts and the vectors are returned in input order. Raises
# `MatchError` if any single embed does not return `{:ok, vector}`.
@impl true
def embed_many(texts, opts) when is_list(texts) and is_list(opts) do
  vectors =
    for text <- texts do
      {:ok, vector} = embed(text, opts)
      vector
    end

  {:ok, vectors}
end
|
||||
|
||||
defp tokenize(text) do
|
||||
Regex.scan(~r/[[:alnum:]]+/u, String.downcase(text))
|
||||
|> List.flatten()
|
||||
|
||||
@@ -17,6 +17,14 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
with `"query: "`, pooled with mean pooling over the attention mask, and
|
||||
L2-normalised. This is what makes cross-language semantic similarity
|
||||
work.
|
||||
* Inference is batched. `embed_many/2` runs the model on `batch_size`
|
||||
texts per compiled inference run instead of one at a time, which is the
|
||||
dominant cost when (re)indexing large numbers of posts. The serving is
|
||||
compiled for a fixed `batch_size`/`sequence_length` (configurable);
|
||||
shorter sequences mean less wasted transformer compute.
|
||||
|
||||
EXLA on Apple Silicon runs on the CPU — XLA has no Metal/GPU backend. See
|
||||
SPECGAPS A1-14c for the planned EMLX (Apple GPU via MLX) acceleration path.
|
||||
"""
|
||||
|
||||
@behaviour BDS.Embeddings.Backend
|
||||
@@ -24,11 +32,13 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
use GenServer
|
||||
|
||||
# Prefix prepended to every text before it is sent to the model.
@query_prefix "query: "

# Upper bound for one GenServer call into this backend. A single call may
# run several inference batches back to back.
# NOTE(review): the earlier 2-minute value was superseded by this one; only
# the effective definition is kept.
@embed_timeout :timer.minutes(10)

@default_model_id "Xenova/multilingual-e5-small"
@default_model_repo "intfloat/multilingual-e5-small"
@default_dimensions 384

# Fallbacks used when the :embeddings config does not set :batch_size or
# :sequence_length.
@default_batch_size 16
@default_sequence_length 256
|
||||
|
||||
def child_spec(opts) do
|
||||
%{id: __MODULE__, start: {__MODULE__, :start_link, [opts]}}
|
||||
@@ -50,7 +60,22 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
|
||||
# Embeds a single text.
#
# The text is prefixed with @query_prefix and sent through run/1 as a
# one-element batch, so exit handling lives in one place. The options are
# ignored.
#
# Returns `{:ok, vector}` on success, `{:error, :unexpected_embedding_result}`
# when the backend answers with anything other than exactly one vector, and
# passes through any `{:error, reason}` from run/1.
@impl BDS.Embeddings.Backend
def embed(text, _opts) when is_binary(text) do
  case run([@query_prefix <> text]) do
    {:ok, [vector]} -> {:ok, vector}
    {:ok, _other} -> {:error, :unexpected_embedding_result}
    {:error, _reason} = error -> error
  end
end
|
||||
|
||||
# Embeds a list of texts through a single run/1 call. Every text is prefixed
# with @query_prefix first; an empty list short-circuits to `{:ok, []}`
# without calling run/1. The options are ignored.
@impl BDS.Embeddings.Backend
def embed_many(texts, _opts) when is_list(texts) do
  case texts do
    [] ->
      {:ok, []}

    _ ->
      prefixed = for text <- texts, do: @query_prefix <> text
      run(prefixed)
  end
end
|
||||
|
||||
# Sends the already-prefixed texts to the backend process and returns its
# reply. An exit raised by the call (the handler below catches every :exit)
# is converted into `{:error, {:embedding_backend_unavailable, reason}}`
# instead of propagating to the caller.
defp run(prefixed_texts) do
  try do
    GenServer.call(__MODULE__, {:embed, prefixed_texts}, @embed_timeout)
  catch
    :exit, reason ->
      {:error, {:embedding_backend_unavailable, reason}}
  end
end
|
||||
@@ -59,11 +84,15 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
def init(_opts), do: {:ok, %{serving: nil}}
|
||||
|
||||
@impl GenServer
|
||||
def handle_call({:embed, text}, _from, state) do
|
||||
def handle_call({:embed, texts}, _from, state) do
|
||||
case ensure_serving(state) do
|
||||
{:ok, %{serving: serving} = next_state} ->
|
||||
%{embedding: tensor} = Nx.Serving.run(serving, text)
|
||||
{:reply, {:ok, Nx.to_flat_list(tensor)}, next_state}
|
||||
vectors =
|
||||
texts
|
||||
|> Enum.chunk_every(batch_size())
|
||||
|> Enum.flat_map(&run_chunk(serving, &1))
|
||||
|
||||
{:reply, {:ok, vectors}, next_state}
|
||||
|
||||
{:error, _reason} = error ->
|
||||
{:reply, error, state}
|
||||
@@ -73,6 +102,17 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
{:reply, {:error, Exception.message(exception)}, state}
|
||||
end
|
||||
|
||||
# Runs one chunk of texts through the serving and returns a list with one
# flat vector per text.
#
# A one-element chunk is passed to the serving as a bare text and its result
# is matched as a single `%{embedding: tensor}` map; longer chunks are passed
# as a list and the result is walked element by element.
defp run_chunk(serving, [single]) do
  result = Nx.Serving.run(serving, single)
  %{embedding: tensor} = result
  [Nx.to_flat_list(tensor)]
end

defp run_chunk(serving, chunk) do
  results = Nx.Serving.run(serving, chunk)
  Enum.map(results, fn %{embedding: tensor} -> Nx.to_flat_list(tensor) end)
end
|
||||
|
||||
defp ensure_serving(%{serving: nil} = state) do
|
||||
case build_serving() do
|
||||
{:ok, serving} -> {:ok, %{state | serving: serving}}
|
||||
@@ -92,7 +132,7 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
output_pool: :mean_pooling,
|
||||
output_attribute: :hidden_state,
|
||||
embedding_processor: :l2_norm,
|
||||
compile: [batch_size: 1, sequence_length: 512],
|
||||
compile: [batch_size: batch_size(), sequence_length: sequence_length()],
|
||||
defn_options: [compiler: EXLA]
|
||||
)
|
||||
|
||||
@@ -100,5 +140,13 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
end
|
||||
end
|
||||
|
||||
# Configured inference batch size, floored at 1; falls back to
# @default_batch_size when the key is absent.
defp batch_size do
  configured = Keyword.get(config(), :batch_size, @default_batch_size)
  max(configured, 1)
end

# Configured sequence length, floored at 1; falls back to
# @default_sequence_length when the key is absent.
defp sequence_length do
  configured = Keyword.get(config(), :sequence_length, @default_sequence_length)
  max(configured, 1)
end

# The `:embeddings` keyword list from the `:bds` application env, or `[]`
# when it is not set.
defp config do
  Application.get_env(:bds, :embeddings, [])
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user