
feat: add Llama 4 Scout/Maverick tensor and pipeline parallel support#1781

Open
vskiwi wants to merge 3 commits into exo-explore:main from vskiwi:feat/llama4-parallel-support

Conversation

Contributor

vskiwi commented Mar 24, 2026

Summary

Add full inference support for Meta's Llama 4 family — Scout (109B/17B active, 16 experts) and Maverick (400B/17B active, 128 experts). Previously Llama 4 did not work at all in exo, even on a single node.

  • Single-node inference: automatic fallback to SequentialGenerator for models using ChunkedKVCache (mlx-lm's BatchGenerator does not support it — ml-explore/mlx-lm#686)
  • Tensor parallel sharding for both MoE (SwitchGLU experts + shared expert via ShardedMoE) and dense layers
  • Pipeline parallel with patched make_cache and __call__ to handle Llama 4's hybrid ChunkedKVCache/KVCache across pipeline stages
  • Model cards for Scout (4/6/8-bit) and Maverick (4/6-bit) with actual safetensors file sizes

Benchmark (4× Mac Studio M3 Ultra 512GB, TB5 RDMA)

63 tokens input → 512 tokens output, bench mode:

| Model    | Quant | Size   | 1 node | 2n tensor | 4n tensor | 4n pipeline |
|----------|-------|--------|--------|-----------|-----------|-------------|
| Scout    | 4-bit | 57 GB  | 46.7   | 58.2      | 69.7      | 36.8        |
| Scout    | 8-bit | 108 GB | 30.3   | 44.0      | 57.7      | 26.0        |
| Maverick | 4-bit | 211 GB | 53.1   | 67.5      | 76.1      | 41.6        |
| Maverick | 6-bit | 304 GB | n/a    | 56.3      | 69.9      | 33.0        |

(all values are generation tok/s; tensor parallel over RDMA scales up to +90% going from 1 to 4 nodes on memory-bound configs)

Changes (9 files, +283/-2)

auto_parallel.py (+200):

  • Llama4ShardingStrategy — mixed dense/MoE: GQA attention sharding + SwitchGLU expert sharding via ShardedMoE + shared expert. Handles Scout (all-MoE, interleave_moe_layer_step=1) and Maverick (alternating, step=2).
  • Llama4TextShardingStrategy — dense-only variant (Llama4ForCausalLM)
  • _patch_llama4_pipeline() — pipeline parallel fix: Llama 4's LlamaModel.__call__ hardcodes cache[3] and (idx+1)%4 for chunked/global attention mask selection using local layer indices. After pipeline slicing these become incorrect. The patch replaces make_cache and __call__ on the inner model (via dynamic subclass) to use global layer offsets. Skipped for world_size=1 to preserve MLX JIT compilation cache.
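The index-offset idea behind the pipeline patch can be sketched as follows. This is a minimal illustration, not exo's actual patch: the cache classes here are stand-in stubs, and `start_layer` models the pipeline stage's global layer offset (the real patch rebuilds `make_cache` and `__call__` on a dynamic subclass of the inner model).

```python
# Illustrative sketch only. Llama 4 uses global (full) attention on every
# 4th layer and chunked attention elsewhere; upstream selects the cache type
# via the *local* index `(idx + 1) % 4`, which is wrong once a pipeline
# stage starts at a nonzero global layer.

class ChunkedKVCache:  # stand-in for mlx_lm's chunked-attention cache
    kind = "chunked"

class KVCache:  # stand-in for the standard full-attention cache
    kind = "global"

def make_cache_with_offset(num_local_layers: int, start_layer: int) -> list:
    """Build per-layer caches using the *global* layer index so every 4th
    global layer gets a KVCache and the rest get a ChunkedKVCache."""
    caches = []
    for local_idx in range(num_local_layers):
        global_idx = start_layer + local_idx
        is_global_attn = (global_idx + 1) % 4 == 0
        caches.append(KVCache() if is_global_attn else ChunkedKVCache())
    return caches
```

At `start_layer=0` this reduces to the upstream local-index behavior, which is why the patch can be skipped for `world_size=1` and the MLX JIT compilation cache preserved.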

runner.py (+17):

  • _has_chunked_kv_cache() — detect models with ChunkedKVCache in make_cache() and route to SequentialGenerator instead of BatchGenerator. Workaround for ml-explore/mlx-lm#686.
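The detection-and-routing logic amounts to something like the sketch below. The stub classes and function names are illustrative stand-ins, not exo's actual API; the real code inspects the caches returned by the model's `make_cache()`.

```python
# Hypothetical sketch of ChunkedKVCache detection and generator routing.

class KVCache:
    def merge(self, other):  # standard caches support batch merging
        pass

class ChunkedKVCache:
    pass  # no merge() -> BatchGenerator's _merge_caches would crash

def has_chunked_kv_cache(caches) -> bool:
    """True if any cache entry is a ChunkedKVCache (lacks merge())."""
    return any(isinstance(c, ChunkedKVCache) for c in caches)

def pick_generator(caches) -> str:
    # Fall back to one-request-at-a-time generation when batching is unsafe.
    if has_chunked_kv_cache(caches):
        return "SequentialGenerator"
    return "BatchGenerator"
```

Doing the check once at generator build time keeps the hot generation path free of per-request isinstance checks.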

model_cards.py (+2): Llama4ForConditionalGeneration, Llama4ForCausalLM in supports_tensor whitelist.

Model cards (5 TOML files): Scout 4/6/8-bit, Maverick 4/6-bit with actual safetensors file sizes.

test_prefix_cache_architectures.py (+1): ArchSpec("llama4", ...) for Scout.

Known limitations

  • No concurrent batching for Llama 4 — ChunkedKVCache doesn't implement .merge() (upstream mlx-lm limitation). Falls back to sequential (one request at a time). Does not affect distributed inference.
  • Single-node memory threshold — models occupying >60% unified memory (e.g. Maverick 6-bit 304GB on 512GB node) may deadlock in Metal/MLX eval_impl due to insufficient memory for GPU command buffers. Recommend distributing to 2+ nodes.
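The memory rule of thumb above can be written as a small pre-flight check. The 0.6 ratio and the helper name are assumptions taken from the limitation text, not anything exo or MLX enforces:

```python
# Hypothetical helper; the 60% threshold is a heuristic from the
# limitation noted above, not a hard limit in exo or Metal.

GB = 1024 ** 3

def fits_single_node(model_bytes: int, unified_memory_bytes: int,
                     max_ratio: float = 0.6) -> bool:
    """Leave ~40% of unified memory free for Metal command buffers and
    activations; beyond that, eval_impl may deadlock."""
    return model_bytes <= max_ratio * unified_memory_bytes

# Scout 8-bit (~108 GB) comfortably fits a 512 GB node; Maverick 6-bit
# (~304 GB) is right at the edge and safer distributed across 2+ nodes.
```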

Test plan

  • basedpyright — 0 errors
  • ruff check — 0 errors
  • pytest — all tests pass
  • Single-node inference (Scout 4-bit): ~47 tok/s generation, ~153 tok/s prefill
  • Tensor parallel RDMA: 2-node and 4-node (Scout + Maverick, all quantizations)
  • Pipeline parallel TCP: 4-node (Scout + Maverick)
  • Full benchmark matrix: 5 models × up to 5 configs (see table above)

vskiwi added 3 commits March 24, 2026 09:21
Add sharding strategies for both Llama 4 model variants:
- Llama4ShardingStrategy for MoE models (Scout 16E, Maverick 128E)
  with SwitchGLU expert sharding via ShardedMoE + shared_expert
- Llama4TextShardingStrategy for dense-only text models

Pipeline parallel support patches LlamaModel.__call__ to use global
layer indices for chunked/global attention mask selection and hybrid
ChunkedKVCache/KVCache creation after layer slicing.

Includes model cards for Scout (4/6/8-bit) and Maverick (4/6-bit),
whitelist entries for Llama4ForConditionalGeneration and
Llama4ForCausalLM, and prefix cache test coverage.

The dynamic class replacement in _patch_llama4_pipeline breaks MLX's
JIT compilation cache, causing ~100x slowdown on single-device.
Only apply the patch when world_size > 1, since at start_layer=0
the original LlamaModel.__call__ uses correct global indices.

mlx_lm's BatchGenerator calls _merge_caches which requires a merge()
method on each cache entry. ChunkedKVCache (used by Llama 4 for
chunked attention layers) does not implement merge, causing a crash
after prefill completes.

Detect ChunkedKVCache at generator build time and fall back to
SequentialGenerator which uses stream_generate without batching.