diff --git a/.vscode/settings.json b/.vscode/settings.json index 6c83a94..5e57c86 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -8,4 +8,4 @@ "gh": true, "git add": true } -} \ No newline at end of file +} diff --git a/src/main/engine/OpenCodeManager.ts b/src/main/engine/OpenCodeManager.ts index 3545d8e..fd3eac8 100644 --- a/src/main/engine/OpenCodeManager.ts +++ b/src/main/engine/OpenCodeManager.ts @@ -473,6 +473,7 @@ export class OpenCodeManager { while (round < MAX_TOOL_ROUNDS) { round++; + if (signal.aborted) break; const body: Record = { model: modelId, @@ -484,51 +485,57 @@ export class OpenCodeManager { cache_control: { type: 'ephemeral' }, }; - // Stream the response with retry for transient errors - const streamAccumulator = createAnthropicStreamAccumulator(); - let stopReason = ''; - let inputTokens = 0; - let outputTokens = 0; - let cacheReadTokens = 0; - let cacheWriteTokens = 0; - let roundText = ''; // Text produced in this round only + // Stream the response with retry for transient errors (including mid-stream failures) + const streamResult = await withRetry(async () => { + const streamAccumulator = createAnthropicStreamAccumulator(); + let stopReason = ''; + let inputTokens = 0; + let outputTokens = 0; + let cacheReadTokens = 0; + let cacheWriteTokens = 0; + let roundText = ''; - const { events } = await withRetry(() => httpRequestStream(ZEN_ANTHROPIC_URL, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'x-api-key': this.apiKey, - 'Authorization': `Bearer ${this.apiKey}`, - 'anthropic-version': '2023-06-01', - }, - body: JSON.stringify(body), - signal, - })); + const { events } = await httpRequestStream(ZEN_ANTHROPIC_URL, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'x-api-key': this.apiKey, + 'Authorization': `Bearer ${this.apiKey}`, + 'anthropic-version': '2023-06-01', + }, + body: JSON.stringify(body), + signal, + }); - for await (const event of events) { - const result = parseAnthropicStreamEvent(event, streamAccumulator); + for await (const event of events) { + const result = parseAnthropicStreamEvent(event, streamAccumulator); - // Emit text deltas immediately for real-time streaming - if (result.textDelta) { - accumulatedText += result.textDelta; - roundText += result.textDelta; - if (callbacks.onDelta) { - callbacks.onDelta(result.textDelta); + if (result.textDelta) { + roundText += result.textDelta; + if (callbacks.onDelta) { + callbacks.onDelta(result.textDelta); + } } + + if (result.usage) { + if (result.usage.inputTokens !== undefined) inputTokens = result.usage.inputTokens; + if (result.usage.cacheReadTokens !== undefined) cacheReadTokens = result.usage.cacheReadTokens; + if (result.usage.cacheWriteTokens !== undefined) cacheWriteTokens = result.usage.cacheWriteTokens; + if (result.usage.outputTokens !== undefined) outputTokens = result.usage.outputTokens; + } + + if (result.finishReason) { + stopReason = result.finishReason; + } + + if (result.done) break; } - // Collect usage from message_start (input tokens) and message_delta (output tokens) - if (result.usage) { - if (result.usage.inputTokens !== undefined) inputTokens = result.usage.inputTokens; - if (result.usage.cacheReadTokens !== undefined) cacheReadTokens = result.usage.cacheReadTokens; - if (result.usage.cacheWriteTokens !== undefined) cacheWriteTokens = result.usage.cacheWriteTokens; - if (result.usage.outputTokens !== undefined) outputTokens = result.usage.outputTokens; - } + return { roundText, stopReason, toolCalls: streamAccumulator.toolCalls, inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens }; + }); - if (result.finishReason) { - stopReason = result.finishReason; - } - } + const { roundText, stopReason, toolCalls: streamToolCalls, inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens } = streamResult; + accumulatedText += roundText; // Emit token usage after stream completes if (callbacks.onTokenUsage) { @@ -558,7 +565,7 @@ export class OpenCodeManager { // Collect tool calls from stream accumulator const toolUseBlocks: Array<{ id: string; name: string; input: unknown }> = []; - for (const [, tc] of streamAccumulator.toolCalls) { + for (const [, tc] of streamToolCalls) { try { toolUseBlocks.push({ id: tc.id, name: tc.name, input: JSON.parse(tc.arguments) }); } catch { @@ -626,7 +633,8 @@ export class OpenCodeManager { continue; } - // Execute the tool + // Execute the tool (check abort before each tool execution) + if (signal.aborted) break; const result = await this.executeTool(toolName, toolArgs as Record); if (callbacks.onToolResult) { @@ -745,6 +753,7 @@ export class OpenCodeManager { while (round < MAX_TOOL_ROUNDS) { round++; + if (signal.aborted) break; const body: Record = { model: modelId, @@ -755,51 +764,55 @@ export class OpenCodeManager { stream_options: { include_usage: true }, }; - // Stream the response with retry for transient errors - const streamAccumulator = createOpenAIStreamAccumulator(); - let finishReason = ''; - let promptTokens = 0; - let completionTokens = 0; - let totalTokens = 0; - let cacheReadTokens = 0; - let roundText = ''; // Text produced in this round only + // Stream the response with retry for transient errors (including mid-stream failures) + const streamResult = await withRetry(async () => { + const streamAccumulator = createOpenAIStreamAccumulator(); + let finishReason = ''; + let promptTokens = 0; + let completionTokens = 0; + let totalTokens = 0; + let cacheReadTokens = 0; + let roundText = ''; - const { events } = await withRetry(() => httpRequestStream(ZEN_OPENAI_URL, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${this.apiKey}`, - }, - body: JSON.stringify(body), - signal, - })); + const { events } = await httpRequestStream(ZEN_OPENAI_URL, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${this.apiKey}`, + }, + body: JSON.stringify(body), + signal, + }); - for await (const event of events) { - const result = parseOpenAIStreamEvent(event, streamAccumulator); + for await (const event of events) { + const result = parseOpenAIStreamEvent(event, streamAccumulator); - // Emit text deltas immediately for real-time streaming - if (result.textDelta) { - accumulatedText += result.textDelta; - roundText += result.textDelta; - if (callbacks.onDelta) { - callbacks.onDelta(result.textDelta); + if (result.textDelta) { + roundText += result.textDelta; + if (callbacks.onDelta) { + callbacks.onDelta(result.textDelta); + } } + + if (result.usage) { + if (result.usage.promptTokens !== undefined) promptTokens = result.usage.promptTokens; + if (result.usage.completionTokens !== undefined) completionTokens = result.usage.completionTokens; + if (result.usage.totalTokens !== undefined) totalTokens = result.usage.totalTokens; + if (result.usage.cacheReadTokens !== undefined) cacheReadTokens = result.usage.cacheReadTokens; + } + + if (result.finishReason) { + finishReason = result.finishReason; + } + + if (result.done) break; } - // Collect usage from final chunk - if (result.usage) { - if (result.usage.promptTokens !== undefined) promptTokens = result.usage.promptTokens; - if (result.usage.completionTokens !== undefined) completionTokens = result.usage.completionTokens; - if (result.usage.totalTokens !== undefined) totalTokens = result.usage.totalTokens; - if (result.usage.cacheReadTokens !== undefined) cacheReadTokens = result.usage.cacheReadTokens; - } + return { roundText, finishReason, toolCalls: streamAccumulator.toolCalls, promptTokens, completionTokens, totalTokens, cacheReadTokens }; + }); - if (result.finishReason) { - finishReason = result.finishReason; - } - - if (result.done) break; - } + const { roundText, finishReason, toolCalls: streamToolCalls, promptTokens, completionTokens, totalTokens, cacheReadTokens } = streamResult; + accumulatedText += roundText; // Emit token usage after stream completes if (callbacks.onTokenUsage) { @@ -818,7 +831,8 @@ export class OpenCodeManager { this.conversationUsage.set(conversationId, cumulative); callbacks.onTokenUsage({ - inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens: 0, + inputTokens, outputTokens, cacheReadTokens, + cacheWriteTokens: 0, // OpenAI streaming does not report cache write tokens totalTokens: totalTokens || inputTokens + outputTokens, cumulativeInputTokens: cumulative.inputTokens, cumulativeOutputTokens: cumulative.outputTokens, @@ -830,7 +844,7 @@ export class OpenCodeManager { // Collect tool calls from stream accumulator const parsedToolCalls: Array<{ id: string; name: string; args: unknown }> = []; - for (const [, tc] of streamAccumulator.toolCalls) { + for (const [, tc] of streamToolCalls) { try { parsedToolCalls.push({ id: tc.id, name: tc.name, args: JSON.parse(tc.arguments) }); } catch { @@ -887,6 +901,8 @@ export class OpenCodeManager { continue; } + // Check abort before each tool execution + if (signal.aborted) break; const result = await this.executeTool(toolName, toolArgs as Record); if (callbacks.onToolResult) { diff --git a/src/main/engine/streaming.ts b/src/main/engine/streaming.ts index 99c4d5b..1cc6017 100644 --- a/src/main/engine/streaming.ts +++ b/src/main/engine/streaming.ts @@ -93,9 +93,11 @@ export function parseSSELines(text: string): { events: SSEEvent[]; remaining: st if (line.startsWith(':')) continue; if (line.startsWith('event: ') || line.startsWith('event:')) { - eventType = line.slice(line.indexOf(':') + 1).trim(); + const afterColon = line.slice(line.indexOf(':') + 1); + eventType = afterColon.startsWith(' ') ? afterColon.slice(1) : afterColon; } else if (line.startsWith('data: ') || line.startsWith('data:')) { - dataLines.push(line.slice(line.indexOf(':') + 1).trimStart()); + const afterColon = line.slice(line.indexOf(':') + 1); + dataLines.push(afterColon.startsWith(' ') ? afterColon.slice(1) : afterColon); } } @@ -326,7 +328,7 @@ const RETRYABLE_STATUS_CODES = new Set([429, 502, 503]); */ export async function withRetry( fn: () => Promise, - options: { maxRetries?: number } = {}, + options: { maxRetries?: number; onRetry?: (attempt: number, error: Error) => void } = {}, ): Promise { const maxRetries = options.maxRetries ?? 3; let lastError: Error | undefined; @@ -363,6 +365,10 @@ export async function withRetry( delay = Math.max(delay, httpError.retryAfter * 1000); } + if (options.onRetry) { + options.onRetry(attempt + 1, lastError); + } + await new Promise(resolve => setTimeout(resolve, delay)); } } @@ -446,6 +452,7 @@ export function httpRequestStream( [Symbol.asyncIterator]() { let buffer = ''; let done = false; + let pendingError: Error | null = null; const eventQueue: SSEEvent[] = []; let resolveNext: ((value: IteratorResult) => void) | null = null; let rejectNext: ((error: Error) => void) | null = null; @@ -484,6 +491,9 @@ export function httpRequestStream( resolveNext = null; rejectNext = null; reject(err); + } else { + // Store error for next .next() call so it's not silently swallowed + pendingError = err; } }); @@ -494,6 +504,13 @@ export function httpRequestStream( return Promise.resolve({ value: eventQueue.shift()!, done: false }); } + // Throw stored error from a previous event that fired with no consumer waiting + if (pendingError) { + const err = pendingError; + pendingError = null; + return Promise.reject(err); + } + // Stream already ended if (done) { return Promise.resolve({ value: undefined as unknown as SSEEvent, done: true }); diff --git a/tests/engine/streaming.test.ts b/tests/engine/streaming.test.ts index 8f914e9..cd8ba00 100644 --- a/tests/engine/streaming.test.ts +++ b/tests/engine/streaming.test.ts @@ -10,12 +10,14 @@ * - Retry with exponential backoff (429/502/503, Retry-After, no retry on 4xx/abort) */ -import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import http from 'http'; import { parseSSELines, parseOpenAIStreamEvent, parseAnthropicStreamEvent, withRetry, + httpRequestStream, type SSEEvent, type OpenAIStreamAccumulator, type AnthropicStreamAccumulator, @@ -838,29 +840,268 @@ describe('OpenAI cache token extraction', () => { // ── httpRequestStream ── describe('httpRequestStream', () => { - // We test httpRequestStream by mocking Node's http/https modules - // These tests verify the async iterable, error handling, and abort behavior + // Use a real HTTP server for integration tests (avoids ESM spyOn limitations) - // Helper to create a mock response - function createMockResponse(statusCode: number) { - const handlers: Record void)[]> = {}; - return { - statusCode, - headers: {} as Record, - on(event: string, handler: (...args: unknown[]) => void) { - if (!handlers[event]) handlers[event] = []; - handlers[event].push(handler); - return this; - }, - emit(event: string, ...args: unknown[]) { - for (const h of handlers[event] || []) h(...args); - }, - }; + function startTestServer(handler: (req: http.IncomingMessage, res: http.ServerResponse) => void): Promise<{ url: string; close: () => Promise }> { + return new Promise((resolve) => { + const server = http.createServer(handler); + server.listen(0, () => { + const addr = server.address() as { port: number }; + resolve({ + url: `http://localhost:${addr.port}`, + close: () => new Promise((r) => server.close(() => r())), + }); + }); + }); } - it('should be importable', async () => { - // Verify the function exists and has the right shape - const { httpRequestStream } = await import('../../src/main/engine/streaming'); - expect(typeof httpRequestStream).toBe('function'); + it('parses streamed SSE events from response data chunks', async () => { + const srv = await startTestServer((_req, res) => { + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.write('data: {"choices":[{"delta":{"content":"Hello"}}]}\n\n'); + res.write('data: {"choices":[{"delta":{"content":" world"}}]}\n\n'); + res.write('data: [DONE]\n\n'); + res.end(); + }); + + try { + const { events } = await httpRequestStream(srv.url, { method: 'POST', body: '{}' }); + const collected: SSEEvent[] = []; + for await (const event of events) { + collected.push(event); + } + expect(collected).toHaveLength(3); + expect(collected[0].data).toBe('{"choices":[{"delta":{"content":"Hello"}}]}'); + expect(collected[1].data).toBe('{"choices":[{"delta":{"content":" world"}}]}'); + expect(collected[2].data).toBe('[DONE]'); + } finally { + await srv.close(); + } + }); + + it('collects error body and rejects on non-2xx status', async () => { + const srv = await startTestServer((_req, res) => { + res.writeHead(429, { 'Content-Type': 'application/json', 'Retry-After': '5' }); + res.end(JSON.stringify({ error: { message: 'Rate limited' } })); + }); + + try { + await expect(httpRequestStream(srv.url, {})).rejects.toMatchObject({ + message: 'Rate limited', + statusCode: 429, + retryAfter: 5, + }); + } finally { + await srv.close(); + } + }); + + it('propagates mid-stream errors to async iterable consumer', async () => { + const srv = await startTestServer((_req, res) => { + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + res.write('data: {"choices":[{"delta":{"content":"Hi"}}]}\n\n'); + // Destroy the socket to simulate TCP disconnect + setTimeout(() => res.destroy(), 20); + }); + + try { + const { events } = await httpRequestStream(srv.url, {}); + const collected: SSEEvent[] = []; + await expect(async () => { + for await (const event of events) { + collected.push(event); + } + }).rejects.toThrow(); + + // Should have received the first event before the error + expect(collected).toHaveLength(1); + expect(collected[0].data).toBe('{"choices":[{"delta":{"content":"Hi"}}]}'); + } finally { + await srv.close(); + } + }); + + it('propagates stored error when no consumer was waiting (pendingError fix)', async () => { + const srv = await startTestServer((_req, res) => { + res.writeHead(200, { 'Content-Type': 'text/event-stream' }); + // Send data and immediately destroy — error fires before consumer calls .next() + res.write('data: {"ok":true}\n\n'); + // Give a tiny delay so the data event fires first + setTimeout(() => res.destroy(), 5); + }); + + try { + const { events } = await httpRequestStream(srv.url, {}); + const iter = events[Symbol.asyncIterator](); + + // Wait a bit for both data and error to fire + await new Promise(resolve => setTimeout(resolve, 50)); + + // First call should return the queued event + const first = await iter.next(); + expect(first.done).toBe(false); + expect(first.value.data).toBe('{"ok":true}'); + + // Second call should throw the stored (pending) error + await expect(iter.next()).rejects.toThrow(); + } finally { + await srv.close(); + } + }); + + it('handles already-aborted signal', async () => { + // No server needed — should reject immediately + const controller = new AbortController(); + controller.abort(); + + await expect(httpRequestStream('http://localhost:1/test', { + signal: controller.signal, + })).rejects.toMatchObject({ + isAbort: true, + }); + }); + + it('handles non-JSON error body', async () => { + const srv = await startTestServer((_req, res) => { + res.writeHead(500, { 'Content-Type': 'text/plain' }); + res.end('Internal Server Error'); + }); + + try { + await expect(httpRequestStream(srv.url, {})).rejects.toMatchObject({ + statusCode: 500, + }); + } finally { + await srv.close(); + } + }); +}); + +// ── withRetry onRetry callback ── + +describe('withRetry onRetry callback', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('calls onRetry callback before each retry attempt', async () => { + const error429 = Object.assign(new Error('Rate limited'), { statusCode: 429 }); + const onRetry = vi.fn(); + const fn = vi.fn() + .mockRejectedValueOnce(error429) + .mockRejectedValueOnce(error429) + .mockResolvedValue('success'); + + const promise = withRetry(fn, { maxRetries: 3, onRetry }); + await vi.advanceTimersByTimeAsync(10000); + const result = await promise; + + expect(result).toBe('success'); + expect(onRetry).toHaveBeenCalledTimes(2); + expect(onRetry).toHaveBeenCalledWith(1, error429); + expect(onRetry).toHaveBeenCalledWith(2, error429); + }); + + it('does not call onRetry when first attempt succeeds', async () => { + const onRetry = vi.fn(); + const fn = vi.fn().mockResolvedValue('ok'); + + const result = await withRetry(fn, { maxRetries: 3, onRetry }); + + expect(result).toBe('ok'); + expect(onRetry).not.toHaveBeenCalled(); + }); +}); + +// ── Mid-stream retry integration ── + +describe('mid-stream retry with withRetry', () => { + it('retries stream consumption on transient mid-stream error', async () => { + vi.useRealTimers(); + + let attempt = 0; + const fn = async () => { + attempt++; + if (attempt === 1) { + // First attempt: simulate partial stream then error + const error = Object.assign(new Error('Service temporarily unavailable'), { statusCode: 503 }); + throw error; + } + // Second attempt: succeed + return { text: 'Hello world!', toolCalls: [] }; + }; + + const result = await withRetry(fn, { maxRetries: 2 }); + expect(result).toEqual({ text: 'Hello world!', toolCalls: [] }); + expect(attempt).toBe(2); + }); + + it('retries on mid-stream TCP error (no status code)', async () => { + vi.useRealTimers(); + + let attempt = 0; + const fn = async () => { + attempt++; + if (attempt === 1) { + throw new Error('ECONNRESET'); + } + return 'recovered'; + }; + + const result = await withRetry(fn, { maxRetries: 2 }); + expect(result).toBe('recovered'); + expect(attempt).toBe(2); + }); + + it('does not retry mid-stream abort errors', async () => { + const abortError = Object.assign(new Error('Request cancelled'), { isAbort: true }); + + let attempt = 0; + const fn = async () => { + attempt++; + throw abortError; + }; + + await expect(withRetry(fn, { maxRetries: 3 })).rejects.toThrow('Request cancelled'); + expect(attempt).toBe(1); + }); +}); + +// ── SSE spec compliance ── + +describe('SSE spec compliance - single space removal', () => { + it('removes exactly one leading space after colon in data field', () => { + const chunk = 'data: {"key": "value"}\n\n'; + const { events } = parseSSELines(chunk); + expect(events[0].data).toBe('{"key": "value"}'); + }); + + it('preserves data when no space after colon', () => { + const chunk = 'data:{"key":"value"}\n\n'; + const { events } = parseSSELines(chunk); + expect(events[0].data).toBe('{"key":"value"}'); + }); + + it('preserves extra leading spaces after removing one', () => { + const chunk = 'data: two spaces\n\n'; + const { events } = parseSSELines(chunk); + // Per SSE spec: only one leading space is removed + expect(events[0].data).toBe(' two spaces'); + }); + + it('removes exactly one leading space from event type', () => { + const chunk = 'event: message_start\ndata: {}\n\n'; + const { events } = parseSSELines(chunk); + expect(events[0].event).toBe('message_start'); + }); + + it('handles event type with no space after colon', () => { + const chunk = 'event:ping\ndata: {}\n\n'; + const { events } = parseSSELines(chunk); + expect(events[0].event).toBe('ping'); }); });