Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion bench/METHODOLOGY.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Chat template formatting means that it may be impossible to attain very small pp

When a request reaches the server via the `/bench/chat/completions` endpoint, three things change compared to a normal chat completion:

- **KV prefix cache is disabled**. Every request starts from a cold cache, ensuring prefill timing is not affected by prior requests.
- **KV prefix cache is disabled by default**. Every request starts from a cold cache, ensuring prefill timing is not affected by prior requests. See [Prefix Cache Mode](#prefix-cache-mode) for the `--use-prefix-cache` option.
- **EOS tokens are banned**. A logits processor suppresses all end-of-sequence tokens, forcing the model to generate exactly `max_tokens` tokens. This guarantees consistent generation length for fair TPS comparison — the model cannot short-circuit a run by stopping early.
- **No model output parsing**. The bench collection path concatenates raw token text without any model-specific post-processing (thinking tag extraction, structured output handling, etc.). This prevents model output issues — such as tool-call parsing failures or structural mistakes — from breaking the benchmark; we are testing for speed, so see Exo-Eval for performance metrics.

Expand Down Expand Up @@ -94,6 +94,22 @@ agg_gen_tps = per_req_tps * concurrency

---

## Prefix Cache Mode

When `--use-prefix-cache` is passed, the KV prefix cache remains active during benchmarking. This speeds up repeated runs by skipping redundant prefill work, which is useful when prompt processing is not the focus of the benchmark (e.g. when measuring generation throughput or power consumption across many configurations).

Each response includes a `prefix_cache_hit` field (`"none"`, `"partial"`, or `"exact"`):

- **none**: Cold prefill — no cached KV state was available. The reported `prompt_tps` is the real prefill throughput.
- **partial**: A prefix of the prompt was found in cache. Only the remaining tokens were prefilled. The reported `prompt_tps` reflects the real throughput on the uncached portion. This occurs when multiple ascending `--pp` values share a common prefix (e.g. `--pp 1000,5000` — the 5000-token prompt reuses the 1000-token cache entry and prefills the remaining 4000 tokens).
- **exact**: The entire prompt was found in cache (e.g. same `--pp` value on a `--repeat`). No prefill work was done. The reported `prompt_tps` is the TPS from when the cache entry was originally created, not a new measurement.

**Prompt TPS is approximate in this mode.** Exact-hit runs report the stored TPS from the original cold/partial prefill rather than a freshly measured value. For accurate cold prefill numbers, run without `--use-prefix-cache`.

Ascending `--pp` order (e.g. `--pp 1000,5000,10000`) gives the most useful data: each size gets a meaningful partial hit, except the first, which is cold. Descending order produces exact hits with approximate TPS from a longer prompt's original run.

---

## Warmup

Before measurement begins, `--warmup N` (default: 0) discarded requests are sent using the first pp/tg pair. Warmup results are not included in the output.
Expand Down
38 changes: 35 additions & 3 deletions bench/exo_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,13 @@ def parse_int_list(values: list[str]) -> list[int]:


def run_one_completion(
client: ExoClient, model_id: str, pp_hint: int, tg: int, prompt_sizer: PromptSizer
client: ExoClient,
model_id: str,
pp_hint: int,
tg: int,
prompt_sizer: PromptSizer,
*,
use_prefix_cache: bool = False,
) -> tuple[dict[str, Any], int]:
content, pp_tokens = prompt_sizer.build(pp_hint)
payload: dict[str, Any] = {
Expand All @@ -239,6 +245,7 @@ def run_one_completion(
"stream": False,
"max_tokens": tg,
"logprobs": False,
"use_prefix_cache": use_prefix_cache,
}

t0 = time.perf_counter()
Expand Down Expand Up @@ -379,6 +386,11 @@ def main() -> int:
default=1.0,
help="System metrics polling interval in seconds (default: 1.0).",
)
ap.add_argument(
"--use-prefix-cache",
action="store_true",
help="Enable KV prefix cache during bench (default: disabled for cold-cache measurements).",
)
args = ap.parse_args()

pp_list = parse_int_list(args.pp)
Expand All @@ -394,6 +406,15 @@ def main() -> int:
logger.error("--concurrency values must be >= 1")
return 2

if args.use_prefix_cache:
logger.warning(
"--use-prefix-cache: prompt TPS will be approximate. See METHODOLOGY.md for details."
)
if pp_list != sorted(pp_list):
logger.warning(
"--pp values are not in ascending order: prompt TPS will be less accurate. Use ascending --pp for best results."
)

# Log pairing mode
use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
if use_combinations:
Expand Down Expand Up @@ -505,7 +526,12 @@ def main() -> int:
try:
for i in range(args.warmup):
run_one_completion(
client, full_model_id, pp_list[0], tg_list[0], prompt_sizer
client,
full_model_id,
pp_list[0],
tg_list[0],
prompt_sizer,
use_prefix_cache=args.use_prefix_cache,
)
logger.debug(f" warmup {i + 1}/{args.warmup} done")

Expand All @@ -529,7 +555,12 @@ def main() -> int:
try:
inf_t0 = time.monotonic()
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
client,
full_model_id,
pp,
tg,
prompt_sizer,
use_prefix_cache=args.use_prefix_cache,
)
inference_windows.append((inf_t0, time.monotonic()))
except Exception as e:
Expand Down Expand Up @@ -566,6 +597,7 @@ def main() -> int:
"stream": False,
"max_tokens": tg,
"logprobs": False,
"use_prefix_cache": args.use_prefix_cache,
}
barrier = threading.Barrier(concurrency)
batch_start = threading.Event()
Expand Down
8 changes: 7 additions & 1 deletion src/exo/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,13 @@ async def bench_chat_completions(
)
task_params = task_params.model_copy(update={"model": resolved_model})

task_params = task_params.model_copy(update={"stream": False, "bench": True})
task_params = task_params.model_copy(
update={
"stream": False,
"bench": True,
"use_prefix_cache": payload.use_prefix_cache,
}
)

command = await self._send_text_generation_with_images(task_params)

Expand Down
3 changes: 2 additions & 1 deletion src/exo/api/types/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class GenerationStats(BaseModel):
prompt_tokens: int
generation_tokens: int
peak_memory_usage: Memory
prefix_cache_hit: Literal["none", "partial", "exact"] = "none"


class ImageGenerationStats(BaseModel):
Expand Down Expand Up @@ -232,7 +233,7 @@ class ChatCompletionRequest(BaseModel):


class BenchChatCompletionRequest(ChatCompletionRequest):
pass
use_prefix_cache: bool = False


class AddCustomModelParams(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions src/exo/shared/types/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class TextGenerationTaskParams(BaseModel, frozen=True):
stream: bool = False
tools: list[dict[str, Any]] | None = None
bench: bool = False
use_prefix_cache: bool = False
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
Expand Down
18 changes: 13 additions & 5 deletions src/exo/worker/engines/mlx/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(self, group: mx.distributed.Group | None):
self._snapshots: list[list[CacheSnapshot] | None] = []
self._media_regions: list[list["MediaRegion"]] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self.prefill_tps: list[float] = []
self._access_counter: int = 0
self._group = group

Expand All @@ -130,20 +131,23 @@ def clear(self):
self._snapshots.clear()
self._media_regions.clear()
self._last_used.clear()
self.prefill_tps.clear()

def add_kv_cache(
self,
prompt_tokens: mx.array,
cache: KVCacheType,
ssm_snapshots: list[CacheSnapshot] | None = None,
media_regions: list["MediaRegion"] | None = None,
prefill_tps: float = 0.0,
):
"""Add a new cache entry. Evicts LRU entries if memory is high."""
self._evict_if_needed()
self.prompts.append(prompt_tokens)
self.caches.append(deepcopy(cache))
self._snapshots.append(ssm_snapshots)
self._media_regions.append(media_regions or [])
self.prefill_tps.append(prefill_tps)
self._access_counter += 1
self._last_used.append(self._access_counter)
logger.info(f"KV cache added: {len(prompt_tokens)} tokens")
Expand All @@ -156,6 +160,7 @@ def update_kv_cache(
snapshots: list[CacheSnapshot] | None,
restore_pos: int,
media_regions: list["MediaRegion"] | None = None,
prefill_tps: float = 0.0,
):
"""Update an existing cache entry in-place."""
old_snapshots = self._snapshots[index]
Expand All @@ -169,6 +174,7 @@ def update_kv_cache(
self.caches[index] = deepcopy(cache)
self._snapshots[index] = merged or None
self._media_regions[index] = media_regions or []
self.prefill_tps[index] = prefill_tps
self._access_counter += 1
self._last_used[index] = self._access_counter
logger.info(f"KV cache updated (index {index}): {len(prompt_tokens)} tokens")
Expand All @@ -194,14 +200,15 @@ def get_kv_cache(
model: Model,
prompt_tokens: mx.array,
media_regions: list["MediaRegion"] | None = None,
) -> tuple[KVCacheType, mx.array, int | None]:
) -> tuple[KVCacheType, mx.array, int | None, bool]:
"""Get KV cache for prompt, returning remaining tokens to prefill.

Returns:
Tuple of (cache, remaining_tokens, matched_index) where:
Tuple of (cache, remaining_tokens, matched_index, is_exact) where:
- cache: KV cache to use for generation
- remaining_tokens: tokens that still need prefilling
- matched_index: index of the matched entry (None if no match)
- is_exact: True if the full prompt matched the cached entry

For models with SSM layers (which are ArraysCache in mlx), the cache is trimmed to the
nearest SSM snapshot position at or before the match point for correctness.
Expand Down Expand Up @@ -235,7 +242,7 @@ def get_kv_cache(
best_index, best_length = i, length

if best_index is None:
return make_kv_cache(model), prompt_tokens, None
return make_kv_cache(model), prompt_tokens, None, False

# For exact match: trim to max_length-1 so remaining has the last token
# For partial match: trim to best_length, remaining has suffix to prefill
Expand All @@ -246,7 +253,7 @@ def get_kv_cache(

# No usable snapshot — need fresh cache
if restore_snap is None and has_ssm:
return make_kv_cache(model), prompt_tokens, None
return make_kv_cache(model), prompt_tokens, None, False

prompt_cache = deepcopy(self.caches[best_index])
cached_length = cache_length(self.caches[best_index])
Expand All @@ -262,7 +269,7 @@ def get_kv_cache(
self._last_used[best_index] = self._access_counter
remaining = prompt_tokens[restore_pos:]

return prompt_cache, remaining, best_index
return prompt_cache, remaining, best_index, is_exact

@staticmethod
def _validate_media_match(
Expand Down Expand Up @@ -312,6 +319,7 @@ def _evict_if_needed(self):
self._snapshots.pop(lru_index)
self._media_regions.pop(lru_index)
self._last_used.pop(lru_index)
self.prefill_tps.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
)
Expand Down
35 changes: 26 additions & 9 deletions src/exo/worker/engines/mlx/generator/batch_generate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import time
from dataclasses import dataclass, field
from typing import Callable, cast
from typing import Callable, Literal, cast

import mlx.core as mx
from mlx_lm.generate import (
Expand Down Expand Up @@ -74,14 +74,14 @@ class _EngineTask:
all_prompt_tokens: mx.array
prefix_hit_length: int
matched_index: int | None
cache_snapshots: list[CacheSnapshot] | None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this just unused?

detokenizer: StreamingDetokenizer
on_generation_token: Callable[[], None] | None = None
generated_text_parts: list[str] = field(default_factory=list)
potential_stop_sequence_text: str = ""
completion_tokens: int = 0
generation_start_time: float = 0.0
prefill_tps: float = 0.0
prefix_cache_hit: Literal["none", "partial", "exact"] = "none"
media_regions: list[MediaRegion] = field(default_factory=list)
first_gen_token_time: float | None = None
last_gen_token_time: float | None = None
Expand Down Expand Up @@ -155,11 +155,16 @@ def submit(

prefix_hit_length = 0
matched_index: int | None = None
is_exact_hit = False
prompt_tokens = all_prompt_tokens

if self.kv_prefix_cache is not None and not is_bench:
cache, remaining_tokens, matched_index = self.kv_prefix_cache.get_kv_cache(
self.model, all_prompt_tokens, media_regions=media_regions
if self.kv_prefix_cache is not None and (
not is_bench or task_params.use_prefix_cache
):
cache, remaining_tokens, matched_index, is_exact_hit = (
self.kv_prefix_cache.get_kv_cache(
self.model, all_prompt_tokens, media_regions=media_regions
)
)
prefix_hit_length = len(all_prompt_tokens) - len(remaining_tokens)
if prefix_hit_length > 0:
Expand All @@ -168,8 +173,6 @@ def submit(
f"cached ({100 * prefix_hit_length / len(all_prompt_tokens):.1f}%)"
)
prompt_tokens = remaining_tokens
else:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to do this twice

cache = make_kv_cache(self.model)
else:
cache = make_kv_cache(self.model)

Expand Down Expand Up @@ -208,6 +211,15 @@ def submit(
distributed_prompt_progress_callback,
)

prefix_cache_hit: Literal["none", "partial", "exact"] = "none"
if matched_index is not None and prefix_hit_length > 0:
assert self.kv_prefix_cache is not None
if is_exact_hit:
prefix_cache_hit = "exact"
_prefill_tps = self.kv_prefix_cache.prefill_tps[matched_index]
else:
prefix_cache_hit = "partial"

# We need to clamp rotating kv caches to max size so that mlx lm's _merge_caches behaves
for c in cache:
if (
Expand All @@ -221,7 +233,7 @@ def submit(
c.values = c._trim(trim_size, c.values)
c._idx = c.max_size

if not is_bench:
if not is_bench or task_params.use_prefix_cache:
min_prefix_hit_length = max(
1000, system_prompt_token_count(task_params, self.tokenizer)
)
Expand All @@ -233,6 +245,7 @@ def submit(
matched_index,
min_prefix_hit_length,
media_regions,
prefill_tps=_prefill_tps,
)

last_tokens = prompt_tokens[-2:]
Expand Down Expand Up @@ -268,11 +281,11 @@ def submit(
all_prompt_tokens=all_prompt_tokens,
prefix_hit_length=prefix_hit_length,
matched_index=matched_index,
cache_snapshots=cache_snapshots or None,
detokenizer=self.tokenizer.detokenizer,
on_generation_token=on_generation_token,
generation_start_time=time.perf_counter(),
prefill_tps=_prefill_tps,
prefix_cache_hit=prefix_cache_hit,
media_regions=media_regions,
)

Expand Down Expand Up @@ -383,6 +396,7 @@ def step(self) -> list[tuple[int, GenerationResponse]]:
prompt_tokens=len(state.all_prompt_tokens),
generation_tokens=state.completion_tokens,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
prefix_cache_hit=state.prefix_cache_hit,
)
total_prompt_tokens = len(state.all_prompt_tokens)
usage = Usage(
Expand Down Expand Up @@ -449,6 +463,7 @@ def _save_prefix_cache(
matched_index: int | None,
min_prefix_hit_length: int = 1000,
media_regions: list[MediaRegion] | None = None,
prefill_tps: float = 0.0,
) -> None:
if self.kv_prefix_cache is None:
return
Expand All @@ -470,13 +485,15 @@ def _save_prefix_cache(
cache_snapshots,
restore_pos=prefix_hit_length,
media_regions=media_regions,
prefill_tps=prefill_tps,
)
else:
self.kv_prefix_cache.add_kv_cache(
all_prompt_tokens,
cache,
cache_snapshots,
media_regions=media_regions,
prefill_tps=prefill_tps,
)
except Exception:
logger.warning("Failed to save prefix cache", exc_info=True)
Loading
Loading