diff --git a/mlx_server/engine.py b/mlx_server/engine.py index 8187eeb..5e1014a 100644 --- a/mlx_server/engine.py +++ b/mlx_server/engine.py @@ -151,25 +151,34 @@ def _get_param_names(func_name: str, tool_defs: dict[str, dict] | None) -> list[ class PromptCache: - """Manages KV cache reuse across requests with shared prompt prefixes.""" + """Manages KV cache reuse across requests with shared prompt prefixes. + + Gemma 3 uses a mix of KVCache (full attention every 6th layer) and + RotatingKVCache (sliding window, 1024 tokens). Since RotatingKVCache + cannot be safely trimmed mid-sequence, we only reuse the cache when + the ENTIRE cached token sequence is a prefix of the new prompt. + + In multi-turn chat this is the common case: each new request extends + the previous prompt with the assistant response + new user message. + """ def __init__(self): - self._cache = None + self._cache: list | None = None self._cached_token_ids: list[int] | None = None def get_reusable_length(self, new_token_ids: list[int]) -> int: - """Find how many leading tokens match the cached prefix.""" + """Return cached length if the entire cache is a valid prefix, else 0.""" if self._cached_token_ids is None or self._cache is None: return 0 - max_match = min(len(self._cached_token_ids), len(new_token_ids)) - match_len = 0 - for i in range(max_match): + cached_len = len(self._cached_token_ids) + if cached_len > len(new_token_ids): + return 0 + for i in range(cached_len): if self._cached_token_ids[i] != new_token_ids[i]: - break - match_len = i + 1 - return match_len + return 0 + return cached_len - def update(self, cache, token_ids: list[int]) -> None: + def update(self, cache: list, token_ids: list[int]) -> None: """Store cache and the token IDs it was built from.""" self._cache = cache self._cached_token_ids = list(token_ids) @@ -431,13 +440,106 @@ class InferenceEngine: return "\n".join(parts) # ------------------------------------------------------------------ - # Generation + # Prefix cache & generation # ------------------------------------------------------------------ # Common kwargs for mlx_vlm generate calls # Note: KV cache quantization is not supported with Gemma 3's RotatingKVCache _GENERATE_KWARGS: dict = {} + def _get_tokenizer(self): + """Get the underlying tokenizer from the processor.""" + proc = self.processor + return proc.tokenizer if hasattr(proc, "tokenizer") else proc + + def _prepare_generation( + self, prompt: str, images: list[str] | None = None + ) -> dict: + """Tokenize prompt, check prefix cache, return generation kwargs. + + Returns a dict with keys: + input_ids, pixel_values, mask, prompt_cache, + _full_token_ids, _prompt_token_count + """ + from mlx_vlm.models import cache as cache_module + from mlx_vlm.utils import prepare_inputs + + model_type = getattr(self.config, "model_type", "") + add_special_tokens = ( + not hasattr(self.processor, "chat_template") + if model_type in ("gemma3", "gemma3n") + else True + ) + image_token_index = getattr(self.model.config, "image_token_index", None) + + # Tokenize the full prompt (+ process pixel values if images present) + inputs = prepare_inputs( + self.processor, + images=images if images else None, + prompts=prompt, + image_token_index=image_token_index, + add_special_tokens=add_special_tokens, + ) + full_input_ids = inputs["input_ids"] + pixel_values = inputs.get("pixel_values") + mask = inputs.get("attention_mask") + + full_token_list = full_input_ids.flatten().tolist() + prefix_len = self._prompt_cache.get_reusable_length(full_token_list) + + if prefix_len > 0: + suffix_token_list = full_token_list[prefix_len:] + + # If the suffix contains image placeholder tokens, we can't skip + # the vision encoder — fall back to full processing. + if ( + image_token_index is not None + and image_token_index in suffix_token_list + ): + logger.info( + "New images in suffix — prefix cache invalidated" + ) + prefix_len = 0 + + if prefix_len > 0: + suffix_ids = mx.array([suffix_token_list]) + logger.info( + "Prefix cache hit: reusing %d/%d tokens (%.1f%%), " + "processing %d new tokens", + prefix_len, + len(full_token_list), + 100 * prefix_len / len(full_token_list), + len(suffix_token_list), + ) + return { + "input_ids": suffix_ids, + "pixel_values": None, # images already in cached KV + "mask": None, + "prompt_cache": self._prompt_cache.cache, + "_full_token_ids": full_token_list, + "_prompt_token_count": len(full_token_list), + } + + # Cache miss — create a fresh KV cache + cache = cache_module.make_prompt_cache(self.model.language_model) + logger.info( + "Prefix cache miss: processing %d tokens from scratch", + len(full_token_list), + ) + return { + "input_ids": full_input_ids, + "pixel_values": pixel_values, + "mask": mask, + "prompt_cache": cache, + "_full_token_ids": full_token_list, + "_prompt_token_count": len(full_token_list), + } + + def _save_cache(self, prep: dict, generated_tokens: list[int]) -> None: + """Persist the KV cache and token IDs after generation.""" + full_sequence = prep["_full_token_ids"] + generated_tokens + self._prompt_cache.update(prep["prompt_cache"], full_sequence) + def generate( self, prompt: str, @@ -450,23 +552,41 @@ class InferenceEngine: ) -> tuple[str, int, int]: """Generate a complete response. Returns (text, prompt_tokens, completion_tokens).""" with self._lock: - image_arg = images if images else None - result = mlx_vlm.generate( + prep = self._prepare_generation(prompt, images) + prompt_token_count = prep["_prompt_token_count"] + + # Ensure stopping criteria is initialised + tokenizer = self._get_tokenizer() + tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) + + text = "" + generated_tokens: list[int] = [] + gen_tokens = 0 + + for result in mlx_vlm.stream_generate( self.model, self.processor, prompt, - image=image_arg, + input_ids=prep["input_ids"], + pixel_values=prep.get("pixel_values"), + mask=prep.get("mask"), + prompt_cache=prep["prompt_cache"], max_tokens=max_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty, - verbose=False, **self._GENERATE_KWARGS, - ) - text = result.text + ): + text += result.text + if result.token is not None: + generated_tokens.append(result.token) + gen_tokens = result.generation_tokens + + self._save_cache(prep, generated_tokens) + if stop: text = self._apply_stop(text, stop) - return text, result.prompt_tokens, result.generation_tokens + return text, prompt_token_count, gen_tokens def stream_generate( self, @@ -480,40 +600,52 @@ class InferenceEngine: ) -> Generator[tuple[str, bool, int, int], None, None]: """Stream tokens. Yields (token_text, is_final, prompt_tokens, gen_tokens).""" with self._lock: - image_arg = images if images else None + prep = self._prepare_generation(prompt, images) + prompt_token_count = prep["_prompt_token_count"] + + # Ensure stopping criteria is initialised + tokenizer = self._get_tokenizer() + tokenizer.stopping_criteria.reset(self.model.config.eos_token_id) + accumulated = "" - prompt_tokens = 0 + generated_tokens: list[int] = [] gen_tokens = 0 - for result in mlx_vlm.stream_generate( - self.model, - self.processor, - prompt, - image=image_arg, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - repetition_penalty=repetition_penalty, - **self._GENERATE_KWARGS, - ): - # result.text is the incremental segment (detokenizer.last_segment), - # NOT the full accumulated text. - token_text = result.text - accumulated += token_text - prompt_tokens = result.prompt_tokens - gen_tokens = result.generation_tokens - if stop and self._check_stop(accumulated, stop): - # Trim the accumulated text and yield what's safe - trimmed = self._apply_stop(accumulated, stop) - # Only yield the part we haven't yielded yet - safe_delta = trimmed[len(accumulated) - len(token_text):] - yield safe_delta, True, prompt_tokens, gen_tokens - return + try: + for result in mlx_vlm.stream_generate( + self.model, + self.processor, + prompt, + input_ids=prep["input_ids"], + pixel_values=prep.get("pixel_values"), + mask=prep.get("mask"), + prompt_cache=prep["prompt_cache"], + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + repetition_penalty=repetition_penalty, + **self._GENERATE_KWARGS, + ): + token_text = result.text + accumulated += token_text + if result.token is not None: + generated_tokens.append(result.token) + gen_tokens = result.generation_tokens - yield token_text, False, prompt_tokens, gen_tokens + if stop and self._check_stop(accumulated, stop): + trimmed = self._apply_stop(accumulated, stop) + safe_delta = trimmed[ + len(accumulated) - len(token_text) : + ] + yield safe_delta, True, prompt_token_count, gen_tokens + return - # Final yield to signal completion - yield "", True, prompt_tokens, gen_tokens + yield token_text, False, prompt_token_count, gen_tokens + + # Final yield to signal completion + yield "", True, prompt_token_count, gen_tokens + finally: + self._save_cache(prep, generated_tokens) @staticmethod def _apply_stop(text: str, stop: list[str]) -> str: