diff --git a/bench/METHODOLOGY.md b/bench/METHODOLOGY.md new file mode 100644 index 0000000000..841752a6f8 --- /dev/null +++ b/bench/METHODOLOGY.md @@ -0,0 +1,152 @@ +# Exo-Bench — Methodology + +exo bench measures inference throughput and resource consumption of an exo cluster under controlled conditions. It sends prompts to the `/bench/chat/completions` endpoint, collects server-reported timing statistics, and records system-level metrics (power, GPU utilisation, temperature) throughout each run. + +The goal is to have accurate, transparent and reproducible numbers to compare speed and scaling across different models and different setups, and to be able to track these results as optimizations and features are added to EXO. + +Below is a technical summary of how Exo-Bench works. While the methodology and benchmark may change over time, this document will be kept up to date whenever this happens. If you find an issue with the methodology, or would like a feature to be added, please open a GitHub issue! + +--- + +## Prompt Construction + +Benchmarks need prompts of an exact token length. Unfortunately, we do not have direct access to the model but just the chat completion endpoint. To get around this fact, we create a request that will tokenise to a certain prompt length. + +This is achieved by: + +1. Tokenising a sample message through the model's `apply_chat_template()` to measure overhead (system tokens, special tokens, chat formatting). +2. Binary-searching over a repeated atom string (default `"a "`) to find the content length that produces exactly the target number of tokens after template expansion. +3. Returning both the content string and the verified token count. + +The actual token count is recorded in every result row as `pp_tokens`, so downstream analysis can confirm the prompt hit its target. + +Chat template formatting means that it may be impossible to attain very small pp benchmarks. e.g. pp=32 may not work. 
This tradeoff was made because the result of such a small prompt does not seem very interesting or useful for any real-world use cases. + +--- + +## Bench Endpoint + +When a request reaches the server via the `/bench/chat/completions` endpoint, three things change compared to a normal chat completion: + +- **KV prefix cache is disabled**. Every request starts from a cold cache, ensuring prefill timing is not affected by prior requests. +- **EOS tokens are banned**. A logits processor suppresses all end-of-sequence tokens, forcing the model to generate exactly `max_tokens` tokens. This guarantees consistent generation length for fair TPS comparison — the model cannot short-circuit a run by stopping early. +- **No model output parsing**. The bench collection path concatenates raw token text without any model-specific post-processing (thinking tag extraction, structured output handling, etc.). This is to avoid model outputs such as tool parsing or any structural mistakes from breaking the benchmark - we are testing for speed; see Exo-Eval for performance metrics. + +--- + +## Timing + +### Prefill TPS + +Measured server-side per task. + +``` +prefill_tps = num_prompt_tokens / prefill_wall_seconds +``` + +### Generation TPS + +Measured server-side per task. Each task records wall-clock timestamps as tokens arrive: + +- First generated token: timestamp recorded +- Every subsequent token: timestamp updated + +When generation completes: + +``` +gen_span = last_token_time - first_token_time +generation_tps = (completion_tokens - 1) / gen_span +``` + +The first token is excluded from the numerator because the rate measures inter-token throughput — the time between the first and last token divided by the number of intervals. + +This does mean that tg=1 will not work. + +--- + +## Concurrency + +### Single Request + +The client records wall-clock `elapsed_s` around the HTTP round-trip (network latency + server prefill + generation + response serialisation). 
This is a convenience metric for end-to-end latency. The authoritative TPS numbers come from the server-side per-task timing in the `generation_stats` response. + +### Concurrent Requests + +When `--concurrency N` is set with N > 1, all N requests must hit the server at the same instant. The mechanism: + +1. The prompt is built once and shared across all threads. +2. Each thread gets its own HTTP connection. +3. A thread barrier blocks all threads until every thread is ready. +4. The first thread past the barrier records the batch start time and signals the others. +5. All threads use the same start time as their reference, then fire their HTTP request. +6. Each thread's `elapsed_s` is measured from the shared start time to its own response completion. + +**Batch wall time** is the maximum `elapsed_s` across all N requests — the time until the last request finishes. + +### Aggregate TPS + +``` +per_req_tps = max(generation_tps across N concurrent requests) +agg_gen_tps = per_req_tps * concurrency +``` + +`max` is used instead of `mean` because all requests run in parallel against the same model. The fastest request's generation rate represents the system's per-stream throughput capacity; multiplying by concurrency gives aggregate throughput. + +--- + +## Warmup + +Before measurement begins, `--warmup N` (default: 0) discarded requests are sent using the first pp/tg pair. Warmup results are not included in the output. + +--- + +## Compatibility Version + +Each exo node reports a `benchCompat` version (e.g. `"0.3.69.1"`) — the exo release version plus a bench-specific revision. This tracks bench-relevant changes independently of the release cycle: timing methodology, stats format, endpoint behavior, etc. + +Before benchmarking, exo bench checks every node's `benchCompat` against a supported range. If any node has an unknown or out-of-range version, the benchmark aborts. This prevents silently producing inaccurate results with incompatible nodes. 
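Note on the compatibility check described above: the range test can be sketched without any third-party dependency. This is a minimal illustration only. The helper names `parse_compat` and `check_nodes` are hypothetical, and the tuple-of-integers comparison is a simplified stand-in for the `packaging.version.Version` comparison that the patch itself uses in `exo_bench.py`.

```python
def parse_compat(raw: str) -> tuple[int, ...]:
    """Parse a dotted numeric version such as "0.3.69.1" into a tuple of ints.

    Hypothetical helper: a simplified stand-in for packaging.version.Version.
    """
    parts = raw.split(".")
    if any(not (p.isascii() and p.isdigit()) for p in parts):
        raise ValueError(f"invalid benchCompat {raw!r}")
    return tuple(int(p) for p in parts)


def check_nodes(node_compat: dict[str, str], minimum: str, maximum: str) -> list[str]:
    """Return one error message per node that is missing, malformed or out of range."""
    lo, hi = parse_compat(minimum), parse_compat(maximum)
    errors: list[str] = []
    for node_id, raw in node_compat.items():
        if raw == "Unknown":
            # Node predates the benchCompat field entirely.
            errors.append(f"{node_id}: does not report benchCompat")
            continue
        try:
            version = parse_compat(raw)
        except ValueError:
            errors.append(f"{node_id}: invalid benchCompat {raw!r}")
            continue
        if not lo <= version <= hi:
            errors.append(f"{node_id}: benchCompat {raw} outside {minimum}..{maximum}")
    return errors


if __name__ == "__main__":
    nodes = {"node-a": "0.3.69.1", "node-b": "Unknown", "node-c": "0.3.68.9"}
    for message in check_nodes(nodes, "0.3.69.1", "0.3.69.1"):
        print(message)
```

A benchmark driver would abort as soon as the returned list is non-empty, which matches the "any node out of range aborts the run" behaviour described in the text.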
+ +--- + +## System Metrics + +A background thread polls each node at 1 Hz, collecting: + +- GPU utilisation (%) +- Temperature (°C) +- System power draw (W) +- CPU cluster usage (performance and efficiency cores) + +**Energy** is computed via trapezoidal integration of the power samples over each inference window (the wall-clock span of each benchmark request or concurrent batch). Average power is `total_joules / total_inference_seconds`. + +--- + +## Output Format + +Results are written as JSON with a `bench_compat` version string and up to three further top-level keys: + +- **`runs`**: Array of per-request result objects, each containing: + - `elapsed_s`, `output_text_preview` (first 200 chars) + - `stats`: `{ prompt_tps, generation_tps, prompt_tokens, generation_tokens, peak_memory_usage }` + - Placement metadata: `model_id`, `placement_sharding`, `placement_instance_meta`, `placement_nodes` + - Run metadata: `pp_tokens`, `tg`, `repeat_index`, `concurrency`, `concurrent_index` + - `download_duration_s` (if model was freshly downloaded) +- **`cluster`**: Cluster state snapshot at time of benchmarking. +- **`system_metrics`**: Per-node time-series samples (GPU, power, temperature). + +--- + +## Reproducing Results + +```bash +cd bench && uv run python exo_bench.py \ + --model "mlx-community/Qwen3.5-27B-4bit" \ + --instance-meta jaccl \ + --sharding tensor \ + --min-nodes 2 --max-nodes 2 \ + --pp 512 4096 --tg 128 \ + --repeat 3 \ + --warmup 1 +``` + +Run the script with `--help` to list all available flags. 
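Note on the arithmetic in METHODOLOGY.md: the three formulas it describes (inter-token generation TPS, aggregate TPS under concurrency, and trapezoidal energy) can be checked with a small worked sketch. The function names and sample values below are illustrative assumptions, not code from the patch, and the energy helper assumes the power samples have already been restricted to a single inference window.

```python
def generation_tps(token_times: list[float]) -> float:
    """Inter-token rate: (n - 1) intervals divided by the first-to-last span."""
    if len(token_times) < 2:
        return 0.0  # a single token has no interval, so tg=1 yields no rate
    span = token_times[-1] - token_times[0]
    return (len(token_times) - 1) / span if span > 0 else 0.0


def aggregate_tps(per_request_tps: list[float], concurrency: int) -> float:
    """Fastest valid per-request rate multiplied by the concurrency level."""
    valid = [tps for tps in per_request_tps if tps > 0]
    return max(valid) * concurrency if valid else 0.0


def energy_joules(samples: list[tuple[float, float]]) -> float:
    """Trapezoidal integration over (timestamp_seconds, watts) samples."""
    total = 0.0
    for (t0, p0), (t1, p1) in zip(samples, samples[1:]):
        total += 0.5 * (p0 + p1) * (t1 - t0)
    return total


if __name__ == "__main__":
    # Five tokens arriving 0.5 s apart: 4 intervals over 2.0 s.
    print(generation_tps([0.0, 0.5, 1.0, 1.5, 2.0]))
    # Three concurrent requests, one of which reported no valid rate.
    print(aggregate_tps([2.0, 1.5, 0.0], concurrency=3))
    # Power sampled at 1 Hz over a two-second window.
    window = [(0.0, 10.0), (1.0, 20.0), (2.0, 30.0)]
    joules = energy_joules(window)
    print(joules, joules / (window[-1][0] - window[0][0]))
```

Because the aggregate uses `max` rather than `mean`, a single slow or failed stream does not lower the reported figure; that is worth keeping in mind when comparing against numbers produced before this change.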
diff --git a/bench/exo_bench.py b/bench/exo_bench.py index 97b6b3ac83..33f118382a 100644 --- a/bench/exo_bench.py +++ b/bench/exo_bench.py @@ -45,8 +45,14 @@ wait_for_instance_ready, ) from loguru import logger +from packaging.version import InvalidVersion, Version from transformers import AutoTokenizer +from exo.shared.types.profiling import BENCH_COMPAT + +MIN_BENCH_COMPAT = Version("0.3.69.1") +MAX_BENCH_COMPAT = Version(BENCH_COMPAT) + # Monkey-patch for transformers 5.x compatibility # Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location # which was moved in transformers 5.0.0rc2 @@ -435,6 +441,7 @@ def main() -> int: reverse=True, ) + logger.info(f"exo-bench benchCompat={BENCH_COMPAT}") logger.debug(f"exo-bench model: short_id={short_id} full_id={full_model_id}") logger.info(f"placements: {len(selected)}") for p in selected: @@ -464,6 +471,30 @@ def main() -> int: logger.info("Download: model already cached") cluster_snapshot = capture_cluster_snapshot(client) + + node_identities = cluster_snapshot.get("nodeIdentities", {}) + for node_id, identity in node_identities.items(): + bench_compat_str = identity.get("benchCompat", "Unknown") + exo_ver = identity.get("exoVersion", "Unknown") + if bench_compat_str == "Unknown": + logger.error( + f"Node {node_id} (exo {exo_ver}) does not report benchCompat — " + f"requires {MIN_BENCH_COMPAT}–{MAX_BENCH_COMPAT}. " + f"Update exo on this node." 
+ ) + return 1 + try: + bench_compat = Version(bench_compat_str) + except InvalidVersion: + logger.error(f"Node {node_id} has invalid benchCompat {bench_compat_str!r}") + return 1 + if bench_compat < MIN_BENCH_COMPAT or bench_compat > MAX_BENCH_COMPAT: + logger.error( + f"Node {node_id} benchCompat {bench_compat} outside supported range " + f"{MIN_BENCH_COMPAT}–{MAX_BENCH_COMPAT}" + ) + return 1 + all_rows: list[dict[str, Any]] = [] all_system_metrics: dict[str, dict[str, dict[str, float]]] = {} @@ -557,21 +588,51 @@ def main() -> int: all_rows.append(row) else: # Concurrent: fire N requests in parallel - # Each thread gets its own ExoClient (separate HTTP connection) + # Pre-build prompt once, barrier ensures simultaneous dispatch + content, actual_pp = prompt_sizer.build(pp) + pre_built_payload: dict[str, Any] = { + "model": full_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "max_tokens": tg, + } + barrier = threading.Barrier(concurrency) + batch_start = threading.Event() + batch_t0: float = 0.0 batch_results: list[tuple[dict[str, Any], int]] = [] batch_errors = 0 def _run_concurrent( - idx: int, *, _pp: int = pp, _tg: int = tg + idx: int, + _barrier: threading.Barrier = barrier, + _batch_start: threading.Event = batch_start, + _payload: dict[str, Any] = pre_built_payload, + _actual_pp: int = actual_pp, ) -> tuple[dict[str, Any], int]: + nonlocal batch_t0 c = ExoClient( args.host, args.port, timeout_s=args.timeout ) - return run_one_completion( - c, full_model_id, _pp, _tg, prompt_sizer + if _barrier.wait() == 0: + batch_t0 = time.perf_counter() + _batch_start.set() + else: + _batch_start.wait() + t0 = batch_t0 + out = c.post_bench_chat_completions(_payload) + elapsed = time.perf_counter() - t0 + stats = out.get("generation_stats") + choices = out.get("choices") or [{}] + message = ( + choices[0].get("message", {}) if choices else {} ) + text = message.get("content") or "" + return { + "elapsed_s": elapsed, + 
"output_text_preview": text[:200], + "stats": stats, + }, _actual_pp - inf_t0 = time.monotonic() with ThreadPoolExecutor(max_workers=concurrency) as pool: futures = { pool.submit(_run_concurrent, i): i @@ -583,6 +644,11 @@ def _run_concurrent( except Exception as e: logger.error(f"Concurrent request failed: {e}") batch_errors += 1 + batch_wall_s = ( + max(x["elapsed_s"] for x, _ in batch_results) + if batch_results + else time.perf_counter() - batch_t0 + ) inference_windows.append((inf_t0, time.monotonic())) for idx, (row, actual_pp_tokens) in enumerate( @@ -618,19 +684,25 @@ def _run_concurrent( if x["stats"]["generation_tps"] > 0 ] per_req_tps = ( - mean(valid_gen_tps) if valid_gen_tps else 0.0 + max(valid_gen_tps) if valid_gen_tps else 0.0 ) agg_gen_tps = per_req_tps * concurrency logger.info( f"[concurrent {concurrency}x] " f"agg_gen_tps={agg_gen_tps:.2f} " f"per_req_tps={per_req_tps:.2f} " + f"wall_s={batch_wall_s:.2f} " f"errors={batch_errors}" ) if runs: prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs) - per_req_tps = mean(x["stats"]["generation_tps"] for x in runs) + valid_gen = [ + x["stats"]["generation_tps"] + for x in runs + if x["stats"]["generation_tps"] > 0 + ] + per_req_tps = max(valid_gen) if valid_gen else 0.0 gen_tps = per_req_tps * concurrency ptok = mean(x["stats"]["prompt_tokens"] for x in runs) gtok = mean(x["stats"]["generation_tokens"] for x in runs) @@ -672,7 +744,7 @@ def _run_concurrent( time.sleep(5) - output: dict[str, Any] = {"runs": all_rows} + output: dict[str, Any] = {"bench_compat": BENCH_COMPAT, "runs": all_rows} if cluster_snapshot: output["cluster"] = cluster_snapshot if all_system_metrics: diff --git a/rust/exo_pyo3_bindings/tests/test_python.py b/rust/exo_pyo3_bindings/tests/test_python.py index fc6f0caa13..a653103d16 100644 --- a/rust/exo_pyo3_bindings/tests/test_python.py +++ b/rust/exo_pyo3_bindings/tests/test_python.py @@ -12,7 +12,7 @@ @pytest.mark.asyncio async def test_sleep_on_multiple_items() -> None: 
print("PYTHON: starting handle") - h = NetworkingHandle(Keypair.generate()) + h = NetworkingHandle(Keypair.generate(), [], 0) rt = asyncio.create_task(_await_recv(h)) diff --git a/src/exo/shared/apply.py b/src/exo/shared/apply.py index fb9f6d9ab4..b453788dfa 100644 --- a/src/exo/shared/apply.py +++ b/src/exo/shared/apply.py @@ -316,6 +316,8 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State: "chip_id": info.chip, "os_version": info.os_version, "os_build_version": info.os_build_version, + "exo_version": info.exo_version, + "bench_compat": info.bench_compat, } ) update["node_identities"] = { diff --git a/src/exo/shared/types/profiling.py b/src/exo/shared/types/profiling.py index ad1d48c0f2..3aa5bd3f40 100644 --- a/src/exo/shared/types/profiling.py +++ b/src/exo/shared/types/profiling.py @@ -75,6 +75,9 @@ class NetworkInterfaceInfo(CamelCaseModel): interface_type: InterfaceType = "unknown" +BENCH_COMPAT: str = "0.3.69.1" + + class NodeIdentity(CamelCaseModel): """Static and slow-changing node identification data.""" @@ -83,6 +86,8 @@ class NodeIdentity(CamelCaseModel): friendly_name: str = "Unknown" os_version: str = "Unknown" os_build_version: str = "Unknown" + exo_version: str = "Unknown" + bench_compat: str = "Unknown" class NodeNetworkInfo(CamelCaseModel): diff --git a/src/exo/utils/info_gatherer/info_gatherer.py b/src/exo/utils/info_gatherer/info_gatherer.py index 9e75e14353..3a421bcfad 100644 --- a/src/exo/utils/info_gatherer/info_gatherer.py +++ b/src/exo/utils/info_gatherer/info_gatherer.py @@ -4,6 +4,7 @@ import tomllib from collections.abc import Sequence from dataclasses import dataclass, field +from importlib.metadata import version as pkg_version from subprocess import CalledProcessError from typing import Self, cast @@ -16,6 +17,7 @@ from exo.shared.constants import EXO_CONFIG_FILE, EXO_DEFAULT_MODELS_DIR from exo.shared.types.memory import Memory from exo.shared.types.profiling import ( + BENCH_COMPAT, DiskUsage, 
MemoryUsage, NetworkInterfaceInfo, @@ -185,6 +187,8 @@ class StaticNodeInformation(TaggedModel): chip: str os_version: str os_build_version: str + exo_version: str + bench_compat: str @classmethod async def gather(cls) -> Self: @@ -194,6 +198,8 @@ async def gather(cls) -> Self: chip=chip, os_version=get_os_version(), os_build_version=await get_os_build_version(), + exo_version=pkg_version("exo"), + bench_compat=BENCH_COMPAT, ) diff --git a/src/exo/worker/engines/mlx/generator/batch_generate.py b/src/exo/worker/engines/mlx/generator/batch_generate.py index 143c4d4012..dba462aa31 100644 --- a/src/exo/worker/engines/mlx/generator/batch_generate.py +++ b/src/exo/worker/engines/mlx/generator/batch_generate.py @@ -82,6 +82,8 @@ class _EngineTask: reasoning_tokens: int = 0 prefill_tps: float = 0.0 media_regions: list[MediaRegion] = field(default_factory=list) + first_gen_token_time: float | None = None + last_gen_token_time: float | None = None @dataclass(eq=False) @@ -292,6 +294,10 @@ def step(self) -> list[tuple[int, GenerationResponse]]: continue state = self._active_tasks[response.uid] + now = time.perf_counter() + if state.first_gen_token_time is None: + state.first_gen_token_time = now + state.last_gen_token_time = now if state.on_generation_token is not None: state.on_generation_token() if response.finish_reason != "stop": @@ -300,6 +306,11 @@ def step(self) -> list[tuple[int, GenerationResponse]]: state.detokenizer.finalize() text = state.detokenizer.last_segment state.completion_tokens += 1 + if state.task_params.bench: + delta = now - state.first_gen_token_time + logger.debug( + f"[bench] uid={response.uid} tok#{state.completion_tokens} {text!r} t={delta:.4f}s" + ) state.generated_text_parts.append(text) state.potential_stop_sequence_text += text @@ -354,15 +365,15 @@ def step(self) -> list[tuple[int, GenerationResponse]]: stats: GenerationStats | None = None usage: Usage | None = None if is_done: - gen_time_delta = ( - self._mlx_gen._stats.generation_time - - 
state.generation_time_at_start - ) - generation_tps = ( - state.completion_tokens / gen_time_delta - if gen_time_delta > 0 - else 0.0 - ) + if state.completion_tokens > 1: + gen_span = state.last_gen_token_time - state.first_gen_token_time + generation_tps = ( + (state.completion_tokens - 1) / gen_span + if gen_span > 0 + else 0.0 + ) + else: + generation_tps = 0.0 stats = GenerationStats( prompt_tps=state.prefill_tps, diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index bb89dc7ef1..500d174945 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -218,6 +218,8 @@ def pipeline_parallel_prefill( model(prompt[-1:][None], cache=_prompt_cache) quantize_cache_fn(_prompt_cache) flush_prefill_sends() + if distributed_prompt_progress_callback is not None: + distributed_prompt_progress_callback() assert _prompt_cache is not None with mx.stream(generation_stream): diff --git a/src/exo/worker/tests/unittests/test_mlx/test_pipeline_prefill_callbacks.py b/src/exo/worker/tests/unittests/test_mlx/test_pipeline_prefill_callbacks.py index be723380de..78b1d9e63e 100644 --- a/src/exo/worker/tests/unittests/test_mlx/test_pipeline_prefill_callbacks.py +++ b/src/exo/worker/tests/unittests/test_mlx/test_pipeline_prefill_callbacks.py @@ -329,16 +329,16 @@ class TestPipelineNoDeadlock: """Pipeline prefill must not deadlock at any rank count or prompt length.""" @pytest.mark.parametrize( - "layer_splits,prompt_tokens", + "layer_splits,prompt_tokens,base_port", [ - (LAYER_SPLITS_2WAY, 128), - (LAYER_SPLITS_2WAY, 4096), - (LAYER_SPLITS_2WAY, 8192), - (LAYER_SPLITS_2WAY, 16384), - (LAYER_SPLITS_4WAY, 128), - (LAYER_SPLITS_4WAY, 4096), - (LAYER_SPLITS_4WAY, 8192), - (LAYER_SPLITS_4WAY, 16384), + (LAYER_SPLITS_2WAY, 128, 29650), + (LAYER_SPLITS_2WAY, 4096, 29654), + (LAYER_SPLITS_2WAY, 8192, 29658), + (LAYER_SPLITS_2WAY, 16384, 29662), + (LAYER_SPLITS_4WAY, 128, 
29666), + (LAYER_SPLITS_4WAY, 4096, 29670), + (LAYER_SPLITS_4WAY, 8192, 29674), + (LAYER_SPLITS_4WAY, 16384, 29678), ], ids=[ "2rank_128tok", @@ -355,12 +355,13 @@ def test_no_deadlock( self, layer_splits: list[tuple[int, int]], prompt_tokens: int, + base_port: int, ) -> None: """Pipeline must complete without deadlock at various prompt lengths.""" pipeline_results = _run_pipeline_test( layer_splits=layer_splits, prompt_tokens=prompt_tokens, - base_port=29650, + base_port=base_port, timeout=60, ) # If we get here, no deadlock. Verify all ranks produced output.