feat: phase 6 implemented and tested
This commit is contained in:
251
MLXServerTests/Server/ModelBackedQuantizationTests.swift
Normal file
251
MLXServerTests/Server/ModelBackedQuantizationTests.swift
Normal file
@@ -0,0 +1,251 @@
|
||||
import Foundation
|
||||
import Hub
|
||||
import MLX
|
||||
import MLXLMCommon
|
||||
import MLXVLM
|
||||
import XCTest
|
||||
@testable import MLX_Server
|
||||
|
||||
final class ModelBackedQuantizationTests: XCTestCase {
|
||||
func testQuantizedLookupRoundTripPreservesRealModelCache() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
let input = quantizationPrompt()
|
||||
let prepared = try await engine.prepare(input)
|
||||
|
||||
let workingCache = try await generatePromptCache(
|
||||
engine: engine,
|
||||
prepared: prepared,
|
||||
maxTokens: 1
|
||||
)
|
||||
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 1_000_000_000,
|
||||
quantizationConfig: .init(enabled: true, bits: 8, groupSize: 64, minTokens: 1)
|
||||
)
|
||||
cache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: workingCache,
|
||||
cacheKey: prepared.tokens,
|
||||
modelId: "gemma"
|
||||
)
|
||||
|
||||
let lease = cache.lookup(cacheKey: prepared.tokens, modelId: "gemma")
|
||||
let roundTripped = try XCTUnwrap(lease.kvCache)
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertFalse(roundTripped.isEmpty)
|
||||
XCTAssertFalse(roundTripped.contains { $0 is QuantizedKVCache })
|
||||
XCTAssertEqual(workingCache.count, roundTripped.count)
|
||||
|
||||
for (original, returned) in zip(workingCache, roundTripped) {
|
||||
XCTAssertEqual(original.offset, returned.offset)
|
||||
XCTAssertEqual(original.state.count, returned.state.count)
|
||||
for (lhs, rhs) in zip(original.state, returned.state) {
|
||||
XCTAssertEqual(lhs.shape, rhs.shape)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testQuantizedCacheHitProducesUsableDeterministicResponseAndAdvancesCacheLikeUnquantizedHit() async throws {
|
||||
let container = try await localGemmaContainer()
|
||||
let engine = InferenceEngine(container: container)
|
||||
let input = quantizationPrompt()
|
||||
let prepared = try await engine.prepare(input)
|
||||
|
||||
let promptCache = try await generatePromptCache(
|
||||
engine: engine,
|
||||
prepared: prepared,
|
||||
maxTokens: 1
|
||||
)
|
||||
|
||||
let unquantizedCache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 1_000_000_000,
|
||||
quantizationConfig: .default
|
||||
)
|
||||
let quantizedCache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 1_000_000_000,
|
||||
quantizationConfig: .init(enabled: true, bits: 8, groupSize: 64, minTokens: 1)
|
||||
)
|
||||
|
||||
unquantizedCache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: promptCache,
|
||||
cacheKey: prepared.tokens,
|
||||
modelId: "gemma"
|
||||
)
|
||||
quantizedCache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: promptCache,
|
||||
cacheKey: prepared.tokens,
|
||||
modelId: "gemma"
|
||||
)
|
||||
|
||||
let unquantizedLease = unquantizedCache.lookup(cacheKey: prepared.tokens, modelId: "gemma")
|
||||
let quantizedLease = quantizedCache.lookup(cacheKey: prepared.tokens, modelId: "gemma")
|
||||
|
||||
XCTAssertTrue(unquantizedLease.isHit)
|
||||
XCTAssertTrue(quantizedLease.isHit)
|
||||
XCTAssertEqual(unquantizedLease.matchedTokenCount, prepared.tokens.count)
|
||||
XCTAssertEqual(quantizedLease.matchedTokenCount, prepared.tokens.count)
|
||||
|
||||
let parameters = GenerateParameters(maxTokens: 4, temperature: 0)
|
||||
let unquantizedHandle = try await engine.stream(
|
||||
InferenceEngine.InferenceRequest(
|
||||
input: prepared.lmInput,
|
||||
tokens: prepared.tokens,
|
||||
parameters: parameters,
|
||||
cachedKV: unquantizedLease.kvCache,
|
||||
cachedTokenCount: unquantizedLease.matchedTokenCount
|
||||
),
|
||||
cancellation: CancellationToken()
|
||||
)
|
||||
|
||||
let unquantizedText = await collectText(unquantizedHandle.stream)
|
||||
XCTAssertFalse(unquantizedText.isEmpty)
|
||||
|
||||
let quantizedHandle = try await engine.stream(
|
||||
InferenceEngine.InferenceRequest(
|
||||
input: prepared.lmInput,
|
||||
tokens: prepared.tokens,
|
||||
parameters: parameters,
|
||||
cachedKV: quantizedLease.kvCache,
|
||||
cachedTokenCount: quantizedLease.matchedTokenCount
|
||||
),
|
||||
cancellation: CancellationToken()
|
||||
)
|
||||
let quantizedText = await collectText(quantizedHandle.stream)
|
||||
XCTAssertFalse(quantizedText.isEmpty)
|
||||
|
||||
XCTAssertEqual(unquantizedHandle.workingCache.count, quantizedHandle.workingCache.count)
|
||||
for (lhs, rhs) in zip(unquantizedHandle.workingCache, quantizedHandle.workingCache) {
|
||||
XCTAssertLessThanOrEqual(abs(lhs.offset - rhs.offset), 1)
|
||||
XCTAssertEqual(lhs.state.count, rhs.state.count)
|
||||
for (lhsState, rhsState) in zip(lhs.state, rhs.state) {
|
||||
XCTAssertEqual(lhsState.shape.count, rhsState.shape.count)
|
||||
if lhsState.shape.count == 4 {
|
||||
XCTAssertEqual(lhsState.shape[0], rhsState.shape[0])
|
||||
XCTAssertEqual(lhsState.shape[1], rhsState.shape[1])
|
||||
XCTAssertLessThanOrEqual(abs(lhsState.shape[2] - rhsState.shape[2]), 1)
|
||||
XCTAssertEqual(lhsState.shape[3], rhsState.shape[3])
|
||||
} else {
|
||||
XCTAssertEqual(lhsState.shape, rhsState.shape)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testPreferencesIntegrationWithQuantization() throws {
|
||||
Preferences.kvQuantizationEnabled = true
|
||||
Preferences.kvQuantizationBits = 8
|
||||
|
||||
XCTAssertTrue(Preferences.kvQuantizationEnabled)
|
||||
XCTAssertEqual(Preferences.kvQuantizationBits, 8)
|
||||
|
||||
Preferences.kvQuantizationBits = 2
|
||||
XCTAssertGreaterThanOrEqual(Preferences.kvQuantizationBits, 4)
|
||||
|
||||
Preferences.kvQuantizationBits = 32
|
||||
XCTAssertLessThanOrEqual(Preferences.kvQuantizationBits, 16)
|
||||
|
||||
Preferences.kvQuantizationEnabled = false
|
||||
Preferences.kvQuantizationBits = 8
|
||||
}
|
||||
|
||||
private func quantizationPrompt() -> UserInput {
|
||||
UserInput(
|
||||
prompt: .chat([
|
||||
Chat.Message(role: .system, content: "You are terse and deterministic."),
|
||||
Chat.Message(role: .user, content: String(repeating: "cache reuse test ", count: 48))
|
||||
]),
|
||||
images: [],
|
||||
videos: [],
|
||||
tools: nil
|
||||
)
|
||||
}
|
||||
|
||||
private func generatePromptCache(
|
||||
engine: InferenceEngine,
|
||||
prepared: InferenceEngine.PreparedInference,
|
||||
maxTokens: Int
|
||||
) async throws -> [KVCache] {
|
||||
let handle = try await engine.stream(
|
||||
InferenceEngine.InferenceRequest(
|
||||
input: prepared.lmInput,
|
||||
tokens: prepared.tokens,
|
||||
parameters: GenerateParameters(maxTokens: maxTokens, temperature: 0),
|
||||
cachedKV: nil,
|
||||
cachedTokenCount: 0
|
||||
),
|
||||
cancellation: CancellationToken()
|
||||
)
|
||||
|
||||
_ = await collectText(handle.stream)
|
||||
trimCacheToPrompt(handle.workingCache, promptTokenCount: prepared.tokens.count)
|
||||
return handle.workingCache
|
||||
}
|
||||
|
||||
private func collectText(_ stream: AsyncStream<Generation>) async -> String {
|
||||
var text = ""
|
||||
for await generation in stream {
|
||||
if case .chunk(let chunk) = generation {
|
||||
text += chunk
|
||||
}
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
private func trimCacheToPrompt(_ cache: [KVCache], promptTokenCount: Int) {
|
||||
for layer in cache {
|
||||
let excess = layer.offset - promptTokenCount
|
||||
if excess > 0 {
|
||||
XCTAssertTrue(layer.isTrimmable)
|
||||
XCTAssertEqual(layer.trim(excess), excess)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func localGemmaContainer() async throws -> ModelContainer {
|
||||
try await LocalGemmaFixture.shared.container()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - LocalGemmaFixture
|
||||
|
||||
private actor LocalGemmaFixture {
|
||||
static let shared = LocalGemmaFixture()
|
||||
|
||||
private var task: Task<ModelContainer, Error>?
|
||||
|
||||
func container() async throws -> ModelContainer {
|
||||
if let task {
|
||||
return try await task.value
|
||||
}
|
||||
|
||||
guard let config = ModelConfig.resolve("gemma") else {
|
||||
throw XCTSkip("Gemma model config is unavailable")
|
||||
}
|
||||
guard let localDir = LocalModelResolver.resolve(repoId: config.repoId) else {
|
||||
throw XCTSkip("Local gemma cache is unavailable")
|
||||
}
|
||||
|
||||
let loadTask = Task<ModelContainer, Error> {
|
||||
let cachesDir = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first
|
||||
let hub = HubApi(downloadBase: cachesDir, cache: nil)
|
||||
return try await VLMModelFactory.shared.loadContainer(
|
||||
hub: hub,
|
||||
configuration: ModelConfiguration(directory: localDir),
|
||||
progressHandler: { _ in }
|
||||
)
|
||||
}
|
||||
task = loadTask
|
||||
|
||||
do {
|
||||
return try await loadTask.value
|
||||
} catch {
|
||||
task = nil
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
252
MLXServerTests/Server/TokenPrefixCacheQuantizationTests.swift
Normal file
252
MLXServerTests/Server/TokenPrefixCacheQuantizationTests.swift
Normal file
@@ -0,0 +1,252 @@
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXLMCommon
|
||||
import XCTest
|
||||
@testable import MLX_Server
|
||||
|
||||
final class TokenPrefixCacheQuantizationTests: XCTestCase {
|
||||
func testQuantizationConfigDefault() {
|
||||
let config = TokenPrefixCache.QuantizationConfig.default
|
||||
XCTAssertFalse(config.enabled)
|
||||
XCTAssertEqual(config.bits, 8)
|
||||
XCTAssertEqual(config.groupSize, 64)
|
||||
XCTAssertEqual(config.minTokens, 256)
|
||||
}
|
||||
|
||||
func testQuantizationReducesStoredMemoryAndTracksSavings() {
|
||||
let rawCache = [makeSimpleCache(tokenCount: 320, heads: 4, headDim: 64)]
|
||||
let rawBytes = estimateBytes(rawCache)
|
||||
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: rawBytes * 2,
|
||||
quantizationConfig: .aggressive
|
||||
)
|
||||
|
||||
cache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: rawCache,
|
||||
cacheKey: Array(1...320),
|
||||
modelId: "model"
|
||||
)
|
||||
|
||||
let snapshot = cache.snapshot()
|
||||
|
||||
XCTAssertTrue(snapshot.quantizationEnabled)
|
||||
XCTAssertGreaterThan(snapshot.quantizationBytesSaved, 0)
|
||||
XCTAssertLessThan(snapshot.estimatedBytes, rawBytes)
|
||||
XCTAssertLessThan(Double(snapshot.estimatedBytes) / Double(rawBytes), 0.80)
|
||||
}
|
||||
|
||||
func testShortSequencesBelowThresholdRemainUnquantized() throws {
|
||||
let rawCache = [makeSimpleCache(tokenCount: 32)]
|
||||
let rawBytes = estimateBytes(rawCache)
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: rawBytes * 2,
|
||||
quantizationConfig: .aggressive
|
||||
)
|
||||
|
||||
cache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: rawCache,
|
||||
cacheKey: Array(1...32),
|
||||
modelId: "model"
|
||||
)
|
||||
|
||||
let snapshot = cache.snapshot()
|
||||
XCTAssertEqual(snapshot.quantizationBytesSaved, 0)
|
||||
XCTAssertEqual(snapshot.estimatedBytes, rawBytes)
|
||||
|
||||
let lease = cache.lookup(cacheKey: Array(1...32), modelId: "model")
|
||||
let returned = try XCTUnwrap(lease.kvCache)
|
||||
XCTAssertTrue(returned.allSatisfy { $0 is KVCacheSimple })
|
||||
XCTAssertFalse(returned.contains { $0 is QuantizedKVCache })
|
||||
}
|
||||
|
||||
func testQuantizedExactHitReturnsDequantizedCacheCloseToOriginal() throws {
|
||||
let rawCache = [makeSimpleCache(tokenCount: 300)]
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: estimateBytes(rawCache) * 2,
|
||||
quantizationConfig: .aggressive
|
||||
)
|
||||
|
||||
cache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: rawCache,
|
||||
cacheKey: Array(1...300),
|
||||
modelId: "model"
|
||||
)
|
||||
|
||||
let lease = cache.lookup(cacheKey: Array(1...300), modelId: "model")
|
||||
let returned = try XCTUnwrap(lease.kvCache)
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertTrue(returned.allSatisfy { $0 is KVCacheSimple })
|
||||
XCTAssertFalse(returned.contains { $0 is QuantizedKVCache })
|
||||
XCTAssertEqual(returned.count, rawCache.count)
|
||||
|
||||
for (original, roundTripped) in zip(rawCache, returned) {
|
||||
XCTAssertEqual(original.offset, roundTripped.offset)
|
||||
XCTAssertLessThanOrEqual(maxRelativeError(original.state[0], roundTripped.state[0]), 0.02)
|
||||
XCTAssertLessThanOrEqual(maxRelativeError(original.state[1], roundTripped.state[1]), 0.02)
|
||||
}
|
||||
}
|
||||
|
||||
func testNonStandardLayersPassThroughUnquantized() throws {
|
||||
let nonStandard = NonStandardCache(tokenCount: 300, headDim: 32)
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: estimateBytes([nonStandard]) * 2,
|
||||
quantizationConfig: .aggressive
|
||||
)
|
||||
|
||||
cache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: [nonStandard],
|
||||
cacheKey: Array(1...300),
|
||||
modelId: "model"
|
||||
)
|
||||
|
||||
let snapshot = cache.snapshot()
|
||||
XCTAssertEqual(snapshot.quantizationBytesSaved, 0)
|
||||
|
||||
let lease = cache.lookup(cacheKey: Array(1...300), modelId: "model")
|
||||
let returned = try XCTUnwrap(lease.kvCache)
|
||||
XCTAssertEqual(returned.count, 1)
|
||||
XCTAssertTrue(returned[0] is NonStandardCache)
|
||||
}
|
||||
|
||||
func testQuantizedSupersequenceHitReturnsDequantizedTrimmedCache() throws {
|
||||
let rawCache = [makeSimpleCache(tokenCount: 300)]
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: estimateBytes(rawCache) * 2,
|
||||
quantizationConfig: .aggressive
|
||||
)
|
||||
|
||||
cache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: rawCache,
|
||||
cacheKey: Array(1...300),
|
||||
modelId: "model"
|
||||
)
|
||||
|
||||
let lease = cache.lookup(cacheKey: Array(1...260), modelId: "model")
|
||||
let returned = try XCTUnwrap(lease.kvCache)
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertEqual(lease.matchedTokenCount, 260)
|
||||
XCTAssertTrue(returned.allSatisfy { $0 is KVCacheSimple })
|
||||
for layer in returned {
|
||||
XCTAssertEqual(layer.offset, 260)
|
||||
}
|
||||
}
|
||||
|
||||
func testQuantizationConfigChangesOnlyAffectFutureStores() {
|
||||
let firstCache = [makeSimpleCache(tokenCount: 300)]
|
||||
let secondCache = [makeSimpleCache(tokenCount: 300, base: 10_000)]
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: estimateBytes(firstCache) * 4,
|
||||
quantizationConfig: .default
|
||||
)
|
||||
|
||||
cache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: firstCache,
|
||||
cacheKey: Array(1...300),
|
||||
modelId: "model"
|
||||
)
|
||||
let before = cache.snapshot()
|
||||
XCTAssertEqual(before.quantizationBytesSaved, 0)
|
||||
|
||||
cache.setQuantizationConfig(.aggressive)
|
||||
let toggled = cache.snapshot()
|
||||
XCTAssertTrue(toggled.quantizationEnabled)
|
||||
XCTAssertEqual(toggled.quantizationBytesSaved, 0)
|
||||
|
||||
cache.store(
|
||||
entryId: UUID(),
|
||||
kvCache: secondCache,
|
||||
cacheKey: Array(1001...1300),
|
||||
modelId: "model"
|
||||
)
|
||||
|
||||
let after = cache.snapshot()
|
||||
XCTAssertGreaterThan(after.quantizationBytesSaved, 0)
|
||||
XCTAssertGreaterThan(after.totalEntries, 1)
|
||||
}
|
||||
|
||||
private func makeSimpleCache(tokenCount: Int, heads: Int = 2, headDim: Int = 64, base: Int = 0)
|
||||
-> KVCacheSimple
|
||||
{
|
||||
let count = heads * tokenCount * headDim
|
||||
let keyValues = (0..<count).map { index in
|
||||
Float(base + index) / Float(max(count - 1, 1)) * 2 - 1
|
||||
}
|
||||
let valueValues = keyValues.reversed()
|
||||
let keys = MLXArray(keyValues, [1, heads, tokenCount, headDim])
|
||||
let values = MLXArray(Array(valueValues), [1, heads, tokenCount, headDim])
|
||||
let cache = KVCacheSimple()
|
||||
cache.state = [keys, values]
|
||||
MLX.eval(cache.state)
|
||||
return cache
|
||||
}
|
||||
|
||||
private func estimateBytes(_ cache: [KVCache]) -> Int {
|
||||
max(cache.flatMap(\.state).reduce(0) { $0 + $1.nbytes }, 1024)
|
||||
}
|
||||
|
||||
private func maxRelativeError(_ lhs: MLXArray, _ rhs: MLXArray) -> Float {
|
||||
let left = lhs.asArray(Float.self)
|
||||
let right = rhs.asArray(Float.self)
|
||||
XCTAssertEqual(left.count, right.count)
|
||||
|
||||
var maximum: Float = 0
|
||||
for (l, r) in zip(left, right) {
|
||||
let denominator = max(abs(l), 1e-6)
|
||||
maximum = max(maximum, abs(l - r) / denominator)
|
||||
}
|
||||
return maximum
|
||||
}
|
||||
}
|
||||
|
||||
private final class NonStandardCache: KVCache {
|
||||
private var arrays: [MLXArray]
|
||||
var offset: Int
|
||||
let maxSize: Int? = nil
|
||||
|
||||
init(tokenCount: Int, headDim: Int) {
|
||||
let count = tokenCount * headDim
|
||||
let values = (0..<count).map { Float($0) / Float(max(count - 1, 1)) }
|
||||
self.arrays = [MLXArray(values, [1, 1, tokenCount, headDim])]
|
||||
self.offset = tokenCount
|
||||
}
|
||||
|
||||
func innerState() -> [MLXArray] {
|
||||
arrays
|
||||
}
|
||||
|
||||
var state: [MLXArray] {
|
||||
get { arrays }
|
||||
set { arrays = newValue }
|
||||
}
|
||||
|
||||
var metaState: [String] {
|
||||
get { [String(offset)] }
|
||||
set { offset = Int(newValue.first ?? "0") ?? 0 }
|
||||
}
|
||||
|
||||
var isTrimmable: Bool { false }
|
||||
|
||||
func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
|
||||
fatalError("NonStandardCache is test-only and does not support update")
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
func trim(_ n: Int) -> Int { 0 }
|
||||
|
||||
func makeMask(
|
||||
n: Int,
|
||||
windowSize: Int?,
|
||||
returnArray: Bool
|
||||
) -> MLXFast.ScaledDotProductAttentionMaskMode {
|
||||
.none
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user