feat: finished all open things up to and including phase 6
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import Foundation
|
||||
import MLX
|
||||
import XCTest
|
||||
import MLXLMCommon
|
||||
@testable import MLX_Server
|
||||
@@ -225,6 +226,96 @@ final class TokenPrefixCacheTests: XCTestCase {
|
||||
XCTAssertEqual(snapshot.lcpHits, 0)
|
||||
}
|
||||
|
||||
func testSupersequenceSkipsNonTrimmableLayersGracefully() {
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 1_024 }
|
||||
)
|
||||
|
||||
let layer = TestTrimRecordingCache(offset: 4, trimmable: false)
|
||||
cache.store(entryId: UUID(), kvCache: [layer], cacheKey: [1, 2, 3, 4], modelId: "model")
|
||||
|
||||
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
|
||||
let snapshot = cache.snapshot()
|
||||
|
||||
XCTAssertFalse(lease.isHit)
|
||||
XCTAssertEqual(layer.offset, 4)
|
||||
XCTAssertTrue(layer.trimCalls.isEmpty)
|
||||
XCTAssertEqual(snapshot.supersequenceHits, 0)
|
||||
XCTAssertEqual(snapshot.totalMisses, 1)
|
||||
}
|
||||
|
||||
func testSupersequenceChoosesShallowestCandidate() {
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 1_024 }
|
||||
)
|
||||
|
||||
let shallowestId = UUID()
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 4, 5], modelId: "model")
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 4], modelId: "model")
|
||||
cache.store(entryId: shallowestId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
||||
|
||||
let lease = cache.lookup(cacheKey: [1, 2], modelId: "model")
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertEqual(lease.entryId, shallowestId)
|
||||
XCTAssertEqual(lease.matchedTokenCount, 2)
|
||||
}
|
||||
|
||||
func testSupersequencePathWinsWhenFullQueryWalkCanAlsoSeeDivergentSibling() {
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 1_024 }
|
||||
)
|
||||
|
||||
let supersequenceId = UUID()
|
||||
cache.store(entryId: supersequenceId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 9, 8], modelId: "model")
|
||||
|
||||
let lease = cache.lookup(cacheKey: [1, 2], modelId: "model")
|
||||
let snapshot = cache.snapshot()
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertEqual(lease.entryId, supersequenceId)
|
||||
XCTAssertEqual(snapshot.supersequenceHits, 1)
|
||||
XCTAssertEqual(snapshot.lcpHits, 0)
|
||||
}
|
||||
|
||||
func testLCPChoosesShallowestSiblingCandidate() {
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 1_024 }
|
||||
)
|
||||
|
||||
let shallowestId = UUID()
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 7], modelId: "model")
|
||||
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 4, 7, 8], modelId: "model")
|
||||
cache.store(entryId: shallowestId, kvCache: [], cacheKey: [1, 2, 5], modelId: "model")
|
||||
|
||||
let lease = cache.lookup(cacheKey: [1, 2, 9, 9], modelId: "model")
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertEqual(lease.entryId, shallowestId)
|
||||
XCTAssertEqual(lease.matchedTokenCount, 2)
|
||||
}
|
||||
|
||||
func testTrimUsesExactExcessAndReducesOffset() {
|
||||
let cache = TokenPrefixCache(
|
||||
memoryBudgetBytes: 10_000,
|
||||
estimateBytesProvider: { _ in 1_024 }
|
||||
)
|
||||
|
||||
let layer = TestTrimRecordingCache(offset: 5, trimmable: true)
|
||||
cache.store(entryId: UUID(), kvCache: [layer], cacheKey: [1, 2, 3, 4, 5], modelId: "model")
|
||||
|
||||
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
|
||||
|
||||
XCTAssertTrue(lease.isHit)
|
||||
XCTAssertEqual(layer.trimCalls, [2])
|
||||
XCTAssertEqual(layer.offset, 3)
|
||||
}
|
||||
|
||||
func testComputeMemoryBudgetUsesFallbackWhenDeviceUnavailable() {
|
||||
let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: nil)
|
||||
|
||||
@@ -248,4 +339,53 @@ final class TokenPrefixCacheTests: XCTestCase {
|
||||
|
||||
XCTAssertEqual(budget, 8 * 1024 * 1024 * 1024)
|
||||
}
|
||||
}
|
||||
|
||||
private final class TestTrimRecordingCache: KVCache {
|
||||
private var arrays: [MLXArray] = []
|
||||
var offset: Int
|
||||
let maxSize: Int? = nil
|
||||
let trimmable: Bool
|
||||
private(set) var trimCalls: [Int] = []
|
||||
|
||||
init(offset: Int, trimmable: Bool) {
|
||||
self.offset = offset
|
||||
self.trimmable = trimmable
|
||||
}
|
||||
|
||||
func innerState() -> [MLXArray] {
|
||||
arrays
|
||||
}
|
||||
|
||||
var state: [MLXArray] {
|
||||
get { arrays }
|
||||
set { arrays = newValue }
|
||||
}
|
||||
|
||||
var metaState: [String] {
|
||||
get { [String(offset)] }
|
||||
set { offset = Int(newValue.first ?? "0") ?? 0 }
|
||||
}
|
||||
|
||||
var isTrimmable: Bool { trimmable }
|
||||
|
||||
func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
|
||||
fatalError("TestTrimRecordingCache does not support update")
|
||||
}
|
||||
|
||||
@discardableResult
|
||||
func trim(_ n: Int) -> Int {
|
||||
guard trimmable else { return 0 }
|
||||
trimCalls.append(n)
|
||||
offset = max(0, offset - n)
|
||||
return n
|
||||
}
|
||||
|
||||
func makeMask(
|
||||
n: Int,
|
||||
windowSize: Int?,
|
||||
returnArray: Bool
|
||||
) -> MLXFast.ScaledDotProductAttentionMaskMode {
|
||||
.none
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user