397 lines
14 KiB
Swift
397 lines
14 KiB
Swift
import Foundation
|
|
import MLX
|
|
import XCTest
|
|
import MLXLMCommon
|
|
@testable import MLX_Server
|
|
|
|
final class TokenPrefixCacheTests: XCTestCase {
|
|
func testStoreAndLookupRemovesCheckedOutEntry() {
|
|
var now = Date(timeIntervalSince1970: 100)
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 },
|
|
nowProvider: { now }
|
|
)
|
|
|
|
let entryId = UUID()
|
|
cache.store(entryId: entryId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
|
|
|
XCTAssertEqual(cache.snapshot().totalEntries, 1)
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(lease.entryId, entryId)
|
|
XCTAssertEqual(lease.matchedTokenCount, 3)
|
|
XCTAssertNotNil(lease.kvCache)
|
|
XCTAssertEqual(cache.snapshot().totalEntries, 0)
|
|
}
|
|
|
|
func testLookupPrefersDeepestPrefixMatch() {
|
|
var now = Date(timeIntervalSince1970: 100)
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 },
|
|
nowProvider: { now }
|
|
)
|
|
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2], modelId: "model")
|
|
now.addTimeInterval(1)
|
|
let deepId = UUID()
|
|
cache.store(entryId: deepId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(lease.entryId, deepId)
|
|
XCTAssertEqual(lease.matchedTokenCount, 3)
|
|
}
|
|
|
|
func testEvictsLeastRecentlyUsedEntryWhenOverBudget() {
|
|
var now = Date(timeIntervalSince1970: 100)
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 2_048,
|
|
estimateBytesProvider: { _ in 1_024 },
|
|
nowProvider: { now }
|
|
)
|
|
|
|
let firstId = UUID()
|
|
cache.store(entryId: firstId, kvCache: [], cacheKey: [1], modelId: "model")
|
|
now.addTimeInterval(1)
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [2], modelId: "model")
|
|
now.addTimeInterval(1)
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [3], modelId: "model")
|
|
|
|
let firstLookup = cache.lookup(cacheKey: [1], modelId: "model")
|
|
let secondLookup = cache.lookup(cacheKey: [2], modelId: "model")
|
|
let thirdLookup = cache.lookup(cacheKey: [3], modelId: "model")
|
|
|
|
XCTAssertFalse(firstLookup.isHit)
|
|
XCTAssertTrue(secondLookup.isHit)
|
|
XCTAssertTrue(thirdLookup.isHit)
|
|
}
|
|
|
|
func testSnapshotPrunesExpiredEntries() {
|
|
var now = Date(timeIntervalSince1970: 100)
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
idleTTL: 5,
|
|
estimateBytesProvider: { _ in 1_024 },
|
|
nowProvider: { now }
|
|
)
|
|
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
|
XCTAssertEqual(cache.snapshot().totalEntries, 1)
|
|
|
|
now.addTimeInterval(10)
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertEqual(snapshot.totalEntries, 0)
|
|
XCTAssertGreaterThanOrEqual(snapshot.totalEvictions, 1)
|
|
}
|
|
|
|
func testLookupPrunesTrieNodesForRemovedBranch() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 4], modelId: "model")
|
|
|
|
XCTAssertEqual(cache.debugTrieNodeCount(), 5)
|
|
|
|
_ = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
|
|
|
|
XCTAssertEqual(cache.debugTrieNodeCount(), 4)
|
|
|
|
_ = cache.lookup(cacheKey: [1, 2, 4], modelId: "model")
|
|
|
|
XCTAssertEqual(cache.debugTrieNodeCount(), 1)
|
|
}
|
|
|
|
func testCheckoutHitDoesNotCountAsEviction() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2, 3, 4], modelId: "model")
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(snapshot.totalEvictions, 0)
|
|
}
|
|
|
|
func testSnapshotReportsHitRateAndTokenTotals() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 2_048 }
|
|
)
|
|
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [10, 20, 30], modelId: "model")
|
|
_ = cache.lookup(cacheKey: [10, 20, 30, 40], modelId: "model")
|
|
_ = cache.lookup(cacheKey: [99], modelId: "model")
|
|
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertEqual(snapshot.totalHits, 1)
|
|
XCTAssertEqual(snapshot.totalMisses, 1)
|
|
XCTAssertEqual(snapshot.hitRate, 50, accuracy: 0.001)
|
|
XCTAssertEqual(snapshot.totalCachedTokens, 0)
|
|
XCTAssertEqual(snapshot.estimatedBytes, 0)
|
|
}
|
|
|
|
func testSupersequenceLookupReusesLongerEntryForShorterQuery() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
let entryId = UUID()
|
|
cache.store(entryId: entryId, kvCache: [], cacheKey: [1, 2, 3, 4], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(lease.entryId, entryId)
|
|
XCTAssertEqual(lease.matchedTokenCount, 3)
|
|
XCTAssertEqual(snapshot.totalHits, 1)
|
|
XCTAssertEqual(snapshot.supersequenceHits, 1)
|
|
XCTAssertEqual(snapshot.prefixHits, 0)
|
|
XCTAssertEqual(snapshot.lcpHits, 0)
|
|
}
|
|
|
|
func testLCPLookupReusesSharedPrefixAcrossDivergentSuffixes() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
let entryId = UUID()
|
|
cache.store(entryId: entryId, kvCache: [], cacheKey: [10, 20, 90], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [10, 20, 30], modelId: "model")
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(lease.entryId, entryId)
|
|
XCTAssertEqual(lease.matchedTokenCount, 2)
|
|
XCTAssertEqual(snapshot.totalHits, 1)
|
|
XCTAssertEqual(snapshot.lcpHits, 1)
|
|
XCTAssertEqual(snapshot.prefixHits, 0)
|
|
XCTAssertEqual(snapshot.supersequenceHits, 0)
|
|
}
|
|
|
|
func testLCPLookupRejectsShallowSharedPrefix() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [10, 20, 30, 40], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [10, 99, 98, 97], modelId: "model")
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertFalse(lease.isHit)
|
|
XCTAssertEqual(lease.matchedTokenCount, 0)
|
|
XCTAssertEqual(snapshot.totalHits, 0)
|
|
XCTAssertEqual(snapshot.totalMisses, 1)
|
|
XCTAssertEqual(snapshot.lcpHits, 0)
|
|
}
|
|
|
|
func testLookupPrefersPrefixMatchOverSupersequenceAndLCP() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
let prefixId = UUID()
|
|
cache.store(entryId: prefixId, kvCache: [], cacheKey: [7, 8], modelId: "model")
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [7, 8, 9, 10], modelId: "model")
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [7, 8, 11], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [7, 8, 12], modelId: "model")
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(lease.entryId, prefixId)
|
|
XCTAssertEqual(lease.matchedTokenCount, 2)
|
|
XCTAssertEqual(snapshot.prefixHits, 1)
|
|
XCTAssertEqual(snapshot.supersequenceHits, 0)
|
|
XCTAssertEqual(snapshot.lcpHits, 0)
|
|
}
|
|
|
|
func testSupersequenceSkipsNonTrimmableLayersGracefully() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
let layer = TestTrimRecordingCache(offset: 4, trimmable: false)
|
|
cache.store(entryId: UUID(), kvCache: [layer], cacheKey: [1, 2, 3, 4], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertFalse(lease.isHit)
|
|
XCTAssertEqual(layer.offset, 4)
|
|
XCTAssertTrue(layer.trimCalls.isEmpty)
|
|
XCTAssertEqual(snapshot.supersequenceHits, 0)
|
|
XCTAssertEqual(snapshot.totalMisses, 1)
|
|
}
|
|
|
|
func testSupersequenceChoosesShallowestCandidate() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
let shallowestId = UUID()
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 4, 5], modelId: "model")
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 4], modelId: "model")
|
|
cache.store(entryId: shallowestId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2], modelId: "model")
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(lease.entryId, shallowestId)
|
|
XCTAssertEqual(lease.matchedTokenCount, 2)
|
|
}
|
|
|
|
func testSupersequencePathWinsWhenFullQueryWalkCanAlsoSeeDivergentSibling() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
let supersequenceId = UUID()
|
|
cache.store(entryId: supersequenceId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 9, 8], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2], modelId: "model")
|
|
let snapshot = cache.snapshot()
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(lease.entryId, supersequenceId)
|
|
XCTAssertEqual(snapshot.supersequenceHits, 1)
|
|
XCTAssertEqual(snapshot.lcpHits, 0)
|
|
}
|
|
|
|
func testLCPChoosesShallowestSiblingCandidate() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
let shallowestId = UUID()
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 7], modelId: "model")
|
|
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 4, 7, 8], modelId: "model")
|
|
cache.store(entryId: shallowestId, kvCache: [], cacheKey: [1, 2, 5], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2, 9, 9], modelId: "model")
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(lease.entryId, shallowestId)
|
|
XCTAssertEqual(lease.matchedTokenCount, 2)
|
|
}
|
|
|
|
func testTrimUsesExactExcessAndReducesOffset() {
|
|
let cache = TokenPrefixCache(
|
|
memoryBudgetBytes: 10_000,
|
|
estimateBytesProvider: { _ in 1_024 }
|
|
)
|
|
|
|
let layer = TestTrimRecordingCache(offset: 5, trimmable: true)
|
|
cache.store(entryId: UUID(), kvCache: [layer], cacheKey: [1, 2, 3, 4, 5], modelId: "model")
|
|
|
|
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
|
|
|
|
XCTAssertTrue(lease.isHit)
|
|
XCTAssertEqual(layer.trimCalls, [2])
|
|
XCTAssertEqual(layer.offset, 3)
|
|
}
|
|
|
|
func testComputeMemoryBudgetUsesFallbackWhenDeviceUnavailable() {
|
|
let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: nil)
|
|
|
|
XCTAssertEqual(budget, 512 * 1024 * 1024)
|
|
}
|
|
|
|
func testComputeMemoryBudgetClampsToMinimumFloor() {
|
|
let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: 512 * 1024 * 1024)
|
|
|
|
XCTAssertEqual(budget, 256 * 1024 * 1024)
|
|
}
|
|
|
|
func testComputeMemoryBudgetUsesTwentyPercentOfWorkingSet() {
|
|
let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: 8 * 1024 * 1024 * 1024)
|
|
|
|
XCTAssertEqual(budget, Int(Double(8 * 1024 * 1024 * 1024) * 0.20))
|
|
}
|
|
|
|
func testComputeMemoryBudgetClampsToMaximumCap() {
|
|
let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: 80 * 1024 * 1024 * 1024)
|
|
|
|
XCTAssertEqual(budget, 8 * 1024 * 1024 * 1024)
|
|
}
|
|
}
|
|
|
|
private final class TestTrimRecordingCache: KVCache {
|
|
private var arrays: [MLXArray] = []
|
|
var offset: Int
|
|
let maxSize: Int? = nil
|
|
let trimmable: Bool
|
|
private(set) var trimCalls: [Int] = []
|
|
|
|
init(offset: Int, trimmable: Bool) {
|
|
self.offset = offset
|
|
self.trimmable = trimmable
|
|
}
|
|
|
|
func innerState() -> [MLXArray] {
|
|
arrays
|
|
}
|
|
|
|
var state: [MLXArray] {
|
|
get { arrays }
|
|
set { arrays = newValue }
|
|
}
|
|
|
|
var metaState: [String] {
|
|
get { [String(offset)] }
|
|
set { offset = Int(newValue.first ?? "0") ?? 0 }
|
|
}
|
|
|
|
var isTrimmable: Bool { trimmable }
|
|
|
|
func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
|
|
fatalError("TestTrimRecordingCache does not support update")
|
|
}
|
|
|
|
@discardableResult
|
|
func trim(_ n: Int) -> Int {
|
|
guard trimmable else { return 0 }
|
|
trimCalls.append(n)
|
|
offset = max(0, offset - n)
|
|
return n
|
|
}
|
|
|
|
func makeMask(
|
|
n: Int,
|
|
windowSize: Int?,
|
|
returnArray: Bool
|
|
) -> MLXFast.ScaledDotProductAttentionMaskMode {
|
|
.none
|
|
}
|
|
|
|
func copy() -> any KVCache {
|
|
let c = TestTrimRecordingCache(offset: offset, trimmable: trimmable)
|
|
c.state = state
|
|
return c
|
|
}
|
|
} |