Files
bDS/src/main/engine/ai/chat.ts

513 lines
18 KiB
TypeScript

/**
* ChatService — streaming chat service built on AI SDK v6 `streamText()`.
*
* AI SDK handles:
* - SSE parsing, reconnection, abort
* - Provider-specific request/response format (Anthropic Messages, OpenAI Chat Completions)
* - Tool call/result loop (bounded by `stopWhen: stepCountIs(...)`)
* - Token usage extraction
*/
import { streamText, generateText, stepCountIs } from 'ai';
import type { ModelMessage, LanguageModelUsage } from 'ai';
import type { BrowserWindow } from 'electron';
import type { ChatEngine, ChatMessageData } from '../ChatEngine';
import { isRenderTool, generateFromToolCall } from '../../a2ui/generator';
import type { A2UIServerMessage } from '../../a2ui/types';
import { ProviderRegistry, detectProvider } from './providers';
import { createBlogTools, type BlogToolDeps } from './blog-tools';
import { createA2UITools } from './a2ui-tools';
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
/**
* Optional hooks a caller can pass to `ChatService.sendMessage()` to observe
* a request while it is running. Every hook may be omitted.
*/
export interface ChatCallbacks {
/** Receives the text of each `text-delta` chunk as it is streamed. */
onDelta?: (delta: string) => void;
/** Called once per tool call when a step finishes, with the tool name and its input. */
onToolCall?: (toolCall: { name: string; args: unknown }) => void;
/** Called once per tool result when a step finishes, with the tool name and its output. */
onToolResult?: (result: { name: string; result: unknown }) => void;
/** Receives each A2UI message generated from a render-tool call. */
onA2UIMessage?: (message: A2UIServerMessage) => void;
/** Receives per-turn and cumulative token usage after the stream has completed. */
onTokenUsage?: (usage: TokenUsageReport) => void;
}
/**
* Token usage for one request plus running totals for the conversation.
* The `cumulative*` fields are sums kept in memory by `ChatService`.
*/
export interface TokenUsageReport {
/** Input tokens for this turn with cache-read and cache-write tokens subtracted. */
inputTokens: number;
/** Output tokens for this turn. */
outputTokens: number;
/** Input tokens served from the provider cache this turn (0 when not reported). */
cacheReadTokens: number;
/** Input tokens written to the provider cache this turn (0 when not reported). */
cacheWriteTokens: number;
/** Reported input tokens (before the cache subtraction) plus output tokens for this turn. */
totalTokens: number;
/** Running sum of `inputTokens` for the conversation. */
cumulativeInputTokens: number;
/** Running sum of `outputTokens` for the conversation. */
cumulativeOutputTokens: number;
/** Running sum of `cacheReadTokens` for the conversation. */
cumulativeCacheReadTokens: number;
/** Running sum of `cacheWriteTokens` for the conversation. */
cumulativeCacheWriteTokens: number;
/** Sum of the four cumulative counters above. */
cumulativeTotalTokens: number;
}
/** Outcome of `ChatService.sendMessage()`; failures are reported here rather than thrown. */
export interface SendResult {
/** False when the request failed; an aborted request is reported as a success. */
success: boolean;
/** Final assistant text; an empty string when the request was aborted. */
message?: string;
/** Human-readable failure reason, set when `success` is false. */
error?: string;
/** Tool calls made during the request; omitted when there were none. */
toolCalls?: Array<{ name: string; args: unknown }>;
}
// Maximum tool-call rounds per request.
// Passed to stepCountIs() as the streamText stop condition.
const MAX_TOOL_ROUNDS = 10;
// ---------------------------------------------------------------------------
// Message serialization — DB flat rows ↔ AI SDK messages
// ---------------------------------------------------------------------------
/**
 * Map persisted chat rows onto the `ModelMessage[]` shape expected by
 * `streamText({ messages })`.
 *
 * Only `user` and `assistant` rows are forwarded. Rows with any other role
 * (system, tool) are skipped: the system prompt is supplied separately and
 * tool results only matter inside the request that produced them.
 *
 * Per Open Questions #3, tool calls recorded on an assistant row are not
 * replayed as structured parts; they are appended to the text as a short
 * annotation instead.
 *
 * @param dbMessages Rows in conversation order.
 * @returns The messages to send, in the same order.
 */
function dbMessagesToAIMessages(
  dbMessages: Pick<ChatMessageData, 'role' | 'content' | 'toolCalls'>[],
): ModelMessage[] {
  // Render the stored tool-call JSON as a text annotation; yields '' when
  // there is nothing to add or the JSON cannot be used.
  const describeToolCalls = (rawJson: string): string => {
    try {
      const parsed = JSON.parse(rawJson) as Array<{ name: string; args: unknown }>;
      if (!(parsed.length > 0)) return '';
      const lines = parsed.map(call => `- ${call.name}(${JSON.stringify(call.args)})`);
      return `\n\n[Tools used in this turn:\n${lines.join('\n')}\n]`;
    } catch {
      // Malformed tool-call JSON: keep the message text without an annotation.
      return '';
    }
  };

  const converted: ModelMessage[] = [];
  for (const row of dbMessages) {
    switch (row.role) {
      case 'user':
        converted.push({ role: 'user', content: row.content || '' });
        break;
      case 'assistant': {
        const text = row.content || '';
        const annotation = row.toolCalls ? describeToolCalls(row.toolCalls) : '';
        converted.push({ role: 'assistant', content: text + annotation });
        break;
      }
      default:
        // Neither user nor assistant — not forwarded to the model.
        break;
    }
  }
  return converted;
}
// ---------------------------------------------------------------------------
// System prompt augmentation
// ---------------------------------------------------------------------------
/**
 * Append a live summary of the blog's data volume to the system prompt.
 *
 * Best-effort: if either lookup fails, or the blog has no posts, the base
 * prompt is returned unchanged.
 *
 * @param basePrompt   The system prompt to extend.
 * @param blogToolDeps Supplies `postEngine.getBlogStats()` and `mediaEngine.getAllMedia()`.
 * @returns The base prompt, followed by the stats summary when one is available.
 */
async function appendBlogStats(
  basePrompt: string,
  blogToolDeps: BlogToolDeps,
): Promise<string> {
  try {
    // The two lookups are independent, so run them concurrently.
    const [stats, mediaList] = await Promise.all([
      blogToolDeps.postEngine.getBlogStats(),
      blogToolDeps.mediaEngine.getAllMedia(),
    ]);
    if (stats.totalPosts === 0) return basePrompt;

    // Keep only the calendar-day part of the ISO timestamp.
    const isoDay = (date: Date): string | undefined => date.toISOString().split('T')[0];
    const dateRange = stats.oldestPostDate && stats.newestPostDate
      ? `from ${isoDay(stats.oldestPostDate)} to ${isoDay(stats.newestPostDate)}`
      : 'unknown';
    const yearBreakdown = Object.entries(stats.postsPerYear)
      .sort(([a], [b]) => Number(a) - Number(b))
      .map(([year, count]) => `${year}: ${count}`)
      .join(', ');

    return basePrompt + `
--- CURRENT BLOG DATA SUMMARY ---
Total posts: ${stats.totalPosts} (${stats.publishedCount} published, ${stats.draftCount} drafts, ${stats.archivedCount} archived)
Date range: ${dateRange}
Posts per year: ${yearBreakdown}
Unique tags: ${stats.tagCount}, Unique categories: ${stats.categoryCount}
Total media files: ${mediaList.length}
NOTE: Use pagination (offset/limit) in list_posts and search_posts to access all data. Default page size is 20.`;
  } catch (error) {
    // Stats are an optional extra; log and fall back to the plain prompt.
    console.warn('[ChatService] Could not load blog stats for system prompt:', error);
    return basePrompt;
  }
}
// ---------------------------------------------------------------------------
// Token estimation (for context truncation)
// ---------------------------------------------------------------------------
/**
 * Rough token estimate for a piece of text: one token per 3.5 characters,
 * rounded up.
 */
function estimateTokens(text: string): number {
  return Math.ceil(text.length / 3.5);
}

/**
 * Drop the oldest messages until the conversation fits the context budget,
 * keeping the most recent messages for continuity.
 *
 * The budget is `maxContextTokens` minus the estimated size of the system
 * prompt and tool description and a fixed reserve for the response. Messages
 * are removed from the front: a leading user message is dropped together with
 * the assistant reply that follows it; anything else is dropped on its own.
 * Trimming stops as soon as the remaining messages fit, and never goes below
 * the last two messages.
 *
 * @param messages         Conversation in chronological order (not mutated).
 * @param systemPrompt     System prompt text, counted against the budget.
 * @param toolsJson        Serialized tool description, counted against the budget.
 * @param maxContextTokens Context window of the target model.
 * @returns `messages` itself when it already fits; otherwise a new, shorter
 *          array. When there is no room at all, only the last message.
 */
function truncateMessages(
  messages: ModelMessage[],
  systemPrompt: string,
  toolsJson: string,
  maxContextTokens: number,
): ModelMessage[] {
  const responseReserve = 4096;
  const availableBudget =
    maxContextTokens - estimateTokens(systemPrompt) - estimateTokens(toolsJson) - responseReserve;
  if (availableBudget <= 0) return messages.slice(-1);

  // Estimate each message once; the running total is then updated as
  // messages are dropped, so the loop sees the size of what is actually left.
  const costs = messages.map(m =>
    estimateTokens(typeof m.content === 'string' ? m.content : JSON.stringify(m.content)),
  );
  let remainingTokens = costs.reduce((sum, cost) => sum + cost, 0);
  if (remainingTokens <= availableBudget) return messages;

  let start = 0;
  while (messages.length - start > 2 && remainingTokens > availableBudget) {
    const isUserAssistantPair =
      messages[start]?.role === 'user' && messages[start + 1]?.role === 'assistant';
    const dropCount = isUserAssistantPair ? 2 : 1;
    for (let i = start; i < start + dropCount; i++) {
      remainingTokens -= costs[i] ?? 0;
    }
    start += dropCount;
  }
  return messages.slice(start);
}
// ---------------------------------------------------------------------------
// ChatService
// ---------------------------------------------------------------------------
/**
 * Streams chat completions for stored conversations.
 *
 * For each request it persists the user message, builds the system prompt and
 * tool set, runs `streamText()` with a bounded tool loop, forwards stream
 * events to the caller's callbacks, and persists the assistant reply. It also
 * tracks one abort controller and one running token total per conversation.
 */
export class ChatService {
  private readonly chatEngine: ChatEngine;
  private readonly providers: ProviderRegistry;
  private readonly blogToolDeps: BlogToolDeps;
  private readonly getMainWindow: () => BrowserWindow | null;
  /** Abort controller of the in-flight request, keyed by conversation id. */
  private readonly abortControllers = new Map<string, AbortController>();
  /** Running token totals, keyed by conversation id (in memory only). */
  private readonly conversationUsage = new Map<string, {
    inputTokens: number;
    outputTokens: number;
    cacheReadTokens: number;
    cacheWriteTokens: number;
  }>();

  constructor(
    chatEngine: ChatEngine,
    providers: ProviderRegistry,
    blogToolDeps: BlogToolDeps,
    getMainWindow: () => BrowserWindow | null,
  ) {
    this.chatEngine = chatEngine;
    this.providers = providers;
    this.blogToolDeps = blogToolDeps;
    this.getMainWindow = getMainWindow;
  }

  /**
   * Send a user message in a conversation and stream the AI response,
   * including any tool use.
   *
   * @param conversationId Conversation to append to.
   * @param userMessage    Text of the new user message.
   * @param callbacks      Optional hooks for deltas, tool events, A2UI messages and usage.
   * @returns The final text and tool calls on success. An aborted request
   *          resolves to `{ success: true, message: '' }`. Failures are logged
   *          and returned as `{ success: false, error }`; this method does not throw.
   */
  async sendMessage(
    conversationId: string,
    userMessage: string,
    callbacks: ChatCallbacks = {},
  ): Promise<SendResult> {
    try {
      // Readiness check
      if (!this.providers.isReady()) {
        return { success: false, error: 'API key not configured' };
      }

      // Load conversation (snapshot taken BEFORE the new user message is stored)
      const conversation = await this.chatEngine.getConversation(conversationId);
      if (!conversation) {
        return { success: false, error: 'Conversation not found' };
      }

      // Add user message to DB
      await this.chatEngine.addMessage({
        conversationId,
        role: 'user',
        content: userMessage,
        createdAt: new Date(),
      });

      // Register the abort controller. Everything after this point runs inside
      // the try/finally below so the registration is removed on every exit
      // path (early return, setup failure, stream failure, abort, success).
      const abortController = new AbortController();
      this.abortControllers.set(conversationId, abortController);

      try {
        const modelId = conversation.model || 'claude-sonnet-4';
        const provider = detectProvider(modelId);

        // Verify provider key is available
        if (!this.providers.isProviderKeySet(provider)) {
          const providerLabel = provider === 'mistral' ? 'Mistral' : 'OpenCode';
          return { success: false, error: `The model '${modelId}' requires a ${providerLabel} API key. Configure it in Settings.` };
        }

        // Build system prompt with live blog stats
        const systemMessage = conversation.messages.find(m => m.role === 'system');
        const basePrompt = systemMessage?.content || await this.chatEngine.getDefaultSystemPrompt();
        const systemPrompt = await appendBlogStats(basePrompt, this.blogToolDeps);

        // Convert DB messages to AI SDK format. The snapshot predates the new
        // user message, so it is appended here.
        const dbMessages = conversation.messages.filter(m => m.role !== 'system');
        dbMessages.push({
          conversationId,
          role: 'user',
          content: userMessage,
          createdAt: new Date(),
        });
        const aiMessages = dbMessagesToAIMessages(dbMessages);

        // Build tools
        const blogTools = createBlogTools(this.blogToolDeps);
        const a2uiToolsRaw = createA2UITools();
        const allTools = { ...blogTools, ...a2uiToolsRaw };

        // Get context window for truncation. Only the tool names are
        // serialized here, so the tools' share of the budget is a rough estimate.
        const contextWindow = await this.providers.getModelCatalogEngine().getContextWindow(modelId) ?? 150_000;
        const truncatedMessages = truncateMessages(
          aiMessages,
          systemPrompt,
          JSON.stringify(Object.keys(allTools)),
          contextWindow,
        );

        // Resolve model
        const model = this.providers.resolveModel(modelId);

        // Zero-based index of this user turn, attached to A2UI surfaces
        const turnIndex = dbMessages.filter(m => m.role === 'user').length - 1;

        // Track tool calls for response
        const allToolCalls: Array<{ name: string; args: unknown }> = [];

        // Build Anthropic-specific provider options for cache control
        const providerOptions = modelId.startsWith('claude')
          ? { anthropic: { cacheControl: { type: 'ephemeral' as const } } }
          : undefined;

        const result = streamText({
          model,
          system: systemPrompt,
          messages: truncatedMessages,
          tools: allTools,
          stopWhen: stepCountIs(MAX_TOOL_ROUNDS),
          abortSignal: abortController.signal,
          maxRetries: 3,
          providerOptions,
          onChunk: ({ chunk }) => {
            if (chunk.type === 'text-delta' && callbacks.onDelta) {
              callbacks.onDelta(chunk.text);
            }
          },
          onStepFinish: ({ staticToolCalls: stepToolCalls, staticToolResults: stepToolResults }) => {
            // Emit tool call/result events for each step
            if (stepToolCalls) {
              for (const tc of stepToolCalls) {
                allToolCalls.push({ name: tc.toolName, args: tc.input });
                callbacks.onToolCall?.({ name: tc.toolName, args: tc.input });
              }
            }
            if (stepToolResults) {
              for (const tr of stepToolResults) {
                const toolName = tr.toolName;
                const toolResult = tr.output;
                // Handle A2UI render tools
                if (isRenderTool(toolName)) {
                  // Pair the result with its own call by id, so that two calls
                  // to the same render tool in one step each use their own
                  // input; fall back to the tool name if no id matches.
                  const matchingCall =
                    stepToolCalls?.find(tc => tc.toolCallId === tr.toolCallId)
                    ?? stepToolCalls?.find(tc => tc.toolName === toolName);
                  if (matchingCall) {
                    const a2uiMessages = generateFromToolCall(
                      conversationId,
                      toolName,
                      matchingCall.input as Record<string, unknown>,
                    );
                    if (a2uiMessages && callbacks.onA2UIMessage) {
                      for (const msg of a2uiMessages) {
                        if (msg.type === 'createSurface') {
                          msg.metadata = { ...msg.metadata, turnIndex };
                        }
                        callbacks.onA2UIMessage(msg);
                      }
                    }
                  }
                }
                callbacks.onToolResult?.({ name: toolName, result: toolResult });
              }
            }
          },
        });

        // Consume the stream to completion
        await result.response;

        // Extract usage from the response
        this.emitUsage(conversationId, await result.usage, callbacks);

        // Get the final text.
        // NOTE(review): the AI SDK documents `text` as the text of the last
        // step, so text streamed through onDelta in earlier steps is presumably
        // not part of what gets persisted below — confirm this is intended.
        const fullResponse = await result.text;

        // Save assistant response to DB
        if (fullResponse) {
          await this.chatEngine.addMessage({
            conversationId,
            role: 'assistant',
            content: fullResponse,
            toolCalls: allToolCalls.length > 0 ? JSON.stringify(allToolCalls) : undefined,
            createdAt: new Date(),
          });
        }

        // Generate title after first user message (the snapshot had none yet).
        // Fire-and-forget: the title must not delay the reply.
        const priorUserMessages = conversation.messages.filter(m => m.role === 'user').length;
        if (priorUserMessages === 0 && fullResponse) {
          this.generateConversationTitle(conversationId, userMessage).catch(err =>
            console.error('[ChatService] Error generating title:', err),
          );
        }

        return {
          success: true,
          message: fullResponse,
          toolCalls: allToolCalls.length > 0 ? allToolCalls : undefined,
        };
      } catch (error) {
        const isAborted = abortController.signal.aborted
          || (error instanceof Error && error.message === 'Request cancelled');
        if (!isAborted) throw error;
        return { success: true, message: '' };
      } finally {
        // Only remove our own registration: an overlapping request for the
        // same conversation may have replaced it in the meantime.
        if (this.abortControllers.get(conversationId) === abortController) {
          this.abortControllers.delete(conversationId);
        }
      }
    } catch (error) {
      console.error('[ChatService] Error sending message:', error);
      return { success: false, error: error instanceof Error ? error.message : String(error) };
    }
  }

  /**
   * Abort the in-flight request for a conversation.
   *
   * @returns `{ success: false, error }` when no request is registered for the conversation.
   */
  async abortMessage(conversationId: string): Promise<{ success: boolean; error?: string }> {
    const controller = this.abortControllers.get(conversationId);
    if (!controller) {
      return { success: false, error: 'No active request for this conversation' };
    }
    controller.abort();
    this.abortControllers.delete(conversationId);
    return { success: true };
  }

  /** Abort all in-flight requests and forget their controllers. */
  async stop(): Promise<void> {
    for (const controller of this.abortControllers.values()) {
      controller.abort();
    }
    this.abortControllers.clear();
  }

  // ---- Private helpers ----

  /**
   * Generate a short conversation title from the first user message, store it
   * and notify the main window over the `chat-title-updated` channel.
   *
   * Non-streaming one-shot call using the configured title model. Errors are
   * logged and swallowed.
   */
  private async generateConversationTitle(
    conversationId: string,
    userMessage: string,
  ): Promise<void> {
    try {
      let titleModel = await this.chatEngine.getSetting('chat_title_model');
      // Fallback chain: setting → haiku → mistral-small
      if (!titleModel || !this.providers.isProviderKeySet(detectProvider(titleModel))) {
        titleModel = this.providers.getOpencodeKey()
          ? 'claude-haiku-4-5'
          : this.providers.getMistralKey()
            ? 'mistral-small-latest'
            : null;
      }
      if (!titleModel) return;

      const model = this.providers.resolveModel(titleModel);
      const { text } = await generateText({
        model,
        system: 'Generate an ultra-short title (2-3 words, max 25 characters) for this conversation. Focus ONLY on the topic. Ignore any capability disclaimers. Output ONLY the title text.',
        prompt: `Topic: ${userMessage.substring(0, 100)}`,
        maxOutputTokens: 20,
        maxRetries: 2,
      });

      // Strip one surrounding quote on each side and trailing sentence punctuation.
      let title = text.trim().replace(/^["']|["']$/g, '').replace(/[.!?]+$/, '');
      const MAX_TITLE_LENGTH = 30;
      if (title.length > MAX_TITLE_LENGTH) {
        title = title.substring(0, MAX_TITLE_LENGTH - 1) + '…';
      }
      if (title) {
        await this.chatEngine.updateConversation(conversationId, { title });
        const mainWindow = this.getMainWindow();
        // The window may have been closed while the title was being generated.
        if (mainWindow && !mainWindow.isDestroyed()) {
          mainWindow.webContents.send('chat-title-updated', { conversationId, title });
        }
      }
    } catch (error) {
      console.error('[ChatService] Error generating title:', error);
    }
  }

  /**
   * Add this turn's usage to the conversation totals and report both to
   * `callbacks.onTokenUsage`, when that callback is provided.
   *
   * The totals are updated whether or not a callback is present, so that a
   * later report still reflects every turn.
   */
  private emitUsage(
    conversationId: string,
    usage: LanguageModelUsage | undefined,
    callbacks: ChatCallbacks,
  ): void {
    if (!usage) return;

    // AI SDK v6 normalizes usage into inputTokens/outputTokens
    // Cache tokens are in inputTokenDetails
    const inputTokens = usage.inputTokens ?? 0;
    const outputTokens = usage.outputTokens ?? 0;
    const cacheReadTokens = usage.inputTokenDetails?.cacheReadTokens ?? 0;
    const cacheWriteTokens = usage.inputTokenDetails?.cacheWriteTokens ?? 0;
    // Report cache traffic separately instead of inside the input count.
    const adjustedInputTokens = inputTokens - cacheReadTokens - cacheWriteTokens;
    const totalTokens = inputTokens + outputTokens;

    const prev = this.conversationUsage.get(conversationId) ?? {
      inputTokens: 0, outputTokens: 0, cacheReadTokens: 0, cacheWriteTokens: 0,
    };
    const cumulative = {
      inputTokens: prev.inputTokens + adjustedInputTokens,
      outputTokens: prev.outputTokens + outputTokens,
      cacheReadTokens: prev.cacheReadTokens + cacheReadTokens,
      cacheWriteTokens: prev.cacheWriteTokens + cacheWriteTokens,
    };
    this.conversationUsage.set(conversationId, cumulative);

    callbacks.onTokenUsage?.({
      inputTokens: adjustedInputTokens, outputTokens, cacheReadTokens, cacheWriteTokens, totalTokens,
      cumulativeInputTokens: cumulative.inputTokens,
      cumulativeOutputTokens: cumulative.outputTokens,
      cumulativeCacheReadTokens: cumulative.cacheReadTokens,
      cumulativeCacheWriteTokens: cumulative.cacheWriteTokens,
      cumulativeTotalTokens: cumulative.inputTokens + cumulative.outputTokens + cumulative.cacheReadTokens + cumulative.cacheWriteTokens,
    });
  }
}