diff --git a/src/main/engine/OpenCodeManager.ts b/src/main/engine/OpenCodeManager.ts index 80c83ba..c433ec1 100644 --- a/src/main/engine/OpenCodeManager.ts +++ b/src/main/engine/OpenCodeManager.ts @@ -12,6 +12,16 @@ import https from 'https'; import http from 'http'; import { URL } from 'url'; import { BrowserWindow } from 'electron'; +import { + parseSSELines, + parseAnthropicStreamEvent, + parseOpenAIStreamEvent, + createAnthropicStreamAccumulator, + createOpenAIStreamAccumulator, + httpRequestStream, + withRetry, + type HttpStreamError, +} from './streaming'; import { ChatEngine } from './ChatEngine'; import { PostEngine, type PostData } from './PostEngine'; import { MediaEngine, type MediaData } from './MediaEngine'; @@ -470,10 +480,20 @@ export class OpenCodeManager { system: systemPrompt, messages, tools, + stream: true, cache_control: { type: 'ephemeral' }, }; - const response = await this.httpRequest(ZEN_ANTHROPIC_URL, { + // Stream the response with retry for transient errors + const streamAccumulator = createAnthropicStreamAccumulator(); + let stopReason = ''; + let inputTokens = 0; + let outputTokens = 0; + let cacheReadTokens = 0; + let cacheWriteTokens = 0; + let roundText = ''; // Text produced in this round only + + const { events } = await withRetry(() => httpRequestStream(ZEN_ANTHROPIC_URL, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -483,29 +503,43 @@ export class OpenCodeManager { }, body: JSON.stringify(body), signal, - }); + })); - if (response.statusCode >= 400) { - const errorMsg = this.parseErrorResponse(response); - throw new Error(errorMsg); + for await (const event of events) { + const result = parseAnthropicStreamEvent(event, streamAccumulator); + + // Emit text deltas immediately for real-time streaming + if (result.textDelta) { + accumulatedText += result.textDelta; + roundText += result.textDelta; + if (callbacks.onDelta) { + callbacks.onDelta(result.textDelta); + } + } + + // Collect usage from message_start 
(input tokens) and message_delta (output tokens) + if (result.usage) { + if (result.usage.inputTokens !== undefined) inputTokens = result.usage.inputTokens; + if (result.usage.cacheReadTokens !== undefined) cacheReadTokens = result.usage.cacheReadTokens; + if (result.usage.cacheWriteTokens !== undefined) cacheWriteTokens = result.usage.cacheWriteTokens; + if (result.usage.outputTokens !== undefined) outputTokens = result.usage.outputTokens; + } + + if (result.finishReason) { + stopReason = result.finishReason; + } } - const data = JSON.parse(response.body); - - // Extract and emit token usage - if (data.usage && callbacks.onTokenUsage) { - const usage = data.usage; - const cacheReadTokens = usage.cache_read_input_tokens || 0; - const cacheWriteTokens = usage.cache_creation_input_tokens || 0; - const inputTokens = (usage.input_tokens || 0) - cacheReadTokens - cacheWriteTokens; - const outputTokens = usage.output_tokens || 0; - const totalTokens = (usage.input_tokens || 0) + outputTokens; + // Emit token usage after stream completes + if (callbacks.onTokenUsage) { + const adjustedInputTokens = inputTokens - cacheReadTokens - cacheWriteTokens; + const totalTokens = inputTokens + outputTokens; const prev = this.conversationUsage.get(conversationId) || { inputTokens: 0, outputTokens: 0, cacheReadTokens: 0, cacheWriteTokens: 0, }; const cumulative = { - inputTokens: prev.inputTokens + inputTokens, + inputTokens: prev.inputTokens + adjustedInputTokens, outputTokens: prev.outputTokens + outputTokens, cacheReadTokens: prev.cacheReadTokens + cacheReadTokens, cacheWriteTokens: prev.cacheWriteTokens + cacheWriteTokens, @@ -513,7 +547,7 @@ export class OpenCodeManager { this.conversationUsage.set(conversationId, cumulative); callbacks.onTokenUsage({ - inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens, totalTokens, + inputTokens: adjustedInputTokens, outputTokens, cacheReadTokens, cacheWriteTokens, totalTokens, cumulativeInputTokens: cumulative.inputTokens, 
cumulativeOutputTokens: cumulative.outputTokens, cumulativeCacheReadTokens: cumulative.cacheReadTokens, @@ -522,35 +556,19 @@ export class OpenCodeManager { }); } - console.log('[OpenCodeManager] Round', round, 'stop_reason:', data.stop_reason, 'content blocks:', JSON.stringify(data.content?.map((b: AnthropicContentBlock) => ({ type: b.type, textLen: b.text?.length, name: b.name })))); - - if (!data.content) { - throw new Error('API response missing content field'); - } - - // Check if there are tool_use blocks - const toolUseBlocks = (data.content as AnthropicContentBlock[]).filter( - (b: AnthropicContentBlock) => b.type === 'tool_use' - ); - - // Capture text from any block type that has a text field (text, thinking, etc.) - const textBlocks = (data.content as AnthropicContentBlock[]).filter( - (b: AnthropicContentBlock) => b.text - ); - - // Accumulate and stream text content to frontend - for (const block of textBlocks) { - if (block.text) { - accumulatedText += block.text; - if (callbacks.onDelta) { - callbacks.onDelta(block.text); - } + // Collect tool calls from stream accumulator + const toolUseBlocks: Array<{ id: string; name: string; input: unknown }> = []; + for (const [, tc] of streamAccumulator.toolCalls) { + try { + toolUseBlocks.push({ id: tc.id, name: tc.name, input: JSON.parse(tc.arguments) }); + } catch { + toolUseBlocks.push({ id: tc.id, name: tc.name, input: {} }); } } - console.log('[OpenCodeManager] Round', round, 'accumulatedText length:', accumulatedText.length, 'toolUseBlocks:', toolUseBlocks.length); + console.log('[OpenCodeManager] Round', round, 'stopReason:', stopReason, 'accumulatedText length:', accumulatedText.length, 'toolCalls:', toolUseBlocks.length); - if (toolUseBlocks.length === 0 || data.stop_reason !== 'tool_use') { + if (toolUseBlocks.length === 0 || stopReason !== 'tool_use') { // No more tool calls - return all accumulated text console.log('[OpenCodeManager] Returning accumulated text length:', accumulatedText.length); 
return { content: accumulatedText, toolCalls: allToolCalls }; @@ -558,11 +576,26 @@ export class OpenCodeManager { // Execute tool calls const toolResults: AnthropicContentBlock[] = []; + // Build assistant content blocks for the next message round + const assistantContentBlocks: AnthropicContentBlock[] = []; + + // Add text block with text from this round + if (roundText) { + assistantContentBlocks.push({ type: 'text', text: roundText }); + } for (const toolBlock of toolUseBlocks) { - const toolName = toolBlock.name!; + const toolName = toolBlock.name; const toolArgs = toolBlock.input; - const toolUseId = toolBlock.id!; + const toolUseId = toolBlock.id; + + // Add tool_use block to assistant content + assistantContentBlocks.push({ + type: 'tool_use', + id: toolUseId, + name: toolName, + input: toolArgs, + }); allToolCalls.push({ name: toolName, args: toolArgs }); @@ -643,7 +676,7 @@ export class OpenCodeManager { // Add assistant response and tool results to messages for next round messages = [ ...messages, - { role: 'assistant' as const, content: data.content }, + { role: 'assistant' as const, content: assistantContentBlocks }, { role: 'user' as const, content: toolResults }, ]; } @@ -718,9 +751,18 @@ export class OpenCodeManager { max_tokens: 4096, messages, tools: openaiTools, + stream: true, + stream_options: { include_usage: true }, }; - const response = await this.httpRequest(ZEN_OPENAI_URL, { + // Stream the response with retry for transient errors + const streamAccumulator = createOpenAIStreamAccumulator(); + let finishReason = ''; + let promptTokens = 0; + let completionTokens = 0; + let totalTokens = 0; + + const { events } = await withRetry(() => httpRequestStream(ZEN_OPENAI_URL, { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -728,23 +770,38 @@ export class OpenCodeManager { }, body: JSON.stringify(body), signal, - }); + })); - if (response.statusCode >= 400) { - const errorMsg = this.parseErrorResponse(response); - throw new 
Error(errorMsg); + for await (const event of events) { + const result = parseOpenAIStreamEvent(event, streamAccumulator); + + // Emit text deltas immediately for real-time streaming + if (result.textDelta) { + accumulatedText += result.textDelta; + if (callbacks.onDelta) { + callbacks.onDelta(result.textDelta); + } + } + + // Collect usage from final chunk + if (result.usage) { + if (result.usage.promptTokens !== undefined) promptTokens = result.usage.promptTokens; + if (result.usage.completionTokens !== undefined) completionTokens = result.usage.completionTokens; + if (result.usage.totalTokens !== undefined) totalTokens = result.usage.totalTokens; + } + + if (result.finishReason) { + finishReason = result.finishReason; + } + + if (result.done) break; } - const data = JSON.parse(response.body); - const choice = data.choices?.[0]; - - // Extract and emit token usage (OpenAI format) - if (data.usage && callbacks.onTokenUsage) { - const usage = data.usage; - const cacheReadTokens = usage.prompt_tokens_details?.cached_tokens || 0; - const inputTokens = (usage.prompt_tokens || 0) - cacheReadTokens; - const outputTokens = usage.completion_tokens || 0; - const totalTokens = usage.total_tokens || (usage.prompt_tokens || 0) + outputTokens; + // Emit token usage after stream completes + if (callbacks.onTokenUsage) { + const cacheReadTokens = 0; // OpenAI doesn't provide cache info in streaming + const inputTokens = promptTokens; + const outputTokens = completionTokens; const prev = this.conversationUsage.get(conversationId) || { inputTokens: 0, outputTokens: 0, cacheReadTokens: 0, cacheWriteTokens: 0, @@ -758,7 +815,8 @@ export class OpenCodeManager { this.conversationUsage.set(conversationId, cumulative); callbacks.onTokenUsage({ - inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens: 0, totalTokens, + inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens: 0, + totalTokens: totalTokens || inputTokens + outputTokens, cumulativeInputTokens: 
cumulative.inputTokens, cumulativeOutputTokens: cumulative.outputTokens, cumulativeCacheReadTokens: cumulative.cacheReadTokens, @@ -767,57 +825,40 @@ export class OpenCodeManager { }); } - console.log('[OpenCodeManager:OpenAI] Round', round, 'status:', response.statusCode, 'content type:', typeof choice?.message?.content, 'content length:', choice?.message?.content?.length, 'tool_calls:', choice?.message?.tool_calls?.length); - - if (!choice?.message) { - throw new Error('API response missing expected message content'); + // Collect tool calls from stream accumulator + const parsedToolCalls: Array<{ id: string; name: string; args: unknown }> = []; + for (const [, tc] of streamAccumulator.toolCalls) { + try { + parsedToolCalls.push({ id: tc.id, name: tc.name, args: JSON.parse(tc.arguments) }); + } catch { + parsedToolCalls.push({ id: tc.id, name: tc.name, args: {} }); + } } - // Handle content that might be a string or an array of content parts - let textContent = ''; - const content = choice.message.content; - if (typeof content === 'string') { - textContent = content; - } else if (Array.isArray(content)) { - // Handle array of content parts (some models return this format) - // Accept any part that has a text field, regardless of type - textContent = content - .filter((part: { type?: string; text?: string }) => part.text) - .map((part: { text: string }) => part.text) - .join(''); - - // Log what types we're seeing for debugging - const types = content.map((p: { type?: string }) => p.type).filter(Boolean); - if (types.length > 0) { - console.log('[OpenCodeManager:OpenAI] Content block types:', types); - } - } else if (content && typeof content === 'object') { - // Handle single object with text field - if ('text' in content && typeof content.text === 'string') { - textContent = content.text; - } - } - - if (textContent) { - accumulatedText += textContent; - if (callbacks.onDelta) { - callbacks.onDelta(textContent); - } - } + console.log('[OpenCodeManager:OpenAI] 
Round', round, 'finishReason:', finishReason, 'text length:', accumulatedText.length, 'toolCalls:', parsedToolCalls.length); // If no tool calls, we're done - if (!choice.message.tool_calls || choice.message.tool_calls.length === 0) { + if (parsedToolCalls.length === 0 || finishReason !== 'tool_calls') { console.log('[OpenCodeManager:OpenAI] Done. Accumulated text length:', accumulatedText.length); return { content: accumulatedText, toolCalls: allToolCalls }; } - // Add assistant message (with tool_calls) to conversation - messages.push(choice.message); + // Build the assistant message with tool_calls for conversation history + const assistantMessage: Record = { + role: 'assistant', + content: accumulatedText || null, + tool_calls: parsedToolCalls.map((tc) => ({ + id: tc.id, + type: 'function', + function: { name: tc.name, arguments: JSON.stringify(tc.args) }, + })), + }; + messages.push(assistantMessage); // Execute tool calls and add results - for (const toolCall of choice.message.tool_calls) { - const toolName = toolCall.function.name; - const toolArgs = JSON.parse(toolCall.function.arguments || '{}'); + for (const toolCall of parsedToolCalls) { + const toolName = toolCall.name; + const toolArgs = toolCall.args; allToolCalls.push({ name: toolName, args: toolArgs }); if (callbacks.onToolCall) { @@ -826,7 +867,7 @@ export class OpenCodeManager { // Check if this is a render tool if (isRenderTool(toolName)) { - const a2uiMessages = generateFromToolCall(conversationId, toolName, toolArgs); + const a2uiMessages = generateFromToolCall(conversationId, toolName, toolArgs as Record); if (a2uiMessages) { emitA2UIMessages(a2uiMessages); } @@ -843,7 +884,7 @@ export class OpenCodeManager { continue; } - const result = await this.executeTool(toolName, toolArgs); + const result = await this.executeTool(toolName, toolArgs as Record); if (callbacks.onToolResult) { callbacks.onToolResult({ name: toolName, result }); diff --git a/src/main/engine/streaming.ts 
/**
 * SSE Streaming Infrastructure
 *
 * Provides SSE line parsing, event parsers for OpenAI/Mistral and Anthropic
 * stream formats, tool-call accumulation, and retry-with-exponential-backoff.
 *
 * Used by OpenCodeManager to convert buffered HTTP calls to real-time
 * token-by-token streaming for all chat providers.
 */

import https from 'https';
import http from 'http';
import { URL } from 'url';
import { StringDecoder } from 'string_decoder';

// ── Types ──

/** One parsed server-sent event: optional `event:` name plus joined `data:` payload. */
export interface SSEEvent {
  event?: string;
  data: string;
}

export interface StreamEventResult {
  /** Text content delta to emit to UI */
  textDelta?: string;
  /** Whether the stream is complete */
  done: boolean;
  /** Finish reason from the model */
  finishReason?: string;
  /** Token usage information */
  usage?: {
    promptTokens?: number;
    completionTokens?: number;
    totalTokens?: number;
    inputTokens?: number;
    outputTokens?: number;
    cacheReadTokens?: number;
    cacheWriteTokens?: number;
  };
}

/** A tool call being assembled; `arguments` is the raw JSON text received so far. */
interface ToolCallAccumulator {
  id: string;
  name: string;
  arguments: string;
}

/** Tool calls keyed by the `index` carried in each OpenAI tool-call delta. */
export interface OpenAIStreamAccumulator {
  toolCalls: Map<number, ToolCallAccumulator>;
}

/** Tool calls keyed by the content-block `index` of the Anthropic `tool_use` block. */
export interface AnthropicStreamAccumulator {
  toolCalls: Map<number, ToolCallAccumulator>;
}

export interface HttpStreamError extends Error {
  /** HTTP status of a non-2xx response. */
  statusCode?: number;
  /** Seconds to wait before retrying, taken from the Retry-After header. */
  retryAfter?: number;
  /** True when the failure was caused by the caller's AbortSignal. */
  isAbort?: boolean;
}

// ── Wire shapes (only the fields this module reads) ──

interface OpenAIToolCallDelta {
  index?: number;
  id?: string;
  function?: { name?: string; arguments?: string };
}

interface OpenAIStreamChunk {
  choices?: Array<{
    delta?: { content?: unknown; tool_calls?: OpenAIToolCallDelta[] | null };
    finish_reason?: string | null;
  }>;
  usage?: {
    prompt_tokens?: number;
    completion_tokens?: number;
    total_tokens?: number;
    prompt_tokens_details?: { cached_tokens?: number };
  } | null;
}

interface AnthropicStreamPayload {
  type?: string;
  index?: number;
  message?: {
    usage?: {
      input_tokens?: number;
      cache_read_input_tokens?: number;
      cache_creation_input_tokens?: number;
    };
  };
  content_block?: { type?: string; id?: string; name?: string };
  delta?: { type?: string; text?: string; partial_json?: string; stop_reason?: string | null };
  usage?: { output_tokens?: number };
  error?: { message?: string };
}

// ── SSE Line Parsing ──

/**
 * Parse raw SSE text into structured events.
 *
 * SSE protocol: events are separated by double-newlines (\n\n).
 * Each event can have `event:` and `data:` lines.
 * Multiple `data:` lines within one event are concatenated with newlines.
 * Lines starting with `:` are comments (ignored).
 *
 * @param text Buffered stream text; `\r\n` is normalized to `\n` first.
 * @returns The complete events found, plus `remaining`: the (normalized) text
 *   after the last event boundary, to be prepended to the next chunk.
 */
export function parseSSELines(text: string): { events: SSEEvent[]; remaining: string } {
  const events: SSEEvent[] = [];
  const normalized = text.replace(/\r\n/g, '\n');
  const parts = normalized.split('\n\n');

  // Whatever follows the last blank line is an event still in flight.
  const remaining = normalized.endsWith('\n\n') ? '' : (parts.pop() ?? '');

  for (const part of parts) {
    if (!part.trim()) continue;

    let eventType: string | undefined;
    const dataLines: string[] = [];

    for (const line of part.split('\n')) {
      const colon = line.indexOf(':');
      // colon === 0 is a comment line; colon === -1 has no value to read.
      if (colon <= 0) continue;

      const field = line.slice(0, colon);
      const value = line.slice(colon + 1);
      if (field === 'event') {
        eventType = value.trim();
      } else if (field === 'data') {
        dataLines.push(value.trimStart());
      }
    }

    // Events without any data line (e.g. comment-only keep-alives) are dropped.
    if (dataLines.length > 0) {
      events.push({ event: eventType, data: dataLines.join('\n') });
    }
  }

  return { events, remaining };
}

// ── Accumulator Factories ──

/** Create an empty accumulator for one OpenAI/Mistral streamed response. */
export function createOpenAIStreamAccumulator(): OpenAIStreamAccumulator {
  return { toolCalls: new Map<number, ToolCallAccumulator>() };
}

/** Create an empty accumulator for one Anthropic streamed response. */
export function createAnthropicStreamAccumulator(): AnthropicStreamAccumulator {
  return { toolCalls: new Map<number, ToolCallAccumulator>() };
}

// ── OpenAI/Mistral SSE Parser ──

/**
 * Flatten a delta `content` value to text: a string is returned as-is, an
 * array of parts contributes every part that has a string `text` field, and
 * anything else yields ''.
 */
function extractDeltaText(content: unknown): string {
  if (typeof content === 'string') return content;
  if (!Array.isArray(content)) return '';

  let text = '';
  for (const part of content as unknown[]) {
    if (typeof part === 'object' && part !== null) {
      const candidate = (part as { text?: unknown }).text;
      if (typeof candidate === 'string') text += candidate;
    }
  }
  return text;
}

/**
 * Parse a single OpenAI/Mistral SSE event and update the accumulator.
 *
 * OpenAI streaming format:
 * - Text deltas: choices[0].delta.content (string, or array of `{ text }` parts)
 * - Tool call start: delta.tool_calls[i] with id + function.name
 * - Tool call fragments: delta.tool_calls[i].function.arguments (append)
 * - Finish reason: choices[0].finish_reason
 * - Usage: usage object in final chunk (requires stream_options.include_usage);
 *   `prompt_tokens_details.cached_tokens` is surfaced as `cacheReadTokens`
 * - [DONE] sentinel: stop iteration
 *
 * @param event Parsed SSE event whose `data` is `[DONE]` or a JSON chunk.
 * @param accumulator Mutated in place with tool-call ids, names and arguments.
 * @returns What this chunk contributed; `done` is true only for `[DONE]`.
 * @throws SyntaxError when `event.data` is neither `[DONE]` nor valid JSON.
 */
export function parseOpenAIStreamEvent(
  event: SSEEvent,
  accumulator: OpenAIStreamAccumulator,
): StreamEventResult {
  if (event.data.trim() === '[DONE]') {
    return { done: true };
  }

  const data = JSON.parse(event.data) as OpenAIStreamChunk | null;
  const result: StreamEventResult = { done: false };
  const choice = data?.choices?.[0];

  if (choice) {
    const delta = choice.delta;

    const text = extractDeltaText(delta?.content);
    if (text.length > 0) {
      result.textDelta = text;
    }

    for (const [position, tc] of (delta?.tool_calls ?? []).entries()) {
      // Fall back to the array position if the provider omits `index`.
      const idx = tc.index ?? position;
      const rawArgs = tc.function?.arguments;
      const fragment = typeof rawArgs === 'string' ? rawArgs : '';
      const existing = accumulator.toolCalls.get(idx);

      if (existing) {
        if (tc.id) existing.id = tc.id;
        if (tc.function?.name) existing.name = tc.function.name;
        existing.arguments += fragment;
      } else if (tc.id || tc.function?.name) {
        // A call is only opened by a chunk that identifies it; argument
        // fragments for an unknown index are ignored.
        accumulator.toolCalls.set(idx, {
          id: tc.id || '',
          name: tc.function?.name || '',
          arguments: fragment,
        });
      }
    }

    if (choice.finish_reason) {
      result.finishReason = choice.finish_reason;
    }
  }

  const usage = data?.usage;
  if (usage) {
    result.usage = {
      promptTokens: usage.prompt_tokens,
      completionTokens: usage.completion_tokens,
      totalTokens: usage.total_tokens,
    };
    const cached = usage.prompt_tokens_details?.cached_tokens;
    if (typeof cached === 'number') {
      result.usage.cacheReadTokens = cached;
    }
  }

  return result;
}

// ── Anthropic SSE Parser ──

/**
 * Parse a single Anthropic SSE event and update the accumulator.
 *
 * Anthropic streaming format uses named event types:
 * - message_start: input token usage
 * - content_block_start: text or tool_use block begins
 * - content_block_delta: text_delta or input_json_delta
 * - content_block_stop: block ends
 * - message_delta: output tokens + stop_reason
 * - message_stop: stream complete
 * - ping: keep-alive (ignored)
 * - error: server error mid-stream
 *
 * The event type is taken from the SSE `event:` name, falling back to the
 * payload's own `type` field when the name is absent.
 *
 * @param event Parsed SSE event whose `data` is a JSON payload.
 * @param accumulator Mutated in place with tool_use ids, names and input JSON.
 * @returns What this event contributed; `done` is true for `message_stop`.
 * @throws Error carrying the server's message for an `error` event.
 * @throws SyntaxError when `event.data` is not valid JSON.
 */
export function parseAnthropicStreamEvent(
  event: SSEEvent,
  accumulator: AnthropicStreamAccumulator,
): StreamEventResult {
  const data = (JSON.parse(event.data) as AnthropicStreamPayload | null) ?? {};
  const result: StreamEventResult = { done: false };
  const eventType = event.event || data.type;

  switch (eventType) {
    case 'message_start': {
      const usage = data.message?.usage;
      if (usage) {
        result.usage = {
          inputTokens: usage.input_tokens || 0,
          cacheReadTokens: usage.cache_read_input_tokens || 0,
          cacheWriteTokens: usage.cache_creation_input_tokens || 0,
        };
      }
      break;
    }

    case 'content_block_start': {
      const block = data.content_block;
      // Text blocks start empty, so only tool_use blocks need bookkeeping.
      if (block?.type === 'tool_use' && typeof data.index === 'number') {
        accumulator.toolCalls.set(data.index, {
          id: block.id ?? '',
          name: block.name ?? '',
          arguments: '',
        });
      }
      break;
    }

    case 'content_block_delta': {
      const delta = data.delta;
      if (delta?.type === 'text_delta' && delta.text) {
        result.textDelta = delta.text;
      } else if (delta?.type === 'input_json_delta' && delta.partial_json) {
        const open = typeof data.index === 'number'
          ? accumulator.toolCalls.get(data.index)
          : undefined;
        if (open) {
          open.arguments += delta.partial_json;
        }
      }
      break;
    }

    case 'message_delta': {
      if (data.usage) {
        result.usage = { outputTokens: data.usage.output_tokens || 0 };
      }
      if (data.delta?.stop_reason) {
        result.finishReason = data.delta.stop_reason;
      }
      break;
    }

    case 'message_stop':
      result.done = true;
      break;

    case 'error':
      throw new Error(data.error?.message || 'Unknown streaming error');

    default:
      // content_block_stop, ping and unknown event types carry nothing to report.
      break;
  }

  return result;
}

// ── Retry with Exponential Backoff ──

const RETRYABLE_STATUS_CODES = new Set([429, 502, 503]);

/**
 * Retry a function with exponential backoff for transient HTTP errors.
 *
 * Retries on 429 (rate limit), 502 (bad gateway), 503 (service unavailable)
 * and on errors that carry no `statusCode` at all (e.g. connection failures).
 * Does NOT retry on other status codes or on abort.
 * The delay is 1s, 2s, 4s… plus up to 500ms of jitter, raised to the error's
 * `retryAfter` (seconds) when that is larger.
 *
 * @param fn Operation to attempt; called once per attempt.
 * @param options `maxRetries` is the number of re-attempts after the first
 *   call (default 3; negative values are treated as 0).
 * @returns The first successful result of `fn`.
 * @throws The last error from `fn`, rethrown unchanged.
 */
export async function withRetry<T>(
  fn: () => Promise<T>,
  options: { maxRetries?: number } = {},
): Promise<T> {
  const maxRetries = Math.max(0, options.maxRetries ?? 3);

  for (let attempt = 0; ; attempt++) {
    try {
      return await fn();
    } catch (error: unknown) {
      // Thrown values are not guaranteed to be objects; read fields defensively.
      const details = (error ?? {}) as Partial<HttpStreamError>;

      const aborted =
        Boolean(details.isAbort) ||
        details.name === 'AbortError' ||
        details.message === 'Request cancelled';
      const fatalStatus =
        Boolean(details.statusCode) && !RETRYABLE_STATUS_CODES.has(details.statusCode as number);

      if (aborted || fatalStatus || attempt >= maxRetries) {
        throw error;
      }

      const backoff = Math.pow(2, attempt) * 1000 + Math.random() * 500;
      const retryAfterMs = details.retryAfter && details.retryAfter > 0
        ? details.retryAfter * 1000
        : 0;
      const delay = Math.max(backoff, retryAfterMs);

      await new Promise<void>((resolve) => setTimeout(resolve, delay));
    }
  }
}

// ── HTTP Streaming Request ──

interface HttpStreamOptions {
  method?: string;
  headers?: Record<string, string>;
  body?: string;
  signal?: AbortSignal;
  timeout?: number;
}

/** Build the error used whenever the caller's AbortSignal ended the request. */
function createAbortError(): HttpStreamError {
  const error: HttpStreamError = new Error('Request cancelled');
  error.isAbort = true;
  return error;
}

/**
 * Convert a Retry-After header to whole seconds. Accepts the delta-seconds
 * form and the HTTP-date form (relative to `now`, never negative).
 */
function parseRetryAfterSeconds(header: string, now: number = Date.now()): number | undefined {
  const seconds = parseInt(header, 10);
  if (!Number.isNaN(seconds)) return seconds;

  const when = Date.parse(header);
  if (Number.isNaN(when)) return undefined;
  return Math.max(0, Math.ceil((when - now) / 1000));
}

/**
 * Build the rejection for a non-2xx response: status code, Retry-After (for
 * 429/503), and the most specific message found in the response body.
 */
function buildHttpError(
  statusCode: number,
  retryAfterHeader: string | undefined,
  body: string,
): HttpStreamError {
  const error: HttpStreamError = new Error(`API error: ${statusCode}`);
  error.statusCode = statusCode;

  if ((statusCode === 429 || statusCode === 503) && retryAfterHeader) {
    const seconds = parseRetryAfterSeconds(retryAfterHeader);
    if (seconds !== undefined) error.retryAfter = seconds;
  }

  try {
    const parsed = JSON.parse(body);
    const candidates: unknown[] = [parsed?.error?.message, parsed?.error, parsed?.message];
    const message = candidates.find((c) => typeof c === 'string' && c.length > 0);
    if (typeof message === 'string') error.message = message;
  } catch {
    // Not JSON: surface the start of the raw body instead.
    if (body.length > 0) {
      error.message = `${error.message}: ${body.slice(0, 200)}`;
    }
  }
  return error;
}

/**
 * Stateful chunk-to-events decoder. Bytes go through a StringDecoder so a
 * multi-byte UTF-8 character split across two chunks is decoded intact, and
 * incomplete events are buffered until their terminating blank line arrives.
 */
function createSSEChunkDecoder(): { push(chunk: Buffer | string): SSEEvent[] } {
  const utf8 = new StringDecoder('utf8');
  let pending = '';

  return {
    push(chunk: Buffer | string): SSEEvent[] {
      pending += typeof chunk === 'string' ? chunk : utf8.write(chunk);
      const { events, remaining } = parseSSELines(pending);
      pending = remaining;
      return events;
    },
  };
}

/**
 * Expose a streaming response as an async iterable of SSE events.
 *
 * Listeners are attached once, immediately. Events are queued until consumed.
 * A failure (response `error`, or `close` without `end`) is delivered after
 * any already-queued events, even if no `next()` was pending when it happened.
 * Leaving the loop early (`break`/throw) destroys the response.
 *
 * @param res The 2xx response to read.
 * @param hooks `onSettled` runs once when the stream ends for any reason;
 *   `toFailure` maps the low-level cause to the error handed to the consumer.
 */
function createSSEEventIterable(
  res: http.IncomingMessage,
  hooks: { onSettled: () => void; toFailure: (cause?: Error) => Error },
): AsyncIterable<SSEEvent> {
  type Waiter = {
    resolve: (result: IteratorResult<SSEEvent>) => void;
    reject: (error: Error) => void;
  };

  const decoder = createSSEChunkDecoder();
  const queue: SSEEvent[] = [];
  const waiters: Waiter[] = [];
  let settled = false;
  let failure: Error | undefined;

  const finished = (): IteratorResult<SSEEvent> => ({ value: undefined, done: true });

  const settle = (error?: Error): void => {
    if (settled) return;
    settled = true;
    hooks.onSettled();

    const pendingWaiters = waiters.splice(0);
    // Keep the error for the next next() call only if nobody received it now.
    failure = pendingWaiters.length === 0 ? error : undefined;
    for (const waiter of pendingWaiters) {
      if (error) waiter.reject(error);
      else waiter.resolve(finished());
    }
  };

  res.on('data', (chunk: Buffer | string) => {
    if (settled) return;
    for (const event of decoder.push(chunk)) {
      const waiter = waiters.shift();
      if (waiter) waiter.resolve({ value: event, done: false });
      else queue.push(event);
    }
  });
  res.on('end', () => settle());
  res.on('error', (err: Error) => settle(hooks.toFailure(err)));
  res.on('close', () => {
    // 'close' without a prior 'end' or 'error' means the body was cut short.
    if (!settled) settle(hooks.toFailure());
  });

  const iterator: AsyncIterator<SSEEvent> = {
    next(): Promise<IteratorResult<SSEEvent>> {
      const queued = queue.shift();
      if (queued !== undefined) {
        return Promise.resolve({ value: queued, done: false });
      }
      if (failure) {
        const error = failure;
        failure = undefined;
        return Promise.reject(error);
      }
      if (settled) {
        return Promise.resolve(finished());
      }
      return new Promise<IteratorResult<SSEEvent>>((resolve, reject) => {
        waiters.push({ resolve, reject });
      });
    },

    return(): Promise<IteratorResult<SSEEvent>> {
      queue.length = 0;
      failure = undefined;
      if (!settled) {
        settled = true;
        hooks.onSettled();
        res.destroy();
      }
      for (const waiter of waiters.splice(0)) waiter.resolve(finished());
      return Promise.resolve(finished());
    },
  };

  return { [Symbol.asyncIterator]: () => iterator };
}

/**
 * Make an HTTP request that returns an async iterable of SSE events.
 *
 * Uses Node.js http/https modules directly, reading the response
 * as a readable stream and parsing SSE events incrementally.
 *
 * On non-2xx status: collects the error body and rejects with an
 * HttpStreamError carrying `statusCode` (and `retryAfter` for 429/503).
 * Supports AbortSignal for cancellation, both before the response arrives
 * (the promise rejects) and mid-stream (iteration rejects); either way the
 * error has `isAbort` set.
 *
 * @param urlStr Absolute http: or https: URL.
 * @param options Method (default POST), headers, body, AbortSignal and socket
 *   idle timeout in milliseconds (default 120000).
 * @returns The response status code and the event stream.
 */
export function httpRequestStream(
  urlStr: string,
  options: HttpStreamOptions,
): Promise<{
  statusCode: number;
  events: AsyncIterable<SSEEvent>;
}> {
  return new Promise((resolve, reject) => {
    const { signal } = options;
    if (signal?.aborted) {
      reject(createAbortError());
      return;
    }

    const url = new URL(urlStr);
    const protocol = url.protocol === 'https:' ? https : http;
    const timeout = options.timeout ?? 120000;
    let timedOut = false;

    const onAbort = (): void => {
      req.destroy();
    };
    const detachAbort = (): void => {
      signal?.removeEventListener('abort', onAbort);
    };

    // Translate a low-level stream failure into the error the consumer sees.
    const toFailure = (cause?: Error): Error => {
      if (signal?.aborted) {
        const error: HttpStreamError = cause ?? new Error('Request cancelled');
        error.isAbort = true;
        return error;
      }
      if (timedOut) return new Error('Request timed out');
      return cause ?? new Error('Stream closed before completion');
    };

    const req = protocol.request(url, {
      method: options.method || 'POST',
      headers: options.headers || {},
      timeout,
    }, (res) => {
      const statusCode = res.statusCode || 0;

      // Non-2xx: collect the error body and reject.
      if (statusCode < 200 || statusCode >= 300) {
        res.setEncoding('utf8');
        let errorBody = '';
        const fail = (): void => {
          detachAbort();
          const error = buildHttpError(statusCode, res.headers['retry-after'], errorBody);
          if (signal?.aborted) error.isAbort = true;
          reject(error);
        };
        res.on('data', (chunk: string) => { errorBody += chunk; });
        // Any of these ends the body; only the first reject() has an effect.
        res.on('end', fail);
        res.on('error', fail);
        res.on('close', fail);
        return;
      }

      resolve({
        statusCode,
        events: createSSEEventIterable(res, { onSettled: detachAbort, toFailure }),
      });
    });

    req.on('error', (err: Error) => {
      detachAbort();
      const error: HttpStreamError = err;
      if (signal?.aborted) {
        error.isAbort = true;
      }
      reject(error);
    });

    req.on('timeout', () => {
      timedOut = true;
      req.destroy();
      reject(new Error('Request timed out'));
    });

    signal?.addEventListener('abort', onAbort, { once: true });

    if (options.body) {
      req.write(options.body);
    }
    req.end();
  });
}
+ * - Tool-call argument accumulation during streaming + * - Error handling (mid-stream errors, non-2xx status, abort) + * - Retry with exponential backoff (429/502/503, Retry-After, no retry on 4xx/abort) + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { + parseSSELines, + parseOpenAIStreamEvent, + parseAnthropicStreamEvent, + withRetry, + type SSEEvent, + type OpenAIStreamAccumulator, + type AnthropicStreamAccumulator, + createOpenAIStreamAccumulator, + createAnthropicStreamAccumulator, +} from '../../src/main/engine/streaming'; + +// ── SSE Line Parsing ── + +describe('parseSSELines', () => { + it('parses a complete SSE event from a single chunk', () => { + const buffer = ''; + const chunk = 'data: {"id":"1","choices":[{"delta":{"content":"Hello"}}]}\n\n'; + const { events, remaining } = parseSSELines(buffer + chunk); + expect(events).toHaveLength(1); + expect(events[0]).toEqual({ event: undefined, data: '{"id":"1","choices":[{"delta":{"content":"Hello"}}]}' }); + expect(remaining).toBe(''); + }); + + it('handles partial lines across TCP chunks', () => { + // First chunk ends mid-line + const chunk1 = 'data: {"id":"1","cho'; + const { events: events1, remaining: rem1 } = parseSSELines(chunk1); + expect(events1).toHaveLength(0); + expect(rem1).toBe('data: {"id":"1","cho'); + + // Second chunk completes the line + const chunk2 = 'ices":[{"delta":{"content":"Hello"}}]}\n\n'; + const { events: events2, remaining: rem2 } = parseSSELines(rem1 + chunk2); + expect(events2).toHaveLength(1); + expect(events2[0].data).toBe('{"id":"1","choices":[{"delta":{"content":"Hello"}}]}'); + expect(rem2).toBe(''); + }); + + it('handles multiple events in a single chunk', () => { + const chunk = 'data: {"a":1}\n\ndata: {"b":2}\n\n'; + const { events, remaining } = parseSSELines(chunk); + expect(events).toHaveLength(2); + expect(events[0].data).toBe('{"a":1}'); + expect(events[1].data).toBe('{"b":2}'); + expect(remaining).toBe(''); + }); + + it('handles 
named event types (Anthropic format)', () => {
    const chunk = 'event: message_start\ndata: {"type":"message_start"}\n\n';
    const { events, remaining } = parseSSELines(chunk);
    expect(events).toHaveLength(1);
    expect(events[0].event).toBe('message_start');
    expect(events[0].data).toBe('{"type":"message_start"}');
    expect(remaining).toBe('');
  });

  it('handles [DONE] sentinel', () => {
    const chunk = 'data: [DONE]\n\n';
    const { events, remaining } = parseSSELines(chunk);
    expect(events).toHaveLength(1);
    expect(events[0].data).toBe('[DONE]');
    expect(remaining).toBe('');
  });

  // The input starts with an SSE comment line (":"), which servers send as a
  // keep-alive ping; it must not produce an event of its own.
  it('ignores comment lines (keep-alive pings)', () => {
    const chunk = ':\n\ndata: {"a":1}\n\n';
    const { events, remaining } = parseSSELines(chunk);
    // The comment line ':' should be ignored
    expect(events).toHaveLength(1);
    expect(events[0].data).toBe('{"a":1}');
    expect(remaining).toBe('');
  });

  it('handles multiple data lines for a single event (concatenation per SSE spec)', () => {
    const chunk = 'data: line1\ndata: line2\n\n';
    const { events, remaining } = parseSSELines(chunk);
    expect(events).toHaveLength(1);
    expect(events[0].data).toBe('line1\nline2');
    expect(remaining).toBe('');
  });

  it('returns incomplete data as remaining buffer', () => {
    const chunk = 'data: {"partial';
    const { events, remaining } = parseSSELines(chunk);
    expect(events).toHaveLength(0);
    expect(remaining).toBe('data: {"partial');
  });

  it('handles \\r\\n line endings', () => {
    const chunk = 'data: {"a":1}\r\n\r\n';
    const { events, remaining } = parseSSELines(chunk);
    expect(events).toHaveLength(1);
    expect(events[0].data).toBe('{"a":1}');
    expect(remaining).toBe('');
  });
});

// ── OpenAI/Mistral Stream Event Parsing ──

describe('parseOpenAIStreamEvent', () => {
  let accumulator: OpenAIStreamAccumulator;

  // Fresh accumulator per test so tool-call state never leaks between cases.
  beforeEach(() => {
    accumulator = createOpenAIStreamAccumulator();
  });

  it('extracts text delta from content field', () => {
    const event: SSEEvent = {
      data: JSON.stringify({
        id: 'chatcmpl-1',
        choices: [{ delta: { content: 'Hello' }, index: 0 }],
      }),
    };
    const result = parseOpenAIStreamEvent(event, accumulator);
    expect(result.textDelta).toBe('Hello');
    expect(result.done).toBe(false);
  });

  it('accumulates tool call start (id + name)', () => {
    const event: SSEEvent = {
      data: JSON.stringify({
        id: 'chatcmpl-1',
        choices: [{
          delta: {
            tool_calls: [{
              index: 0,
              id: 'call_abc',
              function: { name: 'search_posts', arguments: '' },
            }],
          },
          index: 0,
        }],
      }),
    };
    const result = parseOpenAIStreamEvent(event, accumulator);
    expect(result.textDelta).toBeUndefined();
    expect(accumulator.toolCalls.get(0)).toEqual({
      id: 'call_abc',
      name: 'search_posts',
      arguments: '',
    });
  });

  it('accumulates tool call argument fragments', () => {
    // First event: tool call start
    parseOpenAIStreamEvent({
      data: JSON.stringify({
        choices: [{
          delta: {
            tool_calls: [{
              index: 0, id: 'call_abc',
              function: { name: 'search_posts', arguments: '' },
            }],
          },
          index: 0,
        }],
      }),
    }, accumulator);

    // Second event: argument fragment
    parseOpenAIStreamEvent({
      data: JSON.stringify({
        choices: [{
          delta: {
            tool_calls: [{
              index: 0,
              function: { arguments: '{"query"' },
            }],
          },
          index: 0,
        }],
      }),
    }, accumulator);

    // Third event: more arguments
    parseOpenAIStreamEvent({
      data: JSON.stringify({
        choices: [{
          delta: {
            tool_calls: [{
              index: 0,
              function: { arguments: ': "test"}' },
            }],
          },
          index: 0,
        }],
      }),
    }, accumulator);

    expect(accumulator.toolCalls.get(0)?.arguments).toBe('{"query": "test"}');
  });

  it('handles multiple concurrent tool calls', () => {
    // Both tool calls arrive in a single delta, distinguished by `index`.
    parseOpenAIStreamEvent({
      data: JSON.stringify({
        choices: [{
          delta: {
            tool_calls: [
              { index: 0, id: 'call_1', function: { name: 'search_posts', arguments: '{"q":"a"}' } },
              { index: 1, id: 'call_2', function: { name: 'list_posts', arguments: '{"limit":5}' } },
            ],
          },
          index: 0,
        }],
      }),
    }, accumulator);

    expect(accumulator.toolCalls.get(0)?.name).toBe('search_posts');
    expect(accumulator.toolCalls.get(1)?.name).toBe('list_posts');
  });

  it('detects finish_reason stop', () => {
    const event: SSEEvent = {
      data: JSON.stringify({
        choices: [{ delta: {}, finish_reason: 'stop', index: 0 }],
      }),
    };
    const result = parseOpenAIStreamEvent(event, accumulator);
    expect(result.finishReason).toBe('stop');
  });

  it('detects finish_reason tool_calls', () => {
    const event: SSEEvent = {
      data: JSON.stringify({
        choices: [{ delta: {}, finish_reason: 'tool_calls', index: 0 }],
      }),
    };
    const result = parseOpenAIStreamEvent(event, accumulator);
    expect(result.finishReason).toBe('tool_calls');
  });

  it('extracts token usage from final chunk', () => {
    const event: SSEEvent = {
      data: JSON.stringify({
        choices: [{ delta: {}, index: 0 }],
        usage: {
          prompt_tokens: 150,
          completion_tokens: 42,
          total_tokens: 192,
        },
      }),
    };
    const result = parseOpenAIStreamEvent(event, accumulator);
    expect(result.usage).toEqual({
      promptTokens: 150,
      completionTokens: 42,
      totalTokens: 192,
    });
  });

  it('handles [DONE] sentinel', () => {
    const event: SSEEvent = { data: '[DONE]' };
    const result = parseOpenAIStreamEvent(event, accumulator);
    expect(result.done).toBe(true);
  });

  it('returns empty result for empty content delta', () => {
    const event: SSEEvent = {
      data: JSON.stringify({
        choices: [{ delta: { content: '' }, index: 0 }],
      }),
    };
    const result = parseOpenAIStreamEvent(event, accumulator);
    expect(result.textDelta).toBeUndefined();
  });
});

// ── Anthropic Stream Event Parsing ──

describe('parseAnthropicStreamEvent', () => {
  let accumulator: AnthropicStreamAccumulator;

  // Fresh accumulator per test so tool-call state never leaks between cases.
  beforeEach(() => {
    accumulator = createAnthropicStreamAccumulator();
  });

  it('extracts input_tokens from message_start', () => {
    const event: SSEEvent = {
      event: 'message_start',
      data: JSON.stringify({
        type: 'message_start',
        message: {
          id: 'msg_1',
          model: 'claude-sonnet-4-5',
          usage: {
            input_tokens: 150,
            cache_read_input_tokens: 50,
            cache_creation_input_tokens: 10,
          },
        },
      }),
    };
    const result = parseAnthropicStreamEvent(event, accumulator);
    expect(result.usage).toEqual({
      inputTokens: 150,
      cacheReadTokens: 50,
      cacheWriteTokens: 10,
    });
  });

  it('handles text content_block_start (no-op)', () => {
    const event: SSEEvent = {
      event: 'content_block_start',
      data: JSON.stringify({
        type: 'content_block_start',
        index: 0,
        content_block: { type: 'text', text: '' },
      }),
    };
    const result = parseAnthropicStreamEvent(event, accumulator);
    expect(result.textDelta).toBeUndefined();
  });

  it('handles tool_use content_block_start', () => {
    const event: SSEEvent = {
      event: 'content_block_start',
      data: JSON.stringify({
        type: 'content_block_start',
        index: 1,
        content_block: { type: 'tool_use', id: 'toolu_abc', name: 'search_posts' },
      }),
    };
    const result = parseAnthropicStreamEvent(event, accumulator);
    expect(result.textDelta).toBeUndefined();
    // Tool calls are keyed by the content block index (1 here, not 0).
    expect(accumulator.toolCalls.get(1)).toEqual({
      id: 'toolu_abc',
      name: 'search_posts',
      arguments: '',
    });
  });

  it('extracts text_delta from content_block_delta', () => {
    const event: SSEEvent = {
      event: 'content_block_delta',
      data: JSON.stringify({
        type: 'content_block_delta',
        index: 0,
        delta: { type: 'text_delta', text: 'Hello world' },
      }),
    };
    const result = parseAnthropicStreamEvent(event, accumulator);
    expect(result.textDelta).toBe('Hello world');
  });

  it('accumulates tool input_json_delta fragments', () => {
    // Start tool block
    parseAnthropicStreamEvent({
      event: 'content_block_start',
      data: JSON.stringify({
        type: 'content_block_start',
        index: 1,
        content_block: { type: 'tool_use', id: 'toolu_abc', name: 'search_posts' },
      }),
    }, accumulator);

    // First argument fragment
    parseAnthropicStreamEvent({
      event: 'content_block_delta',
      data: JSON.stringify({
        type: 'content_block_delta',
        index: 1,
        delta: { type: 'input_json_delta', partial_json: '{"query"' },
      }),
    }, accumulator);

    // Second argument fragment
    parseAnthropicStreamEvent({
      event: 'content_block_delta',
      data: JSON.stringify({
        type: 'content_block_delta',
        index: 1,
        delta: { type: 'input_json_delta', partial_json: ': "test"}' },
      }),
    }, accumulator);

    expect(accumulator.toolCalls.get(1)?.arguments).toBe('{"query": "test"}');
  });

  it('extracts output_tokens from message_delta', () => {
    const event: SSEEvent = {
      event: 'message_delta',
      data: JSON.stringify({
        type: 'message_delta',
        delta: { stop_reason: 'end_turn' },
        usage: { output_tokens: 42 },
      }),
    };
    const result = parseAnthropicStreamEvent(event, accumulator);
    expect(result.usage).toEqual({ outputTokens: 42 });
    expect(result.finishReason).toBe('end_turn');
  });

  it('signals done on message_stop', () => {
    const event: SSEEvent = {
      event: 'message_stop',
      data: JSON.stringify({ type: 'message_stop' }),
    };
    const result = parseAnthropicStreamEvent(event, accumulator);
    expect(result.done).toBe(true);
  });

  it('ignores ping events', () => {
    const event: SSEEvent = {
      event: 'ping',
      data: JSON.stringify({ type: 'ping' }),
    };
    const result = parseAnthropicStreamEvent(event, accumulator);
    expect(result.textDelta).toBeUndefined();
    expect(result.done).toBe(false);
  });

  it('throws on error events', () => {
    const event: SSEEvent = {
      event: 'error',
      data: JSON.stringify({
        type: 'error',
        error: { type: 'overloaded_error', message: 'Server is overloaded' },
      }),
    };
    expect(() => parseAnthropicStreamEvent(event, accumulator)).toThrow('Server is overloaded');
  });

  it('signals tool_use finish reason from message_delta', () => {
    const event: SSEEvent = {
      event: 'message_delta',
      data: JSON.stringify({
        type: 'message_delta',
        delta: { stop_reason: 'tool_use' },
        usage: { output_tokens: 10 },
      }),
    };
    const result = parseAnthropicStreamEvent(event, accumulator);
    expect(result.finishReason).toBe('tool_use');
  });
});

// ── Tool Call Accumulation ──

describe('tool call accumulation', () => {
  it('OpenAI: builds complete tool calls from fragments', () => {
    const acc = createOpenAIStreamAccumulator();

    // Start
    parseOpenAIStreamEvent({
      data: JSON.stringify({
        choices: [{
          delta: {
            tool_calls: [{
              index: 0, id: 'call_1',
              function: { name: 'search_posts', arguments: '' },
            }],
          },
          index: 0,
        }],
      }),
    }, acc);

    // Fragments
    for (const frag of ['{"', 'query', '": "', 'hello', '"}']) {
      parseOpenAIStreamEvent({
        data: JSON.stringify({
          choices: [{
            delta: { tool_calls: [{ index: 0, function: { arguments: frag } }] },
            index: 0,
          }],
        }),
      }, acc);
    }

    const tc = acc.toolCalls.get(0)!;
    expect(tc.id).toBe('call_1');
    expect(tc.name).toBe('search_posts');
    expect(JSON.parse(tc.arguments)).toEqual({ query: 'hello' });
  });

  it('Anthropic: builds complete tool calls from fragments', () => {
    const acc = createAnthropicStreamAccumulator();

    // Start block
    parseAnthropicStreamEvent({
      event: 'content_block_start',
      data: JSON.stringify({
        type: 'content_block_start',
        index: 1,
        content_block: { type: 'tool_use', id: 'toolu_1', name: 'list_posts' },
      }),
    }, acc);

    // Fragments
    for (const frag of ['{"', 'limit', '": ', '5}']) {
      parseAnthropicStreamEvent({
        event: 'content_block_delta',
        data: JSON.stringify({
          type: 'content_block_delta',
          index: 1,
          delta: { type: 'input_json_delta', partial_json: frag },
        }),
      }, acc);
    }

    const tc = acc.toolCalls.get(1)!;
    expect(tc.id).toBe('toolu_1');
    expect(tc.name).toBe('list_posts');
    expect(JSON.parse(tc.arguments)).toEqual({ limit: 5 });
  });
});

// ── Retry with Exponential Backoff ──

describe('withRetry', () => {
  // Every test runs on fake timers so retry delays never cost wall-clock time.
  beforeEach(() => {
    vi.useFakeTimers();
  });

  afterEach(() => {
    vi.useRealTimers();
  });

  it('returns result on first successful call', async () => {
    const fn = vi.fn().mockResolvedValue('success');
    const promise = withRetry(fn, { maxRetries: 3 });
    const result = await promise;
    expect(result).toBe('success');
    expect(fn).toHaveBeenCalledTimes(1);
  });

  it('retries on 429 status and succeeds', async () => {
    const error429 = Object.assign(new Error('Rate limited'), { statusCode: 429 });
    const fn = vi.fn()
      .mockRejectedValueOnce(error429)
      .mockResolvedValue('success');

    const promise = withRetry(fn, { maxRetries: 3 });
    // Advance past the retry delay
    await vi.advanceTimersByTimeAsync(2000);
    const result = await promise;
    expect(result).toBe('success');
    expect(fn).toHaveBeenCalledTimes(2);
  });

  it('retries on 502 status', async () => {
    const error502 = Object.assign(new Error('Bad Gateway'), { statusCode: 502 });
    const fn = vi.fn()
      .mockRejectedValueOnce(error502)
      .mockResolvedValue('ok');

    const promise = withRetry(fn, { maxRetries: 3 });
    await vi.advanceTimersByTimeAsync(2000);
    const result = await promise;
    expect(result).toBe('ok');
    expect(fn).toHaveBeenCalledTimes(2);
  });

  it('retries on 503 status', async () => {
    const error503 = Object.assign(new Error('Service Unavailable'), { statusCode: 503 });
    const fn = vi.fn()
      .mockRejectedValueOnce(error503)
      .mockResolvedValue('ok');

    const promise = withRetry(fn, { maxRetries: 3 });
    await vi.advanceTimersByTimeAsync(2000);
    const result = await promise;
    expect(result).toBe('ok');
    expect(fn).toHaveBeenCalledTimes(2);
  });

  it('does NOT retry on 400 status', async () => {
    const error400 = Object.assign(new Error('Bad Request'), { statusCode: 400 });
    const fn = vi.fn().mockRejectedValue(error400);

    await expect(withRetry(fn, { maxRetries: 3 })).rejects.toThrow('Bad Request');
    expect(fn).toHaveBeenCalledTimes(1);
  });

  it('does NOT retry on 401 status', async () => {
    const error401 = Object.assign(new Error('Unauthorized'), { statusCode: 401 });
    const fn = vi.fn().mockRejectedValue(error401);

    await expect(withRetry(fn, { maxRetries: 3 })).rejects.toThrow('Unauthorized');
    expect(fn).toHaveBeenCalledTimes(1);
  });

  it('does NOT retry on 403 status', async () => {
    const error403 = Object.assign(new Error('Forbidden'), { statusCode: 403 });
    const fn = vi.fn().mockRejectedValue(error403);

    await expect(withRetry(fn, { maxRetries: 3 })).rejects.toThrow('Forbidden');
    expect(fn).toHaveBeenCalledTimes(1);
  });

  it('does NOT retry on abort', async () => {
    const abortError = Object.assign(new Error('Request cancelled'), { isAbort: true });
    const fn = vi.fn().mockRejectedValue(abortError);

    await expect(withRetry(fn, { maxRetries: 3 })).rejects.toThrow('Request cancelled');
    expect(fn).toHaveBeenCalledTimes(1);
  });

  it('exhausts max retries and throws last error', async () => {
    const error429 = Object.assign(new Error('Rate limited'), { statusCode: 429 });
    const fn = vi.fn().mockRejectedValue(error429);

    // Attach the rejection expectation BEFORE draining timers: the promise
    // rejects while timers are being advanced, and a handler must already be
    // in place or the runtime reports an unhandled rejection.
    const assertion = expect(withRetry(fn, { maxRetries: 2 })).rejects.toThrow('Rate limited');
    // Drain every scheduled backoff delay on the fake clock, so the test takes
    // no wall-clock time whatever the backoff schedule is.
    await vi.runAllTimersAsync();
    await assertion;

    expect(fn).toHaveBeenCalledTimes(3); // 1 initial + 2 retries
  });

  it('respects Retry-After header for 429', async () => {
    const error429 = Object.assign(new Error('Rate limited'), {
      statusCode: 429,
      retryAfter: 5,
    });
    const fn = vi.fn()
      .mockRejectedValueOnce(error429)
      .mockResolvedValue('ok');

    const promise = withRetry(fn, { maxRetries: 3 });
    // Should NOT have retried yet at 3 seconds (Retry-After is 5)
    await vi.advanceTimersByTimeAsync(3000);
    expect(fn).toHaveBeenCalledTimes(1);
    // Advance past the Retry-After
    await vi.advanceTimersByTimeAsync(3000);
    const result = await promise;
    expect(result).toBe('ok');
    expect(fn).toHaveBeenCalledTimes(2);
  });
});

// ── Full stream-to-result integration ──

describe('stream event sequences', () => {
  it('OpenAI: processes a complete text response stream', () => {
    const acc = createOpenAIStreamAccumulator();
    const textChunks: string[] = [];

    const events: SSEEvent[] = [
      { data: JSON.stringify({ choices: [{ delta: { role: 'assistant' }, index: 0 }] }) },
      { data: JSON.stringify({ choices: [{ delta: { content: 'Hello' }, index: 0 }] }) },
      { data: JSON.stringify({ choices: [{ delta: { content: ' world' }, index: 0 }] }) },
      { data: JSON.stringify({ choices: [{ delta: { content: '!' }, index: 0 }] }) },
      { data: JSON.stringify({ choices: [{ delta: {}, finish_reason: 'stop', index: 0 }], usage: { prompt_tokens: 10, completion_tokens: 3, total_tokens: 13 } }) },
      { data: '[DONE]' },
    ];

    for (const event of events) {
      const result = parseOpenAIStreamEvent(event, acc);
      if (result.textDelta) textChunks.push(result.textDelta);
    }

    expect(textChunks.join('')).toBe('Hello world!');
  });

  it('Anthropic: processes a complete text response stream', () => {
    const acc = createAnthropicStreamAccumulator();
    const textChunks: string[] = [];

    const events: SSEEvent[] = [
      { event: 'message_start', data: JSON.stringify({ type: 'message_start', message: { id: 'msg_1', model: 'claude-sonnet-4', usage: { input_tokens: 100 } } }) },
      { event: 'content_block_start', data: JSON.stringify({ type: 'content_block_start', index: 0, content_block: { type: 'text', text: '' } }) },
      { event: 'content_block_delta', data: JSON.stringify({ type: 'content_block_delta', index: 0, delta: { type: 'text_delta', text: 'Hello' } }) },
      { event: 'content_block_delta', data: JSON.stringify({ type: 'content_block_delta', index: 0, delta: { type: 'text_delta', text: ' world!' } }) },
      { event: 'content_block_stop', data: JSON.stringify({ type: 'content_block_stop', index: 0 }) },
      { event: 'message_delta', data: JSON.stringify({ type: 'message_delta', delta: { stop_reason: 'end_turn' }, usage: { output_tokens: 5 } }) },
      { event: 'message_stop', data: JSON.stringify({ type: 'message_stop' }) },
    ];

    for (const event of events) {
      const result = parseAnthropicStreamEvent(event, acc);
      if (result.textDelta) textChunks.push(result.textDelta);
    }

    expect(textChunks.join('')).toBe('Hello world!');
  });

  it('OpenAI: processes a tool call response stream', () => {
    const acc = createOpenAIStreamAccumulator();

    const events: SSEEvent[] = [
      { data: JSON.stringify({ choices: [{ delta: { role: 'assistant', tool_calls: [{ index: 0, id: 'call_1', function: { name: 'search_posts', arguments: '' } }] }, index: 0 }] }) },
      { data: JSON.stringify({ choices: [{ delta: { tool_calls: [{ index: 0, function: { arguments: '{"query"' } }] }, index: 0 }] }) },
      { data: JSON.stringify({ choices: [{ delta: { tool_calls: [{ index: 0, function: { arguments: ': "test"}' } }] }, index: 0 }] }) },
      { data: JSON.stringify({ choices: [{ delta: {}, finish_reason: 'tool_calls', index: 0 }] }) },
      { data: '[DONE]' },
    ];

    for (const event of events) {
      parseOpenAIStreamEvent(event, acc);
    }

    expect(acc.toolCalls.size).toBe(1);
    const tc = acc.toolCalls.get(0)!;
    expect(tc.name).toBe('search_posts');
    expect(JSON.parse(tc.arguments)).toEqual({ query: 'test' });
  });

  it('Anthropic: processes a tool call response stream', () => {
    const acc = createAnthropicStreamAccumulator();

    const events: SSEEvent[] = [
      { event: 'message_start', data: JSON.stringify({ type: 'message_start', message: { id: 'msg_1', usage: { input_tokens: 100 } } }) },
      { event: 'content_block_start', data: JSON.stringify({ type: 'content_block_start', index: 0, content_block: { type: 'text', text: '' } }) },
      { event: 'content_block_delta', data: JSON.stringify({ type: 'content_block_delta', index: 0, delta: { type: 'text_delta', text: 'Let me search.' } }) },
      { event: 'content_block_stop', data: JSON.stringify({ type: 'content_block_stop', index: 0 }) },
      { event: 'content_block_start', data: JSON.stringify({ type: 'content_block_start', index: 1, content_block: { type: 'tool_use', id: 'toolu_1', name: 'search_posts' } }) },
      { event: 'content_block_delta', data: JSON.stringify({ type: 'content_block_delta', index: 1, delta: { type: 'input_json_delta', partial_json: '{"query": "test"}' } }) },
      { event: 'content_block_stop', data: JSON.stringify({ type: 'content_block_stop', index: 1 }) },
      { event: 'message_delta', data: JSON.stringify({ type: 'message_delta', delta: { stop_reason: 'tool_use' }, usage: { output_tokens: 20 } }) },
      { event: 'message_stop', data: JSON.stringify({ type: 'message_stop' }) },
    ];

    const textChunks: string[] = [];
    for (const event of events) {
      const result = parseAnthropicStreamEvent(event, acc);
      if (result.textDelta) textChunks.push(result.textDelta);
    }

    expect(textChunks.join('')).toBe('Let me search.');
    expect(acc.toolCalls.size).toBe(1);
    // The tool_use block was content block index 1, so that is its map key.
    const tc = acc.toolCalls.get(1)!;
    expect(tc.name).toBe('search_posts');
    expect(JSON.parse(tc.arguments)).toEqual({ query: 'test' });
  });
});