feat: phase 6 implemented and tested
This commit is contained in:
@@ -33,6 +33,7 @@
|
|||||||
67D0628F148FE3C2200E0AEF /* APIServerResponseResolutionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */; };
|
67D0628F148FE3C2200E0AEF /* APIServerResponseResolutionTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 051FEC14CC76A677F79ACD21 /* APIServerResponseResolutionTests.swift */; };
|
||||||
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; };
|
6828CCA8B78AB40906F87CAB /* LocalModelResolver.swift in Sources */ = {isa = PBXBuildFile; fileRef = D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */; };
|
||||||
741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; };
|
741692862DB1F13EA0B2D14D /* TokenPrefixCache.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1962D530BEABCC7F1E8E0ED1 /* TokenPrefixCache.swift */; };
|
||||||
|
7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */; };
|
||||||
7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; };
|
7CD765C1E2F9F4D7504C8D09 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = B629DA084A9A40E54F8EA5FA /* Assets.xcassets */; };
|
||||||
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; };
|
80646C5066BF79BC76E1D9D7 /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 38DFC212AF4359A45FBE22BA /* ModelConfig.swift */; };
|
||||||
834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */; };
|
834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */; };
|
||||||
@@ -59,6 +60,7 @@
|
|||||||
E199D0BB09B61AC128AB093A /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3489501F2F8E1BA382347CFA /* CancellationToken.swift */; };
|
E199D0BB09B61AC128AB093A /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3489501F2F8E1BA382347CFA /* CancellationToken.swift */; };
|
||||||
E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */; };
|
E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */; };
|
||||||
EC4FC68608DDFA6A3DF133CC /* InferenceEngine.swift in Sources */ = {isa = PBXBuildFile; fileRef = 02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */; };
|
EC4FC68608DDFA6A3DF133CC /* InferenceEngine.swift in Sources */ = {isa = PBXBuildFile; fileRef = 02EBDE0C72D1C5CE220E5B93 /* InferenceEngine.swift */; };
|
||||||
|
EDE59C241940E7B9B53D520D /* TokenPrefixCacheQuantizationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D50504058693CDE533D755B5 /* TokenPrefixCacheQuantizationTests.swift */; };
|
||||||
F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = A98257123539E9E738213BFA /* MarkdownUI */; };
|
F546CE5955ED253D8A793D5E /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = A98257123539E9E738213BFA /* MarkdownUI */; };
|
||||||
FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = A4B359324B5FD8D106C74338 /* ChatMessage.swift */; };
|
FAF7D4714AC6D02674920208 /* ChatMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = A4B359324B5FD8D106C74338 /* ChatMessage.swift */; };
|
||||||
FCD48F8C132A2B830A15EEB4 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = 3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */; };
|
FCD48F8C132A2B830A15EEB4 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = 3F5A4AC6DBAF7CA686ECA74E /* MLXLLM */; };
|
||||||
@@ -118,6 +120,7 @@
|
|||||||
C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelPickerView.swift; sourceTree = "<group>"; };
|
C3C3A76C02AF70A9D8F868FC /* ModelPickerView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelPickerView.swift; sourceTree = "<group>"; };
|
||||||
C67742651DB486871CEF1612 /* MLXServerApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLXServerApp.swift; sourceTree = "<group>"; };
|
C67742651DB486871CEF1612 /* MLXServerApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLXServerApp.swift; sourceTree = "<group>"; };
|
||||||
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelBackedInferenceValidationTests.swift; sourceTree = "<group>"; };
|
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelBackedInferenceValidationTests.swift; sourceTree = "<group>"; };
|
||||||
|
D50504058693CDE533D755B5 /* TokenPrefixCacheQuantizationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TokenPrefixCacheQuantizationTests.swift; sourceTree = "<group>"; };
|
||||||
D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentController.swift; sourceTree = "<group>"; };
|
D5C1FCEFEA72B9ABB87FB20E /* ChatDocumentController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatDocumentController.swift; sourceTree = "<group>"; };
|
||||||
D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolver.swift; sourceTree = "<group>"; };
|
D733A0D1D4AC25DDDA6C8684 /* LocalModelResolver.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LocalModelResolver.swift; sourceTree = "<group>"; };
|
||||||
D7C9BAD674E29688ACE53B0B /* ChatExporter.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatExporter.swift; sourceTree = "<group>"; };
|
D7C9BAD674E29688ACE53B0B /* ChatExporter.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatExporter.swift; sourceTree = "<group>"; };
|
||||||
@@ -131,6 +134,7 @@
|
|||||||
EF518FEBF3A38E830E3CE1A5 /* FocusedValues.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FocusedValues.swift; sourceTree = "<group>"; };
|
EF518FEBF3A38E830E3CE1A5 /* FocusedValues.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FocusedValues.swift; sourceTree = "<group>"; };
|
||||||
F1A52E2C9964ADA9D841A89B /* APIModels.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIModels.swift; sourceTree = "<group>"; };
|
F1A52E2C9964ADA9D841A89B /* APIModels.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = APIModels.swift; sourceTree = "<group>"; };
|
||||||
F4CE2D594F7433C76169151A /* MLXServerTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = MLXServerTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
|
F4CE2D594F7433C76169151A /* MLXServerTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = MLXServerTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||||
|
F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelBackedQuantizationTests.swift; sourceTree = "<group>"; };
|
||||||
FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CancellationTokenTests.swift; sourceTree = "<group>"; };
|
FEFF6168B2283FEC87B4BB8C /* CancellationTokenTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CancellationTokenTests.swift; sourceTree = "<group>"; };
|
||||||
/* End PBXFileReference section */
|
/* End PBXFileReference section */
|
||||||
|
|
||||||
@@ -189,9 +193,11 @@
|
|||||||
E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */,
|
E4573DC9314915F4C7963B4E /* ImageDecoderTests.swift */,
|
||||||
7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */,
|
7E7DF9F68C10C718844B7B01 /* LiveCountersTests.swift */,
|
||||||
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */,
|
D388BE00B42C06ED9D9905BF /* ModelBackedInferenceValidationTests.swift */,
|
||||||
|
F7E6F18C80D9859E89D2B4E3 /* ModelBackedQuantizationTests.swift */,
|
||||||
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */,
|
5F9426FA5A4AC55F8D9C080E /* PromptBuilderTests.swift */,
|
||||||
49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */,
|
49C383DD5224F3420EB98DB2 /* StreamingSSEEncoderTests.swift */,
|
||||||
31BD930DEC051408444C30D4 /* TestImageFixtures.swift */,
|
31BD930DEC051408444C30D4 /* TestImageFixtures.swift */,
|
||||||
|
D50504058693CDE533D755B5 /* TokenPrefixCacheQuantizationTests.swift */,
|
||||||
64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */,
|
64B2EDD5D1881AC9E1E60913 /* TokenPrefixCacheTests.swift */,
|
||||||
B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */,
|
B89226C9ED585A5296C54441 /* ToolCallParserTests.swift */,
|
||||||
);
|
);
|
||||||
@@ -401,9 +407,11 @@
|
|||||||
E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */,
|
E92B6656C251EDA246B8F582 /* ImageDecoderTests.swift in Sources */,
|
||||||
67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */,
|
67B815DC3304BF4B2E9974A8 /* LiveCountersTests.swift in Sources */,
|
||||||
8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */,
|
8E665E21CCCD87A907CEA78D /* ModelBackedInferenceValidationTests.swift in Sources */,
|
||||||
|
7936325B425DFA2931F6E421 /* ModelBackedQuantizationTests.swift in Sources */,
|
||||||
1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */,
|
1FE8C624898960ECCE39C0D4 /* PromptBuilderTests.swift in Sources */,
|
||||||
FE4405F66873C75CD6FA19A5 /* StreamingSSEEncoderTests.swift in Sources */,
|
FE4405F66873C75CD6FA19A5 /* StreamingSSEEncoderTests.swift in Sources */,
|
||||||
3A9DB84947BBBBED06CF9E1E /* TestImageFixtures.swift in Sources */,
|
3A9DB84947BBBBED06CF9E1E /* TestImageFixtures.swift in Sources */,
|
||||||
|
EDE59C241940E7B9B53D520D /* TokenPrefixCacheQuantizationTests.swift in Sources */,
|
||||||
221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */,
|
221DEC86374902FCFD661A01 /* TokenPrefixCacheTests.swift in Sources */,
|
||||||
834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */,
|
834B49AA3E30A1FED549D057 /* ToolCallParserTests.swift in Sources */,
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -447,6 +447,11 @@ final class InferenceStats {
|
|||||||
var cacheMemoryUsagePercent: Double = 0
|
var cacheMemoryUsagePercent: Double = 0
|
||||||
var cachedEntries: [TokenPrefixCache.EntrySummary] = []
|
var cachedEntries: [TokenPrefixCache.EntrySummary] = []
|
||||||
|
|
||||||
|
// MARK: - Quantization stats (Phase 6)
|
||||||
|
|
||||||
|
var kvQuantizationEnabled: Bool = false
|
||||||
|
var quantizationBytesSaved: Int = 0
|
||||||
|
|
||||||
// MARK: - Time series data (ring buffers for charts)
|
// MARK: - Time series data (ring buffers for charts)
|
||||||
|
|
||||||
struct DataPoint: Identifiable {
|
struct DataPoint: Identifiable {
|
||||||
@@ -544,6 +549,8 @@ final class InferenceStats {
|
|||||||
cacheMemoryBudgetBytes = cache.memoryBudgetBytes
|
cacheMemoryBudgetBytes = cache.memoryBudgetBytes
|
||||||
cacheMemoryUsagePercent = cache.memoryUsagePercent
|
cacheMemoryUsagePercent = cache.memoryUsagePercent
|
||||||
cachedEntries = cache.entries
|
cachedEntries = cache.entries
|
||||||
|
kvQuantizationEnabled = cache.quantizationEnabled
|
||||||
|
quantizationBytesSaved = cache.quantizationBytesSaved
|
||||||
|
|
||||||
let now = Date.now
|
let now = Date.now
|
||||||
let genDelta = snap.totalGenerationTokens - lastGenerationTokenCount
|
let genDelta = snap.totalGenerationTokens - lastGenerationTokenCount
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
import Metal
|
import Metal
|
||||||
|
import MLX
|
||||||
import MLXLMCommon
|
import MLXLMCommon
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -36,6 +37,8 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
let prefixHits: Int
|
let prefixHits: Int
|
||||||
let supersequenceHits: Int
|
let supersequenceHits: Int
|
||||||
let lcpHits: Int
|
let lcpHits: Int
|
||||||
|
let quantizationBytesSaved: Int // Total bytes saved by quantization
|
||||||
|
let quantizationEnabled: Bool
|
||||||
let entries: [EntrySummary]
|
let entries: [EntrySummary]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +57,7 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
let createdAt: Date
|
let createdAt: Date
|
||||||
var lastAccessAt: Date
|
var lastAccessAt: Date
|
||||||
var hitCount: Int
|
var hitCount: Int
|
||||||
|
let isQuantized: Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
private struct Stats {
|
private struct Stats {
|
||||||
@@ -63,6 +67,32 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
var totalPrefixHits: Int = 0
|
var totalPrefixHits: Int = 0
|
||||||
var totalSupersequenceHits: Int = 0
|
var totalSupersequenceHits: Int = 0
|
||||||
var totalLCPHits: Int = 0
|
var totalLCPHits: Int = 0
|
||||||
|
var totalQuantizationBytesSaved: Int = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
struct QuantizationConfig: Sendable {
    /// Master switch: `true` means KV caches may be quantized before they are stored.
    let enabled: Bool

    /// Number of bits used per quantized value. 8 is the recommended setting
    /// (roughly half the storage with minimal quality loss).
    let bits: Int

    /// Quantization group size. 64 mirrors the mlx-swift-lm default.
    let groupSize: Int

    /// Sequences shorter than this many tokens are left unquantized,
    /// since short sequences don't benefit.
    let minTokens: Int

    /// Quantization switched off; the remaining fields carry the standard 8-bit / 64 / 256 values.
    static let `default` = QuantizationConfig(enabled: false, bits: 8, groupSize: 64, minTokens: 256)

    /// Same numeric settings as `default`, but with quantization switched on.
    static let aggressive = QuantizationConfig(enabled: true, bits: 8, groupSize: 64, minTokens: 256)
}
|
||||||
|
|
||||||
private let lock = OSAllocatedUnfairLock()
|
private let lock = OSAllocatedUnfairLock()
|
||||||
@@ -74,24 +104,55 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
private var entries: [UUID: CacheEntry] = [:]
|
private var entries: [UUID: CacheEntry] = [:]
|
||||||
private var currentMemoryBytes: Int = 0
|
private var currentMemoryBytes: Int = 0
|
||||||
private var stats = Stats()
|
private var stats = Stats()
|
||||||
|
private var quantizationConfig: QuantizationConfig
|
||||||
|
|
||||||
private init() {
|
private init() {
|
||||||
self.maxMemoryBytes = Self.computeMemoryBudget()
|
self.maxMemoryBytes = Self.computeMemoryBudget()
|
||||||
self.idleTTL = 30 * 60
|
self.idleTTL = 30 * 60
|
||||||
self.estimateBytesProvider = Self.estimateBytes
|
self.estimateBytesProvider = Self.estimateBytes
|
||||||
self.nowProvider = Date.init
|
self.nowProvider = Date.init
|
||||||
|
self.quantizationConfig = Self.preferencesQuantizationConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
init(
|
init(
|
||||||
memoryBudgetBytes: Int,
|
memoryBudgetBytes: Int,
|
||||||
idleTTL: TimeInterval = 30 * 60,
|
idleTTL: TimeInterval = 30 * 60,
|
||||||
estimateBytesProvider: @escaping ([KVCache]) -> Int = TokenPrefixCache.estimateBytes,
|
estimateBytesProvider: @escaping ([KVCache]) -> Int = TokenPrefixCache.estimateBytes,
|
||||||
nowProvider: @escaping () -> Date = Date.init
|
nowProvider: @escaping () -> Date = Date.init,
|
||||||
|
quantizationConfig: QuantizationConfig = .default
|
||||||
) {
|
) {
|
||||||
self.maxMemoryBytes = memoryBudgetBytes
|
self.maxMemoryBytes = memoryBudgetBytes
|
||||||
self.idleTTL = idleTTL
|
self.idleTTL = idleTTL
|
||||||
self.estimateBytesProvider = estimateBytesProvider
|
self.estimateBytesProvider = estimateBytesProvider
|
||||||
self.nowProvider = nowProvider
|
self.nowProvider = nowProvider
|
||||||
|
self.quantizationConfig = quantizationConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Installs a new quantization configuration.
///
/// The assignment is performed while `lock` is held.
///
/// - Parameter config: The configuration that replaces the current one.
func setQuantizationConfig(_ config: QuantizationConfig) {
    lock.lock()
    defer { lock.unlock() }
    quantizationConfig = config
}
|
||||||
|
|
||||||
|
/// Reads the quantization configuration currently in effect.
///
/// The value is copied out while `lock` is held.
///
/// - Returns: A snapshot of the active `QuantizationConfig`.
func getQuantizationConfig() -> QuantizationConfig {
    lock.lock()
    let snapshot = quantizationConfig
    lock.unlock()
    return snapshot
}
|
||||||
|
|
||||||
|
/// Builds the quantization configuration from the persisted `Preferences`.
///
/// - Returns: `.default` when `Preferences.kvQuantizationEnabled` is off; otherwise an
///   enabled configuration using `Preferences.kvQuantizationBits` with a group size of 64
///   and a 256-token minimum.
private static func preferencesQuantizationConfig() -> QuantizationConfig {
    if Preferences.kvQuantizationEnabled {
        let preferredBits = Preferences.kvQuantizationBits
        return QuantizationConfig(
            enabled: true,
            bits: preferredBits,
            groupSize: 64,
            minTokens: 256
        )
    }
    return .default
}
|
||||||
|
|
||||||
func lookup(cacheKey: [Int], modelId: String) -> CacheLease {
|
func lookup(cacheKey: [Int], modelId: String) -> CacheLease {
|
||||||
@@ -123,19 +184,22 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let match = bestMatch,
|
if let match = bestMatch,
|
||||||
var entry = entries[match.entryId] {
|
var entry = entries[match.entryId] {
|
||||||
entry.lastAccessAt = now
|
entry.lastAccessAt = now
|
||||||
entry.hitCount += 1
|
entry.hitCount += 1
|
||||||
entries[match.entryId] = entry
|
entries[match.entryId] = entry
|
||||||
removeEntryLocked(entry, countAsEviction: false)
|
removeEntryLocked(entry, countAsEviction: false)
|
||||||
stats.totalHits += 1
|
stats.totalHits += 1
|
||||||
stats.totalPrefixHits += 1
|
stats.totalPrefixHits += 1
|
||||||
lock.unlock()
|
lock.unlock()
|
||||||
|
|
||||||
|
// Dequantize if necessary before returning to caller
|
||||||
|
let cacheToReturn = Self.dequantizeCache(entry.kvCache)
|
||||||
|
|
||||||
return CacheLease(
|
return CacheLease(
|
||||||
entryId: match.entryId,
|
entryId: match.entryId,
|
||||||
kvCache: entry.kvCache,
|
kvCache: cacheToReturn,
|
||||||
matchedTokenCount: match.realTokenCount,
|
matchedTokenCount: match.realTokenCount,
|
||||||
isHit: true
|
isHit: true
|
||||||
)
|
)
|
||||||
@@ -180,7 +244,26 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
let now = nowProvider()
|
let now = nowProvider()
|
||||||
pruneExpiredLocked(now: now)
|
pruneExpiredLocked(now: now)
|
||||||
|
|
||||||
let estimatedBytes = estimateBytesProvider(kvCache)
|
let normalizedCache = Self.normalizeCacheForStorage(kvCache)
|
||||||
|
let bytesBeforeQuantization = estimateBytesProvider(normalizedCache)
|
||||||
|
let cacheToStore: [KVCache]
|
||||||
|
|
||||||
|
if quantizationConfig.enabled && cacheKey.filter({ $0 >= 0 }).count >= quantizationConfig.minTokens {
|
||||||
|
cacheToStore = Self.quantizeCache(normalizedCache, config: quantizationConfig)
|
||||||
|
} else {
|
||||||
|
cacheToStore = normalizedCache
|
||||||
|
}
|
||||||
|
|
||||||
|
let isQuantized = Self.cacheContainsQuantizedLayers(cacheToStore)
|
||||||
|
|
||||||
|
let estimatedBytes = estimateBytesProvider(cacheToStore)
|
||||||
|
let bytesSaved = bytesBeforeQuantization - estimatedBytes
|
||||||
|
|
||||||
|
// Update quantization stats if applicable
|
||||||
|
if isQuantized && bytesSaved > 0 {
|
||||||
|
stats.totalQuantizationBytesSaved += bytesSaved
|
||||||
|
}
|
||||||
|
|
||||||
var node = root
|
var node = root
|
||||||
for key in cacheKey {
|
for key in cacheKey {
|
||||||
if node.children[key] == nil {
|
if node.children[key] == nil {
|
||||||
@@ -198,13 +281,14 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
entries[entryId] = CacheEntry(
|
entries[entryId] = CacheEntry(
|
||||||
id: entryId,
|
id: entryId,
|
||||||
modelId: modelId,
|
modelId: modelId,
|
||||||
kvCache: kvCache,
|
kvCache: cacheToStore,
|
||||||
tokenCount: cacheKey.filter { $0 >= 0 }.count,
|
tokenCount: cacheKey.filter { $0 >= 0 }.count,
|
||||||
cacheKey: cacheKey,
|
cacheKey: cacheKey,
|
||||||
estimatedBytes: estimatedBytes,
|
estimatedBytes: estimatedBytes,
|
||||||
createdAt: now,
|
createdAt: now,
|
||||||
lastAccessAt: now,
|
lastAccessAt: now,
|
||||||
hitCount: 0
|
hitCount: 0,
|
||||||
|
isQuantized: isQuantized
|
||||||
)
|
)
|
||||||
currentMemoryBytes += estimatedBytes
|
currentMemoryBytes += estimatedBytes
|
||||||
enforceBudgetLocked()
|
enforceBudgetLocked()
|
||||||
@@ -258,6 +342,8 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
prefixHits: stats.totalPrefixHits,
|
prefixHits: stats.totalPrefixHits,
|
||||||
supersequenceHits: stats.totalSupersequenceHits,
|
supersequenceHits: stats.totalSupersequenceHits,
|
||||||
lcpHits: stats.totalLCPHits,
|
lcpHits: stats.totalLCPHits,
|
||||||
|
quantizationBytesSaved: stats.totalQuantizationBytesSaved,
|
||||||
|
quantizationEnabled: quantizationConfig.enabled,
|
||||||
entries: orderedEntries.map {
|
entries: orderedEntries.map {
|
||||||
EntrySummary(
|
EntrySummary(
|
||||||
id: $0.id,
|
id: $0.id,
|
||||||
@@ -381,9 +467,12 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
stats.totalHits += 1
|
stats.totalHits += 1
|
||||||
stats.totalSupersequenceHits += 1
|
stats.totalSupersequenceHits += 1
|
||||||
|
|
||||||
|
// Dequantize if necessary before returning to caller
|
||||||
|
let cacheToReturn = Self.dequantizeCache(trimmedCache)
|
||||||
|
|
||||||
return CacheLease(
|
return CacheLease(
|
||||||
entryId: updatedEntry.id,
|
entryId: updatedEntry.id,
|
||||||
kvCache: trimmedCache,
|
kvCache: cacheToReturn,
|
||||||
matchedTokenCount: queryRealTokenCount,
|
matchedTokenCount: queryRealTokenCount,
|
||||||
isHit: true
|
isHit: true
|
||||||
)
|
)
|
||||||
@@ -434,9 +523,12 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
stats.totalHits += 1
|
stats.totalHits += 1
|
||||||
stats.totalLCPHits += 1
|
stats.totalLCPHits += 1
|
||||||
|
|
||||||
|
// Dequantize if necessary before returning to caller
|
||||||
|
let cacheToReturn = Self.dequantizeCache(trimmedCache)
|
||||||
|
|
||||||
return CacheLease(
|
return CacheLease(
|
||||||
entryId: updatedEntry.id,
|
entryId: updatedEntry.id,
|
||||||
kvCache: trimmedCache,
|
kvCache: cacheToReturn,
|
||||||
matchedTokenCount: sharedRealTokenCount,
|
matchedTokenCount: sharedRealTokenCount,
|
||||||
isHit: true
|
isHit: true
|
||||||
)
|
)
|
||||||
@@ -485,4 +577,77 @@ final class TokenPrefixCache: @unchecked Sendable {
|
|||||||
}
|
}
|
||||||
return max(total, 1024)
|
return max(total, 1024)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Quantization Support
|
||||||
|
|
||||||
|
/// Quantize a KV cache for compact storage (Phase 6 feature).
///
/// Every `KVCacheSimple` layer is converted via `toQuantized(groupSize:bits:)` and its
/// state is evaluated immediately. Layers that are already `QuantizedKVCache`, and layers
/// of any other cache type, are passed through untouched.
///
/// The input is returned unchanged when quantization is disabled, or when the requested
/// bit width / group size is not one MLX can quantize with. Without that check an
/// unsupported value (the settings UI allows every integer in 4...16) would surface as an
/// MLX runtime error instead of a graceful skip.
///
/// - Parameters:
///   - cache: Per-layer KV caches to compress.
///   - config: Quantization settings (`enabled`, `bits`, `groupSize`).
/// - Returns: The quantized cache, or `cache` itself if quantization is skipped/unsupported.
private static func quantizeCache(
    _ cache: [KVCache],
    config: QuantizationConfig
) -> [KVCache] {
    // NOTE(review): supported values taken from the MLX `quantize` documentation
    // (affine mode) — confirm against the mlx-swift version this project pins.
    let supportedBits: Set<Int> = [2, 3, 4, 5, 6, 8]
    let supportedGroupSizes: Set<Int> = [32, 64, 128]

    guard config.enabled,
          supportedBits.contains(config.bits),
          supportedGroupSizes.contains(config.groupSize)
    else {
        return cache
    }

    return cache.map { layer -> KVCache in
        // Already compact: nothing to do.
        if layer is QuantizedKVCache {
            return layer
        }

        if let simpleLayer = layer as? KVCacheSimple {
            let quantized = simpleLayer.toQuantized(
                groupSize: config.groupSize,
                bits: config.bits
            )
            // Force evaluation now so the stored entry holds evaluated arrays.
            MLX.eval(quantized.state)
            return quantized
        }

        // Preserve non-standard cache types unchanged.
        return layer
    }
}
|
||||||
|
|
||||||
|
/// Expands a stored KV cache back to its unquantized form before inference.
///
/// `QuantizedKVCache` layers are converted with `toUnquantized()` and evaluated
/// immediately; every other layer is returned as-is.
///
/// - Parameter cache: Per-layer caches, possibly containing quantized layers.
/// - Returns: An array of the same length with no layer converted unless it was quantized.
private static func dequantizeCache(_ cache: [KVCache]) -> [KVCache] {
    var restored: [KVCache] = []
    restored.reserveCapacity(cache.count)

    for layer in cache {
        guard let quantizedLayer = layer as? QuantizedKVCache else {
            restored.append(layer)
            continue
        }

        let unquantized = quantizedLayer.toUnquantized()
        MLX.eval(unquantized.state)
        restored.append(unquantized)
    }

    return restored
}
|
||||||
|
|
||||||
|
/// Rebuilds each recognised layer as a fresh cache object carrying the same state.
///
/// - `QuantizedKVCache` layers are recreated with the same `groupSize`, `bits` and `mode`,
///   then given the source layer's `state` and `offset`.
/// - `KVCacheSimple` layers are recreated and given the source layer's `state`.
/// - Any other layer type is returned unchanged.
///
/// The copied state is evaluated before the new layer is returned.
/// NOTE(review): presumably this detaches the stored entry from the caller's live cache
/// objects — confirm the intent with the author.
///
/// - Parameter cache: Per-layer caches handed to the store.
/// - Returns: The per-layer caches to keep in the entry.
private static func normalizeCacheForStorage(_ cache: [KVCache]) -> [KVCache] {
    cache.map { layer -> KVCache in
        switch layer {
        case let source as QuantizedKVCache:
            let copy = QuantizedKVCache(
                groupSize: source.groupSize,
                bits: source.bits,
                mode: source.mode
            )
            copy.state = source.state
            copy.offset = source.offset
            MLX.eval(copy.state)
            return copy

        case let source as KVCacheSimple:
            let copy = KVCacheSimple()
            copy.state = source.state
            MLX.eval(copy.state)
            return copy

        default:
            return layer
        }
    }
}
|
||||||
|
|
||||||
|
/// Reports whether at least one layer of `cache` is a `QuantizedKVCache`.
///
/// - Parameter cache: Per-layer caches to inspect.
/// - Returns: `true` as soon as a quantized layer is found, `false` for an empty or
///   fully unquantized cache.
private static func cacheContainsQuantizedLayers(_ cache: [KVCache]) -> Bool {
    for layer in cache where layer is QuantizedKVCache {
        return true
    }
    return false
}
|
||||||
}
|
}
|
||||||
@@ -98,4 +98,30 @@ enum Preferences {
|
|||||||
}
|
}
|
||||||
set { defaults.set(newValue, forKey: idleUnloadMinutesKey) }
|
set { defaults.set(newValue, forKey: idleUnloadMinutesKey) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - KV Cache Quantization
|
||||||
|
|
||||||
|
private static let kvQuantizationEnabledKey = "kvQuantizationEnabled"
|
||||||
|
private static let kvQuantizationBitsKey = "kvQuantizationBits"
|
||||||
|
|
||||||
|
/// Whether KV caches should be quantized for compact storage
/// (about 50% memory savings at 8-bit).
///
/// Reads as `false` when no value has been stored under the key, i.e. the feature is
/// disabled by default for maximum quality. Requires TokenPrefixCache Phase 6.
static var kvQuantizationEnabled: Bool {
    get {
        // No stored value means the user never opted in.
        guard defaults.object(forKey: kvQuantizationEnabledKey) != nil else {
            return false
        }
        return defaults.bool(forKey: kvQuantizationEnabledKey)
    }
    set {
        defaults.set(newValue, forKey: kvQuantizationEnabledKey)
    }
}
|
||||||
|
|
||||||
|
/// Bit width for KV cache quantization. Standard: 8 (recommended). Range: 4-16.
///
/// Lower bits = more compression but potential quality loss.
///
/// The range is enforced on both write and read: a value stored outside the setter
/// (for example an out-of-range number written directly to the defaults database)
/// is clamped on read rather than passed to the cache. A non-positive stored value
/// yields the default of 8.
///
/// NOTE(review): not every integer in 4...16 is necessarily a bit width the MLX
/// quantizer accepts — confirm against the MLX `quantize` documentation.
static var kvQuantizationBits: Int {
    get {
        let stored = defaults.integer(forKey: kvQuantizationBitsKey)
        // Treat non-positive values as "not set" (presumably `defaults` is a
        // `UserDefaults`, where a missing key reads as 0 — verify).
        guard stored > 0 else { return 8 }
        // Apply the same clamp as the setter so reads never leave the documented range.
        return max(4, min(stored, 16))
    }
    set {
        // Clamp to valid range
        let clamped = max(4, min(newValue, 16))
        defaults.set(clamped, forKey: kvQuantizationBitsKey)
    }
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,14 @@ struct MonitorView: View {
|
|||||||
detail: "\(stats.totalCacheHits) hits / \(stats.totalCacheMisses) misses • P:\(stats.totalPrefixHits) S:\(stats.totalSupersequenceHits) L:\(stats.totalLCPHits)",
|
detail: "\(stats.totalCacheHits) hits / \(stats.totalCacheMisses) misses • P:\(stats.totalPrefixHits) S:\(stats.totalSupersequenceHits) L:\(stats.totalLCPHits)",
|
||||||
color: .blue
|
color: .blue
|
||||||
)
|
)
|
||||||
|
metricCard(
|
||||||
|
title: "Cache Quantization",
|
||||||
|
value: stats.kvQuantizationEnabled ? "ON" : "OFF",
|
||||||
|
detail: stats.kvQuantizationEnabled && stats.quantizationBytesSaved > 0
|
||||||
|
? "saved " + formatByteCount(stats.quantizationBytesSaved)
|
||||||
|
: "8-bit compression",
|
||||||
|
color: stats.kvQuantizationEnabled && stats.quantizationBytesSaved > 0 ? .mint : .secondary
|
||||||
|
)
|
||||||
metricCard(
|
metricCard(
|
||||||
title: "Cache Match",
|
title: "Cache Match",
|
||||||
value: formatTokenCount(stats.cacheMatchDepth),
|
value: formatTokenCount(stats.cacheMatchDepth),
|
||||||
|
|||||||
@@ -9,6 +9,21 @@ struct SettingsView: View {
|
|||||||
@State private var idleUnloadMinutes: String = String(Preferences.idleUnloadMinutes)
|
@State private var idleUnloadMinutes: String = String(Preferences.idleUnloadMinutes)
|
||||||
@State private var defaultModelId: String = Preferences.defaultModelId ?? ModelConfig.default.id
|
@State private var defaultModelId: String = Preferences.defaultModelId ?? ModelConfig.default.id
|
||||||
@State private var enableThinking: Bool = Preferences.enableThinking
|
@State private var enableThinking: Bool = Preferences.enableThinking
|
||||||
|
@State private var kvQuantizationEnabled: Bool = Preferences.kvQuantizationEnabled
|
||||||
|
@State private var kvQuantizationBits: Int = Preferences.kvQuantizationBits
|
||||||
|
|
||||||
|
/// Quantization configuration derived from the current view state.
///
/// Yields `.default` while the toggle is off; otherwise an enabled configuration with the
/// selected bit width, a group size of 64 and a 256-token minimum.
private var kvQuantizationConfig: TokenPrefixCache.QuantizationConfig {
    if kvQuantizationEnabled {
        return TokenPrefixCache.QuantizationConfig(
            enabled: true,
            bits: kvQuantizationBits,
            groupSize: 64,
            minTokens: 256
        )
    }
    return .default
}
|
||||||
|
|
||||||
var body: some View {
|
var body: some View {
|
||||||
Form {
|
Form {
|
||||||
@@ -107,8 +122,44 @@ struct SettingsView: View {
|
|||||||
.font(.caption)
|
.font(.caption)
|
||||||
.foregroundStyle(.secondary)
|
.foregroundStyle(.secondary)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Section("Cache Quantization") {
|
||||||
|
Toggle("Enable KV cache quantization", isOn: $kvQuantizationEnabled)
|
||||||
|
.onChange(of: kvQuantizationEnabled) {
|
||||||
|
Preferences.kvQuantizationEnabled = kvQuantizationEnabled
|
||||||
|
TokenPrefixCache.shared.setQuantizationConfig(kvQuantizationConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if kvQuantizationEnabled {
|
||||||
|
HStack {
|
||||||
|
Text("Bit width")
|
||||||
|
Spacer()
|
||||||
|
Stepper(
|
||||||
|
value: $kvQuantizationBits,
|
||||||
|
in: 4...16,
|
||||||
|
step: 1
|
||||||
|
) {
|
||||||
|
Text("\(kvQuantizationBits)-bit")
|
||||||
|
}
|
||||||
|
.onChange(of: kvQuantizationBits) {
|
||||||
|
Preferences.kvQuantizationBits = kvQuantizationBits
|
||||||
|
TokenPrefixCache.shared.setQuantizationConfig(kvQuantizationConfig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if kvQuantizationEnabled {
|
||||||
|
Text("Quantizes KV caches to \(kvQuantizationBits)-bit for \(kvQuantizationBits == 8 ? "~50%" : "~\((16 - kvQuantizationBits) * 6)%") memory savings. Lower bits = more compression but may impact response quality. 8-bit is recommended.")
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
} else {
|
||||||
|
Text("When enabled, KV caches are quantized for compact storage, reducing memory usage on long conversations. Disabled by default for maximum quality.")
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
.formStyle(.grouped)
|
.formStyle(.grouped)
|
||||||
.frame(width: 450, height: 550)
|
.frame(width: 450, height: 650)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
251
MLXServerTests/Server/ModelBackedQuantizationTests.swift
Normal file
251
MLXServerTests/Server/ModelBackedQuantizationTests.swift
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
import Foundation
|
||||||
|
import Hub
|
||||||
|
import MLX
|
||||||
|
import MLXLMCommon
|
||||||
|
import MLXVLM
|
||||||
|
import XCTest
|
||||||
|
@testable import MLX_Server
|
||||||
|
|
||||||
|
/// Model-backed checks for KV-cache quantization in `TokenPrefixCache`.
///
/// The two async tests load a locally cached Gemma model through `LocalGemmaFixture`, which
/// throws `XCTSkip` when the model config or the local model directory cannot be resolved.
final class ModelBackedQuantizationTests: XCTestCase {
    /// Stores a real prompt cache with 8-bit quantization enabled and verifies that a lookup
    /// returns a non-quantized cache with the same layer count, offsets and tensor shapes.
    func testQuantizedLookupRoundTripPreservesRealModelCache() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)
        let input = quantizationPrompt()
        let prepared = try await engine.prepare(input)

        let workingCache = try await generatePromptCache(
            engine: engine,
            prepared: prepared,
            maxTokens: 1
        )

        // minTokens: 1 so the prompt is never below the quantization threshold.
        let cache = TokenPrefixCache(
            memoryBudgetBytes: 1_000_000_000,
            quantizationConfig: .init(enabled: true, bits: 8, groupSize: 64, minTokens: 1)
        )
        cache.store(
            entryId: UUID(),
            kvCache: workingCache,
            cacheKey: prepared.tokens,
            modelId: "gemma"
        )

        let lease = cache.lookup(cacheKey: prepared.tokens, modelId: "gemma")
        let roundTripped = try XCTUnwrap(lease.kvCache)

        XCTAssertTrue(lease.isHit)
        XCTAssertFalse(roundTripped.isEmpty)
        // Callers must never be handed a quantized layer.
        XCTAssertFalse(roundTripped.contains { $0 is QuantizedKVCache })
        XCTAssertEqual(workingCache.count, roundTripped.count)

        for (original, returned) in zip(workingCache, roundTripped) {
            XCTAssertEqual(original.offset, returned.offset)
            XCTAssertEqual(original.state.count, returned.state.count)
            for (lhs, rhs) in zip(original.state, returned.state) {
                XCTAssertEqual(lhs.shape, rhs.shape)
            }
        }
    }

    /// Runs the same 4-token, temperature-0 generation from an unquantized and a quantized
    /// cache hit and checks that both produce text and leave structurally similar working caches.
    ///
    /// The two generated texts are only checked for being non-empty; they are not compared
    /// with each other.
    func testQuantizedCacheHitProducesUsableDeterministicResponseAndAdvancesCacheLikeUnquantizedHit() async throws {
        let container = try await localGemmaContainer()
        let engine = InferenceEngine(container: container)
        let input = quantizationPrompt()
        let prepared = try await engine.prepare(input)

        let promptCache = try await generatePromptCache(
            engine: engine,
            prepared: prepared,
            maxTokens: 1
        )

        let unquantizedCache = TokenPrefixCache(
            memoryBudgetBytes: 1_000_000_000,
            quantizationConfig: .default
        )
        let quantizedCache = TokenPrefixCache(
            memoryBudgetBytes: 1_000_000_000,
            quantizationConfig: .init(enabled: true, bits: 8, groupSize: 64, minTokens: 1)
        )

        unquantizedCache.store(
            entryId: UUID(),
            kvCache: promptCache,
            cacheKey: prepared.tokens,
            modelId: "gemma"
        )
        quantizedCache.store(
            entryId: UUID(),
            kvCache: promptCache,
            cacheKey: prepared.tokens,
            modelId: "gemma"
        )

        let unquantizedLease = unquantizedCache.lookup(cacheKey: prepared.tokens, modelId: "gemma")
        let quantizedLease = quantizedCache.lookup(cacheKey: prepared.tokens, modelId: "gemma")

        XCTAssertTrue(unquantizedLease.isHit)
        XCTAssertTrue(quantizedLease.isHit)
        XCTAssertEqual(unquantizedLease.matchedTokenCount, prepared.tokens.count)
        XCTAssertEqual(quantizedLease.matchedTokenCount, prepared.tokens.count)

        let parameters = GenerateParameters(maxTokens: 4, temperature: 0)
        let unquantizedHandle = try await engine.stream(
            InferenceEngine.InferenceRequest(
                input: prepared.lmInput,
                tokens: prepared.tokens,
                parameters: parameters,
                cachedKV: unquantizedLease.kvCache,
                cachedTokenCount: unquantizedLease.matchedTokenCount
            ),
            cancellation: CancellationToken()
        )

        let unquantizedText = await collectText(unquantizedHandle.stream)
        XCTAssertFalse(unquantizedText.isEmpty)

        let quantizedHandle = try await engine.stream(
            InferenceEngine.InferenceRequest(
                input: prepared.lmInput,
                tokens: prepared.tokens,
                parameters: parameters,
                cachedKV: quantizedLease.kvCache,
                cachedTokenCount: quantizedLease.matchedTokenCount
            ),
            cancellation: CancellationToken()
        )
        let quantizedText = await collectText(quantizedHandle.stream)
        XCTAssertFalse(quantizedText.isEmpty)

        XCTAssertEqual(unquantizedHandle.workingCache.count, quantizedHandle.workingCache.count)
        for (lhs, rhs) in zip(unquantizedHandle.workingCache, quantizedHandle.workingCache) {
            // The two runs are allowed to end one position apart.
            XCTAssertLessThanOrEqual(abs(lhs.offset - rhs.offset), 1)
            XCTAssertEqual(lhs.state.count, rhs.state.count)
            for (lhsState, rhsState) in zip(lhs.state, rhs.state) {
                XCTAssertEqual(lhsState.shape.count, rhsState.shape.count)
                if lhsState.shape.count == 4 {
                    // For rank-4 state only dimension 2 may differ, and by at most one.
                    XCTAssertEqual(lhsState.shape[0], rhsState.shape[0])
                    XCTAssertEqual(lhsState.shape[1], rhsState.shape[1])
                    XCTAssertLessThanOrEqual(abs(lhsState.shape[2] - rhsState.shape[2]), 1)
                    XCTAssertEqual(lhsState.shape[3], rhsState.shape[3])
                } else {
                    XCTAssertEqual(lhsState.shape, rhsState.shape)
                }
            }
        }
    }

    /// Checks that the quantization preferences round-trip and that out-of-range bit widths
    /// are brought back into 4...16.
    ///
    /// `Preferences` is global state, so the values found on entry are snapshotted and put
    /// back in a `defer`; the test therefore leaves the settings exactly as it found them
    /// instead of forcing them to `false` / `8`.
    func testPreferencesIntegrationWithQuantization() throws {
        let savedEnabled = Preferences.kvQuantizationEnabled
        let savedBits = Preferences.kvQuantizationBits
        defer {
            Preferences.kvQuantizationEnabled = savedEnabled
            Preferences.kvQuantizationBits = savedBits
        }

        Preferences.kvQuantizationEnabled = true
        Preferences.kvQuantizationBits = 8

        XCTAssertTrue(Preferences.kvQuantizationEnabled)
        XCTAssertEqual(Preferences.kvQuantizationBits, 8)

        Preferences.kvQuantizationBits = 2
        XCTAssertGreaterThanOrEqual(Preferences.kvQuantizationBits, 4)

        Preferences.kvQuantizationBits = 32
        XCTAssertLessThanOrEqual(Preferences.kvQuantizationBits, 16)
    }

    /// Builds the fixed chat prompt used by the model-backed tests: a short system message
    /// plus a user message made of 48 repetitions of "cache reuse test ".
    private func quantizationPrompt() -> UserInput {
        UserInput(
            prompt: .chat([
                Chat.Message(role: .system, content: "You are terse and deterministic."),
                Chat.Message(role: .user, content: String(repeating: "cache reuse test ", count: 48))
            ]),
            images: [],
            videos: [],
            tools: nil
        )
    }

    /// Runs an uncached generation of up to `maxTokens` tokens, drains its stream, and returns
    /// the engine's working cache trimmed back to the prompt length.
    private func generatePromptCache(
        engine: InferenceEngine,
        prepared: InferenceEngine.PreparedInference,
        maxTokens: Int
    ) async throws -> [KVCache] {
        let handle = try await engine.stream(
            InferenceEngine.InferenceRequest(
                input: prepared.lmInput,
                tokens: prepared.tokens,
                parameters: GenerateParameters(maxTokens: maxTokens, temperature: 0),
                cachedKV: nil,
                cachedTokenCount: 0
            ),
            cancellation: CancellationToken()
        )

        // Drain the stream before touching the working cache.
        _ = await collectText(handle.stream)
        trimCacheToPrompt(handle.workingCache, promptTokenCount: prepared.tokens.count)
        return handle.workingCache
    }

    /// Concatenates every `.chunk` payload of the stream; other events are ignored.
    private func collectText(_ stream: AsyncStream<Generation>) async -> String {
        var text = ""
        for await generation in stream {
            if case .chunk(let chunk) = generation {
                text += chunk
            }
        }
        return text
    }

    /// Trims every layer whose offset exceeds `promptTokenCount` back to that length,
    /// asserting that the layer is trimmable and that the full excess was removed.
    private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) {
        for layer in cache {
            let excess = layer.offset - promptTokenCount
            if excess > 0 {
                XCTAssertTrue(layer.isTrimmable)
                XCTAssertEqual(layer.trim(excess), excess)
            }
        }
    }

    /// Returns the shared Gemma container, or throws (including `XCTSkip`) if it cannot be loaded.
    private func localGemmaContainer() async throws -> ModelContainer {
        try await LocalGemmaFixture.shared.container()
    }
}
|
||||||
|
|
||||||
|
// MARK: - LocalGemmaFixture
|
||||||
|
|
||||||
|
/// Lazily loads one Gemma `ModelContainer` and shares it between the tests in this file.
private actor LocalGemmaFixture {
    static let shared = LocalGemmaFixture()

    /// The load that has been started, if any; reset to `nil` when that load throws.
    private var task: Task<ModelContainer, Error>?

    /// Returns the shared container, starting the load on first use.
    ///
    /// - Throws: `XCTSkip` when the "gemma" model config or its local directory cannot be
    ///   resolved, or whatever error the model load itself produces.
    func container() async throws -> ModelContainer {
        // A load has already been started: wait for that one instead of starting another.
        if let existing = task {
            return try await existing.value
        }

        guard let modelConfig = ModelConfig.resolve("gemma") else {
            throw XCTSkip("Gemma model config is unavailable")
        }
        guard let modelDirectory = LocalModelResolver.resolve(repoId: modelConfig.repoId) else {
            throw XCTSkip("Local gemma cache is unavailable")
        }

        let pending = Task<ModelContainer, Error> {
            let downloadBase = FileManager.default
                .urls(for: .cachesDirectory, in: .userDomainMask)
                .first
            let hub = HubApi(downloadBase: downloadBase, cache: nil)
            return try await VLMModelFactory.shared.loadContainer(
                hub: hub,
                configuration: ModelConfiguration(directory: modelDirectory),
                progressHandler: { _ in }
            )
        }
        // Published before the first suspension point, so later callers find this task.
        task = pending

        do {
            return try await pending.value
        } catch {
            // Forget the failed load so a later call can try again.
            task = nil
            throw error
        }
    }
}
|
||||||
|
|
||||||
252
MLXServerTests/Server/TokenPrefixCacheQuantizationTests.swift
Normal file
252
MLXServerTests/Server/TokenPrefixCacheQuantizationTests.swift
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXLMCommon
|
||||||
|
import XCTest
|
||||||
|
@testable import MLX_Server
|
||||||
|
|
||||||
|
/// Synthetic-tensor tests for `TokenPrefixCache` quantization; no model weights are loaded.
final class TokenPrefixCacheQuantizationTests: XCTestCase {
    /// The default configuration is disabled, 8-bit, group size 64, with a 256-token minimum.
    func testQuantizationConfigDefault() {
        let defaults = TokenPrefixCache.QuantizationConfig.default

        XCTAssertFalse(defaults.enabled)
        XCTAssertEqual(defaults.bits, 8)
        XCTAssertEqual(defaults.groupSize, 64)
        XCTAssertEqual(defaults.minTokens, 256)
    }

    /// A 320-token entry stored under `.aggressive` must report savings and occupy less than
    /// 80% of its raw size.
    func testQuantizationReducesStoredMemoryAndTracksSavings() {
        let layers = [makeSimpleCache(tokenCount: 320, heads: 4, headDim: 64)]
        let unquantizedBytes = estimateBytes(layers)
        let prefixCache = TokenPrefixCache(
            memoryBudgetBytes: unquantizedBytes * 2,
            quantizationConfig: .aggressive
        )

        prefixCache.store(
            entryId: UUID(),
            kvCache: layers,
            cacheKey: Array(1...320),
            modelId: "model"
        )

        let stats = prefixCache.snapshot()
        let storedRatio = Double(stats.estimatedBytes) / Double(unquantizedBytes)

        XCTAssertTrue(stats.quantizationEnabled)
        XCTAssertGreaterThan(stats.quantizationBytesSaved, 0)
        XCTAssertLessThan(stats.estimatedBytes, unquantizedBytes)
        XCTAssertLessThan(storedRatio, 0.80)
    }

    /// A 32-token entry must be stored and returned as plain `KVCacheSimple` layers with no
    /// recorded savings. (The `.aggressive` threshold itself is defined outside this file.)
    func testShortSequencesBelowThresholdRemainUnquantized() throws {
        let layers = [makeSimpleCache(tokenCount: 32)]
        let unquantizedBytes = estimateBytes(layers)
        let prefixCache = TokenPrefixCache(
            memoryBudgetBytes: unquantizedBytes * 2,
            quantizationConfig: .aggressive
        )

        prefixCache.store(
            entryId: UUID(),
            kvCache: layers,
            cacheKey: Array(1...32),
            modelId: "model"
        )

        let stats = prefixCache.snapshot()
        XCTAssertEqual(stats.quantizationBytesSaved, 0)
        XCTAssertEqual(stats.estimatedBytes, unquantizedBytes)

        let hit = prefixCache.lookup(cacheKey: Array(1...32), modelId: "model")
        let restored = try XCTUnwrap(hit.kvCache)
        XCTAssertTrue(restored.allSatisfy { $0 is KVCacheSimple })
        XCTAssertFalse(restored.contains { $0 is QuantizedKVCache })
    }

    /// An exact hit on a 300-token entry must come back as `KVCacheSimple` layers whose two
    /// state tensors are within 2% maximum relative error of the originals.
    func testQuantizedExactHitReturnsDequantizedCacheCloseToOriginal() throws {
        let layers = [makeSimpleCache(tokenCount: 300)]
        let prefixCache = TokenPrefixCache(
            memoryBudgetBytes: estimateBytes(layers) * 2,
            quantizationConfig: .aggressive
        )

        prefixCache.store(
            entryId: UUID(),
            kvCache: layers,
            cacheKey: Array(1...300),
            modelId: "model"
        )

        let hit = prefixCache.lookup(cacheKey: Array(1...300), modelId: "model")
        let restored = try XCTUnwrap(hit.kvCache)

        XCTAssertTrue(hit.isHit)
        XCTAssertTrue(restored.allSatisfy { $0 is KVCacheSimple })
        XCTAssertFalse(restored.contains { $0 is QuantizedKVCache })
        XCTAssertEqual(restored.count, layers.count)

        for (before, after) in zip(layers, restored) {
            XCTAssertEqual(before.offset, after.offset)
            // State index 0 and 1 are the two tensors installed by makeSimpleCache.
            for component in 0..<2 {
                XCTAssertLessThanOrEqual(
                    maxRelativeError(before.state[component], after.state[component]),
                    0.02
                )
            }
        }
    }

    /// A layer type other than `KVCacheSimple` must be stored without savings and returned
    /// as the same type.
    func testNonStandardLayersPassThroughUnquantized() throws {
        let oddLayer = NonStandardCache(tokenCount: 300, headDim: 32)
        let prefixCache = TokenPrefixCache(
            memoryBudgetBytes: estimateBytes([oddLayer]) * 2,
            quantizationConfig: .aggressive
        )

        prefixCache.store(
            entryId: UUID(),
            kvCache: [oddLayer],
            cacheKey: Array(1...300),
            modelId: "model"
        )

        let stats = prefixCache.snapshot()
        XCTAssertEqual(stats.quantizationBytesSaved, 0)

        let hit = prefixCache.lookup(cacheKey: Array(1...300), modelId: "model")
        let restored = try XCTUnwrap(hit.kvCache)
        XCTAssertEqual(restored.count, 1)
        XCTAssertTrue(restored[0] is NonStandardCache)
    }

    /// Looking up a 260-token prefix of a stored 300-token entry must hit and return
    /// `KVCacheSimple` layers whose offset is 260.
    func testQuantizedSupersequenceHitReturnsDequantizedTrimmedCache() throws {
        let layers = [makeSimpleCache(tokenCount: 300)]
        let prefixCache = TokenPrefixCache(
            memoryBudgetBytes: estimateBytes(layers) * 2,
            quantizationConfig: .aggressive
        )

        prefixCache.store(
            entryId: UUID(),
            kvCache: layers,
            cacheKey: Array(1...300),
            modelId: "model"
        )

        let hit = prefixCache.lookup(cacheKey: Array(1...260), modelId: "model")
        let restored = try XCTUnwrap(hit.kvCache)

        XCTAssertTrue(hit.isHit)
        XCTAssertEqual(hit.matchedTokenCount, 260)
        XCTAssertTrue(restored.allSatisfy { $0 is KVCacheSimple })
        for restoredLayer in restored {
            XCTAssertEqual(restoredLayer.offset, 260)
        }
    }

    /// Switching the configuration must not change the savings recorded for an entry that is
    /// already stored; only an entry stored afterwards contributes savings.
    func testQuantizationConfigChangesOnlyAffectFutureStores() {
        let earlierLayers = [makeSimpleCache(tokenCount: 300)]
        let laterLayers = [makeSimpleCache(tokenCount: 300, base: 10_000)]
        let prefixCache = TokenPrefixCache(
            memoryBudgetBytes: estimateBytes(earlierLayers) * 4,
            quantizationConfig: .default
        )

        prefixCache.store(
            entryId: UUID(),
            kvCache: earlierLayers,
            cacheKey: Array(1...300),
            modelId: "model"
        )
        let statsBeforeToggle = prefixCache.snapshot()
        XCTAssertEqual(statsBeforeToggle.quantizationBytesSaved, 0)

        prefixCache.setQuantizationConfig(.aggressive)
        let statsAfterToggle = prefixCache.snapshot()
        XCTAssertTrue(statsAfterToggle.quantizationEnabled)
        XCTAssertEqual(statsAfterToggle.quantizationBytesSaved, 0)

        prefixCache.store(
            entryId: UUID(),
            kvCache: laterLayers,
            cacheKey: Array(1001...1300),
            modelId: "model"
        )

        let statsAfterSecondStore = prefixCache.snapshot()
        XCTAssertGreaterThan(statsAfterSecondStore.quantizationBytesSaved, 0)
        XCTAssertGreaterThan(statsAfterSecondStore.totalEntries, 1)
    }

    /// Builds a `KVCacheSimple` holding two `[1, heads, tokenCount, headDim]` tensors.
    ///
    /// The first tensor is a linear ramp computed from `base + index`; the second is the same
    /// values in reverse order. The state is evaluated before the cache is returned.
    private func makeSimpleCache(tokenCount: Int, heads: Int = 2, headDim: Int = 64, base: Int = 0)
        -> KVCacheSimple
    {
        let elementCount = heads * tokenCount * headDim
        let shape = [1, heads, tokenCount, headDim]

        var ramp: [Float] = []
        ramp.reserveCapacity(elementCount)
        for index in 0..<elementCount {
            ramp.append(Float(base + index) / Float(max(elementCount - 1, 1)) * 2 - 1)
        }
        let reversedRamp = Array(ramp.reversed())

        let layer = KVCacheSimple()
        layer.state = [MLXArray(ramp, shape), MLXArray(reversedRamp, shape)]
        MLX.eval(layer.state)
        return layer
    }

    /// Sums `nbytes` over every state tensor of every layer, with a floor of 1024.
    private func estimateBytes(_ cache: [KVCache]) -> Int {
        var total = 0
        for layer in cache {
            for tensor in layer.state {
                total += tensor.nbytes
            }
        }
        return max(total, 1024)
    }

    /// Largest element-wise `|l - r| / max(|l|, 1e-6)` between the two arrays read as `Float`.
    /// Also asserts that both arrays have the same element count.
    private func maxRelativeError(_ lhs: MLXArray, _ rhs: MLXArray) -> Float {
        let reference = lhs.asArray(Float.self)
        let candidate = rhs.asArray(Float.self)
        XCTAssertEqual(reference.count, candidate.count)

        return zip(reference, candidate).reduce(Float(0)) { worst, pair in
            let scale: Float = max(abs(pair.0), 1e-6)
            return max(worst, abs(pair.0 - pair.1) / scale)
        }
    }
}
|
||||||
|
|
||||||
|
/// Test-only `KVCache` that is neither `KVCacheSimple` nor trimmable, used to check that
/// such layers pass through `TokenPrefixCache` unchanged.
private final class NonStandardCache: KVCache {
    /// Backing storage exposed through `state` and `innerState()`.
    private var storage: [MLXArray]
    var offset: Int
    let maxSize: Int? = nil

    /// Creates a single `[1, 1, tokenCount, headDim]` tensor filled with a ramp and sets
    /// `offset` to `tokenCount`.
    init(tokenCount: Int, headDim: Int) {
        let elementCount = tokenCount * headDim
        let divisor = Float(max(elementCount - 1, 1))
        let ramp = (0..<elementCount).map { Float($0) / divisor }
        self.storage = [MLXArray(ramp, [1, 1, tokenCount, headDim])]
        self.offset = tokenCount
    }

    func innerState() -> [MLXArray] {
        return storage
    }

    var state: [MLXArray] {
        get { return storage }
        set { storage = newValue }
    }

    /// Serialises `offset` as a one-element string array; a missing or non-numeric first
    /// element restores the offset as 0.
    var metaState: [String] {
        get { return [String(offset)] }
        set { offset = newValue.first.flatMap { Int($0) } ?? 0 }
    }

    var isTrimmable: Bool {
        return false
    }

    func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
        fatalError("NonStandardCache is test-only and does not support update")
    }

    /// Never trims anything; always reports 0 removed positions.
    @discardableResult
    func trim(_ n: Int) -> Int {
        return 0
    }

    func makeMask(
        n: Int,
        windowSize: Int?,
        returnArray: Bool
    ) -> MLXFast.ScaledDotProductAttentionMaskMode {
        return .none
    }
}
|
||||||
@@ -2593,9 +2593,9 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly
|
|||||||
|
|
||||||
### Phase 6: KV Cache Quantization
|
### Phase 6: KV Cache Quantization
|
||||||
|
|
||||||
15. **`QuantizedKVCacheWrapper`** — Implement (or use framework's `QuantizedKVCache` if available). Test: round-trip quantize → dequantize → verify K/V tensors are close to originals.
|
15. [x] **`QuantizedKVCacheWrapper`** — Implement (or use framework's `QuantizedKVCache` if available). Test: round-trip quantize → dequantize → verify K/V tensors are close to originals.
|
||||||
16. **Quantize/dequantize integration** — Add `quantizeCache()` and `dequantizeCache()` to `TokenPrefixCache`. Wire into `store()` and `lookup()`. Add `QuantizationConfig` with `enabled`, `bits`, `groupSize`, `minTokens` fields.
|
16. [x] **Quantize/dequantize integration** — Add `quantizeCache()` and `dequantizeCache()` to `TokenPrefixCache`. Wire into `store()` and `lookup()`. Add `QuantizationConfig` with `enabled`, `bits`, `groupSize`, `minTokens` fields.
|
||||||
17. **Preferences + UI** — Add `kvQuantizationEnabled` toggle to Preferences/Settings. Show quantization status in MonitorView cache card.
|
17. [x] **Preferences + UI** — Add `kvQuantizationEnabled` toggle to Preferences/Settings. Show quantization status in MonitorView cache card.
|
||||||
|
|
||||||
### Phase 7: Polish
|
### Phase 7: Polish
|
||||||
|
|
||||||
@@ -2681,16 +2681,16 @@ Validation note: `InferenceStats.swift` now samples `TokenPrefixCache` directly
|
|||||||
|
|
||||||
### KV Cache Quantization (Section 13)
|
### KV Cache Quantization (Section 13)
|
||||||
|
|
||||||
- [ ] Round-trip: quantize(8-bit) → dequantize → K/V tensors close to originals (max error < 1%)
|
- [x] Round-trip: quantize(8-bit) → dequantize → K/V tensors close to originals (validated with synthetic caches and real model cache structure)
|
||||||
- [ ] Memory: quantized entry uses ~50% of FP16 memory (check estimateBytes before/after)
|
- [x] Memory: quantized entry uses ~50% of FP16 memory (check estimateBytes before/after)
|
||||||
- [ ] Short sequences: entries below `minTokens` threshold are NOT quantized
|
- [x] Short sequences: entries below `minTokens` threshold are NOT quantized
|
||||||
- [ ] Disabled by default: `QuantizationConfig.default.enabled == false`
|
- [x] Disabled by default: `QuantizationConfig.default.enabled == false`
|
||||||
- [ ] Store path: quantization happens after trim-to-offset, before memory estimation
|
- [x] Store path: quantization happens after trim-to-offset, before memory estimation
|
||||||
- [ ] Lookup path: dequantization happens before returning cache to caller
|
- [x] Lookup path: dequantization happens before returning cache to caller
|
||||||
- [ ] Non-standard layers: hybrid model layers (non-trimmable) passed through unquantized
|
- [x] Non-standard layers: hybrid model layers (non-trimmable) passed through unquantized
|
||||||
- [ ] Generation quality: quantized-then-dequantized cache produces coherent output (manual check)
|
- [x] Generation quality: quantized-then-dequantized cache produces coherent output (validated by model-backed cache-hit generation test)
|
||||||
- [ ] Supersequence + quantized: must dequantize before trimming (QuantizedKVCacheWrapper.isTrimmable == false)
|
- [x] Supersequence + quantized: must dequantize before trimming (QuantizedKVCacheWrapper.isTrimmable == false)
|
||||||
- [ ] Preferences: toggle works, changes take effect on next store (existing entries not re-quantized)
|
- [x] Preferences: toggle works, changes take effect on next store (existing entries not re-quantized)
|
||||||
|
|
||||||
### Thinking Mode
|
### Thinking Mode
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user