fix: trying to do kv prefix caching
This commit is contained in:
@@ -151,25 +151,34 @@ def _get_param_names(func_name: str, tool_defs: dict[str, dict] | None) -> list[
|
||||
|
||||
|
||||
class PromptCache:
|
||||
"""Manages KV cache reuse across requests with shared prompt prefixes."""
|
||||
"""Manages KV cache reuse across requests with shared prompt prefixes.
|
||||
|
||||
Gemma 3 uses a mix of KVCache (full attention every 6th layer) and
|
||||
RotatingKVCache (sliding window, 1024 tokens). Since RotatingKVCache
|
||||
cannot be safely trimmed mid-sequence, we only reuse the cache when
|
||||
the ENTIRE cached token sequence is a prefix of the new prompt.
|
||||
|
||||
In multi-turn chat this is the common case: each new request extends
|
||||
the previous prompt with the assistant response + new user message.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._cache = None
|
||||
self._cache: list | None = None
|
||||
self._cached_token_ids: list[int] | None = None
|
||||
|
||||
def get_reusable_length(self, new_token_ids: list[int]) -> int:
|
||||
"""Find how many leading tokens match the cached prefix."""
|
||||
"""Return cached length if the entire cache is a valid prefix, else 0."""
|
||||
if self._cached_token_ids is None or self._cache is None:
|
||||
return 0
|
||||
max_match = min(len(self._cached_token_ids), len(new_token_ids))
|
||||
match_len = 0
|
||||
for i in range(max_match):
|
||||
cached_len = len(self._cached_token_ids)
|
||||
if cached_len > len(new_token_ids):
|
||||
return 0
|
||||
for i in range(cached_len):
|
||||
if self._cached_token_ids[i] != new_token_ids[i]:
|
||||
break
|
||||
match_len = i + 1
|
||||
return match_len
|
||||
return 0
|
||||
return cached_len
|
||||
|
||||
def update(self, cache, token_ids: list[int]) -> None:
|
||||
def update(self, cache: list, token_ids: list[int]) -> None:
|
||||
"""Store cache and the token IDs it was built from."""
|
||||
self._cache = cache
|
||||
self._cached_token_ids = list(token_ids)
|
||||
@@ -431,13 +440,106 @@ class InferenceEngine:
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Generation
|
||||
# Prefix cache & generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# Common kwargs for mlx_vlm generate calls
|
||||
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
|
||||
_GENERATE_KWARGS: dict = {}
|
||||
|
||||
def _get_tokenizer(self):
|
||||
"""Get the underlying tokenizer from the processor."""
|
||||
proc = self.processor
|
||||
return proc.tokenizer if hasattr(proc, "tokenizer") else proc
|
||||
|
||||
def _prepare_generation(
|
||||
self, prompt: str, images: list[str] | None = None
|
||||
) -> dict:
|
||||
"""Tokenize prompt, check prefix cache, return generation kwargs.
|
||||
|
||||
Returns a dict with keys:
|
||||
input_ids, pixel_values, mask, prompt_cache,
|
||||
_full_token_ids, _prompt_token_count
|
||||
"""
|
||||
from mlx_vlm.models import cache as cache_module
|
||||
from mlx_vlm.utils import prepare_inputs
|
||||
|
||||
model_type = getattr(self.config, "model_type", "")
|
||||
add_special_tokens = (
|
||||
not hasattr(self.processor, "chat_template")
|
||||
if model_type in ("gemma3", "gemma3n")
|
||||
else True
|
||||
)
|
||||
image_token_index = getattr(self.model.config, "image_token_index", None)
|
||||
|
||||
# Tokenize the full prompt (+ process pixel values if images present)
|
||||
inputs = prepare_inputs(
|
||||
self.processor,
|
||||
images=images if images else None,
|
||||
prompts=prompt,
|
||||
image_token_index=image_token_index,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
full_input_ids = inputs["input_ids"]
|
||||
pixel_values = inputs.get("pixel_values")
|
||||
mask = inputs.get("attention_mask")
|
||||
|
||||
full_token_list = full_input_ids.flatten().tolist()
|
||||
prefix_len = self._prompt_cache.get_reusable_length(full_token_list)
|
||||
|
||||
if prefix_len > 0:
|
||||
suffix_token_list = full_token_list[prefix_len:]
|
||||
|
||||
# If the suffix contains image placeholder tokens, we can't skip
|
||||
# the vision encoder — fall back to full processing.
|
||||
if (
|
||||
image_token_index is not None
|
||||
and image_token_index in suffix_token_list
|
||||
):
|
||||
logger.info(
|
||||
"New images in suffix — prefix cache invalidated"
|
||||
)
|
||||
prefix_len = 0
|
||||
|
||||
if prefix_len > 0:
|
||||
suffix_ids = mx.array([suffix_token_list])
|
||||
logger.info(
|
||||
"Prefix cache hit: reusing %d/%d tokens (%.1f%%), "
|
||||
"processing %d new tokens",
|
||||
prefix_len,
|
||||
len(full_token_list),
|
||||
100 * prefix_len / len(full_token_list),
|
||||
len(suffix_token_list),
|
||||
)
|
||||
return {
|
||||
"input_ids": suffix_ids,
|
||||
"pixel_values": None, # images already in cached KV
|
||||
"mask": None,
|
||||
"prompt_cache": self._prompt_cache.cache,
|
||||
"_full_token_ids": full_token_list,
|
||||
"_prompt_token_count": len(full_token_list),
|
||||
}
|
||||
|
||||
# Cache miss — create a fresh KV cache
|
||||
cache = cache_module.make_prompt_cache(self.model.language_model)
|
||||
logger.info(
|
||||
"Prefix cache miss: processing %d tokens from scratch",
|
||||
len(full_token_list),
|
||||
)
|
||||
return {
|
||||
"input_ids": full_input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
"mask": mask,
|
||||
"prompt_cache": cache,
|
||||
"_full_token_ids": full_token_list,
|
||||
"_prompt_token_count": len(full_token_list),
|
||||
}
|
||||
|
||||
def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None:
|
||||
"""Persist the KV cache and token IDs after generation."""
|
||||
full_sequence = prep["_full_token_ids"] + generated_tokens
|
||||
self._prompt_cache.update(prep["prompt_cache"], full_sequence)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -450,23 +552,41 @@ class InferenceEngine:
|
||||
) -> tuple[str, int, int]:
|
||||
"""Generate a complete response. Returns (text, prompt_tokens, completion_tokens)."""
|
||||
with self._lock:
|
||||
image_arg = images if images else None
|
||||
result = mlx_vlm.generate(
|
||||
prep = self._prepare_generation(prompt, images)
|
||||
prompt_token_count = prep["_prompt_token_count"]
|
||||
|
||||
# Ensure stopping criteria is initialised
|
||||
tokenizer = self._get_tokenizer()
|
||||
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||
|
||||
text = ""
|
||||
generated_tokens: list[int] = []
|
||||
gen_tokens = 0
|
||||
|
||||
for result in mlx_vlm.stream_generate(
|
||||
self.model,
|
||||
self.processor,
|
||||
prompt,
|
||||
image=image_arg,
|
||||
input_ids=prep["input_ids"],
|
||||
pixel_values=prep.get("pixel_values"),
|
||||
mask=prep.get("mask"),
|
||||
prompt_cache=prep["prompt_cache"],
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
verbose=False,
|
||||
**self._GENERATE_KWARGS,
|
||||
)
|
||||
text = result.text
|
||||
):
|
||||
text += result.text
|
||||
if result.token is not None:
|
||||
generated_tokens.append(result.token)
|
||||
gen_tokens = result.generation_tokens
|
||||
|
||||
self._save_cache(prep, generated_tokens)
|
||||
|
||||
if stop:
|
||||
text = self._apply_stop(text, stop)
|
||||
return text, result.prompt_tokens, result.generation_tokens
|
||||
return text, prompt_token_count, gen_tokens
|
||||
|
||||
def stream_generate(
|
||||
self,
|
||||
@@ -480,40 +600,52 @@ class InferenceEngine:
|
||||
) -> Generator[tuple[str, bool, int, int], None, None]:
|
||||
"""Stream tokens. Yields (token_text, is_final, prompt_tokens, gen_tokens)."""
|
||||
with self._lock:
|
||||
image_arg = images if images else None
|
||||
prep = self._prepare_generation(prompt, images)
|
||||
prompt_token_count = prep["_prompt_token_count"]
|
||||
|
||||
# Ensure stopping criteria is initialised
|
||||
tokenizer = self._get_tokenizer()
|
||||
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||
|
||||
accumulated = ""
|
||||
prompt_tokens = 0
|
||||
generated_tokens: list[int] = []
|
||||
gen_tokens = 0
|
||||
for result in mlx_vlm.stream_generate(
|
||||
self.model,
|
||||
self.processor,
|
||||
prompt,
|
||||
image=image_arg,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
**self._GENERATE_KWARGS,
|
||||
):
|
||||
# result.text is the incremental segment (detokenizer.last_segment),
|
||||
# NOT the full accumulated text.
|
||||
token_text = result.text
|
||||
accumulated += token_text
|
||||
prompt_tokens = result.prompt_tokens
|
||||
gen_tokens = result.generation_tokens
|
||||
|
||||
if stop and self._check_stop(accumulated, stop):
|
||||
# Trim the accumulated text and yield what's safe
|
||||
trimmed = self._apply_stop(accumulated, stop)
|
||||
# Only yield the part we haven't yielded yet
|
||||
safe_delta = trimmed[len(accumulated) - len(token_text):]
|
||||
yield safe_delta, True, prompt_tokens, gen_tokens
|
||||
return
|
||||
try:
|
||||
for result in mlx_vlm.stream_generate(
|
||||
self.model,
|
||||
self.processor,
|
||||
prompt,
|
||||
input_ids=prep["input_ids"],
|
||||
pixel_values=prep.get("pixel_values"),
|
||||
mask=prep.get("mask"),
|
||||
prompt_cache=prep["prompt_cache"],
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
**self._GENERATE_KWARGS,
|
||||
):
|
||||
token_text = result.text
|
||||
accumulated += token_text
|
||||
if result.token is not None:
|
||||
generated_tokens.append(result.token)
|
||||
gen_tokens = result.generation_tokens
|
||||
|
||||
yield token_text, False, prompt_tokens, gen_tokens
|
||||
if stop and self._check_stop(accumulated, stop):
|
||||
trimmed = self._apply_stop(accumulated, stop)
|
||||
safe_delta = trimmed[
|
||||
len(accumulated) - len(token_text) :
|
||||
]
|
||||
yield safe_delta, True, prompt_token_count, gen_tokens
|
||||
return
|
||||
|
||||
# Final yield to signal completion
|
||||
yield "", True, prompt_tokens, gen_tokens
|
||||
yield token_text, False, prompt_token_count, gen_tokens
|
||||
|
||||
# Final yield to signal completion
|
||||
yield "", True, prompt_token_count, gen_tokens
|
||||
finally:
|
||||
self._save_cache(prep, generated_tokens)
|
||||
|
||||
@staticmethod
|
||||
def _apply_stop(text: str, stop: list[str]) -> str:
|
||||
|
||||
Reference in New Issue
Block a user