From 0325fa89642683db88f2f9256e781d246494d06a Mon Sep 17 00:00:00 2001 From: Chili Palmer Date: Sat, 21 Mar 2026 07:59:48 +0100 Subject: [PATCH] feat: phase 6 implemented and tested --- MLXServer.xcodeproj/project.pbxproj | 8 + MLXServer/Models/InferenceStats.swift | 7 + MLXServer/Server/TokenPrefixCache.swift | 185 ++++++++++++- MLXServer/Utilities/Preferences.swift | 26 ++ MLXServer/Views/MonitorView.swift | 8 + MLXServer/Views/SettingsView.swift | 53 +++- .../Server/ModelBackedQuantizationTests.swift | 251 +++++++++++++++++ .../TokenPrefixCacheQuantizationTests.swift | 252 ++++++++++++++++++ docs/session-cache-upgrade.md | 26 +- 9 files changed, 792 insertions(+), 24 deletions(-) create mode 100644 MLXServerTests/Server/ModelBackedQuantizationTests.swift create mode 100644 MLXServerTests/Server/TokenPrefixCacheQuantizationTests.swift diff --git a/MLXServer.xcodeproj/project.pbxproj b/MLXServer.xcodeproj/project.pbxproj index 19f8193..da7e92a 100644 --- a/MLXServer.xcodeproj/project.pbxproj +++ b/MLXServer.xcodeproj/project.pbxproj @@ -33,6 +33,7 @@ 67D0628F148FE3C2200E0AEF /* APIServerResponseResolutionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */; }; 6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; }; 741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; }; + 7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */; }; 7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; }; 80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 
38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; }; 834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */; }; @@ -59,6 +60,7 @@ E199D0BB09B61AC128AB093A /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3489501F2F8E1BA382347CFA /* CancellationToken.swift */; }; E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */; }; EC4FC68608DDFA6A3DF133CC /* InferenceEngine.swift in Sources */ = {isa = PBXBuildFile; fileRef = 02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */; }; + EDE59C241940E7B9B53D520D /* TokenPrefixCacheQuantizationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D50504058693CDE533D755B5 /* TokenPrefixCacheQuantizationTests.swift */; }; F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = A98257123539E9E738213BFA /* MarkdownUI */; }; FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = A4B359324B5FD8D106C74338 /* ChatMessage.swift */; }; FCD48F8C132A2B830A15EEB4 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = 3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */; }; @@ -118,6 +120,7 @@ C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelPickerView.swift; sourceTree = ""; }; C67742651DB486871CEF1612 /* MLXServerApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLXServerApp.swift; sourceTree = ""; }; D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelBackedInferenceValidationTests.swift; sourceTree = ""; }; + D50504058693CDE533D755B5 /* TokenPrefixCacheQuantizationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = 
sourcecode.swift; path = TokenPrefixCacheQuantizationTests.swift; sourceTree = ""; }; D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentController.swift; sourceTree = ""; }; D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolver.swift; sourceTree = ""; }; D7C9BAD674E29688ACE53B0B /* ChatExporter.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatExporter.swift; sourceTree = ""; }; @@ -131,6 +134,7 @@ EF518FEBF3A38E830E3CE1A5 /* FocusedValues.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FocusedValues.swift; sourceTree = ""; }; F1A52E2C9964ADA9D841A89B /* APIModels.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIModels.swift; sourceTree = ""; }; F4CE2D594F7433C76169151A /* MLXServerTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = MLXServerTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; + F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelBackedQuantizationTests.swift; sourceTree = ""; }; FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CancellationTokenTests.swift; sourceTree = ""; }; /* End PBXFileReference section */ @@ -189,9 +193,11 @@ E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */, 7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */, D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */, + F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */, 5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */, 49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */, 31BD930DEC051408444C30D4 /* 
TestImageFixtures.swift */, + D50504058693CDE533D755B5 /* TokenPrefixCacheQuantizationTests.swift */, 64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */, B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */, ); @@ -401,9 +407,11 @@ E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */, 67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */, 8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */, + 7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */, 1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */, FE4405F66873C75CD6FA19A5 /* StreamingSSEEncoderTests.swift in Sources */, 3A9DB84947BBBBED06CF9E1E /* TestImageFixtures.swift in Sources */, + EDE59C241940E7B9B53D520D /* TokenPrefixCacheQuantizationTests.swift in Sources */, 221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */, 834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */, ); diff --git a/MLXServer/Models/InferenceStats.swift b/MLXServer/Models/InferenceStats.swift index 932daed..ede5a10 100644 --- a/MLXServer/Models/InferenceStats.swift +++ b/MLXServer/Models/InferenceStats.swift @@ -447,6 +447,11 @@ final class InferenceStats { var cacheMemoryUsagePercent: Double = 0 var cachedEntries: [TokenPrefixCache.EntrySummary] = [] + // MARK: - Quantization stats (Phase 6) + + var kvQuantizationEnabled: Bool = false + var quantizationBytesSaved: Int = 0 + // MARK: - Time series data (ring buffers for charts) struct DataPoint: Identifiable { @@ -544,6 +549,8 @@ final class InferenceStats { cacheMemoryBudgetBytes = cache.memoryBudgetBytes cacheMemoryUsagePercent = cache.memoryUsagePercent cachedEntries = cache.entries + kvQuantizationEnabled = cache.quantizationEnabled + quantizationBytesSaved = cache.quantizationBytesSaved let now = Date.now let genDelta = snap.totalGenerationTokens - lastGenerationTokenCount diff --git a/MLXServer/Server/TokenPrefixCache.swift 
b/MLXServer/Server/TokenPrefixCache.swift index e5476a4..20c3693 100644 --- a/MLXServer/Server/TokenPrefixCache.swift +++ b/MLXServer/Server/TokenPrefixCache.swift @@ -1,5 +1,6 @@ import Foundation import Metal +import MLX import MLXLMCommon import os @@ -36,6 +37,8 @@ final class TokenPrefixCache: @unchecked Sendable { let prefixHits: Int let supersequenceHits: Int let lcpHits: Int + let quantizationBytesSaved: Int // Total bytes saved by quantization + let quantizationEnabled: Bool let entries: [EntrySummary] } @@ -54,6 +57,7 @@ final class TokenPrefixCache: @unchecked Sendable { let createdAt: Date var lastAccessAt: Date var hitCount: Int + let isQuantized: Bool } private struct Stats { @@ -63,6 +67,32 @@ final class TokenPrefixCache: @unchecked Sendable { var totalPrefixHits: Int = 0 var totalSupersequenceHits: Int = 0 var totalLCPHits: Int = 0 + var totalQuantizationBytesSaved: Int = 0 + } + + struct QuantizationConfig: Sendable { + /// Whether to quantize KV caches for storage + let enabled: Bool + /// Bit width for quantization (8 is recommended for 50% savings with minimal quality loss) + let bits: Int + /// Group size for quantization. Matches mlx-swift-lm default. + let groupSize: Int + /// Minimum token count before quantization applies. Short sequences don't benefit. 
+ let minTokens: Int + + static let `default` = QuantizationConfig( + enabled: false, + bits: 8, + groupSize: 64, + minTokens: 256 + ) + + static let aggressive = QuantizationConfig( + enabled: true, + bits: 8, + groupSize: 64, + minTokens: 256 + ) } private let lock = OSAllocatedUnfairLock() @@ -74,24 +104,55 @@ final class TokenPrefixCache: @unchecked Sendable { private var entries: [UUID: CacheEntry] = [:] private var currentMemoryBytes: Int = 0 private var stats = Stats() + private var quantizationConfig: QuantizationConfig private init() { self.maxMemoryBytes = Self.computeMemoryBudget() self.idleTTL = 30 * 60 self.estimateBytesProvider = Self.estimateBytes self.nowProvider = Date.init + self.quantizationConfig = Self.preferencesQuantizationConfig() } init( memoryBudgetBytes: Int, idleTTL: TimeInterval = 30 * 60, estimateBytesProvider: @escaping ([KVCache]) -> Int = TokenPrefixCache.estimateBytes, - nowProvider: @escaping () -> Date = Date.init + nowProvider: @escaping () -> Date = Date.init, + quantizationConfig: QuantizationConfig = .default ) { self.maxMemoryBytes = memoryBudgetBytes self.idleTTL = idleTTL self.estimateBytesProvider = estimateBytesProvider self.nowProvider = nowProvider + self.quantizationConfig = quantizationConfig + } + + /// Update quantization configuration. + func setQuantizationConfig(_ config: QuantizationConfig) { + lock.lock() + self.quantizationConfig = config + lock.unlock() + } + + /// Get current quantization configuration. 
+ func getQuantizationConfig() -> QuantizationConfig { + lock.lock() + defer { lock.unlock() } + return quantizationConfig + } + + private static func preferencesQuantizationConfig() -> QuantizationConfig { + guard Preferences.kvQuantizationEnabled else { + return .default + } + + return QuantizationConfig( + enabled: true, + bits: Preferences.kvQuantizationBits, + groupSize: 64, + minTokens: 256 + ) } func lookup(cacheKey: [Int], modelId: String) -> CacheLease { @@ -123,19 +184,22 @@ final class TokenPrefixCache: @unchecked Sendable { } } - if let match = bestMatch, - var entry = entries[match.entryId] { + if let match = bestMatch, + var entry = entries[match.entryId] { entry.lastAccessAt = now entry.hitCount += 1 entries[match.entryId] = entry - removeEntryLocked(entry, countAsEviction: false) + removeEntryLocked(entry, countAsEviction: false) stats.totalHits += 1 stats.totalPrefixHits += 1 lock.unlock() + // Dequantize if necessary before returning to caller + let cacheToReturn = Self.dequantizeCache(entry.kvCache) + return CacheLease( entryId: match.entryId, - kvCache: entry.kvCache, + kvCache: cacheToReturn, matchedTokenCount: match.realTokenCount, isHit: true ) @@ -180,7 +244,26 @@ final class TokenPrefixCache: @unchecked Sendable { let now = nowProvider() pruneExpiredLocked(now: now) - let estimatedBytes = estimateBytesProvider(kvCache) + let normalizedCache = Self.normalizeCacheForStorage(kvCache) + let bytesBeforeQuantization = estimateBytesProvider(normalizedCache) + let cacheToStore: [KVCache] + + if quantizationConfig.enabled && cacheKey.filter({ $0 >= 0 }).count >= quantizationConfig.minTokens { + cacheToStore = Self.quantizeCache(normalizedCache, config: quantizationConfig) + } else { + cacheToStore = normalizedCache + } + + let isQuantized = Self.cacheContainsQuantizedLayers(cacheToStore) + + let estimatedBytes = estimateBytesProvider(cacheToStore) + let bytesSaved = bytesBeforeQuantization - estimatedBytes + + // Update quantization stats if 
applicable + if isQuantized && bytesSaved > 0 { + stats.totalQuantizationBytesSaved += bytesSaved + } + var node = root for key in cacheKey { if node.children[key] == nil { @@ -198,13 +281,14 @@ final class TokenPrefixCache: @unchecked Sendable { entries[entryId] = CacheEntry( id: entryId, modelId: modelId, - kvCache: kvCache, + kvCache: cacheToStore, tokenCount: cacheKey.filter { $0 >= 0 }.count, cacheKey: cacheKey, estimatedBytes: estimatedBytes, createdAt: now, lastAccessAt: now, - hitCount: 0 + hitCount: 0, + isQuantized: isQuantized ) currentMemoryBytes += estimatedBytes enforceBudgetLocked() @@ -258,6 +342,8 @@ final class TokenPrefixCache: @unchecked Sendable { prefixHits: stats.totalPrefixHits, supersequenceHits: stats.totalSupersequenceHits, lcpHits: stats.totalLCPHits, + quantizationBytesSaved: stats.totalQuantizationBytesSaved, + quantizationEnabled: quantizationConfig.enabled, entries: orderedEntries.map { EntrySummary( id: $0.id, @@ -381,9 +467,12 @@ final class TokenPrefixCache: @unchecked Sendable { stats.totalHits += 1 stats.totalSupersequenceHits += 1 + // Dequantize if necessary before returning to caller + let cacheToReturn = Self.dequantizeCache(trimmedCache) + return CacheLease( entryId: updatedEntry.id, - kvCache: trimmedCache, + kvCache: cacheToReturn, matchedTokenCount: queryRealTokenCount, isHit: true ) @@ -434,9 +523,12 @@ final class TokenPrefixCache: @unchecked Sendable { stats.totalHits += 1 stats.totalLCPHits += 1 + // Dequantize if necessary before returning to caller + let cacheToReturn = Self.dequantizeCache(trimmedCache) + return CacheLease( entryId: updatedEntry.id, - kvCache: trimmedCache, + kvCache: cacheToReturn, matchedTokenCount: sharedRealTokenCount, isHit: true ) @@ -485,4 +577,77 @@ final class TokenPrefixCache: @unchecked Sendable { } return max(total, 1024) } + + // MARK: - Quantization Support + + /// Quantize a KV cache for compact storage (Phase 6 feature). 
+ /// Converts FP16 K/V tensors to a lower-bit representation. + /// Returns the quantized cache or the original cache if quantization is skipped/unsupported. + private static func quantizeCache( + _ cache: [KVCache], + config: QuantizationConfig + ) -> [KVCache] { + guard config.enabled else { return cache } + + return cache.map { layer in + if layer is QuantizedKVCache { + return layer + } + + if let simpleLayer = layer as? KVCacheSimple { + let quantized = simpleLayer.toQuantized( + groupSize: config.groupSize, + bits: config.bits + ) + MLX.eval(quantized.state) + return quantized + } + + // Preserve non-standard cache types unchanged. + return layer + } + } + + /// Dequantize a KV cache back to standard form before inference. + /// If the cache was not quantized, returns it unchanged. + private static func dequantizeCache(_ cache: [KVCache]) -> [KVCache] { + cache.map { layer in + if let quantizedLayer = layer as? QuantizedKVCache { + let unquantized = quantizedLayer.toUnquantized() + MLX.eval(unquantized.state) + return unquantized + } + + return layer + } + } + + private static func normalizeCacheForStorage(_ cache: [KVCache]) -> [KVCache] { + cache.map { layer in + if let quantizedLayer = layer as? QuantizedKVCache { + let compact = QuantizedKVCache( + groupSize: quantizedLayer.groupSize, + bits: quantizedLayer.bits, + mode: quantizedLayer.mode + ) + compact.state = quantizedLayer.state + compact.offset = quantizedLayer.offset + MLX.eval(compact.state) + return compact + } + + if let simpleLayer = layer as? 
KVCacheSimple { + let compact = KVCacheSimple() + compact.state = simpleLayer.state + MLX.eval(compact.state) + return compact + } + + return layer + } + } + + private static func cacheContainsQuantizedLayers(_ cache: [KVCache]) -> Bool { + cache.contains { $0 is QuantizedKVCache } + } } \ No newline at end of file diff --git a/MLXServer/Utilities/Preferences.swift b/MLXServer/Utilities/Preferences.swift index 7642d04..11ff9a2 100644 --- a/MLXServer/Utilities/Preferences.swift +++ b/MLXServer/Utilities/Preferences.swift @@ -98,4 +98,30 @@ enum Preferences { } set { defaults.set(newValue, forKey: idleUnloadMinutesKey) } } + + // MARK: - KV Cache Quantization + + private static let kvQuantizationEnabledKey = "kvQuantizationEnabled" + private static let kvQuantizationBitsKey = "kvQuantizationBits" + + /// Whether to quantize KV caches for compact storage (50% memory savings at 8-bit). + /// Default: false (disabled for maximum quality). Requires TokenPrefixCache Phase 6. + static var kvQuantizationEnabled: Bool { + get { defaults.object(forKey: kvQuantizationEnabledKey) == nil ? false : defaults.bool(forKey: kvQuantizationEnabledKey) } + set { defaults.set(newValue, forKey: kvQuantizationEnabledKey) } + } + + /// Bit width for KV cache quantization. Standard: 8 (recommended). Range: 4-16. + /// Lower bits = more compression but potential quality loss. 8-bit is proven in production. + static var kvQuantizationBits: Int { + get { + let val = defaults.integer(forKey: kvQuantizationBitsKey) + return val > 0 ? 
val : 8 + } + set { + // Clamp to valid range + let clamped = max(4, min(newValue, 16)) + defaults.set(clamped, forKey: kvQuantizationBitsKey) + } + } } diff --git a/MLXServer/Views/MonitorView.swift b/MLXServer/Views/MonitorView.swift index 1c05a73..47ad9a8 100644 --- a/MLXServer/Views/MonitorView.swift +++ b/MLXServer/Views/MonitorView.swift @@ -42,6 +42,14 @@ struct MonitorView: View { detail: "\(stats.totalCacheHits) hits / \(stats.totalCacheMisses) misses • P:\(stats.totalPrefixHits) S:\(stats.totalSupersequenceHits) L:\(stats.totalLCPHits)", color: .blue ) + metricCard( + title: "Cache Quantization", + value: stats.kvQuantizationEnabled ? "ON" : "OFF", + detail: stats.kvQuantizationEnabled && stats.quantizationBytesSaved > 0 + ? "saved " + formatByteCount(stats.quantizationBytesSaved) + : "8-bit compression", + color: stats.kvQuantizationEnabled && stats.quantizationBytesSaved > 0 ? .mint : .secondary + ) metricCard( title: "Cache Match", value: formatTokenCount(stats.cacheMatchDepth), diff --git a/MLXServer/Views/SettingsView.swift b/MLXServer/Views/SettingsView.swift index fdb4796..d4fddb5 100644 --- a/MLXServer/Views/SettingsView.swift +++ b/MLXServer/Views/SettingsView.swift @@ -9,6 +9,21 @@ struct SettingsView: View { @State private var idleUnloadMinutes: String = String(Preferences.idleUnloadMinutes) @State private var defaultModelId: String = Preferences.defaultModelId ?? 
ModelConfig.default.id @State private var enableThinking: Bool = Preferences.enableThinking + @State private var kvQuantizationEnabled: Bool = Preferences.kvQuantizationEnabled + @State private var kvQuantizationBits: Int = Preferences.kvQuantizationBits + + private var kvQuantizationConfig: TokenPrefixCache.QuantizationConfig { + guard kvQuantizationEnabled else { + return .default + } + + return .init( + enabled: true, + bits: kvQuantizationBits, + groupSize: 64, + minTokens: 256 + ) + } var body: some View { Form { @@ -107,8 +122,44 @@ struct SettingsView: View { .font(.caption) .foregroundStyle(.secondary) } + + Section("Cache Quantization") { + Toggle("Enable KV cache quantization", isOn: $kvQuantizationEnabled) + .onChange(of: kvQuantizationEnabled) { + Preferences.kvQuantizationEnabled = kvQuantizationEnabled + TokenPrefixCache.shared.setQuantizationConfig(kvQuantizationConfig) + } + + if kvQuantizationEnabled { + HStack { + Text("Bit width") + Spacer() + Stepper( + value: $kvQuantizationBits, + in: 4...16, + step: 1 + ) { + Text("\(kvQuantizationBits)-bit") + } + .onChange(of: kvQuantizationBits) { + Preferences.kvQuantizationBits = kvQuantizationBits + TokenPrefixCache.shared.setQuantizationConfig(kvQuantizationConfig) + } + } + } + + if kvQuantizationEnabled { + Text("Quantizes KV caches to \(kvQuantizationBits)-bit for \(kvQuantizationBits == 8 ? "~50%" : "~\((16 - kvQuantizationBits) * 6)%") memory savings. Lower bits = more compression but may impact response quality. 8-bit is recommended.") + .font(.caption) + .foregroundStyle(.secondary) + } else { + Text("When enabled, KV caches are quantized for compact storage, reducing memory usage on long conversations. 
Disabled by default for maximum quality.") + .font(.caption) + .foregroundStyle(.secondary) + } + } } .formStyle(.grouped) - .frame(width: 450, height: 550) + .frame(width: 450, height: 650) } } diff --git a/MLXServerTests/Server/ModelBackedQuantizationTests.swift b/MLXServerTests/Server/ModelBackedQuantizationTests.swift new file mode 100644 index 0000000..131e312 --- /dev/null +++ b/MLXServerTests/Server/ModelBackedQuantizationTests.swift @@ -0,0 +1,251 @@ +import Foundation +import Hub +import MLX +import MLXLMCommon +import MLXVLM +import XCTest +@testable import MLX_Server + +final class ModelBackedQuantizationTests: XCTestCase { + func testQuantizedLookupRoundTripPreservesRealModelCache() async throws { + let container = try await localGemmaContainer() + let engine = InferenceEngine(container: container) + let input = quantizationPrompt() + let prepared = try await engine.prepare(input) + + let workingCache = try await generatePromptCache( + engine: engine, + prepared: prepared, + maxTokens: 1 + ) + + let cache = TokenPrefixCache( + memoryBudgetBytes: 1_000_000_000, + quantizationConfig: .init(enabled: true, bits: 8, groupSize: 64, minTokens: 1) + ) + cache.store( + entryId: UUID(), + kvCache: workingCache, + cacheKey: prepared.tokens, + modelId: "gemma" + ) + + let lease = cache.lookup(cacheKey: prepared.tokens, modelId: "gemma") + let roundTripped = try XCTUnwrap(lease.kvCache) + + XCTAssertTrue(lease.isHit) + XCTAssertFalse(roundTripped.isEmpty) + XCTAssertFalse(roundTripped.contains { $0 is QuantizedKVCache }) + XCTAssertEqual(workingCache.count, roundTripped.count) + + for (original, returned) in zip(workingCache, roundTripped) { + XCTAssertEqual(original.offset, returned.offset) + XCTAssertEqual(original.state.count, returned.state.count) + for (lhs, rhs) in zip(original.state, returned.state) { + XCTAssertEqual(lhs.shape, rhs.shape) + } + } + } + + func testQuantizedCacheHitProducesUsableDeterministicResponseAndAdvancesCacheLikeUnquantizedHit() async 
throws { + let container = try await localGemmaContainer() + let engine = InferenceEngine(container: container) + let input = quantizationPrompt() + let prepared = try await engine.prepare(input) + + let promptCache = try await generatePromptCache( + engine: engine, + prepared: prepared, + maxTokens: 1 + ) + + let unquantizedCache = TokenPrefixCache( + memoryBudgetBytes: 1_000_000_000, + quantizationConfig: .default + ) + let quantizedCache = TokenPrefixCache( + memoryBudgetBytes: 1_000_000_000, + quantizationConfig: .init(enabled: true, bits: 8, groupSize: 64, minTokens: 1) + ) + + unquantizedCache.store( + entryId: UUID(), + kvCache: promptCache, + cacheKey: prepared.tokens, + modelId: "gemma" + ) + quantizedCache.store( + entryId: UUID(), + kvCache: promptCache, + cacheKey: prepared.tokens, + modelId: "gemma" + ) + + let unquantizedLease = unquantizedCache.lookup(cacheKey: prepared.tokens, modelId: "gemma") + let quantizedLease = quantizedCache.lookup(cacheKey: prepared.tokens, modelId: "gemma") + + XCTAssertTrue(unquantizedLease.isHit) + XCTAssertTrue(quantizedLease.isHit) + XCTAssertEqual(unquantizedLease.matchedTokenCount, prepared.tokens.count) + XCTAssertEqual(quantizedLease.matchedTokenCount, prepared.tokens.count) + + let parameters = GenerateParameters(maxTokens: 4, temperature: 0) + let unquantizedHandle = try await engine.stream( + InferenceEngine.InferenceRequest( + input: prepared.lmInput, + tokens: prepared.tokens, + parameters: parameters, + cachedKV: unquantizedLease.kvCache, + cachedTokenCount: unquantizedLease.matchedTokenCount + ), + cancellation: CancellationToken() + ) + + let unquantizedText = await collectText(unquantizedHandle.stream) + XCTAssertFalse(unquantizedText.isEmpty) + + let quantizedHandle = try await engine.stream( + InferenceEngine.InferenceRequest( + input: prepared.lmInput, + tokens: prepared.tokens, + parameters: parameters, + cachedKV: quantizedLease.kvCache, + cachedTokenCount: quantizedLease.matchedTokenCount + ), + 
cancellation: CancellationToken() + ) + let quantizedText = await collectText(quantizedHandle.stream) + XCTAssertFalse(quantizedText.isEmpty) + + XCTAssertEqual(unquantizedHandle.workingCache.count, quantizedHandle.workingCache.count) + for (lhs, rhs) in zip(unquantizedHandle.workingCache, quantizedHandle.workingCache) { + XCTAssertLessThanOrEqual(abs(lhs.offset - rhs.offset), 1) + XCTAssertEqual(lhs.state.count, rhs.state.count) + for (lhsState, rhsState) in zip(lhs.state, rhs.state) { + XCTAssertEqual(lhsState.shape.count, rhsState.shape.count) + if lhsState.shape.count == 4 { + XCTAssertEqual(lhsState.shape[0], rhsState.shape[0]) + XCTAssertEqual(lhsState.shape[1], rhsState.shape[1]) + XCTAssertLessThanOrEqual(abs(lhsState.shape[2] - rhsState.shape[2]), 1) + XCTAssertEqual(lhsState.shape[3], rhsState.shape[3]) + } else { + XCTAssertEqual(lhsState.shape, rhsState.shape) + } + } + } + } + + func testPreferencesIntegrationWithQuantization() throws { + Preferences.kvQuantizationEnabled = true + Preferences.kvQuantizationBits = 8 + + XCTAssertTrue(Preferences.kvQuantizationEnabled) + XCTAssertEqual(Preferences.kvQuantizationBits, 8) + + Preferences.kvQuantizationBits = 2 + XCTAssertGreaterThanOrEqual(Preferences.kvQuantizationBits, 4) + + Preferences.kvQuantizationBits = 32 + XCTAssertLessThanOrEqual(Preferences.kvQuantizationBits, 16) + + Preferences.kvQuantizationEnabled = false + Preferences.kvQuantizationBits = 8 + } + + private func quantizationPrompt() -> UserInput { + UserInput( + prompt: .chat([ + Chat.Message(role: .system, content: "You are terse and deterministic."), + Chat.Message(role: .user, content: String(repeating: "cache reuse test ", count: 48)) + ]), + images: [], + videos: [], + tools: nil + ) + } + + private func generatePromptCache( + engine: InferenceEngine, + prepared: InferenceEngine.PreparedInference, + maxTokens: Int + ) async throws -> [KVCache] { + let handle = try await engine.stream( + InferenceEngine.InferenceRequest( + input: 
prepared.lmInput, + tokens: prepared.tokens, + parameters: GenerateParameters(maxTokens: maxTokens, temperature: 0), + cachedKV: nil, + cachedTokenCount: 0 + ), + cancellation: CancellationToken() + ) + + _ = await collectText(handle.stream) + trimCacheToPrompt(handle.workingCache, promptTokenCount: prepared.tokens.count) + return handle.workingCache + } + + private func collectText(_ stream: AsyncStream) async -> String { + var text = "" + for await generation in stream { + if case .chunk(let chunk) = generation { + text += chunk + } + } + return text + } + + private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) { + for layer in cache { + let excess = layer.offset - promptTokenCount + if excess > 0 { + XCTAssertTrue(layer.isTrimmable) + XCTAssertEqual(layer.trim(excess), excess) + } + } + } + + private func localGemmaContainer() async throws -> ModelContainer { + try await LocalGemmaFixture.shared.container() + } +} + +// MARK: - LocalGemmaFixture + +private actor LocalGemmaFixture { + static let shared = LocalGemmaFixture() + + private var task: Task? 
+ + func container() async throws -> ModelContainer { + if let task { + return try await task.value + } + + guard let config = ModelConfig.resolve("gemma") else { + throw XCTSkip("Gemma model config is unavailable") + } + guard let localDir = LocalModelResolver.resolve(repoId: config.repoId) else { + throw XCTSkip("Local gemma cache is unavailable") + } + + let loadTask = Task { + let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first + let hub = HubApi(downloadBase: cachesDir, cache: nil) + return try await VLMModelFactory.shared.loadContainer( + hub: hub, + configuration: ModelConfiguration(directory: localDir), + progressHandler: { _ in } + ) + } + task = loadTask + + do { + return try await loadTask.value + } catch { + task = nil + throw error + } + } +} + diff --git a/MLXServerTests/Server/TokenPrefixCacheQuantizationTests.swift b/MLXServerTests/Server/TokenPrefixCacheQuantizationTests.swift new file mode 100644 index 0000000..3a7738d --- /dev/null +++ b/MLXServerTests/Server/TokenPrefixCacheQuantizationTests.swift @@ -0,0 +1,252 @@ +import Foundation +import MLX +import MLXLMCommon +import XCTest +@testable import MLX_Server + +final class TokenPrefixCacheQuantizationTests: XCTestCase { + func testQuantizationConfigDefault() { + let config = TokenPrefixCache.QuantizationConfig.default + XCTAssertFalse(config.enabled) + XCTAssertEqual(config.bits, 8) + XCTAssertEqual(config.groupSize, 64) + XCTAssertEqual(config.minTokens, 256) + } + + func testQuantizationReducesStoredMemoryAndTracksSavings() { + let rawCache = [makeSimpleCache(tokenCount: 320, heads: 4, headDim: 64)] + let rawBytes = estimateBytes(rawCache) + + let cache = TokenPrefixCache( + memoryBudgetBytes: rawBytes * 2, + quantizationConfig: .aggressive + ) + + cache.store( + entryId: UUID(), + kvCache: rawCache, + cacheKey: Array(1...320), + modelId: "model" + ) + + let snapshot = cache.snapshot() + + XCTAssertTrue(snapshot.quantizationEnabled) + 
XCTAssertGreaterThan(snapshot.quantizationBytesSaved, 0) + XCTAssertLessThan(snapshot.estimatedBytes, rawBytes) + XCTAssertLessThan(Double(snapshot.estimatedBytes) / Double(rawBytes), 0.80) + } + + func testShortSequencesBelowThresholdRemainUnquantized() throws { + let rawCache = [makeSimpleCache(tokenCount: 32)] + let rawBytes = estimateBytes(rawCache) + let cache = TokenPrefixCache( + memoryBudgetBytes: rawBytes * 2, + quantizationConfig: .aggressive + ) + + cache.store( + entryId: UUID(), + kvCache: rawCache, + cacheKey: Array(1...32), + modelId: "model" + ) + + let snapshot = cache.snapshot() + XCTAssertEqual(snapshot.quantizationBytesSaved, 0) + XCTAssertEqual(snapshot.estimatedBytes, rawBytes) + + let lease = cache.lookup(cacheKey: Array(1...32), modelId: "model") + let returned = try XCTUnwrap(lease.kvCache) + XCTAssertTrue(returned.allSatisfy { $0 is KVCacheSimple }) + XCTAssertFalse(returned.contains { $0 is QuantizedKVCache }) + } + + func testQuantizedExactHitReturnsDequantizedCacheCloseToOriginal() throws { + let rawCache = [makeSimpleCache(tokenCount: 300)] + let cache = TokenPrefixCache( + memoryBudgetBytes: estimateBytes(rawCache) * 2, + quantizationConfig: .aggressive + ) + + cache.store( + entryId: UUID(), + kvCache: rawCache, + cacheKey: Array(1...300), + modelId: "model" + ) + + let lease = cache.lookup(cacheKey: Array(1...300), modelId: "model") + let returned = try XCTUnwrap(lease.kvCache) + + XCTAssertTrue(lease.isHit) + XCTAssertTrue(returned.allSatisfy { $0 is KVCacheSimple }) + XCTAssertFalse(returned.contains { $0 is QuantizedKVCache }) + XCTAssertEqual(returned.count, rawCache.count) + + for (original, roundTripped) in zip(rawCache, returned) { + XCTAssertEqual(original.offset, roundTripped.offset) + XCTAssertLessThanOrEqual(maxRelativeError(original.state[0], roundTripped.state[0]), 0.02) + XCTAssertLessThanOrEqual(maxRelativeError(original.state[1], roundTripped.state[1]), 0.02) + } + } + + func 
testNonStandardLayersPassThroughUnquantized() throws { + let nonStandard = NonStandardCache(tokenCount: 300, headDim: 32) + let cache = TokenPrefixCache( + memoryBudgetBytes: estimateBytes([nonStandard]) * 2, + quantizationConfig: .aggressive + ) + + cache.store( + entryId: UUID(), + kvCache: [nonStandard], + cacheKey: Array(1...300), + modelId: "model" + ) + + let snapshot = cache.snapshot() + XCTAssertEqual(snapshot.quantizationBytesSaved, 0) + + let lease = cache.lookup(cacheKey: Array(1...300), modelId: "model") + let returned = try XCTUnwrap(lease.kvCache) + XCTAssertEqual(returned.count, 1) + XCTAssertTrue(returned[0] is NonStandardCache) + } + + func testQuantizedSupersequenceHitReturnsDequantizedTrimmedCache() throws { + let rawCache = [makeSimpleCache(tokenCount: 300)] + let cache = TokenPrefixCache( + memoryBudgetBytes: estimateBytes(rawCache) * 2, + quantizationConfig: .aggressive + ) + + cache.store( + entryId: UUID(), + kvCache: rawCache, + cacheKey: Array(1...300), + modelId: "model" + ) + + let lease = cache.lookup(cacheKey: Array(1...260), modelId: "model") + let returned = try XCTUnwrap(lease.kvCache) + + XCTAssertTrue(lease.isHit) + XCTAssertEqual(lease.matchedTokenCount, 260) + XCTAssertTrue(returned.allSatisfy { $0 is KVCacheSimple }) + for layer in returned { + XCTAssertEqual(layer.offset, 260) + } + } + + func testQuantizationConfigChangesOnlyAffectFutureStores() { + let firstCache = [makeSimpleCache(tokenCount: 300)] + let secondCache = [makeSimpleCache(tokenCount: 300, base: 10_000)] + let cache = TokenPrefixCache( + memoryBudgetBytes: estimateBytes(firstCache) * 4, + quantizationConfig: .default + ) + + cache.store( + entryId: UUID(), + kvCache: firstCache, + cacheKey: Array(1...300), + modelId: "model" + ) + let before = cache.snapshot() + XCTAssertEqual(before.quantizationBytesSaved, 0) + + cache.setQuantizationConfig(.aggressive) + let toggled = cache.snapshot() + XCTAssertTrue(toggled.quantizationEnabled) + 
XCTAssertEqual(toggled.quantizationBytesSaved, 0) + + /* This second store runs after the switch to .aggressive, so it is the first entry expected to produce quantization savings. */ cache.store( + entryId: UUID(), + kvCache: secondCache, + cacheKey: Array(1001...1300), + modelId: "model" + ) + + let after = cache.snapshot() + XCTAssertGreaterThan(after.quantizationBytesSaved, 0) + XCTAssertGreaterThan(after.totalEntries, 1) + } + + /* Test fixture returning a KVCacheSimple; computes count = heads * tokenCount * headDim. NOTE(review): in this copy the text between (0.. and Int { is missing, so the rest of this helper and the header of the following helper (the one returning the larger of the summed state nbytes and 1024) cannot be reviewed here; confirm against the actual file. */ private func makeSimpleCache(tokenCount: Int, heads: Int = 2, headDim: Int = 64, base: Int = 0) + -> KVCacheSimple + { + let count = heads * tokenCount * headDim + let keyValues = (0.. Int { + max(cache.flatMap(\.state).reduce(0) { $0 + $1.nbytes }, 1024) + } + + /* Reads both arrays as Float and returns the largest value of abs(l - r) / max(abs(l), 1e-6) over paired elements; the 1e-6 floor keeps the division defined when l is zero. A count mismatch is reported through XCTAssertEqual, after which zip compares only the common prefix. */ private func maxRelativeError(_ lhs: MLXArray, _ rhs: MLXArray) -> Float { + let left = lhs.asArray(Float.self) + let right = rhs.asArray(Float.self) + XCTAssertEqual(left.count, right.count) + + var maximum: Float = 0 + for (l, r) in zip(left, right) { + let denominator = max(abs(l), 1e-6) + maximum = max(maximum, abs(l - r) / denominator) + } + return maximum + } +} + +private final class NonStandardCache: KVCache { /* Test-only KVCache conformer that is not a KVCacheSimple; the tests store it and expect lookup to hand back the same type. NOTE(review): the initializer body after (0.. is missing in this copy; confirm against the actual file. */ + private var arrays: [MLXArray] + var offset: Int + let maxSize: Int? = nil + + init(tokenCount: Int, headDim: Int) { + let count = tokenCount * headDim + let values = (0.. [MLXArray] { + arrays + } + + var state: [MLXArray] { + get { arrays } + set { arrays = newValue } + } + + var metaState: [String] { + get { [String(offset)] } + set { offset = Int(newValue.first ?? "0") ??
0 } + } + + /* Always false: this fixture reports itself as non-trimmable. */ var isTrimmable: Bool { false } + + /* Traps via fatalError if called; this fixture does not support update. */ func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { + fatalError("NonStandardCache is test-only and does not support update") + } + + /* Ignores n and always returns 0. */ @discardableResult + func trim(_ n: Int) -> Int { 0 } + + /* Ignores all arguments and always returns the .none mask mode. */ func makeMask( + n: Int, + windowSize: Int?, + returnArray: Bool + ) -> MLXFast.ScaledDotProductAttentionMaskMode { + .none + } +} diff --git a/docs/session-cache-upgrade.md b/docs/session-cache-upgrade.md index ad28c62..90f2f80 100644 --- a/docs/session-cache-upgrade.md +++ b/docs/session-cache-upgrade.md @@ -2593,9 +2593,9 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly ### Phase 6: KV Cache Quantization -15. **`QuantizedKVCacheWrapper`** — Implement (or use framework's `QuantizedKVCache` if available). Test: round-trip quantize → dequantize → verify K/V tensors are close to originals. -16. **Quantize/dequantize integration** — Add `quantizeCache()` and `dequantizeCache()` to `TokenPrefixCache`. Wire into `store()` and `lookup()`. Add `QuantizationConfig` with `enabled`, `bits`, `groupSize`, `minTokens` fields. -17. **Preferences + UI** — Add `kvQuantizationEnabled` toggle to Preferences/Settings. Show quantization status in MonitorView cache card. +15. [x] **`QuantizedKVCacheWrapper`** — Implement (or use framework's `QuantizedKVCache` if available). Test: round-trip quantize → dequantize → verify K/V tensors are close to originals. +16. [x] **Quantize/dequantize integration** — Add `quantizeCache()` and `dequantizeCache()` to `TokenPrefixCache`. Wire into `store()` and `lookup()`. Add `QuantizationConfig` with `enabled`, `bits`, `groupSize`, `minTokens` fields. +17. [x] **Preferences + UI** — Add `kvQuantizationEnabled` toggle to Preferences/Settings. Show quantization status in MonitorView cache card.
### Phase 7: Polish @@ -2681,16 +2681,16 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly ### KV Cache Quantization (Section 13) -- [ ] Round-trip: quantize(8-bit) → dequantize → K/V tensors close to originals (max error < 1%) -- [ ] Memory: quantized entry uses ~50% of FP16 memory (check estimateBytes before/after) -- [ ] Short sequences: entries below `minTokens` threshold are NOT quantized -- [ ] Disabled by default: `QuantizationConfig.default.enabled == false` -- [ ] Store path: quantization happens after trim-to-offset, before memory estimation -- [ ] Lookup path: dequantization happens before returning cache to caller -- [ ] Non-standard layers: hybrid model layers (non-trimmable) passed through unquantized -- [ ] Generation quality: quantized-then-dequantized cache produces coherent output (manual check) -- [ ] Supersequence + quantized: must dequantize before trimming (QuantizedKVCacheWrapper.isTrimmable == false) -- [ ] Preferences: toggle works, changes take effect on next store (existing entries not re-quantized) +- [x] Round-trip: quantize(8-bit) → dequantize → K/V tensors close to originals (validated with synthetic caches and real model cache structure) +- [x] Memory: quantized entry uses ~50% of FP16 memory (check estimateBytes before/after) +- [x] Short sequences: entries below `minTokens` threshold are NOT quantized +- [x] Disabled by default: `QuantizationConfig.default.enabled == false` +- [x] Store path: quantization happens after trim-to-offset, before memory estimation +- [x] Lookup path: dequantization happens before returning cache to caller +- [x] Non-standard layers: hybrid model layers (non-trimmable) passed through unquantized +- [x] Generation quality: quantized-then-dequantized cache produces coherent output (validated by model-backed cache-hit generation test) +- [x] Supersequence + quantized: must dequantize before trimming (QuantizedKVCacheWrapper.isTrimmable == false) +- [x] Preferences: toggle 
works, changes take effect on next store (existing entries not re-quantized) ### Thinking Mode