fix: trying to do kv prefix caching
This commit is contained in:
@@ -151,25 +151,34 @@ def _get_param_names(func_name: str, tool_defs: dict[str, dict] | None) -> list[
|
|||||||
|
|
||||||
|
|
||||||
class PromptCache:
|
class PromptCache:
|
||||||
"""Manages KV cache reuse across requests with shared prompt prefixes."""
|
"""Manages KV cache reuse across requests with shared prompt prefixes.
|
||||||
|
|
||||||
|
Gemma 3 uses a mix of KVCache (full attention every 6th layer) and
|
||||||
|
RotatingKVCache (sliding window, 1024 tokens). Since RotatingKVCache
|
||||||
|
cannot be safely trimmed mid-sequence, we only reuse the cache when
|
||||||
|
the ENTIRE cached token sequence is a prefix of the new prompt.
|
||||||
|
|
||||||
|
In multi-turn chat this is the common case: each new request extends
|
||||||
|
the previous prompt with the assistant response + new user message.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._cache = None
|
self._cache: list | None = None
|
||||||
self._cached_token_ids: list[int] | None = None
|
self._cached_token_ids: list[int] | None = None
|
||||||
|
|
||||||
def get_reusable_length(self, new_token_ids: list[int]) -> int:
|
def get_reusable_length(self, new_token_ids: list[int]) -> int:
|
||||||
"""Find how many leading tokens match the cached prefix."""
|
"""Return cached length if the entire cache is a valid prefix, else 0."""
|
||||||
if self._cached_token_ids is None or self._cache is None:
|
if self._cached_token_ids is None or self._cache is None:
|
||||||
return 0
|
return 0
|
||||||
max_match = min(len(self._cached_token_ids), len(new_token_ids))
|
cached_len = len(self._cached_token_ids)
|
||||||
match_len = 0
|
if cached_len > len(new_token_ids):
|
||||||
for i in range(max_match):
|
return 0
|
||||||
|
for i in range(cached_len):
|
||||||
if self._cached_token_ids[i] != new_token_ids[i]:
|
if self._cached_token_ids[i] != new_token_ids[i]:
|
||||||
break
|
return 0
|
||||||
match_len = i + 1
|
return cached_len
|
||||||
return match_len
|
|
||||||
|
|
||||||
def update(self, cache, token_ids: list[int]) -> None:
|
def update(self, cache: list, token_ids: list[int]) -> None:
|
||||||
"""Store cache and the token IDs it was built from."""
|
"""Store cache and the token IDs it was built from."""
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
self._cached_token_ids = list(token_ids)
|
self._cached_token_ids = list(token_ids)
|
||||||
@@ -431,13 +440,106 @@ class InferenceEngine:
|
|||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Generation
|
# Prefix cache & generation
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
# Common kwargs for mlx_vlm generate calls
|
# Common kwargs for mlx_vlm generate calls
|
||||||
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
|
# Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache
|
||||||
_GENERATE_KWARGS: dict = {}
|
_GENERATE_KWARGS: dict = {}
|
||||||
|
|
||||||
|
def _get_tokenizer(self):
|
||||||
|
"""Get the underlying tokenizer from the processor."""
|
||||||
|
proc = self.processor
|
||||||
|
return proc.tokenizer if hasattr(proc, "tokenizer") else proc
|
||||||
|
|
||||||
|
def _prepare_generation(
|
||||||
|
self, prompt: str, images: list[str] | None = None
|
||||||
|
) -> dict:
|
||||||
|
"""Tokenize prompt, check prefix cache, return generation kwargs.
|
||||||
|
|
||||||
|
Returns a dict with keys:
|
||||||
|
input_ids, pixel_values, mask, prompt_cache,
|
||||||
|
_full_token_ids, _prompt_token_count
|
||||||
|
"""
|
||||||
|
from mlx_vlm.models import cache as cache_module
|
||||||
|
from mlx_vlm.utils import prepare_inputs
|
||||||
|
|
||||||
|
model_type = getattr(self.config, "model_type", "")
|
||||||
|
add_special_tokens = (
|
||||||
|
not hasattr(self.processor, "chat_template")
|
||||||
|
if model_type in ("gemma3", "gemma3n")
|
||||||
|
else True
|
||||||
|
)
|
||||||
|
image_token_index = getattr(self.model.config, "image_token_index", None)
|
||||||
|
|
||||||
|
# Tokenize the full prompt (+ process pixel values if images present)
|
||||||
|
inputs = prepare_inputs(
|
||||||
|
self.processor,
|
||||||
|
images=images if images else None,
|
||||||
|
prompts=prompt,
|
||||||
|
image_token_index=image_token_index,
|
||||||
|
add_special_tokens=add_special_tokens,
|
||||||
|
)
|
||||||
|
full_input_ids = inputs["input_ids"]
|
||||||
|
pixel_values = inputs.get("pixel_values")
|
||||||
|
mask = inputs.get("attention_mask")
|
||||||
|
|
||||||
|
full_token_list = full_input_ids.flatten().tolist()
|
||||||
|
prefix_len = self._prompt_cache.get_reusable_length(full_token_list)
|
||||||
|
|
||||||
|
if prefix_len > 0:
|
||||||
|
suffix_token_list = full_token_list[prefix_len:]
|
||||||
|
|
||||||
|
# If the suffix contains image placeholder tokens, we can't skip
|
||||||
|
# the vision encoder — fall back to full processing.
|
||||||
|
if (
|
||||||
|
image_token_index is not None
|
||||||
|
and image_token_index in suffix_token_list
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"New images in suffix — prefix cache invalidated"
|
||||||
|
)
|
||||||
|
prefix_len = 0
|
||||||
|
|
||||||
|
if prefix_len > 0:
|
||||||
|
suffix_ids = mx.array([suffix_token_list])
|
||||||
|
logger.info(
|
||||||
|
"Prefix cache hit: reusing %d/%d tokens (%.1f%%), "
|
||||||
|
"processing %d new tokens",
|
||||||
|
prefix_len,
|
||||||
|
len(full_token_list),
|
||||||
|
100 * prefix_len / len(full_token_list),
|
||||||
|
len(suffix_token_list),
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"input_ids": suffix_ids,
|
||||||
|
"pixel_values": None, # images already in cached KV
|
||||||
|
"mask": None,
|
||||||
|
"prompt_cache": self._prompt_cache.cache,
|
||||||
|
"_full_token_ids": full_token_list,
|
||||||
|
"_prompt_token_count": len(full_token_list),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Cache miss — create a fresh KV cache
|
||||||
|
cache = cache_module.make_prompt_cache(self.model.language_model)
|
||||||
|
logger.info(
|
||||||
|
"Prefix cache miss: processing %d tokens from scratch",
|
||||||
|
len(full_token_list),
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"input_ids": full_input_ids,
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
"mask": mask,
|
||||||
|
"prompt_cache": cache,
|
||||||
|
"_full_token_ids": full_token_list,
|
||||||
|
"_prompt_token_count": len(full_token_list),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None:
|
||||||
|
"""Persist the KV cache and token IDs after generation."""
|
||||||
|
full_sequence = prep["_full_token_ids"] + generated_tokens
|
||||||
|
self._prompt_cache.update(prep["prompt_cache"], full_sequence)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -450,23 +552,41 @@ class InferenceEngine:
|
|||||||
) -> tuple[str, int, int]:
|
) -> tuple[str, int, int]:
|
||||||
"""Generate a complete response. Returns (text, prompt_tokens, completion_tokens)."""
|
"""Generate a complete response. Returns (text, prompt_tokens, completion_tokens)."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
image_arg = images if images else None
|
prep = self._prepare_generation(prompt, images)
|
||||||
result = mlx_vlm.generate(
|
prompt_token_count = prep["_prompt_token_count"]
|
||||||
|
|
||||||
|
# Ensure stopping criteria is initialised
|
||||||
|
tokenizer = self._get_tokenizer()
|
||||||
|
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
generated_tokens: list[int] = []
|
||||||
|
gen_tokens = 0
|
||||||
|
|
||||||
|
for result in mlx_vlm.stream_generate(
|
||||||
self.model,
|
self.model,
|
||||||
self.processor,
|
self.processor,
|
||||||
prompt,
|
prompt,
|
||||||
image=image_arg,
|
input_ids=prep["input_ids"],
|
||||||
|
pixel_values=prep.get("pixel_values"),
|
||||||
|
mask=prep.get("mask"),
|
||||||
|
prompt_cache=prep["prompt_cache"],
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
verbose=False,
|
|
||||||
**self._GENERATE_KWARGS,
|
**self._GENERATE_KWARGS,
|
||||||
)
|
):
|
||||||
text = result.text
|
text += result.text
|
||||||
|
if result.token is not None:
|
||||||
|
generated_tokens.append(result.token)
|
||||||
|
gen_tokens = result.generation_tokens
|
||||||
|
|
||||||
|
self._save_cache(prep, generated_tokens)
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
text = self._apply_stop(text, stop)
|
text = self._apply_stop(text, stop)
|
||||||
return text, result.prompt_tokens, result.generation_tokens
|
return text, prompt_token_count, gen_tokens
|
||||||
|
|
||||||
def stream_generate(
|
def stream_generate(
|
||||||
self,
|
self,
|
||||||
@@ -480,40 +600,52 @@ class InferenceEngine:
|
|||||||
) -> Generator[tuple[str, bool, int, int], None, None]:
|
) -> Generator[tuple[str, bool, int, int], None, None]:
|
||||||
"""Stream tokens. Yields (token_text, is_final, prompt_tokens, gen_tokens)."""
|
"""Stream tokens. Yields (token_text, is_final, prompt_tokens, gen_tokens)."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
image_arg = images if images else None
|
prep = self._prepare_generation(prompt, images)
|
||||||
|
prompt_token_count = prep["_prompt_token_count"]
|
||||||
|
|
||||||
|
# Ensure stopping criteria is initialised
|
||||||
|
tokenizer = self._get_tokenizer()
|
||||||
|
tokenizer.stopping_criteria.reset(self.model.config.eos_token_id)
|
||||||
|
|
||||||
accumulated = ""
|
accumulated = ""
|
||||||
prompt_tokens = 0
|
generated_tokens: list[int] = []
|
||||||
gen_tokens = 0
|
gen_tokens = 0
|
||||||
|
|
||||||
|
try:
|
||||||
for result in mlx_vlm.stream_generate(
|
for result in mlx_vlm.stream_generate(
|
||||||
self.model,
|
self.model,
|
||||||
self.processor,
|
self.processor,
|
||||||
prompt,
|
prompt,
|
||||||
image=image_arg,
|
input_ids=prep["input_ids"],
|
||||||
|
pixel_values=prep.get("pixel_values"),
|
||||||
|
mask=prep.get("mask"),
|
||||||
|
prompt_cache=prep["prompt_cache"],
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
**self._GENERATE_KWARGS,
|
**self._GENERATE_KWARGS,
|
||||||
):
|
):
|
||||||
# result.text is the incremental segment (detokenizer.last_segment),
|
|
||||||
# NOT the full accumulated text.
|
|
||||||
token_text = result.text
|
token_text = result.text
|
||||||
accumulated += token_text
|
accumulated += token_text
|
||||||
prompt_tokens = result.prompt_tokens
|
if result.token is not None:
|
||||||
|
generated_tokens.append(result.token)
|
||||||
gen_tokens = result.generation_tokens
|
gen_tokens = result.generation_tokens
|
||||||
|
|
||||||
if stop and self._check_stop(accumulated, stop):
|
if stop and self._check_stop(accumulated, stop):
|
||||||
# Trim the accumulated text and yield what's safe
|
|
||||||
trimmed = self._apply_stop(accumulated, stop)
|
trimmed = self._apply_stop(accumulated, stop)
|
||||||
# Only yield the part we haven't yielded yet
|
safe_delta = trimmed[
|
||||||
safe_delta = trimmed[len(accumulated) - len(token_text):]
|
len(accumulated) - len(token_text) :
|
||||||
yield safe_delta, True, prompt_tokens, gen_tokens
|
]
|
||||||
|
yield safe_delta, True, prompt_token_count, gen_tokens
|
||||||
return
|
return
|
||||||
|
|
||||||
yield token_text, False, prompt_tokens, gen_tokens
|
yield token_text, False, prompt_token_count, gen_tokens
|
||||||
|
|
||||||
# Final yield to signal completion
|
# Final yield to signal completion
|
||||||
yield "", True, prompt_tokens, gen_tokens
|
yield "", True, prompt_token_count, gen_tokens
|
||||||
|
finally:
|
||||||
|
self._save_cache(prep, generated_tokens)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _apply_stop(text: str, stop: list[str]) -> str:
|
def _apply_stop(text: str, stop: list[str]) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user