fix: implemented TD-06 real SSE implementation
This commit is contained in:
@@ -565,9 +565,10 @@ defmodule BDS.AI.Chat do
|
||||
rounds_left
|
||||
) do
|
||||
request = build_chat_request(conversation, messages, model, project_id, tools)
|
||||
generate_opts = put_stream_callback(opts, conversation.id)
|
||||
|
||||
with {:ok, response} <-
|
||||
runtime.generate(Runtime.endpoint_with_model(endpoint, model), request, opts),
|
||||
runtime.generate(Runtime.endpoint_with_model(endpoint, model), request, generate_opts),
|
||||
{:ok, assistant_message} <- persist_assistant_response(conversation.id, response),
|
||||
:ok <- touch_conversation(conversation.id) do
|
||||
if is_binary(Map.get(response, :content)) and String.trim(Map.get(response, :content)) != "" do
|
||||
@@ -921,6 +922,26 @@ defmodule BDS.AI.Chat do
|
||||
end
|
||||
end
|
||||
|
||||
# Requests streaming from the runtime only when `opts` carries an
# `:event_target`. The runtime then reports cumulative content snapshots,
# which the editor renders with replace semantics. The full-content notify
# sent after each round remains the authoritative final state (and is the
# only event non-streaming runtimes produce).
defp put_stream_callback(opts, conversation_id) do
  if is_nil(Keyword.get(opts, :event_target)) do
    # Nobody to notify: leave the options untouched so the runtime does not stream.
    opts
  else
    forward_snapshot = fn %{content: snapshot} ->
      # Blank or non-binary snapshots carry nothing worth rendering.
      blank? = not is_binary(snapshot) or String.trim(snapshot) == ""

      if not blank? do
        notify_chat_event(opts, {:chat_streaming_content, conversation_id, snapshot})
      end

      :ok
    end

    Keyword.put(opts, :on_stream, forward_snapshot)
  end
end
|
||||
|
||||
defp notify_chat_event(opts, event) do
|
||||
case Keyword.get(opts, :event_target) do
|
||||
pid when is_pid(pid) -> send(pid, event)
|
||||
|
||||
@@ -59,6 +59,62 @@ defmodule BDS.AI.HttpClient do
|
||||
|> normalize_result()
|
||||
end
|
||||
|
||||
@doc """
Performs a streaming POST.

While the response status is 200, every body chunk is folded into the
accumulator with `reducer.(chunk, acc)` as soon as it arrives. For any other
status the chunks are concatenated into the response body so the caller can
report the error. The result is `{:ok, response, final_acc}`, or
`{:error, reason}` when the request itself fails.

The request is never retried (same reasoning as `post/3`) and
`accept-encoding` is disabled so event-stream chunks arrive uncompressed.
The request runs in the calling process — killing that process aborts the
underlying connection, which is what makes mid-stream chat cancellation work.
"""
@spec post_stream(
        String.t(),
        %{String.t() => String.t()},
        binary(),
        acc,
        (binary(), acc -> acc)
      ) ::
        {:ok, %{status: non_neg_integer(), headers: map(), body: binary()}, acc}
        | {:error, term()}
      when acc: term()
def post_stream(url, headers, body, acc, reducer)
    when is_binary(url) and is_map(headers) and is_binary(body) and is_function(reducer, 2) do
  # The running accumulator travels in the response's private storage under
  # :bds_stream_acc; `acc` is the value used before the first chunk arrives.
  fold_chunk = fn
    {:data, data}, {req, %{status: 200} = resp} ->
      folded = reducer.(data, Req.Response.get_private(resp, :bds_stream_acc, acc))
      {:cont, {req, Req.Response.put_private(resp, :bds_stream_acc, folded)}}

    {:data, data}, {req, resp} ->
      # Non-200: keep the raw body so the caller can log or report it.
      {:cont, {req, %{resp | body: collected_body(resp.body) <> data}}}
  end

  request_options =
    Keyword.merge(
      [
        method: :post,
        url: url,
        headers: headers,
        body: body,
        retry: false,
        compressed: false,
        into: fold_chunk
      ],
      base_options()
    )

  case Req.request(request_options) do
    {:ok, %Req.Response{} = resp} ->
      summary = %{
        status: resp.status,
        headers: normalize_headers(resp.headers),
        body: collected_body(resp.body)
      }

      {:ok, summary, Req.Response.get_private(resp, :bds_stream_acc, acc)}

    {:error, %Req.TransportError{reason: reason}} ->
      {:error, reason}

    {:error, reason} ->
      {:error, reason}
  end
end
|
||||
|
||||
# Returns the body when it is a binary and an empty string for anything else,
# so callers can always treat the result as a binary.
defp collected_body(body) do
  if is_binary(body), do: body, else: ""
end
|
||||
|
||||
defp base_options do
|
||||
[
|
||||
connect_options: [timeout: config(:connect_timeout_ms, @default_connect_timeout_ms)],
|
||||
|
||||
@@ -4,6 +4,7 @@ defmodule BDS.AI.OpenAICompatibleRuntime do
|
||||
require Logger
|
||||
|
||||
alias BDS.AI.HttpClient
|
||||
alias BDS.AI.SSE
|
||||
|
||||
def list_models(endpoint, opts \\ []) when is_map(endpoint) and is_list(opts) do
|
||||
http_client = Keyword.get(opts, :http_client, HttpClient)
|
||||
@@ -22,7 +23,7 @@ defmodule BDS.AI.OpenAICompatibleRuntime do
|
||||
end
|
||||
end
|
||||
|
||||
def generate(endpoint, request, _opts) when is_map(endpoint) and is_map(request) do
|
||||
def generate(endpoint, request, opts) when is_map(endpoint) and is_map(request) do
|
||||
url = completions_url(endpoint.url)
|
||||
|
||||
headers =
|
||||
@@ -41,6 +42,14 @@ defmodule BDS.AI.OpenAICompatibleRuntime do
|
||||
|> maybe_disable_thinking(request.model)
|
||||
|> maybe_put_tools(Map.get(request, :tools, []))
|
||||
|
||||
if stream?(request, opts) do
|
||||
generate_streaming(url, headers, payload, request, Keyword.fetch!(opts, :on_stream))
|
||||
else
|
||||
generate_blocking(url, headers, payload, request)
|
||||
end
|
||||
end
|
||||
|
||||
defp generate_blocking(url, headers, payload, request) do
|
||||
payload_json = Jason.encode!(payload)
|
||||
|
||||
Logger.debug(
|
||||
@@ -81,6 +90,81 @@ defmodule BDS.AI.OpenAICompatibleRuntime do
|
||||
end
|
||||
end
|
||||
|
||||
# Streaming counterpart of the blocking request: the same payload with the
# stream flags switched on. Transport chunks are folded into a BDS.AI.SSE
# assembler, which reports cumulative content snapshots to `on_stream` while
# the response is still arriving. The assembled message is normalized with
# the same helpers as the blocking path.
defp generate_streaming(url, headers, payload, request, on_stream) do
  streaming_payload =
    Map.merge(payload, %{"stream" => true, "stream_options" => %{"include_usage" => true}})

  payload_json = Jason.encode!(streaming_payload)

  Logger.debug(
    "AI OpenAI-compatible streaming request operation=#{inspect(Map.get(request, :operation))} model=#{inspect(request.model)} url=#{url} payload_size=#{byte_size(payload_json)}"
  )

  assembler = SSE.new(on_stream, emit_interval_ms: stream_emit_interval_ms())

  case HttpClient.post_stream(url, headers, payload_json, assembler, &SSE.feed(&2, &1)) do
    {:ok, %{status: 200, headers: response_headers}, fed} ->
      if event_stream?(response_headers) do
        %{content: text, tool_calls: calls, usage: usage} = SSE.finish(fed)

        {:ok,
         %{
           content: text,
           json: decode_json_content(text),
           tool_calls: normalize_tool_calls(calls),
           usage: normalize_usage(usage || %{})
         }}
      else
        # The provider ignored the stream flag and answered with a plain
        # completion, so hand the untouched bytes to the blocking normalizer.
        normalize_response(SSE.raw_body(fed))
      end

    {:ok, %{status: status, body: body}, _fed} ->
      Logger.error(
        "AI OpenAI-compatible streaming HTTP error status=#{status} body=#{String.slice(body, 0, 2000)}"
      )

      {:error, %{kind: :http_error, status: status, body: body}}

    {:error, reason} ->
      Logger.error("AI OpenAI-compatible streaming request failed: #{inspect(reason)}")
      {:error, %{kind: :http_error, reason: reason}}
  end
end
|
||||
|
||||
# Decides whether a request should be streamed. All three must hold: the
# request is an interactive chat operation, the caller opted in by passing a
# one-argument :on_stream callback, and streaming has not been switched off
# globally for providers without SSE support
# (config :bds, :chat, streaming: false).
defp stream?(request, opts) do
  if Map.get(request, :operation) == :chat do
    on_stream = Keyword.get(opts, :on_stream)
    is_function(on_stream, 1) and chat_config(:streaming, true)
  else
    false
  end
end
|
||||
|
||||
# Emit interval (milliseconds) handed to the SSE assembler; read from the
# chat config key :stream_emit_interval_ms, defaulting to 100.
defp stream_emit_interval_ms do
  chat_config(:stream_emit_interval_ms, 100)
end
|
||||
|
||||
# Tells whether the response declared itself as `text/event-stream`.
#
# Media types are case-insensitive, so the declared value is downcased before
# comparison (`Text/Event-Stream; charset=utf-8` counts). The header value may
# be a single binary or a list of binaries; non-binary entries are ignored.
# When no usable content type is present we trust the request we made and
# parse the body as SSE.
defp event_stream?(headers) do
  declared =
    for value <- List.wrap(headers["content-type"]), is_binary(value) do
      String.downcase(value)
    end

  case declared do
    [] ->
      # No content type: trust the request we made and parse as SSE.
      true

    media_types ->
      Enum.any?(media_types, &String.contains?(&1, "text/event-stream"))
  end
end
|
||||
|
||||
# Looks up `key` in the keyword list stored under the :bds application's
# :chat environment entry, falling back to `default` when the entry or the
# key is not set.
defp chat_config(key, default) do
  chat_settings = Application.get_env(:bds, :chat, [])
  Keyword.get(chat_settings, key, default)
end
|
||||
|
||||
defp normalize_response(body) do
|
||||
with {:ok, payload} <- decode_json_body(body) do
|
||||
message = get_in(payload, ["choices", Access.at(0), "message"]) || %{}
|
||||
@@ -88,19 +172,22 @@ defmodule BDS.AI.OpenAICompatibleRuntime do
|
||||
tool_calls = normalize_tool_calls(message["tool_calls"] || [])
|
||||
usage = normalize_usage(payload["usage"] || %{})
|
||||
|
||||
json =
|
||||
case content do
|
||||
nil ->
|
||||
nil
|
||||
{:ok,
|
||||
%{
|
||||
content: content,
|
||||
json: decode_json_content(content),
|
||||
tool_calls: tool_calls,
|
||||
usage: usage
|
||||
}}
|
||||
end
|
||||
end
|
||||
|
||||
value when is_binary(value) ->
|
||||
case Jason.decode(value) do
|
||||
{:ok, decoded} when is_map(decoded) -> decoded
|
||||
_other -> nil
|
||||
end
|
||||
end
|
||||
defp decode_json_content(nil), do: nil
|
||||
|
||||
{:ok, %{content: content, json: json, tool_calls: tool_calls, usage: usage}}
|
||||
defp decode_json_content(content) when is_binary(content) do
|
||||
case Jason.decode(content) do
|
||||
{:ok, decoded} when is_map(decoded) -> decoded
|
||||
_other -> nil
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
176
lib/bds/ai/sse.ex
Normal file
176
lib/bds/ai/sse.ex
Normal file
@@ -0,0 +1,176 @@
|
||||
defmodule BDS.AI.SSE do
  @moduledoc """
  Incremental assembler for OpenAI-compatible `text/event-stream` chat
  completions.

  Fed raw transport chunks via `feed/2`, it buffers partial events, decodes
  `data:` payloads, and accumulates content deltas, tool-call fragments, and
  usage. Content is reported to the optional `on_event` callback as
  **cumulative snapshots** (`%{content: binary}`) — replace semantics, which
  matches how the chat editor renders streaming state and resets naturally
  between tool rounds. Emissions are throttled to `:emit_interval_ms`
  (the first delta always emits immediately for perceived latency).

  Payloads without a usable `choices` list (for example usage-only or error
  events) contribute no delta; their `usage`, if any, is still recorded.
  Payloads that are not valid JSON objects are skipped.

  `finish/1` returns the assembled message in OpenAI wire shape so the
  runtime can reuse its non-streaming normalization:
  `%{content: binary | nil, tool_calls: [%{"id" => _, "function" => %{"name" => _, "arguments" => json_string}}], usage: map | nil}`.
  """

  defstruct buffer: "",
            raw: [],
            content: [],
            content?: false,
            tool_calls: %{},
            usage: nil,
            done?: false,
            on_event: nil,
            emit_interval_ms: 100,
            last_emit_at: nil

  @type t :: %__MODULE__{}

  @doc """
  Builds an empty assembler.

  `on_event` is an optional one-argument callback receiving
  `%{content: binary}` snapshots. `opts` accepts `:emit_interval_ms`
  (default `100`), the minimum spacing between two snapshots.
  """
  @spec new((map() -> any()) | nil, keyword()) :: t()
  def new(on_event \\ nil, opts \\ []) when is_list(opts) do
    %__MODULE__{
      on_event: on_event,
      emit_interval_ms: Keyword.get(opts, :emit_interval_ms, 100)
    }
  end

  @doc """
  Feeds one transport chunk into the assembler and returns the updated state.

  Events are delimited by a blank line; an incomplete trailing event stays in
  the buffer until a later chunk (or `finish/1`) completes it. Once `[DONE]`
  has been seen, further chunks are no longer parsed, but binary chunks are
  still recorded for `raw_body/1`.
  """
  @spec feed(t(), binary()) :: t()
  def feed(%__MODULE__{done?: true} = sse, chunk) do
    # Parsing stops at `[DONE]`, but `raw_body/1` promises the transport
    # bytes, so late binary chunks are still kept.
    if is_binary(chunk), do: %{sse | raw: [chunk | sse.raw]}, else: sse
  end

  def feed(%__MODULE__{} = sse, chunk) when is_binary(chunk) do
    sse = %{sse | raw: [chunk | sse.raw]}
    parts = String.split(sse.buffer <> chunk, ~r/\r?\n\r?\n/)
    # String.split/2 always yields at least one part; the last one is the
    # (possibly empty) unterminated tail that must wait for more bytes.
    {complete_events, [rest]} = Enum.split(parts, -1)

    Enum.reduce(complete_events, %{sse | buffer: rest}, &process_event(&2, &1))
  end

  @doc """
  The unparsed transport bytes, for callers that discover after the fact
  that the response was not an event stream (e.g. a provider that ignored
  the `stream` flag and answered with plain JSON).
  """
  @spec raw_body(t()) :: binary()
  def raw_body(%__MODULE__{} = sse) do
    sse.raw |> Enum.reverse() |> IO.iodata_to_binary()
  end

  @doc """
  Completes assembly and returns the message in OpenAI wire shape.

  `content` is `nil` when no content delta was received; `tool_calls` are
  ordered by their stream index; `usage` is the last non-empty usage map seen
  (or `nil`).
  """
  @spec finish(t()) :: %{content: binary() | nil, tool_calls: [map()], usage: map() | nil}
  def finish(%__MODULE__{} = sse) do
    # A final event may arrive without its trailing blank line.
    sse =
      case String.trim(sse.buffer) do
        "" -> sse
        remnant -> process_event(%{sse | buffer: ""}, remnant)
      end

    %{
      content: assembled_content(sse),
      tool_calls: assembled_tool_calls(sse),
      usage: sse.usage
    }
  end

  # Handles one complete event: joins its `data:` lines, recognizes the
  # `[DONE]` sentinel, and applies JSON-object payloads. Anything else
  # (comment/keep-alive lines, undecodable data) leaves the state untouched.
  defp process_event(%{done?: true} = sse, _event), do: sse

  defp process_event(sse, event) do
    data =
      event
      |> String.split(~r/\r?\n/)
      |> Enum.flat_map(&data_line/1)
      |> Enum.join("\n")

    cond do
      data == "" ->
        sse

      String.trim(data) == "[DONE]" ->
        %{sse | done?: true}

      true ->
        case Jason.decode(data) do
          {:ok, payload} when is_map(payload) -> apply_payload(sse, payload)
          _other -> sse
        end
    end
  end

  # Extracts the value of a `data` field line; every other line yields nothing.
  defp data_line("data: " <> rest), do: [rest]
  defp data_line("data:" <> rest), do: [rest]
  defp data_line(_line), do: []

  defp apply_payload(sse, payload) do
    delta = first_delta(payload)

    sse
    |> apply_content(delta["content"])
    |> apply_tool_calls(delta["tool_calls"])
    |> apply_usage(payload["usage"])
  end

  # Returns the first choice's delta map, or an empty map when the payload
  # has no usable `choices` list. Matching on the shape (rather than
  # `Access.at/1`, which raises on non-lists) keeps usage-only and error
  # events from crashing the stream.
  defp first_delta(%{"choices" => [%{"delta" => %{} = delta} | _rest]}), do: delta
  defp first_delta(_payload), do: %{}

  defp apply_content(sse, content) when is_binary(content) and content != "" do
    %{sse | content: [content | sse.content], content?: true}
    |> maybe_emit()
  end

  defp apply_content(sse, _content), do: sse

  # Merges tool-call fragments by their `index`: the first id and name seen
  # win, argument fragments are appended in arrival order.
  defp apply_tool_calls(sse, [_ | _] = fragments) do
    Enum.reduce(fragments, sse, fn fragment, acc ->
      index = fragment["index"] || 0
      existing = Map.get(acc.tool_calls, index, %{id: nil, name: nil, arguments: []})
      function_part = fragment["function"] || %{}

      merged = %{
        id: existing.id || fragment["id"],
        name: existing.name || function_part["name"],
        arguments: [existing.arguments, function_part["arguments"] || ""]
      }

      %{acc | tool_calls: Map.put(acc.tool_calls, index, merged)}
    end)
  end

  defp apply_tool_calls(sse, _fragments), do: sse

  defp apply_usage(sse, usage) when is_map(usage) and map_size(usage) > 0,
    do: %{sse | usage: usage}

  defp apply_usage(sse, _usage), do: sse

  # Emits a cumulative snapshot unless one was emitted less than
  # `emit_interval_ms` ago; the very first snapshot is never delayed.
  defp maybe_emit(%{on_event: nil} = sse), do: sse

  defp maybe_emit(sse) do
    now = System.monotonic_time(:millisecond)

    if is_nil(sse.last_emit_at) or now - sse.last_emit_at >= sse.emit_interval_ms do
      sse.on_event.(%{content: assembled_content(sse) || ""})
      %{sse | last_emit_at: now}
    else
      sse
    end
  end

  defp assembled_content(%{content?: false}), do: nil

  defp assembled_content(sse) do
    # Deltas are prepended as they arrive, so restore stream order first.
    sse.content |> Enum.reverse() |> IO.iodata_to_binary()
  end

  defp assembled_tool_calls(sse) do
    sse.tool_calls
    |> Enum.sort_by(fn {index, _tool_call} -> index end)
    |> Enum.map(fn {_index, tool_call} ->
      %{
        "id" => tool_call.id,
        "function" => %{
          "name" => tool_call.name,
          "arguments" => IO.iodata_to_binary(tool_call.arguments)
        }
      }
    end)
  end
end
|
||||
Reference in New Issue
Block a user