feat: add Llama 4 Scout/Maverick tensor and pipeline parallel support #1781
Open
vskiwi wants to merge 3 commits into exo-explore:main from
Conversation
added 3 commits (March 24, 2026 09:21)
Add sharding strategies for both Llama 4 model variants:

- `Llama4ShardingStrategy` for MoE models (Scout 16E, Maverick 128E), with SwitchGLU expert sharding via `ShardedMoE` plus the shared expert
- `Llama4TextShardingStrategy` for dense-only text models

Pipeline parallel support patches `LlamaModel.__call__` to use global layer indices for chunked/global attention mask selection and for hybrid ChunkedKVCache/KVCache creation after layer slicing.

Includes model cards for Scout (4/6/8-bit) and Maverick (4/6-bit), whitelist entries for `Llama4ForConditionalGeneration` and `Llama4ForCausalLM`, and prefix cache test coverage.
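The hybrid cache creation mentioned above can be sketched as follows. This is illustrative only: the stub classes and `make_hybrid_cache` are hypothetical stand-ins for mlx-lm's cache types and exo's actual code, showing only the global-index idea.

```python
class KVCache:          # stand-in for mlx-lm's standard cache
    pass

class ChunkedKVCache:   # stand-in for mlx-lm's chunked-attention cache
    pass

def make_hybrid_cache(num_local_layers: int, start_layer: int) -> list:
    """Build a per-layer cache list for one pipeline stage."""
    caches = []
    for local_idx in range(num_local_layers):
        # Key point: offset by the stage's first layer so the chunked/global
        # pattern is computed from GLOBAL indices, not local ones.
        global_idx = start_layer + local_idx
        # Llama 4 interleaves attention types: every 4th layer uses global
        # attention (plain KVCache), the rest use chunked attention.
        if (global_idx + 1) % 4 == 0:
            caches.append(KVCache())
        else:
            caches.append(ChunkedKVCache())
    return caches
```

A stage that starts at global layer 2 gets a different cache pattern than one starting at 0, which is exactly what the unpatched local-index code would get wrong.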
The dynamic class replacement in `_patch_llama4_pipeline` breaks MLX's JIT compilation cache, causing a ~100x slowdown on single-device runs. Only apply the patch when `world_size > 1`; at `start_layer = 0` the original `LlamaModel.__call__` already uses the correct global indices.
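A minimal sketch of the guard and the dynamic-subclass technique described above. All names here are illustrative; exo's real patch replaces `make_cache` and `__call__`, which is elided to a single index helper for brevity.

```python
def patch_llama4_pipeline(inner_model, start_layer: int, world_size: int):
    if world_size <= 1:
        # Single device: start_layer == 0, so local indices already equal
        # global indices, and swapping the class would invalidate MLX's
        # JIT compilation cache (~100x slowdown).
        return inner_model

    base = type(inner_model)

    class PatchedModel(base):
        def _is_global_attention_layer(self, local_idx: int) -> bool:
            # The stock code computes (local_idx + 1) % 4 == 0; after
            # pipeline slicing we must offset by this stage's first layer.
            return (start_layer + local_idx + 1) % 4 == 0

    inner_model.__class__ = PatchedModel  # dynamic subclass swap
    return inner_model
```

Because the guard returns the untouched model on a single device, the original class identity (and any compiled-function caches keyed on it) is preserved there.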
mlx_lm's `BatchGenerator` calls `_merge_caches`, which requires a `merge()` method on each cache entry. `ChunkedKVCache` (used by Llama 4's chunked-attention layers) does not implement `merge`, causing a crash after prefill completes. Detect `ChunkedKVCache` at generator build time and fall back to `SequentialGenerator`, which uses `stream_generate` without batching.
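The detection can be sketched as below. The helper names and the stub `ChunkedKVCache` class are hypothetical; the idea is simply to probe the model's cache layout once at generator-build time.

```python
class ChunkedKVCache:  # stand-in for mlx-lm's class, which has no merge()
    pass

def has_chunked_kv_cache(model) -> bool:
    """Return True if the model's cache layout includes a ChunkedKVCache."""
    make_cache = getattr(model, "make_cache", None)
    if make_cache is None:
        return False
    return any(isinstance(c, ChunkedKVCache) for c in make_cache())

def pick_generator(model) -> str:
    # BatchGenerator._merge_caches needs cache.merge(), which
    # ChunkedKVCache lacks (ml-explore/mlx-lm#686), so fall back.
    if has_chunked_kv_cache(model):
        return "SequentialGenerator"
    return "BatchGenerator"
```

Probing `make_cache()` once up front is cheap (it only constructs empty cache objects) and avoids crashing mid-request after prefill.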
Summary
Add full inference support for Meta's Llama 4 family — Scout (109B/17B active, 16 experts) and Maverick (400B/17B active, 128 experts). Previously Llama 4 did not work at all in exo, even on a single node.
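The two variants lay out their MoE layers differently (Scout: `interleave_moe_layer_step=1`, every layer; Maverick: `step=2`, alternating). A rough sketch of which decoder layers carry the sharded experts, assuming the HF-style rule that layer `i` is MoE when `(i + 1) % step == 0` — verify against the actual model config before relying on it:

```python
def moe_layer_indices(num_layers: int, step: int) -> list[int]:
    # Assumed rule: layer i is MoE when (i + 1) % interleave_moe_layer_step == 0.
    return [i for i in range(num_layers) if (i + 1) % step == 0]
```

With `step=1` every layer is MoE (Scout-style); with `step=2` MoE and dense layers alternate (Maverick-style), which is why the sharding strategy must handle both kinds.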
- `SequentialGenerator` for models using `ChunkedKVCache` (mlx-lm's `BatchGenerator` does not support it; see ml-explore/mlx-lm#686)
- Tensor parallel sharding for MoE (`ShardedMoE`) and dense layers
- Pipeline parallel: patched `make_cache` and `__call__` to handle Llama 4's hybrid ChunkedKVCache/KVCache across pipeline stages

Benchmark (4× Mac Studio M3 Ultra 512GB, TB5 RDMA)
63 tokens input → 512 tokens output, bench mode:
(generation tok/s; tensor RDMA scales up to +90% from 1→4 nodes on memory-bound configs)
Changes (9 files, +283/-2)
- `auto_parallel.py` (+200):
  - `Llama4ShardingStrategy`: mixed dense/MoE sharding. GQA attention sharding plus SwitchGLU expert sharding via `ShardedMoE` and the shared expert. Handles Scout (all-MoE, `interleave_moe_layer_step=1`) and Maverick (alternating, `step=2`).
  - `Llama4TextShardingStrategy`: dense-only variant (`Llama4ForCausalLM`).
  - `_patch_llama4_pipeline()`: pipeline parallel fix. Llama 4's `LlamaModel.__call__` hardcodes `cache[3]` and `(idx + 1) % 4` for chunked/global attention mask selection using local layer indices; after pipeline slicing these become incorrect. The patch replaces `make_cache` and `__call__` on the inner model (via a dynamic subclass) to use global layer offsets. Skipped for `world_size=1` to preserve the MLX JIT compilation cache.
- `runner.py` (+17): `_has_chunked_kv_cache()` detects models with `ChunkedKVCache` in `make_cache()` and routes them to `SequentialGenerator` instead of `BatchGenerator`. Workaround for ml-explore/mlx-lm#686.
- `model_cards.py` (+2): `Llama4ForConditionalGeneration`, `Llama4ForCausalLM` added to the `supports_tensor` whitelist.
- Model cards (5 TOML files): Scout 4/6/8-bit, Maverick 4/6-bit with actual safetensors file sizes.
- `test_prefix_cache_architectures.py` (+1): `ArchSpec("llama4", ...)` for Scout.

Known limitations
- `ChunkedKVCache` doesn't implement `.merge()` (upstream mlx-lm limitation), so these models fall back to sequential generation (one request at a time). Does not affect distributed inference.
- Can fail in `eval_impl` due to insufficient memory for GPU command buffers on a single node. Recommend distributing to 2+ nodes.

Test plan
- `basedpyright`: 0 errors
- `ruff check`: 0 errors
- `pytest`: all tests pass