perf: batch CPU embedding inference and add A1-14c Apple GPU (EMLX) spec gap

This commit is contained in:
2026-05-29 14:43:39 +02:00
parent a1004d72bf
commit 744f7543d7
10 changed files with 275 additions and 75 deletions

View File

@@ -37,6 +37,17 @@ defmodule BDS.Embeddings.Backends.InApp do
{:ok, vector}
end
# Embeds every entry of `texts` by delegating to `embed/2` with the same
# `opts`, keeping the input order. The match on `{:ok, vector}` is assertive:
# any other result from `embed/2` raises instead of being returned.
@impl true
def embed_many(texts, opts) when is_list(texts) and is_list(opts) do
  embedded =
    for text <- texts do
      {:ok, vector} = embed(text, opts)
      vector
    end

  {:ok, embedded}
end
defp tokenize(text) do
Regex.scan(~r/[[:alnum:]]+/u, String.downcase(text))
|> List.flatten()

View File

@@ -17,6 +17,14 @@ defmodule BDS.Embeddings.Backends.Neural do
with `"query: "`, pooled with mean pooling over the attention mask, and
L2-normalised. This is what makes cross-language semantic similarity
work.
* Inference is batched. `embed_many/2` runs the model on `batch_size`
texts per compiled inference run instead of one at a time, which is the
dominant cost when (re)indexing large numbers of posts. The serving is
compiled for a fixed `batch_size`/`sequence_length`, both configurable under
`config :bds, :embeddings`; a smaller `sequence_length` means less transformer
compute is wasted on padding.
EXLA on Apple Silicon runs on the CPU — XLA has no Metal/GPU backend. See
SPECGAPS A1-14c for the planned EMLX (Apple GPU via MLX) acceleration path.
"""
@behaviour BDS.Embeddings.Backend
@@ -24,11 +32,13 @@ defmodule BDS.Embeddings.Backends.Neural do
use GenServer
@query_prefix "query: "
@embed_timeout :timer.minutes(2)
@embed_timeout :timer.minutes(10)
@default_model_id "Xenova/multilingual-e5-small"
@default_model_repo "intfloat/multilingual-e5-small"
@default_dimensions 384
@default_batch_size 16
@default_sequence_length 256
def child_spec(opts) do
%{id: __MODULE__, start: {__MODULE__, :start_link, [opts]}}
@@ -50,7 +60,22 @@ defmodule BDS.Embeddings.Backends.Neural do
# Embeds a single `text`.
#
# The text is prefixed with `@query_prefix` and sent through `run/1` as a
# one-element batch, so it shares the batched code path (and its exit
# handling) with `embed_many/2`.
#
# Returns `{:ok, vector}` when exactly one vector comes back,
# `{:error, :unexpected_embedding_result}` when the backend replies with any
# other `{:ok, _}` payload, and passes `{:error, reason}` through unchanged.
#
# Note: there is deliberately no direct `GenServer.call/3` here. A second,
# unguarded call would repeat the backend round-trip, discard its result and
# let an exit (timeout, backend not running) escape the error-tuple contract.
@impl BDS.Embeddings.Backend
def embed(text, _opts) when is_binary(text) do
  case run([@query_prefix <> text]) do
    {:ok, [vector]} -> {:ok, vector}
    {:ok, _other} -> {:error, :unexpected_embedding_result}
    {:error, _reason} = error -> error
  end
end
# Embeds a list of texts in one request to the backend process.
#
# An empty list short-circuits to `{:ok, []}` without contacting the backend.
# Otherwise every text is prefixed with `@query_prefix` and the whole list is
# handed to `run/1`, whose result is returned as-is.
@impl BDS.Embeddings.Backend
def embed_many([], _opts), do: {:ok, []}

def embed_many(texts, _opts) when is_list(texts) do
  texts
  |> Enum.map(fn text -> @query_prefix <> text end)
  |> run()
end
# Sends the already-prefixed texts to this module's GenServer and waits up to
# `@embed_timeout` for the reply. If the call exits (for example on timeout or
# when the process is not running), the exit is converted into
# `{:error, {:embedding_backend_unavailable, reason}}` instead of propagating.
defp run(prefixed_texts) do
  try do
    GenServer.call(__MODULE__, {:embed, prefixed_texts}, @embed_timeout)
  catch
    :exit, reason -> {:error, {:embedding_backend_unavailable, reason}}
  end
end
@@ -59,11 +84,15 @@ defmodule BDS.Embeddings.Backends.Neural do
def init(_opts), do: {:ok, %{serving: nil}}
@impl GenServer
def handle_call({:embed, text}, _from, state) do
def handle_call({:embed, texts}, _from, state) do
case ensure_serving(state) do
{:ok, %{serving: serving} = next_state} ->
%{embedding: tensor} = Nx.Serving.run(serving, text)
{:reply, {:ok, Nx.to_flat_list(tensor)}, next_state}
vectors =
texts
|> Enum.chunk_every(batch_size())
|> Enum.flat_map(&run_chunk(serving, &1))
{:reply, {:ok, vectors}, next_state}
{:error, _reason} = error ->
{:reply, error, state}
@@ -73,6 +102,17 @@ defmodule BDS.Embeddings.Backends.Neural do
{:reply, {:error, Exception.message(exception)}, state}
end
# Runs one chunk of texts through `serving` and returns a list of flat float
# lists, one per input text.
#
# A one-element chunk is submitted as a bare value and the result is matched
# as a single `%{embedding: tensor}` map; the vector is wrapped in a list so
# both clauses return the same shape.
defp run_chunk(serving, [lone_text]) do
  %{embedding: embedding} = Nx.Serving.run(serving, lone_text)
  [Nx.to_flat_list(embedding)]
end

# Any other chunk is submitted as a list; the result is treated as an
# enumerable of `%{embedding: tensor}` maps, one per text.
defp run_chunk(serving, texts) do
  outputs = Nx.Serving.run(serving, texts)
  Enum.map(outputs, fn %{embedding: embedding} -> Nx.to_flat_list(embedding) end)
end
defp ensure_serving(%{serving: nil} = state) do
case build_serving() do
{:ok, serving} -> {:ok, %{state | serving: serving}}
@@ -92,7 +132,7 @@ defmodule BDS.Embeddings.Backends.Neural do
output_pool: :mean_pooling,
output_attribute: :hidden_state,
embedding_processor: :l2_norm,
compile: [batch_size: 1, sequence_length: 512],
compile: [batch_size: batch_size(), sequence_length: sequence_length()],
defn_options: [compiler: EXLA]
)
@@ -100,5 +140,13 @@ defmodule BDS.Embeddings.Backends.Neural do
end
end
# Number of texts per inference run, read from `config :bds, :embeddings`
# under `:batch_size`.
#
# * missing or `nil` -> `@default_batch_size`
# * integer          -> clamped to at least 1
# * anything else    -> raises `ArgumentError`
#
# The integer check is explicit because `max/2` compares with Erlang term
# ordering: `max(nil, 1)` is `nil` and `max("16", 1)` is `"16"`, so a bare
# `max(value, 1)` would let non-numeric configuration through unclamped.
defp batch_size do
  case Keyword.get(config(), :batch_size, @default_batch_size) do
    nil ->
      @default_batch_size

    size when is_integer(size) ->
      max(size, 1)

    other ->
      raise ArgumentError,
            "expected :batch_size under config :bds, :embeddings to be an integer, " <>
              "got: #{inspect(other)}"
  end
end
# Sequence length used when compiling the serving, read from
# `config :bds, :embeddings` under `:sequence_length`.
#
# * missing or `nil` -> `@default_sequence_length`
# * integer          -> clamped to at least 1
# * anything else    -> raises `ArgumentError`
#
# The integer check is explicit because `max/2` compares with Erlang term
# ordering: `max(nil, 1)` is `nil` and `max("256", 1)` is `"256"`, so a bare
# `max(value, 1)` would let non-numeric configuration through unclamped.
defp sequence_length do
  case Keyword.get(config(), :sequence_length, @default_sequence_length) do
    nil ->
      @default_sequence_length

    length when is_integer(length) ->
      max(length, 1)

    other ->
      raise ArgumentError,
            "expected :sequence_length under config :bds, :embeddings to be an integer, " <>
              "got: #{inspect(other)}"
  end
end
# Reads the `:embeddings` entry of the `:bds` application environment at call
# time, defaulting to an empty list when it is not set.
defp config do
  Application.get_env(:bds, :embeddings, [])
end
end