From 32bbf3f20471865c7f6ab497a693ac717339fc70 Mon Sep 17 00:00:00 2001 From: Chili Palmer Date: Sat, 21 Mar 2026 09:39:25 +0100 Subject: [PATCH] feat: context fill grade in chat UI --- MLXServer/Server/PromptBuilder.swift | 9 +++- MLXServer/ViewModels/ChatViewModel.swift | 28 ++++++++++++ MLXServer/Views/StatusBarView.swift | 43 +++++++++++++++++++ .../Server/PromptBuilderTests.swift | 14 ++++++ 4 files changed, 93 insertions(+), 1 deletion(-) diff --git a/MLXServer/Server/PromptBuilder.swift b/MLXServer/Server/PromptBuilder.swift index 0e90ade..5257ceb 100644 --- a/MLXServer/Server/PromptBuilder.swift +++ b/MLXServer/Server/PromptBuilder.swift @@ -96,7 +96,7 @@ enum PromptBuilder { additionalContext: additionalContext ) - let estimatedPromptTokens = (instructions.count + chatMessages.reduce(0) { $0 + $1.content.count }) * 10 / 35 + let estimatedPromptTokens = estimatePromptTokens(instructions: instructions, chatMessages: chatMessages) return PreparedPrompt( instructions: instructions, @@ -111,6 +111,13 @@ enum PromptBuilder { ) } + static func estimatePromptTokens(instructions: String, chatMessages: [Chat.Message]) -> Int { + let characterCount = instructions.count + chatMessages.reduce(0) { partial, message in + partial + message.content.count + } + return max(0, characterCount * 10 / 35) + } + private static func imageFingerprint(_ source: String) -> UInt64 { var hash: UInt64 = 14_695_981_039_346_656_037 for byte in source.utf8 { diff --git a/MLXServer/ViewModels/ChatViewModel.swift b/MLXServer/ViewModels/ChatViewModel.swift index bc84931..3e307f6 100644 --- a/MLXServer/ViewModels/ChatViewModel.swift +++ b/MLXServer/ViewModels/ChatViewModel.swift @@ -49,6 +49,34 @@ final class ChatViewModel { hasUnsavedChanges ? "\(documentDisplayName) *" : documentDisplayName } + var currentContextLength: Int { + modelManager.currentModel?.contextLength ?? 0 + } + + var estimatedPromptTokens: Int { + let draft = inputText.trimmingCharacters(in: .whitespacesAndNewlines) + var chatMessages = conversation.messages.compactMap(historyMessage(from:)) + if !draft.isEmpty { + chatMessages.append(Chat.Message(role: .user, content: draft)) + } + return PromptBuilder.estimatePromptTokens( + instructions: effectiveSystemPrompt, + chatMessages: chatMessages + ) + } + + var contextUsedTokens: Int { + if isGenerating && (promptTokens > 0 || generationTokens > 0) { + return promptTokens + generationTokens + } + return estimatedPromptTokens + } + + var contextFillRatio: Double { + guard currentContextLength > 0 else { return 0 } + return min(max(Double(contextUsedTokens) / Double(currentContextLength), 0), 1) + } + /// Ensure a ChatSession exists for the current model. private func ensureSession() { guard let container = modelManager.modelContainer else { return } diff --git a/MLXServer/Views/StatusBarView.swift b/MLXServer/Views/StatusBarView.swift index 917bc37..2688926 100644 --- a/MLXServer/Views/StatusBarView.swift +++ b/MLXServer/Views/StatusBarView.swift @@ -31,6 +31,10 @@ struct StatusBarView: View { .font(.caption) .foregroundStyle(.secondary) + if let model = modelManager.currentModel, model.contextLength > 0 { + contextFillView(totalContext: model.contextLength) + } + Spacer() // GPU memory @@ -78,4 +82,43 @@ struct StatusBarView: View { .padding(.vertical, 4) .background(.bar) } + + @ViewBuilder + private func contextFillView(totalContext: Int) -> some View { + let usedTokens = viewModel.contextUsedTokens + let ratio = viewModel.contextFillRatio + let percent = Int((ratio * 100).rounded()) + + HStack(spacing: 6) { + Capsule() + .fill(.quaternary) + .frame(width: 48, height: 6) + .overlay(alignment: .leading) { + Capsule() + .fill(contextFillColor(for: ratio)) + .frame(width: max(4, 48 * ratio), height: 6) + } + + Text("Ctx \(percent)%") + .font(.caption.monospacedDigit()) + .foregroundStyle(.secondary) + } + .help("Approximate context usage: \(formatTokenCount(usedTokens)) of \(formatTokenCount(totalContext)) tokens") + } + + private func contextFillColor(for ratio: Double) -> Color { + if ratio >= 0.9 { return .red } + if ratio >= 0.7 { return .orange } + return .blue + } + + private func formatTokenCount(_ count: Int) -> String { + if count >= 1_000_000 { + return String(format: "%.1fM", Double(count) / 1_000_000) + } + if count >= 1_000 { + return String(format: "%.1fk", Double(count) / 1_000) + } + return "\(count)" + } } diff --git a/MLXServerTests/Server/PromptBuilderTests.swift b/MLXServerTests/Server/PromptBuilderTests.swift index 7669cba..164ff03 100644 --- a/MLXServerTests/Server/PromptBuilderTests.swift +++ b/MLXServerTests/Server/PromptBuilderTests.swift @@ -61,6 +61,20 @@ final class PromptBuilderTests: XCTestCase { XCTAssertEqual(prepared.additionalContext?["enable_thinking"] as? Bool, legacy.additionalContext?["enable_thinking"] as? Bool) } + func testEstimatePromptTokensMatchesSharedCharacterHeuristic() { + let messages = [ + Chat.Message(role: .user, content: "1234567890"), + Chat.Message(role: .assistant, content: "abcdefghij") + ] + + let estimated = PromptBuilder.estimatePromptTokens( + instructions: "system12345", + chatMessages: messages + ) + + XCTAssertEqual(estimated, 8) + } + func testBuildAggregatesInstructionsAndMessages() { let request = APIChatCompletionRequest( model: "gemma",