perf: batch CPU embedding inference and add A1-14c Apple GPU (EMLX) spec gap
This commit is contained in:
@@ -17,6 +17,14 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
with `"query: "`, pooled with mean pooling over the attention mask, and
|
||||
L2-normalised. This is what makes cross-language semantic similarity
|
||||
work.
|
||||
* Inference is batched. `embed_many/2` runs the model on `batch_size`
|
||||
texts per compiled inference run instead of one at a time, which is the
|
||||
dominant cost when (re)indexing large numbers of posts. The serving is
|
||||
compiled for a fixed `batch_size`/`sequence_length` (configurable);
|
||||
shorter sequences mean less wasted transformer compute.
|
||||
|
||||
EXLA on Apple Silicon runs on the CPU — XLA has no Metal/GPU backend. See
|
||||
SPECGAPS A1-14c for the planned EMLX (Apple GPU via MLX) acceleration path.
|
||||
"""
|
||||
|
||||
@behaviour BDS.Embeddings.Backend
|
||||
@@ -24,11 +32,13 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
use GenServer
|
||||
|
||||
# Prefix prepended to every text before it is sent to the model.
@query_prefix "query: "

# Upper bound for a single GenServer.call into this backend. One call may
# cover many batches, so the single definition kept here is the 10-minute
# one; the earlier 2-minute assignment was dead because a later assignment
# to the same attribute overrides it.
@embed_timeout :timer.minutes(10)

@default_model_id "Xenova/multilingual-e5-small"
@default_model_repo "intfloat/multilingual-e5-small"
@default_dimensions 384

# Fallbacks used when the application config sets no `:batch_size` /
# `:sequence_length`.
@default_batch_size 16
@default_sequence_length 256
|
||||
|
||||
def child_spec(opts) do
|
||||
%{id: __MODULE__, start: {__MODULE__, :start_link, [opts]}}
|
||||
@@ -50,7 +60,22 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
|
||||
# Embeds one `text` and returns `{:ok, vector}`.
#
# The text gets the query prefix and is submitted through `run/1` as a
# one-element list, so exactly one vector is expected back. Any other
# successful shape is reported as `{:error, :unexpected_embedding_result}`;
# error tuples from `run/1` are returned unchanged.
#
# The server is called only through `run/1`. A direct
# `GenServer.call(__MODULE__, {:embed, binary}, ...)` must not be issued here:
# the `{:embed, _}` handler chunks a list of texts, and a direct call would
# also bypass the exit handling in `run/1`.
@impl BDS.Embeddings.Backend
def embed(text, _opts) when is_binary(text) do
  case run([@query_prefix <> text]) do
    {:ok, [vector]} -> {:ok, vector}
    {:ok, _other} -> {:error, :unexpected_embedding_result}
    {:error, _reason} = error -> error
  end
end
|
||||
|
||||
# Embeds a list of texts in one request to the backend process.
#
# An empty list short-circuits to `{:ok, []}` without touching the server.
# Otherwise every text is given the query prefix and the whole list is handed
# to `run/1`, whose result is returned as-is.
@impl BDS.Embeddings.Backend
def embed_many([], _opts) do
  {:ok, []}
end

def embed_many(texts, _opts) when is_list(texts) do
  texts
  |> Enum.map(fn text -> @query_prefix <> text end)
  |> run()
end
|
||||
|
||||
# Sends already-prefixed texts to the backend process and returns its reply.
#
# An exit raised by the call (for example because the process is not running
# or the call timed out) is converted into
# `{:error, {:embedding_backend_unavailable, reason}}` instead of propagating
# to the caller.
defp run(prefixed_texts) do
  try do
    request = {:embed, prefixed_texts}
    GenServer.call(__MODULE__, request, @embed_timeout)
  catch
    :exit, exit_reason ->
      {:error, {:embedding_backend_unavailable, exit_reason}}
  end
end
|
||||
@@ -59,11 +84,15 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
# Starts with no serving in the state; `ensure_serving/1` builds it when the
# state still holds `serving: nil`.
def init(_opts) do
  initial_state = %{serving: nil}
  {:ok, initial_state}
end
|
||||
|
||||
@impl GenServer
|
||||
def handle_call({:embed, text}, _from, state) do
|
||||
def handle_call({:embed, texts}, _from, state) do
|
||||
case ensure_serving(state) do
|
||||
{:ok, %{serving: serving} = next_state} ->
|
||||
%{embedding: tensor} = Nx.Serving.run(serving, text)
|
||||
{:reply, {:ok, Nx.to_flat_list(tensor)}, next_state}
|
||||
vectors =
|
||||
texts
|
||||
|> Enum.chunk_every(batch_size())
|
||||
|> Enum.flat_map(&run_chunk(serving, &1))
|
||||
|
||||
{:reply, {:ok, vectors}, next_state}
|
||||
|
||||
{:error, _reason} = error ->
|
||||
{:reply, error, state}
|
||||
@@ -73,6 +102,17 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
{:reply, {:error, Exception.message(exception)}, state}
|
||||
end
|
||||
|
||||
# Runs one chunk of texts through the serving and returns one flat list of
# numbers per text, in the same order as the input.
#
# A one-element chunk is passed to the serving as a bare text, and the result
# is matched as a single `%{embedding: tensor}` map. Larger chunks are passed
# as a list, and the result is treated as a list of such maps.
defp run_chunk(serving, [lone_text]) do
  output = Nx.Serving.run(serving, lone_text)
  %{embedding: embedding} = output
  flat_vector = Nx.to_flat_list(embedding)
  [flat_vector]
end

defp run_chunk(serving, texts) do
  outputs = Nx.Serving.run(serving, texts)
  Enum.map(outputs, fn %{embedding: embedding} -> Nx.to_flat_list(embedding) end)
end
|
||||
|
||||
defp ensure_serving(%{serving: nil} = state) do
|
||||
case build_serving() do
|
||||
{:ok, serving} -> {:ok, %{state | serving: serving}}
|
||||
@@ -92,7 +132,7 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
output_pool: :mean_pooling,
|
||||
output_attribute: :hidden_state,
|
||||
embedding_processor: :l2_norm,
|
||||
compile: [batch_size: 1, sequence_length: 512],
|
||||
compile: [batch_size: batch_size(), sequence_length: sequence_length()],
|
||||
defn_options: [compiler: EXLA]
|
||||
)
|
||||
|
||||
@@ -100,5 +140,13 @@ defmodule BDS.Embeddings.Backends.Neural do
|
||||
end
|
||||
end
|
||||
|
||||
# Number of texts per inference run, read from the `:batch_size` key of the
# embeddings config.
#
# Integers are clamped to at least 1. Any non-integer value (`nil`, a string,
# a float, ...) falls back to the default: `max/2` compares by term ordering,
# so such values would otherwise pass through unchanged and only fail later in
# `Enum.chunk_every/2` or in the serving's compile options.
defp batch_size do
  case Keyword.get(config(), :batch_size, @default_batch_size) do
    size when is_integer(size) -> max(size, 1)
    _not_an_integer -> @default_batch_size
  end
end
|
||||
|
||||
# Sequence length the serving is compiled for, read from the
# `:sequence_length` key of the embeddings config.
#
# Integers are clamped to at least 1. Any non-integer value (`nil`, a string,
# a float, ...) falls back to the default: `max/2` compares by term ordering,
# so such values would otherwise pass through unchanged into the serving's
# compile options.
defp sequence_length do
  case Keyword.get(config(), :sequence_length, @default_sequence_length) do
    length when is_integer(length) -> max(length, 1)
    _not_an_integer -> @default_sequence_length
  end
end
|
||||
|
||||
# Reads the `:embeddings` entry of the `:bds` application environment at call
# time; yields an empty list when the entry is not set.
defp config do
  Application.get_env(:bds, :embeddings, [])
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user