# Files
# bDS2/lib/bds/embeddings/backends/neural.ex
#
# 210 lines
# 7.1 KiB
# Elixir

defmodule BDS.Embeddings.Backends.Neural do
  @moduledoc """
  Real on-device neural embedding backend.

  Implements the `RealNeuralModel` and `ModelCaching` invariants from
  `specs/embedding.allium`: embeddings are produced by the actual
  multilingual-e5-small transformer (the `intfloat/multilingual-e5-small`
  weights behind the `Xenova/multilingual-e5-small` identifier) via
  Bumblebee — compiled with EXLA or EMLX, see below — never by a lexical
  approximation.

    * Lazy-loaded — the model pipeline is built on the first embedding
      request, not at application startup.
    * Model files (~100 MB) are downloaded from the Hugging Face Hub on
      first use and cached on disk (Bumblebee cache dir), persisting across
      sessions and project switches.
    * Text preprocessing follows the e5 convention: every input is prefixed
      with `"query: "`; the model's hidden states are mean-pooled over the
      attention mask and the resulting vector is L2-normalised. This is what
      makes cross-language semantic similarity work.
    * Inference is batched. `embed_many/2` runs the model on `batch_size`
      texts per compiled inference run instead of one at a time, which is the
      dominant cost when (re)indexing large numbers of posts. The serving is
      compiled for a fixed `batch_size`/`sequence_length` (configurable);
      shorter sequences mean less wasted transformer compute. Both values are
      read once, when the serving is built, so changing them requires
      restarting this process.

  Hardware acceleration follows the `NativeAcceleratedExecution` invariant.
  The serving's defn compiler is chosen at build time:

    * On Apple Silicon (arm64 macOS) with EMLX available, inference runs on the
      Apple GPU via MLX/Metal (`compiler: EMLX`, params placed on the
      `EMLX.Backend` GPU device). Note that this sets the VM-wide Nx default
      backend, not just this module's.
    * Everywhere else — and as a fallback when EMLX is unavailable or explicitly
      disabled — it runs on optimised native CPU via XLA (`compiler: EXLA`).

  The accelerator can be pinned with `config :bds, :embeddings, accelerator:`
  to `:auto` (default), `:emlx`, or `:exla`.
  """

  @behaviour BDS.Embeddings.Backend

  use GenServer

  @query_prefix "query: "
  @embed_timeout :timer.minutes(10)
  @default_model_id "Xenova/multilingual-e5-small"
  @default_model_repo "intfloat/multilingual-e5-small"
  @default_dimensions 384
  @default_batch_size 16
  @default_sequence_length 256
  @default_accelerator :auto

  # -- Client API ------------------------------------------------------------

  def child_spec(opts) do
    %{id: __MODULE__, start: {__MODULE__, :start_link, [opts]}}
  end

  @doc "Starts the backend process, registered under the module name."
  @spec start_link(keyword()) :: GenServer.on_start()
  def start_link(opts \\ []) do
    GenServer.start_link(__MODULE__, opts, name: __MODULE__)
  end

  @impl BDS.Embeddings.Backend
  def model_info do
    config = config()

    %{
      model_id: Keyword.get(config, :model_id, @default_model_id),
      dimensions: Keyword.get(config, :dimensions, @default_dimensions)
    }
  end

  @impl BDS.Embeddings.Backend
  def embed(text, _opts) when is_binary(text) do
    case run([@query_prefix <> text]) do
      {:ok, [vector]} -> {:ok, vector}
      {:ok, _other} -> {:error, :unexpected_embedding_result}
      {:error, _reason} = error -> error
    end
  end

  @impl BDS.Embeddings.Backend
  def embed_many([], _opts), do: {:ok, []}

  def embed_many(texts, _opts) when is_list(texts) do
    run(Enum.map(texts, &(@query_prefix <> &1)))
  end

  # Sends already-prefixed texts to the named server. Any exit from the call
  # (server not running, crashed, or the timeout elapsed) is turned into an
  # error tuple so the caller does not crash.
  defp run(prefixed_texts) do
    GenServer.call(__MODULE__, {:embed, prefixed_texts}, @embed_timeout)
  catch
    :exit, reason -> {:error, {:embedding_backend_unavailable, reason}}
  end

  # -- GenServer callbacks ---------------------------------------------------

  # State: `serving` is nil until the first request builds it; `batch_size`
  # is the batch size that serving was compiled with and is used to chunk
  # incoming texts.
  @impl GenServer
  def init(_opts), do: {:ok, %{serving: nil, batch_size: nil}}

  @impl GenServer
  def handle_call({:embed, texts}, _from, state) do
    case rescue_to_error(fn -> ensure_serving(state) end) do
      {:ok, ready_state} ->
        # Keep `ready_state` even when inference fails: the serving may have
        # just been built (load + compile), and discarding it would force a
        # full rebuild on the next request.
        reply = rescue_to_error(fn -> {:ok, embed_in_batches(ready_state, texts)} end)
        {:reply, reply, ready_state}

      {:error, _reason} = error ->
        {:reply, error, state}
    end
  end

  # Runs `fun`, converting a raised exception into `{:error, message}` — the
  # error shape this server replies with for failures inside a call.
  defp rescue_to_error(fun) do
    fun.()
  rescue
    exception -> {:error, Exception.message(exception)}
  end

  # Splits `texts` into chunks no larger than the compiled batch size and
  # concatenates the per-chunk vectors, preserving input order.
  defp embed_in_batches(%{serving: serving, batch_size: size}, texts) do
    texts
    |> Enum.chunk_every(size)
    |> Enum.flat_map(&run_chunk(serving, &1))
  end

  # A single text is passed to `Nx.Serving.run/2` as a bare input, which is
  # matched here as one result map; a list input is matched as a list of maps.
  defp run_chunk(serving, [single]) do
    %{embedding: tensor} = Nx.Serving.run(serving, single)
    [Nx.to_flat_list(tensor)]
  end

  defp run_chunk(serving, chunk) do
    serving
    |> Nx.Serving.run(chunk)
    |> Enum.map(fn %{embedding: tensor} -> Nx.to_flat_list(tensor) end)
  end

  # Builds the serving on first use and records the batch size it was
  # compiled with; later calls reuse the stored state unchanged.
  defp ensure_serving(%{serving: nil} = state) do
    compiled_batch_size = batch_size()

    with {:ok, serving} <- build_serving(compiled_batch_size) do
      {:ok, Map.merge(state, %{serving: serving, batch_size: compiled_batch_size})}
    end
  end

  defp ensure_serving(state), do: {:ok, state}

  # Loads model + tokenizer from the configured Hugging Face repo and builds a
  # text-embedding serving compiled for `compiled_batch_size` inputs. Returns
  # `{:error, reason}` unchanged if either load fails.
  defp build_serving(compiled_batch_size) do
    repo = {:hf, Keyword.get(config(), :model_repo, @default_model_repo)}
    accelerator = current_accelerator()
    maybe_set_default_backend(accelerator)

    with {:ok, model_info} <- Bumblebee.load_model(repo),
         {:ok, tokenizer} <- Bumblebee.load_tokenizer(repo) do
      serving =
        Bumblebee.Text.text_embedding(model_info, tokenizer,
          output_pool: :mean_pooling,
          output_attribute: :hidden_state,
          embedding_processor: :l2_norm,
          compile: [batch_size: compiled_batch_size, sequence_length: sequence_length()],
          defn_options: defn_options(accelerator)
        )

      {:ok, serving}
    end
  end

  # Place model params/tensors on the Apple GPU (Metal) when accelerating with
  # EMLX so the compiled inference pass actually runs on-device. This changes
  # the VM-wide Nx default backend. EXLA manages its own device placement, so
  # nothing to do there.
  defp maybe_set_default_backend(:emlx),
    do: Nx.global_default_backend({EMLX.Backend, device: :gpu})

  defp maybe_set_default_backend(:exla), do: :ok

  @doc false
  @spec defn_options(:emlx | :exla) :: keyword()
  def defn_options(:emlx), do: [compiler: EMLX]
  def defn_options(:exla), do: [compiler: EXLA]

  @doc false
  @spec current_accelerator() :: :emlx | :exla
  def current_accelerator do
    select_accelerator(configured_accelerator(), emlx_available?(), apple_silicon?())
  end

  @doc """
  Pure accelerator-selection policy for `NativeAcceleratedExecution`.

  Prefer the Apple GPU (EMLX) under `:auto` only when it is both available and
  running on Apple Silicon; honour an explicit `:emlx`/`:exla` request, but
  degrade a forced `:emlx` to EXLA when EMLX is not loaded so a misconfigured
  host still gets working CPU inference instead of crashing.
  """
  @spec select_accelerator(:auto | :emlx | :exla, boolean(), boolean()) :: :emlx | :exla
  def select_accelerator(:exla, _emlx_available?, _apple_silicon?), do: :exla
  def select_accelerator(:emlx, true, _apple_silicon?), do: :emlx
  def select_accelerator(:emlx, false, _apple_silicon?), do: :exla
  def select_accelerator(:auto, true, true), do: :emlx
  def select_accelerator(:auto, _emlx_available?, _apple_silicon?), do: :exla

  defp configured_accelerator do
    Keyword.get(config(), :accelerator, @default_accelerator)
  end

  defp emlx_available? do
    Code.ensure_loaded?(EMLX) and Code.ensure_loaded?(EMLX.Backend)
  end

  defp apple_silicon? do
    :os.type() == {:unix, :darwin} and
      to_string(:erlang.system_info(:system_architecture)) =~ ~r/aarch64|arm/
  end

  defp batch_size, do: positive_integer_config(:batch_size, @default_batch_size)

  defp sequence_length,
    do: positive_integer_config(:sequence_length, @default_sequence_length)

  # Reads an integer setting, clamped to at least 1. A non-integer value falls
  # back to `default` rather than being handed to the chunker/compiler.
  defp positive_integer_config(key, default) do
    case Keyword.get(config(), key, default) do
      value when is_integer(value) -> max(value, 1)
      _invalid -> default
    end
  end

  defp config, do: Application.get_env(:bds, :embeddings, [])
end