fix: trying to do kv prefix caching

This commit is contained in:
2026-03-17 10:04:14 +01:00
parent 5bf170cedb
commit bdfbd14577

View File

@@ -151,25 +151,34 @@ def _get_param_names(func_name: str, tool_defs: dict[str, dict] | None) -> list[
class PromptCache:
"""Manages KV cache reuse across requests with shared prompt prefixes."""
"""Manages KV cache reuse across requests with shared prompt prefixes.
Gemma 3 uses a mix of KVCache (full attention every 6th layer) and
RotatingKVCache (sliding window, 1024 tokens). Since RotatingKVCache
cannot be safely trimmed mid-sequence, we only reuse the cache when
the ENTIRE cached token sequence is a prefix of the new prompt.
In multi-turn chat this is the common case: each new request extends
the previous prompt with the assistant response + new user message.
"""
def __init__(self):
self._cache = None
self._cache: list | None = None
self._cached_token_ids: list[int] | None = None
def get_reusable_length(self, new_token_ids: list[int]) -> int:
"""Find how many leading tokens match the cached prefix."""
"""Return cached length if the entire cache is a valid prefix, else 0."""
if self._cached_token_ids is None or self._cache is None:
return 0
max_match = min(len(self._cached_token_ids), len(new_token_ids))
match_len = 0
for i in range(max_match):
cached_len = len(self._cached_token_ids)
if cached_len > len(new_token_ids):
return 0
for i in range(cached_len):
if self._cached_token_ids[i] != new_token_ids[i]:
break
match_len = i + 1
return match_len
return 0
return cached_len
def update(self, cache, token_ids: list[int]) -> None:
def update(self, cache: list, token_ids: list[int]) -> None:
"""Store cache and the token IDs it was built from."""
self._cache = cache
self._cached_token_ids = list(token_ids)
@@ -431,13 +440,106 @@ class InferenceEngine:
return "\n".join(parts)
# ------------------------------------------------------------------
# Generation
# Prefix cache & generation
# ------------------------------------------------------------------
# Common kwargs for mlx_vlm generate calls
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
_GENERATE_KWARGS: dict = {}
def _get_tokenizer(self):
"""Get the underlying tokenizer from the processor."""
proc = self.processor
return proc.tokenizer if hasattr(proc, "tokenizer") else proc
def _prepare_generation(
self, prompt: str, images: list[str] | None = None
) -> dict:
"""Tokenize prompt, check prefix cache, return generation kwargs.
Returns a dict with keys:
input_ids, pixel_values, mask, prompt_cache,
_full_token_ids, _prompt_token_count
"""
from mlx_vlm.models import cache as cache_module
from mlx_vlm.utils import prepare_inputs
model_type = getattr(self.config, "model_type", "")
add_special_tokens = (
not hasattr(self.processor, "chat_template")
if model_type in ("gemma3", "gemma3n")
else True
)
image_token_index = getattr(self.model.config, "image_token_index", None)
# Tokenize the full prompt (+ process pixel values if images present)
inputs = prepare_inputs(
self.processor,
images=images if images else None,
prompts=prompt,
image_token_index=image_token_index,
add_special_tokens=add_special_tokens,
)
full_input_ids = inputs["input_ids"]
pixel_values = inputs.get("pixel_values")
mask = inputs.get("attention_mask")
full_token_list = full_input_ids.flatten().tolist()
prefix_len = self._prompt_cache.get_reusable_length(full_token_list)
if prefix_len > 0:
suffix_token_list = full_token_list[prefix_len:]
# If the suffix contains image placeholder tokens, we can't skip
# the vision encoder — fall back to full processing.
if (
image_token_index is not None
and image_token_index in suffix_token_list
):
logger.info(
"New images in suffix — prefix cache invalidated"
)
prefix_len = 0
if prefix_len > 0:
suffix_ids = mx.array([suffix_token_list])
logger.info(
"Prefix cache hit: reusing %d/%d tokens (%.1f%%), "
"processing %d new tokens",
prefix_len,
len(full_token_list),
100 * prefix_len / len(full_token_list),
len(suffix_token_list),
)
return {
"input_ids": suffix_ids,
"pixel_values": None, # images already in cached KV
"mask": None,
"prompt_cache": self._prompt_cache.cache,
"_full_token_ids": full_token_list,
"_prompt_token_count": len(full_token_list),
}
# Cache miss — create a fresh KV cache
cache = cache_module.make_prompt_cache(self.model.language_model)
logger.info(
"Prefix cache miss: processing %d tokens from scratch",
len(full_token_list),
)
return {
"input_ids": full_input_ids,
"pixel_values": pixel_values,
"mask": mask,
"prompt_cache": cache,
"_full_token_ids": full_token_list,
"_prompt_token_count": len(full_token_list),
}
def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None:
"""Persist the KV cache and token IDs after generation."""
full_sequence = prep["_full_token_ids"] + generated_tokens
self._prompt_cache.update(prep["prompt_cache"], full_sequence)
def generate(
self,
prompt: str,
@@ -450,23 +552,41 @@ class InferenceEngine:
) -> tuple[str, int, int]:
"""Generate a complete response. Returns (text, prompt_tokens, completion_tokens)."""
with self._lock:
image_arg = images if images else None
result = mlx_vlm.generate(
prep = self._prepare_generation(prompt, images)
prompt_token_count = prep["_prompt_token_count"]
# Ensure stopping criteria is initialised
tokenizer = self._get_tokenizer()
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
text = ""
generated_tokens: list[int] = []
gen_tokens = 0
for result in mlx_vlm.stream_generate(
self.model,
self.processor,
prompt,
image=image_arg,
input_ids=prep["input_ids"],
pixel_values=prep.get("pixel_values"),
mask=prep.get("mask"),
prompt_cache=prep["prompt_cache"],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
verbose=False,
**self._GENERATE_KWARGS,
)
text = result.text
):
text += result.text
if result.token is not None:
generated_tokens.append(result.token)
gen_tokens = result.generation_tokens
self._save_cache(prep, generated_tokens)
if stop:
text = self._apply_stop(text, stop)
return text, result.prompt_tokens, result.generation_tokens
return text, prompt_token_count, gen_tokens
def stream_generate(
self,
@@ -480,40 +600,52 @@ class InferenceEngine:
) -> Generator[tuple[str, bool, int, int], None, None]:
"""Stream tokens. Yields (token_text, is_final, prompt_tokens, gen_tokens)."""
with self._lock:
image_arg = images if images else None
prep = self._prepare_generation(prompt, images)
prompt_token_count = prep["_prompt_token_count"]
# Ensure stopping criteria is initialised
tokenizer = self._get_tokenizer()
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
accumulated = ""
prompt_tokens = 0
generated_tokens: list[int] = []
gen_tokens = 0
try:
for result in mlx_vlm.stream_generate(
self.model,
self.processor,
prompt,
image=image_arg,
input_ids=prep["input_ids"],
pixel_values=prep.get("pixel_values"),
mask=prep.get("mask"),
prompt_cache=prep["prompt_cache"],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
**self._GENERATE_KWARGS,
):
# result.text is the incremental segment (detokenizer.last_segment),
# NOT the full accumulated text.
token_text = result.text
accumulated += token_text
prompt_tokens = result.prompt_tokens
if result.token is not None:
generated_tokens.append(result.token)
gen_tokens = result.generation_tokens
if stop and self._check_stop(accumulated, stop):
# Trim the accumulated text and yield what's safe
trimmed = self._apply_stop(accumulated, stop)
# Only yield the part we haven't yielded yet
safe_delta = trimmed[len(accumulated) - len(token_text):]
yield safe_delta, True, prompt_tokens, gen_tokens
safe_delta = trimmed[
len(accumulated) - len(token_text) :
]
yield safe_delta, True, prompt_token_count, gen_tokens
return
yield token_text, False, prompt_tokens, gen_tokens
yield token_text, False, prompt_token_count, gen_tokens
# Final yield to signal completion
yield "", True, prompt_tokens, gen_tokens
yield "", True, prompt_token_count, gen_tokens
finally:
self._save_cache(prep, generated_tokens)
@staticmethod
def _apply_stop(text: str, stop: list[str]) -> str: