feat: finished all open things up to and including phase 6

This commit is contained in:
2026-03-21 08:41:13 +01:00
parent 0325fa8964
commit 107ac0524b
9 changed files with 457 additions and 33 deletions

View File

@@ -1,4 +1,5 @@
import Foundation
import MLX
import XCTest
import MLXLMCommon
@testable import MLX_Server
@@ -225,6 +226,96 @@ final class TokenPrefixCacheTests: XCTestCase {
XCTAssertEqual(snapshot.lcpHits, 0)
}
func testSupersequenceSkipsNonTrimmableLayersGracefully() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let layer = TestTrimRecordingCache(offset: 4, trimmable: false)
cache.store(entryId: UUID(), kvCache: [layer], cacheKey: [1, 2, 3, 4], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
let snapshot = cache.snapshot()
XCTAssertFalse(lease.isHit)
XCTAssertEqual(layer.offset, 4)
XCTAssertTrue(layer.trimCalls.isEmpty)
XCTAssertEqual(snapshot.supersequenceHits, 0)
XCTAssertEqual(snapshot.totalMisses, 1)
}
func testSupersequenceChoosesShallowestCandidate() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let shallowestId = UUID()
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 4, 5], modelId: "model")
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 4], modelId: "model")
cache.store(entryId: shallowestId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2], modelId: "model")
XCTAssertTrue(lease.isHit)
XCTAssertEqual(lease.entryId, shallowestId)
XCTAssertEqual(lease.matchedTokenCount, 2)
}
func testSupersequencePathWinsWhenFullQueryWalkCanAlsoSeeDivergentSibling() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let supersequenceId = UUID()
cache.store(entryId: supersequenceId, kvCache: [], cacheKey: [1, 2, 3], modelId: "model")
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 9, 8], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2], modelId: "model")
let snapshot = cache.snapshot()
XCTAssertTrue(lease.isHit)
XCTAssertEqual(lease.entryId, supersequenceId)
XCTAssertEqual(snapshot.supersequenceHits, 1)
XCTAssertEqual(snapshot.lcpHits, 0)
}
func testLCPChoosesShallowestSiblingCandidate() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let shallowestId = UUID()
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 3, 7], modelId: "model")
cache.store(entryId: UUID(), kvCache: [], cacheKey: [1, 2, 4, 7, 8], modelId: "model")
cache.store(entryId: shallowestId, kvCache: [], cacheKey: [1, 2, 5], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2, 9, 9], modelId: "model")
XCTAssertTrue(lease.isHit)
XCTAssertEqual(lease.entryId, shallowestId)
XCTAssertEqual(lease.matchedTokenCount, 2)
}
func testTrimUsesExactExcessAndReducesOffset() {
let cache = TokenPrefixCache(
memoryBudgetBytes: 10_000,
estimateBytesProvider: { _ in 1_024 }
)
let layer = TestTrimRecordingCache(offset: 5, trimmable: true)
cache.store(entryId: UUID(), kvCache: [layer], cacheKey: [1, 2, 3, 4, 5], modelId: "model")
let lease = cache.lookup(cacheKey: [1, 2, 3], modelId: "model")
XCTAssertTrue(lease.isHit)
XCTAssertEqual(layer.trimCalls, [2])
XCTAssertEqual(layer.offset, 3)
}
func testComputeMemoryBudgetUsesFallbackWhenDeviceUnavailable() {
let budget = TokenPrefixCache.computeMemoryBudget(recommendedWorkingSetSize: nil)
@@ -248,4 +339,53 @@ final class TokenPrefixCacheTests: XCTestCase {
XCTAssertEqual(budget, 8 * 1024 * 1024 * 1024)
}
}
private final class TestTrimRecordingCache: KVCache {
private var arrays: [MLXArray] = []
var offset: Int
let maxSize: Int? = nil
let trimmable: Bool
private(set) var trimCalls: [Int] = []
init(offset: Int, trimmable: Bool) {
self.offset = offset
self.trimmable = trimmable
}
func innerState() -> [MLXArray] {
arrays
}
var state: [MLXArray] {
get { arrays }
set { arrays = newValue }
}
var metaState: [String] {
get { [String(offset)] }
set { offset = Int(newValue.first ?? "0") ?? 0 }
}
var isTrimmable: Bool { trimmable }
func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) {
fatalError("TestTrimRecordingCache does not support update")
}
@discardableResult
func trim(_ n: Int) -> Int {
guard trimmable else { return 0 }
trimCalls.append(n)
offset = max(0, offset - n)
return n
}
func makeMask(
n: Int,
windowSize: Int?,
returnArray: Bool
) -> MLXFast.ScaledDotProductAttentionMaskMode {
.none
}
}