diff --git a/.mlx_typings/mlx/nn/layers/quantized.pyi b/.mlx_typings/mlx/nn/layers/quantized.pyi index 137a4c8ed1..7cd43dd65d 100644 --- a/.mlx_typings/mlx/nn/layers/quantized.pyi +++ b/.mlx_typings/mlx/nn/layers/quantized.pyi @@ -2,7 +2,7 @@ This type stub file was generated by pyright. """ -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union import mlx.core as mx from base import Module @@ -13,8 +13,10 @@ def quantize( bits: int = ..., *, mode: str = ..., - class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = ..., -): # -> None: + class_predicate: Optional[ + Callable[[str, Module], Union[bool, dict[str, Any]]] + ] = ..., +) -> None: """Quantize the sub-modules of a module according to a predicate. By default all layers that define a ``to_quantized(group_size, bits)`` diff --git a/.mlx_typings/mlx_lm/models/gemma4.pyi b/.mlx_typings/mlx_lm/models/gemma4.pyi new file mode 100644 index 0000000000..0ff6ef87f8 --- /dev/null +++ b/.mlx_typings/mlx_lm/models/gemma4.pyi @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import Any, Optional + +import mlx.core as mx +import mlx.nn as nn + +from . import gemma4_text +from .base import BaseModelArgs +from .cache import KVCache, RotatingKVCache + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + text_config: Optional[dict[str, Any]] + vocab_size: int + + def __post_init__(self) -> None: ... + +class Model(nn.Module): + args: ModelArgs + model_type: str + language_model: gemma4_text.Model + + def __init__(self, args: ModelArgs) -> None: ... + def __call__(self, *args: Any, **kwargs: Any) -> mx.array: ... + def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ... + @property + def layers(self) -> list[gemma4_text.DecoderLayer]: ... + @property + def quant_predicate(self) -> Any: ... + def make_cache(self) -> list[KVCache | RotatingKVCache]: ... diff --git a/.mlx_typings/mlx_lm/models/gemma4_text.pyi b/.mlx_typings/mlx_lm/models/gemma4_text.pyi new file mode 100644 index 0000000000..728d91c108 --- /dev/null +++ b/.mlx_typings/mlx_lm/models/gemma4_text.pyi @@ -0,0 +1,179 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import KVCache, RotatingKVCache +from .switch_layers import SwitchGLU + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + head_dim: int + global_head_dim: int + global_partial_rotary_factor: float + rms_norm_eps: float + vocab_size: int + vocab_size_per_layer_input: int + num_key_value_heads: int + num_global_key_value_heads: Optional[int] + num_kv_shared_layers: int + pad_token_id: int + hidden_size_per_layer_input: int + rope_traditional: bool + partial_rotary_factor: float + rope_parameters: Optional[Dict[str, Any]] + sliding_window: int + sliding_window_pattern: int + max_position_embeddings: int + attention_k_eq_v: bool + final_logit_softcapping: float + use_double_wide_mlp: bool + enable_moe_block: bool + num_experts: Optional[int] + top_k_experts: Optional[int] + moe_intermediate_size: Optional[int] + layer_types: Optional[List[str]] + tie_word_embeddings: bool + + def __post_init__(self) -> None: ... + +class MLP(nn.Module): + gate_proj: nn.Linear + down_proj: nn.Linear + up_proj: nn.Linear + + def __init__(self, config: ModelArgs, layer_idx: int = 0) -> None: ... + def __call__(self, x: mx.array) -> mx.array: ... + +class Router(nn.Module): + proj: nn.Linear + scale: mx.array + per_expert_scale: mx.array + + def __init__(self, config: ModelArgs) -> None: ... + def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ... + +class Experts(nn.Module): + switch_glu: SwitchGLU + + def __init__(self, config: ModelArgs) -> None: ... + def __call__( + self, x: mx.array, top_k_indices: mx.array, top_k_weights: mx.array + ) -> mx.array: ... + +class Attention(nn.Module): + layer_idx: int + layer_type: str + is_sliding: bool + head_dim: int + n_heads: int + n_kv_heads: int + use_k_eq_v: bool + scale: float + q_proj: nn.Linear + k_proj: nn.Linear + v_proj: nn.Linear + o_proj: nn.Linear + q_norm: nn.Module + k_norm: nn.Module + v_norm: nn.Module + rope: nn.Module + + def __init__(self, config: ModelArgs, layer_idx: int) -> None: ... + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + +class DecoderLayer(nn.Module): + layer_idx: int + layer_type: str + self_attn: Attention + mlp: MLP + enable_moe: bool + router: Router + experts: Experts + input_layernorm: nn.Module + post_attention_layernorm: nn.Module + pre_feedforward_layernorm: nn.Module + post_feedforward_layernorm: nn.Module + post_feedforward_layernorm_1: nn.Module + post_feedforward_layernorm_2: nn.Module + pre_feedforward_layernorm_2: nn.Module + hidden_size_per_layer_input: int + per_layer_input_gate: Optional[nn.Linear] + per_layer_projection: Optional[nn.Linear] + post_per_layer_input_norm: Optional[nn.Module] + layer_scalar: mx.array + + def __init__(self, config: ModelArgs, layer_idx: int) -> None: ... + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = ..., + cache: Optional[Any] = ..., + per_layer_input: Optional[mx.array] = ..., + shared_kv: Optional[tuple[mx.array, mx.array]] = ..., + offset: Optional[mx.array] = ..., + ) -> tuple[mx.array, tuple[mx.array, mx.array], mx.array]: ... + +class Gemma4TextModel(nn.Module): + config: ModelArgs + vocab_size: int + window_size: int + sliding_window_pattern: int + num_hidden_layers: int + embed_tokens: nn.Embedding + embed_scale: float + layers: list[DecoderLayer] + norm: nn.Module + hidden_size_per_layer_input: int + embed_tokens_per_layer: Optional[nn.Embedding] + per_layer_model_projection: Optional[nn.Linear] + per_layer_projection_norm: Optional[nn.Module] + previous_kvs: list[int] + + def __init__(self, config: ModelArgs) -> None: ... + def __call__( + self, + inputs: Optional[mx.array] = ..., + cache: Optional[list[Any]] = ..., + input_embeddings: Optional[mx.array] = ..., + per_layer_inputs: Optional[mx.array] = ..., + ) -> mx.array: ... + def _get_per_layer_inputs( + self, + input_ids: Optional[mx.array], + input_embeddings: Optional[mx.array] = ..., + ) -> mx.array: ... + def _project_per_layer_inputs( + self, + input_embeddings: mx.array, + per_layer_inputs: Optional[mx.array] = ..., + ) -> mx.array: ... + def _make_masks(self, h: mx.array, cache: list[Any]) -> list[Any]: ... + +class Model(nn.Module): + args: ModelArgs + model_type: str + model: Gemma4TextModel + final_logit_softcapping: float + tie_word_embeddings: bool + lm_head: nn.Linear + + def __init__(self, args: ModelArgs) -> None: ... + def __call__(self, *args: Any, **kwargs: Any) -> mx.array: ... + def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ... + @property + def layers(self) -> list[DecoderLayer]: ... + @property + def head_dim(self) -> int: ... + @property + def n_kv_heads(self) -> int: ... + @property + def quant_predicate(self) -> Any: ... + def make_cache(self) -> list[KVCache | RotatingKVCache]: ... diff --git a/.mlx_typings/mlx_lm/tokenizer_utils.pyi b/.mlx_typings/mlx_lm/tokenizer_utils.pyi index 1326f059de..b0ed12c0b6 100644 --- a/.mlx_typings/mlx_lm/tokenizer_utils.pyi +++ b/.mlx_typings/mlx_lm/tokenizer_utils.pyi @@ -117,6 +117,8 @@ class TokenizerWrapper: think_end: str | None think_start_id: int | None think_end_id: int | None + think_start_tokens: list[int] | None + think_end_tokens: list[int] | None def __init__( self, diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml new file mode 100644 index 0000000000..a36497e8b8 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-4bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-26b-a4b-it-4bit" +n_layers = 30 +hidden_size = 2816 +num_key_value_heads = 8 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "4bit" +base_model = "Gemma 4 26B A4B" +capabilities = ["text", "vision"] + +context_length = 262144 + +[storage_size] +in_bytes = 15608614044 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml new file mode 100644 index 0000000000..bd58231183 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-6bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-26b-a4b-it-6bit" +n_layers = 30 +hidden_size = 2816 +num_key_value_heads = 8 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "6bit" +base_model = "Gemma 4 26B A4B" +capabilities = ["text", "vision"] + +context_length = 262144 + +[storage_size] +in_bytes = 21781015708 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml new file mode 100644 index 0000000000..9dda6aa449 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-8bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-26b-a4b-it-8bit" +n_layers = 30 +hidden_size = 2816 +num_key_value_heads = 8 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "8bit" +base_model = "Gemma 4 26B A4B" +capabilities = ["text", "vision"] + +context_length = 262144 + +[storage_size] +in_bytes = 27953417372 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml new file mode 100644 index 0000000000..ece3998288 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-26b-a4b-it-bf16.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-26b-a4b-it-bf16" +n_layers = 30 +hidden_size = 2816 +num_key_value_heads = 8 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "bf16" +base_model = "Gemma 4 26B A4B" +capabilities = ["text", "vision"] + +context_length = 262144 + +[storage_size] +in_bytes = 51611872412 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml new file mode 100644 index 0000000000..d8a40f35ed --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-4bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-31b-it-4bit" +n_layers = 60 +hidden_size = 5376 +num_key_value_heads = 16 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "4bit" +base_model = "Gemma 4 31B" +capabilities = ["text", "vision"] + +context_length = 262144 + +[storage_size] +in_bytes = 18411755224 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml new file mode 100644 index 0000000000..6222ce70c2 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-6bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-31b-it-6bit" +n_layers = 60 +hidden_size = 5376 +num_key_value_heads = 16 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "6bit" +base_model = "Gemma 4 31B" +capabilities = ["text", "vision"] + +context_length = 262144 + +[storage_size] +in_bytes = 26087306968 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml new file mode 100644 index 0000000000..863961183c --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-8bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-31b-it-8bit" +n_layers = 60 +hidden_size = 5376 +num_key_value_heads = 16 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "8bit" +base_model = "Gemma 4 31B" +capabilities = ["text", "vision"] + +context_length = 262144 + +[storage_size] +in_bytes = 33762858712 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml new file mode 100644 index 0000000000..1d4f740c24 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-31b-it-bf16.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-31b-it-bf16" +n_layers = 60 +hidden_size = 5376 +num_key_value_heads = 16 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "bf16" +base_model = "Gemma 4 31B" +capabilities = ["text", "vision"] + +context_length = 262144 + +[storage_size] +in_bytes = 62546177752 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-4bit.toml new file mode 100644 index 0000000000..9d8f99324f --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-4bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-e2b-it-4bit" +n_layers = 35 +hidden_size = 1536 +num_key_value_heads = 1 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "4bit" +base_model = "Gemma 4 E2B" +capabilities = ["text", "vision"] + +context_length = 131072 + +[storage_size] +in_bytes = 3580765126 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-6bit.toml new file mode 100644 index 0000000000..c9c6d1a003 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-6bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-e2b-it-6bit" +n_layers = 35 +hidden_size = 1536 +num_key_value_heads = 1 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "6bit" +base_model = "Gemma 4 E2B" +capabilities = ["text", "vision"] + +context_length = 131072 + +[storage_size] +in_bytes = 4739998662 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-8bit.toml new file mode 100644 index 0000000000..804da9076d --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-8bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-e2b-it-8bit" +n_layers = 35 +hidden_size = 1536 +num_key_value_heads = 1 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "8bit" +base_model = "Gemma 4 E2B" +capabilities = ["text", "vision"] + +context_length = 131072 + +[storage_size] +in_bytes = 5899232198 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-bf16.toml new file mode 100644 index 0000000000..8f2920da32 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-e2b-it-bf16.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-e2b-it-bf16" +n_layers = 35 +hidden_size = 1536 +num_key_value_heads = 1 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "bf16" +base_model = "Gemma 4 E2B" +capabilities = ["text", "vision"] + +context_length = 131072 + +[storage_size] +in_bytes = 10246357958 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-4bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-4bit.toml new file mode 100644 index 0000000000..1122bbeb73 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-4bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-e4b-it-4bit" +n_layers = 42 +hidden_size = 2560 +num_key_value_heads = 2 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "4bit" +base_model = "Gemma 4 E4B" +capabilities = ["text", "vision"] + +context_length = 131072 + +[storage_size] +in_bytes = 5216992212 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-6bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-6bit.toml new file mode 100644 index 0000000000..6f3b430c31 --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-6bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-e4b-it-6bit" +n_layers = 42 +hidden_size = 2560 +num_key_value_heads = 2 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "6bit" +base_model = "Gemma 4 E4B" +capabilities = ["text", "vision"] + +context_length = 131072 + +[storage_size] +in_bytes = 7090961364 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-8bit.toml b/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-8bit.toml new file mode 100644 index 0000000000..48e21fef1c --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-8bit.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-e4b-it-8bit" +n_layers = 42 +hidden_size = 2560 +num_key_value_heads = 2 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "8bit" +base_model = "Gemma 4 E4B" +capabilities = ["text", "vision"] + +context_length = 131072 + +[storage_size] +in_bytes = 8964930516 diff --git a/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-bf16.toml b/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-bf16.toml new file mode 100644 index 0000000000..db87c2ee4a --- /dev/null +++ b/resources/inference_model_cards/mlx-community--gemma-4-e4b-it-bf16.toml @@ -0,0 +1,15 @@ +model_id = "mlx-community/gemma-4-e4b-it-bf16" +n_layers = 42 +hidden_size = 2560 +num_key_value_heads = 2 +supports_tensor = true +tasks = ["TextGeneration"] +family = "gemma" +quantization = "bf16" +base_model = "Gemma 4 E4B" +capabilities = ["text", "vision"] + +context_length = 131072 + +[storage_size] +in_bytes = 15992314836 diff --git a/src/exo/api/adapters/chat_completions.py b/src/exo/api/adapters/chat_completions.py index 17ebbf6b01..6dd5d7c136 100644 --- a/src/exo/api/adapters/chat_completions.py +++ b/src/exo/api/adapters/chat_completions.py @@ -31,20 +31,22 @@ ) from exo.shared.types.common import CommandId from exo.shared.types.text_generation import ( + Base64Image, InputMessage, + InputMessageContent, TextGenerationTaskParams, resolve_reasoning_params, ) -def extract_base64_from_data_url(data_url: str) -> str: +def extract_base64_from_data_url(data_url: str) -> Base64Image: match = re.match(r"data:[^;]+;base64,(.+)", data_url) if match: - return match.group(1) - return data_url + return Base64Image(match.group(1)) + return Base64Image(data_url) -async def fetch_image_url(url: str) -> str: +async def fetch_image_url(url: str) -> Base64Image: headers = {"User-Agent": "exo/1.0"} async with ( create_http_session(timeout_profile="short") as session, @@ -52,7 +54,7 @@ async def fetch_image_url(url: str) -> str: ): resp.raise_for_status() data = await resp.read() - return base64.b64encode(data).decode("ascii") + return Base64Image(base64.b64encode(data).decode("ascii")) async def chat_request_to_text_generation( @@ -61,7 +63,7 @@ async def chat_request_to_text_generation( instructions: str | None = None input_messages: list[InputMessage] = [] chat_template_messages: list[dict[str, Any]] = [] - images: list[str] = [] + images: list[Base64Image] = [] for msg in request.messages: # Normalize content to string @@ -115,7 +117,9 @@ async def chat_request_to_text_generation( continue if msg.role in ("user", "assistant", "developer"): - input_messages.append(InputMessage(role=msg.role, content=content)) + input_messages.append( + InputMessage(role=msg.role, content=InputMessageContent(content)) + ) # Build full message dict for chat template (preserves tool_calls etc.) # Normalize content for model_dump @@ -144,8 +148,8 @@ async def chat_request_to_text_generation( model=request.model, input=input_messages if input_messages - else [InputMessage(role="user", content="")], - instructions=instructions, + else [InputMessage(role="user", content=InputMessageContent(""))], + instructions=InputMessageContent(instructions) if instructions else None, max_output_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, diff --git a/src/exo/api/adapters/claude.py b/src/exo/api/adapters/claude.py index 0b53e330c2..04d700630d 100644 --- a/src/exo/api/adapters/claude.py +++ b/src/exo/api/adapters/claude.py @@ -37,7 +37,13 @@ ToolCallChunk, ) from exo.shared.types.common import CommandId -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + Base64Image, + ChatTemplateValue, + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) def finish_reason_to_claude_stop_reason( @@ -88,8 +94,8 @@ def claude_request_to_text_generation( ) -> TextGenerationTaskParams: # Handle system message instructions: str | None = None - chat_template_messages: list[dict[str, Any]] = [] - images: list[str] = [] + chat_template_messages: list[dict[str, ChatTemplateValue]] = [] + images: list[Base64Image] = [] if request.system: if isinstance(request.system, str): @@ -98,14 +104,20 @@ def claude_request_to_text_generation( instructions = "".join(block.text for block in request.system) instructions = _strip_volatile_headers(instructions) - chat_template_messages.append({"role": "system", "content": instructions}) + chat_template_messages.append( + {"role": "system", "content": InputMessageContent(instructions)} + ) # Convert messages to input input_messages: list[InputMessage] = [] for msg in request.messages: if isinstance(msg.content, str): - input_messages.append(InputMessage(role=msg.role, content=msg.content)) - chat_template_messages.append({"role": msg.role, "content": msg.content}) + input_messages.append( + InputMessage(role=msg.role, content=InputMessageContent(msg.content)) + ) + chat_template_messages.append( + {"role": msg.role, "content": InputMessageContent(msg.content)} + ) continue # Process structured content blocks @@ -120,10 +132,11 @@ def claude_request_to_text_generation( text_parts.append(block.text) elif isinstance(block, ClaudeImageBlock): if block.source.type == "base64" and block.source.data: - images.append(block.source.data) + images.append(Base64Image(block.source.data)) has_images = True elif block.source.type == "url" and block.source.url: - images.append(block.source.url) + # This is obviously wrong. Im not fixing it in this pr + images.append(Base64Image(block.source.url)) has_images = True elif isinstance(block, ClaudeThinkingBlock): thinking_parts.append(block.thinking) @@ -144,10 +157,11 @@ def claude_request_to_text_generation( for sub in block.content: if isinstance(sub, ClaudeImageBlock): if sub.source.type == "base64" and sub.source.data: - images.append(sub.source.data) + images.append(Base64Image(sub.source.data)) has_images = True elif sub.source.type == "url" and sub.source.url: - images.append(sub.source.url) + # This is obviously wrong. Im not fixing it in this pr + images.append(Base64Image(sub.source.url)) has_images = True content = "".join(text_parts) @@ -155,7 +169,9 @@ def claude_request_to_text_generation( # Build InputMessage from text content if msg.role in ("user", "assistant"): - input_messages.append(InputMessage(role=msg.role, content=content)) + input_messages.append( + InputMessage(role=msg.role, content=InputMessageContent(content)) + ) # Build chat_template_messages preserving tool structure if tool_calls: @@ -216,8 +232,8 @@ def claude_request_to_text_generation( model=request.model, input=input_messages if input_messages - else [InputMessage(role="user", content="")], - instructions=instructions, + else [InputMessage(role="user", content=InputMessageContent(""))], + instructions=InputMessageContent(instructions) if instructions else None, max_output_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, diff --git a/src/exo/api/adapters/ollama.py b/src/exo/api/adapters/ollama.py index 7643637063..1f027616ab 100644 --- a/src/exo/api/adapters/ollama.py +++ b/src/exo/api/adapters/ollama.py @@ -21,7 +21,12 @@ ToolCallChunk, ) from exo.shared.types.common import CommandId -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + Base64Image, + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) def _map_done_reason( @@ -82,14 +87,14 @@ def ollama_request_to_text_generation( instructions: str | None = None input_messages: list[InputMessage] = [] chat_template_messages: list[dict[str, Any]] = [] - images: list[str] = [] + images: list[Base64Image] = [] tool_message_index = 0 for msg in request.messages: content = msg.content or "" has_images = False if msg.images: - images.extend(msg.images) + images.extend(map(Base64Image, msg.images)) has_images = True if msg.role == "system": @@ -103,7 +108,9 @@ def ollama_request_to_text_generation( if msg.role in ("user", "assistant") and ( msg.content is not None or msg.thinking is not None or msg.tool_calls ): - input_messages.append(InputMessage(role=msg.role, content=content)) + input_messages.append( + InputMessage(role=msg.role, content=InputMessageContent(content)) + ) if has_images: multimodal: list[dict[str, Any]] = [ @@ -113,7 +120,9 @@ def ollama_request_to_text_generation( multimodal.append({"type": "text", "text": content}) chat_template_messages.append({"role": msg.role, "content": multimodal}) if msg.role in ("user", "assistant"): - input_messages.append(InputMessage(role=msg.role, content=content)) + input_messages.append( + InputMessage(role=msg.role, content=InputMessageContent(content)) + ) continue dumped: dict[str, Any] = {"role": msg.role, "content": content} if msg.thinking is not None: @@ -153,8 +162,8 @@ def ollama_request_to_text_generation( model=request.model, input=input_messages if input_messages - else [InputMessage(role="user", content="")], - instructions=instructions, + else [InputMessage(role="user", content=InputMessageContent(""))], + instructions=InputMessageContent(instructions) if instructions else None, max_output_tokens=options.num_predict if options else None, temperature=options.temperature if options else None, top_p=options.top_p if options else None, @@ -327,11 +336,11 @@ def ollama_generate_request_to_text_generation( ) -> TextGenerationTaskParams: """Convert Ollama generate request to exo's internal text generation format.""" chat_template_messages: list[dict[str, Any]] = [] - images: list[str] = [] + images: list[Base64Image] = [] if request.system: chat_template_messages.append({"role": "system", "content": request.system}) if request.images: - images.extend(request.images) + images.extend(map(Base64Image, request.images)) multimodal: list[dict[str, Any]] = [{"type": "image"} for _ in request.images] if request.prompt: multimodal.append({"type": "text", "text": request.prompt}) @@ -342,8 +351,8 @@ def ollama_generate_request_to_text_generation( options = request.options return TextGenerationTaskParams( model=request.model, - input=[InputMessage(role="user", content=request.prompt)], - instructions=request.system, + input=[InputMessage(role="user", content=InputMessageContent(request.prompt))], + instructions=InputMessageContent(request.system) if request.system else None, max_output_tokens=options.num_predict if options else None, temperature=options.temperature if options else None, top_p=options.top_p if options else None, diff --git a/src/exo/api/adapters/responses.py b/src/exo/api/adapters/responses.py index 2e8baf14b2..a3c248c0d2 100644 --- a/src/exo/api/adapters/responses.py +++ b/src/exo/api/adapters/responses.py @@ -74,7 +74,9 @@ ) from exo.shared.types.common import CommandId from exo.shared.types.text_generation import ( + Base64Image, InputMessage, + InputMessageContent, TextGenerationTaskParams, resolve_reasoning_params, ) @@ -99,9 +101,11 @@ async def responses_request_to_text_generation( ) -> TextGenerationTaskParams: input_value: list[InputMessage] built_chat_template: list[dict[str, Any]] | None = None - images: list[str] = [] + images: list[Base64Image] = [] if isinstance(request.input, str): - input_value = [InputMessage(role="user", content=request.input)] + input_value = [ + InputMessage(role="user", content=InputMessageContent(request.input)) + ] else: input_messages: list[InputMessage] = [] chat_template_messages: list[dict[str, Any]] = [] @@ -130,7 +134,9 @@ async def responses_request_to_text_generation( has_images = True if item.role in ("user", "assistant", "developer"): input_messages.append( - InputMessage(role=item.role, content=content) + InputMessage( + role=item.role, content=InputMessageContent(content) + ) ) if item.role == "system": chat_template_messages.append( @@ -327,7 +333,7 @@ async def responses_request_to_text_generation( input_value = ( input_messages if input_messages - else [InputMessage(role="user", content="")] + else [InputMessage(role="user", content=InputMessageContent(""))] ) built_chat_template = chat_template_messages if chat_template_messages else None @@ -361,7 +367,9 @@ async def responses_request_to_text_generation( return TextGenerationTaskParams( model=request.model, input=input_value, - instructions=request.instructions, + instructions=InputMessageContent(request.instructions) + if request.instructions + else None, max_output_tokens=request.max_output_tokens, temperature=request.temperature, top_p=request.top_p, diff --git a/src/exo/api/main.py b/src/exo/api/main.py index 3c183f4f79..a4c835987e 100644 --- a/src/exo/api/main.py +++ b/src/exo/api/main.py @@ -176,7 +176,7 @@ ) from exo.shared.types.memory import Memory from exo.shared.types.state import State -from exo.shared.types.text_generation import TextGenerationTaskParams +from exo.shared.types.text_generation import Base64Image, TextGenerationTaskParams from exo.shared.types.worker.downloads import DownloadCompleted from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta from exo.shared.types.worker.shards import Sharding @@ -750,9 +750,11 @@ async def _send_text_generation_with_images( self._sent_image_hashes.add(h) new_images.append((idx, img)) + wrapped_hashes = {idx: Base64Image(h) for idx, h in cached_hashes.items()} + if not new_images: task_params = task_params.model_copy( - update={"images": [], "image_hashes": cached_hashes} + update={"images": [], "image_hashes": wrapped_hashes} ) command = TextGeneration(task_params=task_params) await self._send(command) @@ -766,7 +768,7 @@ async def _send_text_generation_with_images( task_params = task_params.model_copy( update={ "images": [], - "image_hashes": cached_hashes, + "image_hashes": wrapped_hashes, "total_input_chunks": len(all_chunks), "image_count": len(new_images), } @@ -1401,7 +1403,9 @@ async def claude_messages( resolved_images.append(await fetch_image_url(img)) else: resolved_images.append(img) - task_params = task_params.model_copy(update={"images": resolved_images}) + task_params = task_params.model_copy( + update={"images": [Base64Image(x) for x in resolved_images]} + ) resolved_model = await self._resolve_and_validate_text_model( ModelId(task_params.model) ) diff --git a/src/exo/master/tests/test_master.py b/src/exo/master/tests/test_master.py index 9d4e579366..c4a1cff0c0 100644 --- a/src/exo/master/tests/test_master.py +++ b/src/exo/master/tests/test_master.py @@ -31,7 +31,11 @@ ) from exo.shared.types.tasks import TaskStatus from exo.shared.types.tasks import TextGeneration as TextGenerationTask -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) from exo.shared.types.worker.instances import ( InstanceMeta, MlxRingInstance, @@ -159,7 +163,10 @@ def _get_events() -> Sequence[IndexedEvent]: task_params=TextGenerationTaskParams( model=ModelId("llama-3.2-1b"), input=[ - InputMessage(role="user", content="Hello, how are you?") + InputMessage( + role="user", + content=InputMessageContent("Hello, how are you?"), + ) ], ), ) @@ -213,7 +220,11 @@ def _get_events() -> Sequence[IndexedEvent]: assert isinstance(events[2].event.task, TextGenerationTask) assert events[2].event.task.task_params == TextGenerationTaskParams( model=ModelId("llama-3.2-1b"), - input=[InputMessage(role="user", content="Hello, how are you?")], + input=[ + InputMessage( + role="user", content=InputMessageContent("Hello, how are you?") + ) + ], ) ev_send.close() diff --git a/src/exo/master/tests/test_placement.py b/src/exo/master/tests/test_placement.py index 3a8a36c087..40530ad288 100644 --- a/src/exo/master/tests/test_placement.py +++ b/src/exo/master/tests/test_placement.py @@ -23,7 +23,11 @@ from exo.shared.types.multiaddr import Multiaddr from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo from exo.shared.types.tasks import TaskId, TaskStatus, TextGeneration -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) from exo.shared.types.topology import Connection, SocketConnection from exo.shared.types.worker.downloads import ( DownloadCompleted, @@ -481,7 +485,7 @@ def _make_task( command_id=CommandId(), task_params=TextGenerationTaskParams( model=ModelId("test-model"), - input=[InputMessage(role="user", content="hello")], + input=[InputMessage(role="user", content=InputMessageContent("hello"))], ), ) diff --git a/src/exo/shared/models/model_cards.py b/src/exo/shared/models/model_cards.py index 50da812569..bdbe3f636b 100644 --- a/src/exo/shared/models/model_cards.py +++ b/src/exo/shared/models/model_cards.py @@ -261,6 +261,7 @@ def supports_tensor(self) -> bool: ["GptOssForCausalLM"], ["Step3p5ForCausalLM"], ["NemotronHForCausalLM"], + ["Gemma4ForConditionalGeneration"], ] @model_validator(mode="before") @@ -287,7 +288,7 @@ def defer_to_text_config(cls, data: dict[str, Any], info: ValidationInfo): image_token_id = data.get("image_token_id") if vision_config is not None and image_token_id is not None: model_type = str( - vision_config.get("model_type", data.get("model_type", "")) # pyright: ignore[reportAny] + data.get("model_type", vision_config.get("model_type", "")) # pyright: ignore[reportAny] ) assert info.context is not None diff --git a/src/exo/shared/types/text_generation.py b/src/exo/shared/types/text_generation.py index 29affa671b..f2cd0c45c7 100644 --- a/src/exo/shared/types/text_generation.py +++ b/src/exo/shared/types/text_generation.py @@ -4,13 +4,14 @@ are converted to TextGenerationTaskParams at the API boundary via adapters. """ -from typing import Any, Literal +from typing import Annotated, Any, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, GetCoreSchemaHandler, WrapValidator +from pydantic_core import CoreSchema, core_schema from exo.shared.types.common import ModelId -MessageRole = Literal["user", "assistant", "system", "developer"] +MessageRole = Literal["user", "assistant", "system", "developer", "tool"] ReasoningEffort = Literal["none", "minimal", "low", "medium", "high", "xhigh"] @@ -36,11 +37,63 @@ def resolve_reasoning_params( return resolved_effort, resolved_thinking +class InputMessageContent(str): + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, # pyright: ignore[reportAny] + handler: GetCoreSchemaHandler, + ) -> CoreSchema: + return core_schema.no_info_after_validator_function(cls, handler(str)) + + def __repr__(self): + return f"" + + class InputMessage(BaseModel, frozen=True): """Internal message for text generation pipelines.""" role: MessageRole - content: str + content: InputMessageContent + + +class Base64Image(str): + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, # pyright: ignore[reportAny] + handler: GetCoreSchemaHandler, + ) -> CoreSchema: + return core_schema.no_info_after_validator_function(cls, handler(str)) + + def __repr__(self): + return f"" + + +def _wrap_chat_value(x: Any) -> Any: # pyright: ignore[reportAny] + if isinstance(x, (InputMessageContent, Base64Image)): + return x + if isinstance(x, str): + return InputMessageContent(x) + if isinstance(x, dict): + return {k: _wrap_chat_value(v) for k, v in x.items()} # pyright: ignore[reportUnknownVariableType] + if isinstance(x, list): + return [_wrap_chat_value(i) for i in x] # pyright: ignore[reportUnknownVariableType] + return x # pyright: ignore[reportAny] + + +type ChatTemplateValue = Annotated[ + InputMessageContent + | Base64Image + | dict[str, ChatTemplateValue] + | list[ChatTemplateValue] + | str + | int + | float + | MessageRole + | bool, + WrapValidator(lambda a, b: b(_wrap_chat_value(a))), # pyright: ignore[reportAny] +] class TextGenerationTaskParams(BaseModel, frozen=True): @@ -52,7 +105,7 @@ class TextGenerationTaskParams(BaseModel, frozen=True): model: ModelId input: list[InputMessage] - instructions: str | None = None + instructions: InputMessageContent | None = None max_output_tokens: int | None = None temperature: float | None = None top_p: float | None = None @@ -62,7 +115,7 @@ class TextGenerationTaskParams(BaseModel, frozen=True): top_k: int | None = None stop: str | list[str] | None = None seed: int | None = None - chat_template_messages: list[dict[str, Any]] | None = None + chat_template_messages: list[dict[str, ChatTemplateValue]] | None = None reasoning_effort: ReasoningEffort | None = None enable_thinking: bool | None = None logprobs: bool = False @@ -70,7 +123,7 @@ class TextGenerationTaskParams(BaseModel, frozen=True): min_p: float | None = None repetition_penalty: float | None = None repetition_context_size: int | None = None - images: list[str] = Field(default_factory=list) - image_hashes: dict[int, str] = Field(default_factory=dict) + images: list[Base64Image] = Field(default_factory=list) + image_hashes: dict[int, Base64Image] = Field(default_factory=dict) total_input_chunks: int = 0 image_count: int = 0 diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py index 0085e2f540..a45c14028c 100644 --- a/src/exo/worker/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -21,6 +21,7 @@ from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model from mlx_lm.models.deepseek_v32 import DeepseekV32MLP from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model +from mlx_lm.models.gemma4 import Model as Gemma4Model from mlx_lm.models.glm4_moe import Model as Glm4MoeModel from mlx_lm.models.glm4_moe import MoE from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP @@ -58,6 +59,12 @@ from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel from exo.shared.types.worker.shards import PipelineShardMetadata +from exo.worker.engines.mlx.gemma4 import ( + is_gemma4_inner_model, + patch_gemma4_pipeline, + try_set_gemma4_pipeline_prefill, + try_set_gemma4_pipeline_queue_sends, +) from exo.worker.runner.bootstrap import logger if TYPE_CHECKING: @@ -225,12 +232,18 @@ def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: def set_pipeline_prefill(model: nn.Module, is_prefill: bool) -> None: + inner = get_inner_model(model) + if try_set_gemma4_pipeline_prefill(inner, is_prefill): + return for layer in model.layers: # type: ignore if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)): layer.is_prefill = is_prefill def set_pipeline_queue_sends(model: nn.Module, queue_sends: bool) -> None: + inner = get_inner_model(model) + if try_set_gemma4_pipeline_queue_sends(inner, queue_sends): + return for layer in model.layers: # type: ignore if isinstance(layer, PipelineLastLayer): layer.queue_sends = queue_sends @@ -331,13 +344,16 @@ def pipeline_auto_parallel( if on_layer_loaded is not None: on_layer_loaded(i, total) - layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group) - layers[-1] = PipelineLastLayer( - layers[-1], - device_rank, - world_size, - group=group, - ) + # Gemma 4 takes over the inner forward pass itself (see gemma4.py) so we + # skip the generic pipeline-layer wrapping for it. + if not is_gemma4_inner_model(inner_model_instance): + layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group) + layers[-1] = PipelineLastLayer( + layers[-1], + device_rank, + world_size, + group=group, + ) if isinstance(inner_model_instance, GptOssMoeModel): inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore @@ -391,6 +407,18 @@ def pipeline_auto_parallel( has_linear=bool(linear_layers), ) + if is_gemma4_inner_model(inner_model_instance): + patch_gemma4_pipeline( + model, + inner_model_instance, + start_layer, + end_layer, + device_rank, + world_size, + group, + _pending_prefill_sends, + ) + if isinstance(inner_model_instance, NemotronHInnerModel): # NemotronH uses block_type: "M" (Mamba/SSM), "*" (Attention), "E" (MoE), "-" (MLP) # Only "M" and "*" blocks have cache entries. @@ -447,7 +475,7 @@ def patched_call( ) # Add dependency to last cache entry to ensure distributed ops are evaluated - if cache is not None: + if cache is not None and len(cache) > 0: # type: ignore last = cache[-1] # type: ignore dep_cache = last[0] if hasattr(last, "caches") else last # type: ignore if hasattr(dep_cache, "keys") and dep_cache.keys is not None: # type: ignore @@ -609,6 +637,14 @@ def _sharded_to_all(path: str, weight: mx.array): all_to_sharded_linear_in_place, sharded_to_all_linear_in_place, ) + elif isinstance(model, Gemma4Model): + tensor_parallel_sharding_strategy = Gemma4ShardingStrategy( + group, + all_to_sharded_linear, + sharded_to_all_linear, + all_to_sharded_linear_in_place, + sharded_to_all_linear_in_place, + ) else: raise ValueError(f"Unsupported model type: {type(model)}") @@ -1395,3 +1431,59 @@ def _shard_mamba2_mixer(self, mixer: NemotronHMamba2Mixer, rank: int) -> None: mixer.intermediate_size = is_per_rank mixer.conv_dim = new_conv_dim mixer.heads_per_group = heads_per_rank // groups_per_rank + + +class WrappedGemma4Experts(CustomMlxLayer): + def __init__(self, layer: _LayerCallable): + super().__init__(layer) + self.sharding_group: mx.distributed.Group | None = None + + def __call__( + self, x: mx.array, top_k_indices: mx.array, top_k_weights: mx.array + ) -> mx.array: + if self.sharding_group is not None: + x = sum_gradients(self.sharding_group)(x) + y: mx.array = self.original_layer(x, top_k_indices, top_k_weights) + if self.sharding_group is not None: + y = mx.distributed.all_sum(y, group=self.sharding_group) + return y + + +class Gemma4ShardingStrategy(TensorParallelShardingStrategy): + def shard_model( + self, + model: nn.Module, + timeout_seconds: float, + on_timeout: TimeoutCallback | None, + on_layer_loaded: LayerLoadedCallback | None, + ) -> nn.Module: + model = cast(Gemma4Model, model) + layers = model.language_model.model.layers + total = len(layers) + for i, layer in enumerate(layers): + eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout) + + attn = layer.self_attn + attn.q_proj = self.all_to_sharded_linear(attn.q_proj) + attn.k_proj = self.all_to_sharded_linear(attn.k_proj) + if not attn.use_k_eq_v: + attn.v_proj = self.all_to_sharded_linear(attn.v_proj) + attn.o_proj = self.sharded_to_all_linear(attn.o_proj) + attn.n_heads //= self.N + attn.n_kv_heads //= self.N + + layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj) + layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj) + layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj) + + if layer.enable_moe: + self.all_to_sharded_linear_in_place(layer.experts.switch_glu.gate_proj) + self.sharded_to_all_linear_in_place(layer.experts.switch_glu.down_proj) + self.all_to_sharded_linear_in_place(layer.experts.switch_glu.up_proj) + layer.experts = WrappedGemma4Experts(layer.experts) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType] + layer.experts.sharding_group = self.group + + mx.eval(layer) + if on_layer_loaded is not None: + on_layer_loaded(i, total) + return model diff --git a/src/exo/worker/engines/mlx/cache.py b/src/exo/worker/engines/mlx/cache.py index 6376917d78..ad941f90e4 100644 --- a/src/exo/worker/engines/mlx/cache.py +++ b/src/exo/worker/engines/mlx/cache.py @@ -338,7 +338,7 @@ def _entry_length( def cache_length(cache: KVCacheType) -> int: """Get the number of tokens in a KV cache.""" - return max(_entry_length(c) for c in cache) + return max((_entry_length(c) for c in cache), default=0) def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int: diff --git a/src/exo/worker/engines/mlx/gemma4.py b/src/exo/worker/engines/mlx/gemma4.py new file mode 100644 index 0000000000..3db2263384 --- /dev/null +++ b/src/exo/worker/engines/mlx/gemma4.py @@ -0,0 +1,285 @@ +from typing import Any + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.cache import KVCache, RotatingKVCache +from mlx_lm.models.gemma4_text import Gemma4TextModel +from mlx_lm.models.gemma4_text import ModelArgs as Gemma4TextModelArgs + +type _IntermediateEntry = tuple[tuple[mx.array, mx.array] | None, mx.array | None] +type _SourceKvs = tuple[mx.array, mx.array, mx.array] + + +# TODO: Really really ugly code that needs refactoring ASAP (but it works) + + +def is_gemma4_inner_model(inner: nn.Module) -> bool: + return isinstance(inner, Gemma4TextModel) + + +def is_gemma4_pipeline_model(inner: nn.Module) -> bool: + return isinstance(inner, Gemma4TextModel) and hasattr( + inner, "_gemma4_pipeline_group" + ) + + +def try_set_gemma4_pipeline_prefill(inner: nn.Module, is_prefill: bool) -> bool: + if isinstance(inner, Gemma4TextModel) and hasattr(inner, "_gemma4_is_prefill"): + inner._gemma4_is_prefill = is_prefill + return True + return False + + +def try_set_gemma4_pipeline_queue_sends(inner: nn.Module, queue_sends: bool) -> bool: + if isinstance(inner, Gemma4TextModel) and hasattr(inner, "_gemma4_queue_sends"): + inner._gemma4_queue_sends = queue_sends + return True + return False + + +def _offset_template() -> mx.array: + return mx.array(0, dtype=mx.int32) + + +def _kv_template( + args: Gemma4TextModelArgs, + layer_types_global: list[str], + h: mx.array, + source_global_idx: int, + seq_len: int, +) -> mx.array: + is_full = ( + source_global_idx < len(layer_types_global) + and layer_types_global[source_global_idx] == "full_attention" + ) + head_dim = ( + int(args.global_head_dim) + if is_full and args.global_head_dim + else int(args.head_dim) + ) + n_kv_heads = int(args.num_key_value_heads) + if ( + is_full + and args.attention_k_eq_v + and args.num_global_key_value_heads is not None + ): + n_kv_heads = int(args.num_global_key_value_heads) + if not is_full and args.sliding_window: + seq_len = min(seq_len, int(args.sliding_window)) + return mx.zeros((int(h.shape[0]), n_kv_heads, seq_len, head_dim), dtype=h.dtype) + + +def patch_gemma4_pipeline( + model: nn.Module, + inner: nn.Module, + start_layer: int, + end_layer: int, + device_rank: int, + world_size: int, + group: mx.distributed.Group, + pending_prefill_sends: list[tuple[mx.array, int, mx.distributed.Group]], +) -> None: + assert isinstance(inner, Gemma4TextModel) + args = inner.config + num_hidden_layers_global = int(args.num_hidden_layers) + num_kv_shared_layers_global = int(args.num_kv_shared_layers) + first_kv_shared_global = num_hidden_layers_global - num_kv_shared_layers_global + layer_types_global = list(args.layer_types or []) + + previous_kvs_global = list(inner.previous_kvs) + + consumed_global_sources: list[int] = sorted( + set(previous_kvs_global[first_kv_shared_global:]) + ) + # Source layers we own locally to local layer index. + local_owned_sources: dict[int, int] = { + g_idx: g_idx - start_layer + for g_idx in consumed_global_sources + if start_layer <= g_idx < end_layer + } + # Local shared-layer slot → the source's global index. + local_shared_to_global_source: dict[int, int] = { + g_idx - start_layer: previous_kvs_global[g_idx] + for g_idx in range(max(start_layer, first_kv_shared_global), end_layer) + } + + new_previous_kvs: list[int] = [] + for g_idx in range(start_layer, end_layer): + local_idx = g_idx - start_layer + source_g = previous_kvs_global[g_idx] + if start_layer <= source_g < end_layer: + new_previous_kvs.append(source_g - start_layer) + else: + new_previous_kvs.append(local_idx) + inner.previous_kvs = new_previous_kvs + + sliding_window = int(args.sliding_window) + + def _make_cache() -> list[KVCache | RotatingKVCache]: + local_source_end = min(first_kv_shared_global, end_layer) + caches: list[KVCache | RotatingKVCache] = [] + for g_idx in range(start_layer, local_source_end): + if layer_types_global[g_idx] == "full_attention": + caches.append(KVCache()) + else: + caches.append(RotatingKVCache(max_size=sliding_window, keep=0)) + return caches + + model.make_cache = _make_cache + + inner._gemma4_pipeline_group = group + inner._gemma4_device_rank = device_rank + inner._gemma4_world_size = world_size + inner._gemma4_is_prefill = False + inner._gemma4_queue_sends = False + # Fallback counter for tracking the source-kv sequence length on ranks + # whose local cache list is empty (e.g. a rank that only owns shared + # layers). Normal ranks read the offset from their local KVCache. + inner._gemma4_prefix_counter = 0 + + next_rank = (device_rank + 1) % world_size + prev_rank = device_rank - 1 + + def patched_call( + self: Gemma4TextModel, + inputs: mx.array | None = None, + cache: list[Any] | None = None, + input_embeddings: mx.array | None = None, + per_layer_inputs: mx.array | None = None, + ) -> mx.array: + if input_embeddings is None: + assert inputs is not None + input_embeddings = self.embed_tokens(inputs) + h: mx.array = input_embeddings * self.embed_scale + + if self.hidden_size_per_layer_input: + if per_layer_inputs is None: + per_layer_inputs = self._get_per_layer_inputs(inputs, input_embeddings) + per_layer_inputs = self._project_per_layer_inputs(h, per_layer_inputs) + # Both helpers above return tensors shaped for the GLOBAL layer + # count (see patch_gemma4_pipeline). Slice to this shard's layers + # so the local loop indexes into the right slots. + per_layer_inputs = per_layer_inputs[:, :, start_layer:end_layer, :] + per_layer_inputs_list: list[mx.array | None] = ( + [per_layer_inputs[:, :, i, :] for i in range(len(self.layers))] + if per_layer_inputs is not None + else [None] * len(self.layers) + ) + + local_cache: list[KVCache | RotatingKVCache | None] = ( + [None] * len(self.layers) + if cache is None + else cache + [None] * (len(self.layers) - len(cache)) + ) + masks: list[mx.array] = self._make_masks(h, local_cache) + + prior_offset: int = int(getattr(self, "_gemma4_prefix_counter", 0)) + + current_seq_len = prior_offset + int(h.shape[1]) + + if device_rank != 0: + mx.eval(h) + h = mx.distributed.recv_like(h, prev_rank, group=group) + mx.eval(h) + + # Receive the source kvs. We always recv ALL consumed + # sources (even ones we own locally) so the byte counts on each side + # of the wire match — entries we own locally are simply ignored when + # seeding intermediates below. + received_source_kvs: dict[int, _SourceKvs] = {} + if device_rank != 0: + offset_template = _offset_template() + len_template = mx.array(0, dtype=mx.int32) + for g_idx in consumed_global_sources: + rlen = mx.distributed.recv_like(len_template, prev_rank, group=group) + mx.eval(rlen) + actual_seq_len = int(rlen.item()) + kv_template = _kv_template( + args, layer_types_global, h, g_idx, actual_seq_len + ) + rk = mx.distributed.recv_like(kv_template, prev_rank, group=group) + mx.eval(rk) + rv = mx.distributed.recv_like(kv_template, prev_rank, group=group) + mx.eval(rv) + ro = mx.distributed.recv_like(offset_template, prev_rank, group=group) + mx.eval(ro) + received_source_kvs[g_idx] = (rk, rv, ro) + + intermediates: list[_IntermediateEntry] = [(None, None)] * len(self.layers) + for local_idx, source_g_idx in local_shared_to_global_source.items(): + if source_g_idx in local_owned_sources: + continue + entry = received_source_kvs.get(source_g_idx) + if entry is not None: + rk, rv, ro = entry + intermediates[local_idx] = ((rk, rv), ro) + + for idx in range(len(self.layers)): + layer = self.layers[idx] + prev_idx = self.previous_kvs[idx] + kvs, offset = intermediates[prev_idx] + h, new_kvs, new_offset = layer( + h, + masks[idx], + local_cache[idx], + per_layer_input=per_layer_inputs_list[idx], + shared_kv=kvs, + offset=offset, + ) + intermediates[idx] = (new_kvs, new_offset) + mx.eval(h) + + # Build the outgoing source kvs. Start from whatever we received (so + # any source we don't own locally gets forwarded along), then + # overwrite with our own freshly-computed entries. + outgoing_source_kvs: dict[int, _SourceKvs] = dict(received_source_kvs) + for g_idx, local_idx in local_owned_sources.items(): + local_kvs, local_offset = intermediates[local_idx] + if local_kvs is not None and local_offset is not None: + outgoing_source_kvs[g_idx] = (local_kvs[0], local_kvs[1], local_offset) + + is_prefill: bool = bool(getattr(self, "_gemma4_is_prefill", False)) + queue_sends: bool = bool(getattr(self, "_gemma4_queue_sends", False)) + + if device_rank != world_size - 1: + mx.eval(h) + if queue_sends: + pending_prefill_sends.append((h, next_rank, group)) + else: + h = mx.distributed.send(h, next_rank, group=group) + mx.eval(h) + offset_template = _offset_template() + for g_idx in consumed_global_sources: + entry = outgoing_source_kvs.get(g_idx) + if entry is None: + kv_template = _kv_template( + args, layer_types_global, h, g_idx, current_seq_len + ) + entry = (kv_template, kv_template, offset_template) + kk, vv, oo = entry + mx.eval(kk, vv, oo) + if queue_sends: + pending_prefill_sends.append((kk, next_rank, group)) + pending_prefill_sends.append((vv, next_rank, group)) + pending_prefill_sends.append((oo, next_rank, group)) + else: + actual_len = mx.array(int(kk.shape[2]), dtype=mx.int32) + sent_len = mx.distributed.send(actual_len, next_rank, group=group) + mx.eval(sent_len) + kk = mx.distributed.send(kk, next_rank, group=group) + mx.eval(kk) + vv = mx.distributed.send(vv, next_rank, group=group) + mx.eval(vv) + oo = mx.distributed.send(oo, next_rank, group=group) + mx.eval(oo) + + self._gemma4_prefix_counter = prior_offset + int(h.shape[1]) + + if not is_prefill: + mx.eval(h) + h = mx.distributed.all_gather(h, group=group)[-h.shape[0] :] + mx.eval(h) + + return self.norm(h) + + type(inner).__call__ = patched_call diff --git a/src/exo/worker/engines/mlx/generator/batch_generate.py b/src/exo/worker/engines/mlx/generator/batch_generate.py index 42e02eea54..2f58e038db 100644 --- a/src/exo/worker/engines/mlx/generator/batch_generate.py +++ b/src/exo/worker/engines/mlx/generator/batch_generate.py @@ -190,7 +190,11 @@ def submit( vision_ctx = ( patch_embed_tokens( - self.model, vision.embeddings, prefix_hit_length, len(prompt_tokens) - 1 + self.model, + vision.embeddings, + prefix_hit_length, + len(prompt_tokens) - 1, + image_token_id=vision.image_token_id, ) if vision is not None else contextlib.nullcontext() diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index eef11101ce..a3f9c402a8 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -6,6 +6,7 @@ from typing import Callable, Generator, cast, get_args import mlx.core as mx +import mlx.nn as nn from mlx_lm.generate import ( maybe_quantize_kv_cache, stream_generate, @@ -25,7 +26,11 @@ from exo.shared.types.common import ModelId from exo.shared.types.memory import Memory from exo.shared.types.mlx import KVCacheType, Model -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) from exo.shared.types.worker.runner_response import ( GenerationResponse, ) @@ -51,6 +56,7 @@ KV_GROUP_SIZE, MAX_TOKENS, ) +from exo.worker.engines.mlx.gemma4 import is_gemma4_pipeline_model from exo.worker.engines.mlx.utils_mlx import ( apply_chat_template, detect_thinking_prompt_suffix, @@ -74,7 +80,11 @@ @contextlib.contextmanager def patch_embed_tokens( - model: Model, embeddings: mx.array, start_offset: int = 0, token_count: int = 0 + model: Model, + embeddings: mx.array, + start_offset: int = 0, + token_count: int = 0, + image_token_id: int | None = None, ) -> Generator[None]: inner = get_inner_model(model) # type: ignore original_embed = inner.embed_tokens # type: ignore @@ -82,15 +92,30 @@ def patch_embed_tokens( offset = [start_offset] def _inject(input_ids: mx.array) -> mx.array: - start = offset[0] - if start >= end_offset: - return original_embed(input_ids) # type: ignore + chunk_start = offset[0] chunk_len = input_ids.shape[-1] - end = min(start + chunk_len, end_offset) - offset[0] = end - if end - start < chunk_len: + chunk_end = chunk_start + chunk_len + offset[0] = chunk_end + + # The injection window is [start_offset, end_offset). + if chunk_end <= start_offset or chunk_start >= end_offset: return original_embed(input_ids) # type: ignore - return embeddings[:, start:end, :] + + # Mixed chunk: splice the pre-computed embeddings for the overlap + # into `original_embed(input_ids)` for any text-only fringes. + overlap_start = max(chunk_start, start_offset) + overlap_end = min(chunk_end, end_offset) + dst_start = overlap_start - chunk_start + dst_end = overlap_end - chunk_start + text_embeds: mx.array = original_embed(input_ids) # type: ignore + return mx.concatenate( + [ + text_embeds[:, :dst_start, :], + embeddings[:, overlap_start:overlap_end, :], + text_embeds[:, dst_end:, :], + ], + axis=1, + ) for attr in dir(original_embed): # type: ignore if not attr.startswith("_") and not hasattr(_inject, attr): @@ -98,10 +123,30 @@ def _inject(input_ids: mx.array) -> mx.array: setattr(_inject, attr, getattr(original_embed, attr)) # type: ignore inner.embed_tokens = _inject + + # Gemma 4 (e2b/e4b) has a second, independent embedding table that produces + # per-layer conditioning signals via self.embed_tokens_per_layer(input_ids). + # The injected vision embeddings live in the main residual stream only, so + # if image_token_id positions are passed through as-is the per-layer table + # produces garbage signals at those positions (the `` token was never + # trained to have meaningful per-layer inputs). + original_per_layer = getattr(inner, "embed_tokens_per_layer", None) # type: ignore + if original_per_layer is not None and image_token_id is not None: + + def _clean_per_layer(input_ids: mx.array) -> mx.array: + clean_ids = mx.where( + input_ids == image_token_id, mx.zeros_like(input_ids), input_ids + ) + return original_per_layer(clean_ids) # type: ignore + + inner.embed_tokens_per_layer = _clean_per_layer + try: yield finally: inner.embed_tokens = original_embed + if original_per_layer is not None and image_token_id is not None: + inner.embed_tokens_per_layer = original_per_layer class PrefillCancelled(BaseException): @@ -109,6 +154,9 @@ class PrefillCancelled(BaseException): def _has_pipeline_communication_layer(model: Model): + inner: nn.Module = get_inner_model(model) # type: ignore + if is_gemma4_pipeline_model(inner): + return True for layer in model.layers: if isinstance(layer, (PipelineFirstLayer, PipelineLastLayer)): return True @@ -356,7 +404,9 @@ def warmup_inference( ) -> int: logger.info(f"warming up inference for instance: {model_id}") - content = "Prompt to warm up the inference engine. Repeat this." + content = InputMessageContent( + "Prompt to warm up the inference engine. Repeat this." + ) warmup_task_params = TextGenerationTaskParams( model=model_id, @@ -571,7 +621,11 @@ def mlx_generate( maybe_vision_ctx = ( patch_embed_tokens( - model, vision.embeddings, prefix_hit_length, len(prompt_tokens) - 1 + model, + vision.embeddings, + prefix_hit_length, + len(prompt_tokens) - 1, + image_token_id=vision.image_token_id, ) if vision is not None else contextlib.nullcontext() diff --git a/src/exo/worker/engines/mlx/utils_mlx.py b/src/exo/worker/engines/mlx/utils_mlx.py index 790dcd8d3b..63f1b6ed84 100644 --- a/src/exo/worker/engines/mlx/utils_mlx.py +++ b/src/exo/worker/engines/mlx/utils_mlx.py @@ -45,7 +45,7 @@ from exo.shared.types.memory import Memory from exo.shared.types.mlx import Model from exo.shared.types.tasks import TaskId, TextGeneration -from exo.shared.types.text_generation import TextGenerationTaskParams +from exo.shared.types.text_generation import ChatTemplateValue, TextGenerationTaskParams from exo.shared.types.worker.instances import ( BoundInstance, MlxJacclInstance, @@ -336,6 +336,8 @@ def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None: elif "qwen3.5" in model_id_lower or "qwen-3.5" in model_id_lower: # For Qwen3.5: 248046 (<|im_end|>), 248044 (<|endoftext|>) return [248046, 248044] + elif "gemma-4" in model_id_lower or "gemma-3" in model_id_lower: + return [1, 106, 50] return None @@ -407,23 +409,13 @@ def _patched_encode(text: str, **_kwargs: object) -> list[int]: tool_parser=_parse_kimi_tool_calls, ) + # We should really consider going back to mlx lm load to get tokenizer tokenizer = load_tokenizer( model_path, tokenizer_config_extra={"trust_remote_code": trust_remote_code}, eos_token_ids=eos_token_ids, ) - if "gemma-3" in model_id_lower: - gemma_3_eos_id = 1 - gemma_3_end_of_turn_id = 106 - if tokenizer.eos_token_ids is not None: - if gemma_3_end_of_turn_id not in tokenizer.eos_token_ids: - tokenizer.eos_token_ids = list(tokenizer.eos_token_ids) + [ - gemma_3_end_of_turn_id - ] - else: - tokenizer.eos_token_ids = [gemma_3_eos_id, gemma_3_end_of_turn_id] - return tokenizer @@ -611,10 +603,10 @@ def apply_chat_template( tokenizer: TokenizerWrapper, task_params: TextGenerationTaskParams, ) -> str: - messages: list[dict[str, Any]] = [] + messages: list[dict[str, ChatTemplateValue]] = [] if task_params.chat_template_messages is not None: # Use pre-formatted messages that preserve tool_calls, thinking, etc. - messages = list(task_params.chat_template_messages) + messages = task_params.chat_template_messages else: # Add system message (instructions) if present if task_params.instructions: @@ -642,7 +634,7 @@ def system_prompt_token_count( if task_params.chat_template_messages is not None: for msg in task_params.chat_template_messages: if msg.get("role") in ("system", "developer"): - content = msg.get("content", "") # type: ignore + content = msg.get("content", "") if isinstance(content, str): parts.append(content) else: @@ -671,21 +663,37 @@ def fix_unmatched_think_end_tokens( ) -> mx.array: if not tokenizer.has_thinking: return tokens - assert tokenizer.think_start_id - assert tokenizer.think_end_id - think_start_id: int = tokenizer.think_start_id - think_end_id: int = tokenizer.think_end_id + assert tokenizer.think_start_tokens + assert tokenizer.think_end_tokens + think_start_tokens: list[int] = tokenizer.think_start_tokens + think_end_tokens: list[int] = tokenizer.think_end_tokens token_list: list[int] = cast(list[int], tokens.tolist()) result: list[int] = [] + depth = 0 + accumulated_think_start_length = 0 + accumulated_think_end_length = 0 + for token in token_list: - if token == think_start_id: - depth += 1 - elif token == think_end_id: - if depth == 0: - result.append(think_start_id) - else: - depth -= 1 + if token == think_start_tokens[accumulated_think_start_length]: + accumulated_think_start_length += 1 + if accumulated_think_start_length == len(think_start_tokens): + depth += 1 + accumulated_think_start_length = 0 + + elif token == think_end_tokens[accumulated_think_end_length]: + accumulated_think_end_length += 1 + if accumulated_think_end_length == len(think_end_tokens): + if depth == 0: + result.extend(think_start_tokens) + else: + depth -= 1 + accumulated_think_end_length = 0 + + else: + accumulated_think_start_length = 0 + accumulated_think_end_length = 0 + result.append(token) return mx.array(result) diff --git a/src/exo/worker/engines/mlx/vision.py b/src/exo/worker/engines/mlx/vision.py index 49cb024dc3..69749aa76a 100644 --- a/src/exo/worker/engines/mlx/vision.py +++ b/src/exo/worker/engines/mlx/vision.py @@ -6,8 +6,9 @@ import io import json import re +from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from mlx_vlm.utils import ImageProcessor @@ -25,7 +26,7 @@ from exo.shared.models.model_cards import VisionCardConfig from exo.shared.types.common import ModelId from exo.shared.types.mlx import Model -from exo.shared.types.text_generation import TextGenerationTaskParams +from exo.shared.types.text_generation import Base64Image, TextGenerationTaskParams from exo.worker.engines.mlx.cache import encode_prompt from exo.worker.engines.mlx.utils_mlx import ( fix_unmatched_think_end_tokens, @@ -33,13 +34,64 @@ ) from exo.worker.runner.bootstrap import logger +_video_processor_patched = False + def _filter_config(cls: type, d: dict[str, Any]) -> dict[str, Any]: valid = set(inspect.signature(cls.__init__).parameters.keys()) - {"self"} return {k: v for k, v in d.items() if k in valid} # type: ignore -_video_processor_patched = False +_ProcessorOutput = dict[str, np.ndarray] | tuple[dict[str, np.ndarray], list[int]] + + +def _run_processor( + processor: "ImageProcessor", + pil_images: list[Image.Image], +) -> tuple[dict[str, np.ndarray], list[int] | None]: + """ + Image processors split into two families by how they report per-image + token counts: + + 1. Variable-resolution patch models (Qwen3-VL, Llama 4 vision, ...): + return a `BatchFeature` dict containing `pixel_values` and + `image_grid_thw` — an (n_images, 3) array of (temporal, height, width) + patch counts. The caller multiplies the three to get the per-image + token count, so no override is needed. + + 2. Fixed-token-budget models (Gemma 4): every image collapses to a fixed + number of soft tokens, so there's no grid to report. These processors + return `(batch_feature_dict, [n_tokens_per_image])` instead. + + We normalize both into (dict, optional tokens override). + """ + raw = cast(_ProcessorOutput, processor(images=pil_images, return_tensors="np")) + if isinstance(raw, tuple): + batch, tokens = raw + return batch, [int(n) for n in tokens] + return raw, None + + +def _instantiate_projector( + cls: type, + model_config: Any, # pyright: ignore[reportAny] + vision_config: Any, # pyright: ignore[reportAny] + text_config: Any, # pyright: ignore[reportAny] +) -> nn.Module: + """ + Instantiate projector/embedding classes with any missing values + """ + init_sig = inspect.signature(cls.__init__) + params = {n: p for n, p in init_sig.parameters.items() if n != "self"} + kwargs: dict[str, Any] = {} + + if "embedding_dim" in params: + kwargs["embedding_dim"] = vision_config.hidden_size # pyright: ignore[reportAny] + if "text_hidden_size" in params: + kwargs["text_hidden_size"] = text_config.hidden_size # pyright: ignore[reportAny] + if "eps" in params: + kwargs["eps"] = getattr(vision_config, "rms_norm_eps", 1e-6) # pyright: ignore[reportAny] + return cls(**kwargs) # type: ignore def _patch_video_processor() -> None: @@ -128,6 +180,7 @@ class VisionResult: prompt_tokens: mx.array embeddings: mx.array media_regions: list[MediaRegion] + image_token_id: int class VisionEncoder: @@ -159,6 +212,45 @@ def _import_mlx_vlm(self, *submodules: str) -> Any: # type: ignore results.append(importlib.import_module(name)) return results[0] if len(results) == 1 else tuple(results) + def _apply_projector_quantization_if_needed( + self, projector_weights: dict[str, mx.array] + ) -> None: + # Quantized models ship the projector's Linear layers as packed uint32 + # weights plus `.scales`/`.biases`. Our now instantiated projector + # uses plain nn.Linear, so we must mirror the packing via nn.quantize + # before load_weights, otherwise MLX rejects the extra parameters. + if self._projector is None: + return + has_quantized_tensors = any( + key.endswith((".scales", ".biases")) or val.dtype == mx.uint32 + for key, val in projector_weights.items() + ) + if not has_quantized_tensors: + return + config = self._load_config_json() + quant_cfg = cast(dict[str, Any], config.get("quantization") or {}) + if not quant_cfg: + return + group_size = int(cast(int, quant_cfg.get("group_size", 64))) + bits = int(cast(int, quant_cfg.get("bits", 4))) + nn.quantize(self._projector, group_size=group_size, bits=bits) + + def _load_image_processor_from_module(self, repo: str) -> "ImageProcessor | None": + # mlx_vlm.utils.load_image_processor only works for models that set + # `Model.ImageProcessor = `, but Gemma4 just uses + # `Gemma4ImageProcessor` from the package `__init__.py` + try: + pkg: Any = importlib.import_module( + f"mlx_vlm.models.{self._config.model_type}" + ) + except ImportError: + return None + for attr in dir(pkg): # pyright: ignore[reportAny] + cls = getattr(pkg, attr) # pyright: ignore[reportAny] + if isinstance(cls, type) and attr.endswith("ImageProcessor"): + return cls.from_pretrained(repo) # type: ignore + return None + def ensure_loaded(self) -> None: if self._loaded: return @@ -194,7 +286,7 @@ def _load_weights(self) -> None: if ( isinstance(obj, type) and issubclass(obj, nn.Module) - and "Projector" in attr_name + and ("Projector" in attr_name or "Embedder" in attr_name) ): projector_cls = obj break @@ -214,7 +306,12 @@ def _load_weights(self) -> None: vision_config=vision_config, **_filter_config(config_mod.ModelConfig, extra), # type: ignore ) - self._projector = projector_cls(model_config) # type: ignore + self._projector = _instantiate_projector( + projector_cls, + model_config, + vision_config, + text_config, + ) processor_repo = self._config.processor_repo if processor_repo: @@ -223,7 +320,9 @@ def _load_weights(self) -> None: self._load_weights_from_model_repo() repo = processor_repo or str(self._model_path) - image_proc = load_image_processor(repo) + image_proc = load_image_processor( + repo + ) or self._load_image_processor_from_module(repo) if image_proc is not None: self._processor = image_proc else: @@ -295,18 +394,32 @@ def _load_weights_from_model_repo(self) -> None: raise FileNotFoundError(f"No safetensors files found in {self._model_path}") vision_prefixes = ["vision_tower.", "model.visual."] + projector_prefixes = [ + "embed_vision.", + "multi_modal_projector.", + "mm_projector.", + ] vision_weights: dict[str, mx.array] = {} - found_raw_prefix = False + projector_weights: dict[str, mx.array] = {} + + # If weights under `model.visual.`, we need to call mlx_vlm's VisionModel.sanitize() + # to remap into its own keys. + needs_sanitize = False + for sf_path in safetensors_files: file_weights: dict[str, mx.array] = mx.load(str(sf_path)) # type: ignore for key, val in file_weights.items(): for prefix in vision_prefixes: if key.startswith(prefix): - short_key = key[len(prefix) :] - vision_weights[short_key] = val + vision_weights[key[len(prefix) :]] = val if prefix == "model.visual.": - found_raw_prefix = True + needs_sanitize = True break + else: + for prefix in projector_prefixes: + if key.startswith(prefix): + projector_weights[key[len(prefix) :]] = val + break if not vision_weights: raise ValueError( @@ -315,16 +428,29 @@ def _load_weights_from_model_repo(self) -> None: ) assert self._vision_tower is not None - if found_raw_prefix and hasattr(self._vision_tower, "sanitize"): - vision_weights = self._vision_tower.sanitize(vision_weights) # type: ignore + if needs_sanitize: + sanitize: Callable[[dict[str, mx.array]], dict[str, mx.array]] | None = ( + getattr(self._vision_tower, "sanitize", None) + ) + if sanitize is not None: + vision_weights = sanitize(vision_weights) - self._vision_tower.load_weights(list(vision_weights.items())) # type: ignore + self._vision_tower.load_weights(list(vision_weights.items())) mx.eval(self._vision_tower.parameters()) - n_vision = sum(v.size for _, v in vision_weights.items()) # type: ignore - logger.info(f"Vision encoder loaded: {n_vision / 1e6:.1f}M params") + if self._projector is not None and projector_weights: + self._apply_projector_quantization_if_needed(projector_weights) + self._projector.load_weights(list(projector_weights.items())) + mx.eval(self._projector.parameters()) + + n_vision = sum(v.size for v in vision_weights.values()) + n_proj = sum(v.size for v in projector_weights.values()) + logger.info( + f"Vision encoder loaded: {n_vision / 1e6:.1f}M params" + + (f" (+ projector {n_proj / 1e6:.1f}M)" if n_proj else "") + ) - def encode_images(self, images: list[str]) -> tuple[mx.array, list[int]]: + def encode_images(self, images: list[Base64Image]) -> tuple[mx.array, list[int]]: self.ensure_loaded() assert self._vision_tower is not None assert self._processor is not None @@ -333,12 +459,21 @@ def encode_images(self, images: list[str]) -> tuple[mx.array, list[int]]: for idx, img in enumerate(pil_images): logger.info(f"Image {idx}: {img.width}x{img.height} mode={img.mode}") + per_image_pixels: list[mx.array] + grid_thw: mx.array | None + n_tokens_per_image: list[int] + if self._config.processor_repo: processed = self._processor.preprocess( [{"type": "image", "image": img} for img in pil_images], return_tensors="np", ) - pixel_values = mx.array(processed["pixel_values"]) # type: ignore + stacked_pixels = mx.array(processed["pixel_values"]) # type: ignore + if stacked_pixels.ndim == 3: + stacked_pixels = stacked_pixels[None] + per_image_pixels = [ + stacked_pixels[i : i + 1] for i in range(stacked_pixels.shape[0]) + ] grid_thw = mx.array(processed["grid_thws"]) # type: ignore assert self._merge_kernel_size is not None merge_length = int(np.prod(self._merge_kernel_size)) @@ -347,31 +482,80 @@ def encode_images(self, images: list[str]) -> tuple[mx.array, list[int]]: for i in range(grid_thw.shape[0]) ] else: - processed = self._processor( - images=pil_images, - return_tensors="np", + batch, tokens_override = _run_processor(self._processor, pil_images) + # `Gemma4ImageProcessor` returns pixel_values as a plain ndarray + # when all images resize to the same shape, or as a Python list of + # per-image (C, H, W) ndarrays when they differ. Treat it as the + # union here. + raw_pixel_values: np.ndarray | list[np.ndarray] = cast( + "np.ndarray | list[np.ndarray]", batch["pixel_values"] ) - pixel_values = mx.array(processed["pixel_values"]) # type: ignore - grid_thw = mx.array(processed["image_grid_thw"]) # type: ignore - merge_unit = self._spatial_merge_size**2 - n_tokens_per_image = [ - int( - grid_thw[i, 0].item() - * grid_thw[i, 1].item() - * grid_thw[i, 2].item() - ) - // merge_unit - for i in range(grid_thw.shape[0]) - ] + raw_grid = batch.get("image_grid_thw") + grid_thw = mx.array(raw_grid) if raw_grid is not None else None + if tokens_override is not None: + n_tokens_per_image = tokens_override + else: + assert grid_thw is not None + merge_unit = self._spatial_merge_size**2 + n_tokens_per_image = [ + int( + grid_thw[i, 0].item() + * grid_thw[i, 1].item() + * grid_thw[i, 2].item() + ) + // merge_unit + for i in range(grid_thw.shape[0]) + ] + + if isinstance(raw_pixel_values, list): + per_image_pixels = [ + # (C, H, W) -> (1, C, H, W) + mx.array(p)[None] if p.ndim == 3 else mx.array(p) + for p in raw_pixel_values + ] + else: + stacked = mx.array(raw_pixel_values) + per_image_pixels = [stacked[i : i + 1] for i in range(stacked.shape[0])] + + patch_embed_weight = None + for head_attr, linear_attr in ( + ("patch_embed", "proj"), + ("patch_embedder", "input_proj"), + ): + head = getattr(self._vision_tower, head_attr, None) + linear = getattr(head, linear_attr, None) if head is not None else None # type: ignore + if linear is not None and hasattr(linear, "weight"): # type: ignore + patch_embed_weight = linear.weight # type: ignore + break + assert patch_embed_weight is not None, ( + "vision tower has no recognised patch-embedding linear" + ) + tower_dtype = cast(mx.Dtype, patch_embed_weight.dtype) if self._needs_nhwc: + assert grid_thw is not None + pixel_values = mx.concatenate(per_image_pixels, axis=0).astype(tower_dtype) grid_hw = grid_thw[:, 1:] if grid_thw.shape[-1] == 3 else grid_thw hidden_states = self._vision_tower( pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True, grid_thw=grid_hw, ) + elif grid_thw is None: + # Fixed-token-budget models (gemma4): run each image separately + # since they can have different spatial shapes *and* different + # soft-token counts, then flatten each to (n_tokens_i, hidden) + # and concatenate along the token axis. + per_image_hidden: list[mx.array] = [] + for pv in per_image_pixels: + result = self._vision_tower(pv.astype(tower_dtype)) + h = result[0] if isinstance(result, tuple) else result + if h.ndim == 3: + h = h.reshape(-1, h.shape[-1]) + per_image_hidden.append(h) + hidden_states = mx.concatenate(per_image_hidden, axis=0) else: + pixel_values = mx.concatenate(per_image_pixels, axis=0).astype(tower_dtype) result = self._vision_tower(pixel_values, grid_thw) hidden_states = result[0] if isinstance(result, tuple) else result @@ -380,6 +564,11 @@ def encode_images(self, images: list[str]) -> tuple[mx.array, list[int]]: else: image_features = hidden_states + # `create_vision_embeddings` expects a 2D (total_tokens, hidden) view, + # but fixed-token-budget models (gemma4) return (n_images, tokens, hidden). + if image_features.ndim == 3: + image_features = image_features.reshape(-1, image_features.shape[-1]) + return image_features, n_tokens_per_image @@ -420,6 +609,16 @@ def create_vision_embeddings( n = min(n_placeholders, image_features.shape[0]) image_features = image_features[:n] + # Gemma-family models apply `h = input_embeddings * embed_scale` inside + # the inner model's forward pass. That scale is appropriate for text + # tokens (which come out of a raw `embed_tokens(id)` lookup) but not + # for our pre-projected image features. Pre-divide by `embed_scale` + # so that after the model multiplies, image features are unchanged + # while text positions remain correctly scaled. + if hasattr(inner, "embed_scale"): # type: ignore + embed_scale = float(inner.embed_scale) # type: ignore + image_features = image_features / embed_scale + image_indices = mx.cumsum(is_image.astype(mx.int32)) - 1 image_indices = mx.clip(image_indices, 0, image_features.shape[0] - 1) @@ -432,7 +631,7 @@ def create_vision_embeddings( def _find_media_regions( prompt_tokens: mx.array, - images: list[str], + images: list[Base64Image], image_token_id: int, ) -> list[MediaRegion]: tokens_np = np.array(prompt_tokens) @@ -483,7 +682,7 @@ def __init__(self, config: VisionCardConfig, model_id: ModelId): def load(self) -> None: self._encoder.ensure_loaded() - def _image_cache_key(self, images: list[str]) -> str: + def _image_cache_key(self, images: list[Base64Image]) -> str: h = hashlib.sha256() for img in images: pil = decode_base64_image(img) @@ -492,7 +691,7 @@ def _image_cache_key(self, images: list[str]) -> str: def process( self, - images: list[str], + images: list[Base64Image], chat_template_messages: list[dict[str, Any]], tokenizer: TokenizerWrapper, model: Model, @@ -563,11 +762,12 @@ def process( prompt_tokens=prompt_tokens, embeddings=embeddings, media_regions=media_regions, + image_token_id=self.vision_config.image_token_id, ) def prepare_vision( - images: list[str] | None, + images: list[Base64Image] | None, chat_template_messages: list[dict[str, Any]] | None, vision_processor: VisionProcessor, tokenizer: TokenizerWrapper, diff --git a/src/exo/worker/main.py b/src/exo/worker/main.py index 08cc1e0ecc..883181a967 100644 --- a/src/exo/worker/main.py +++ b/src/exo/worker/main.py @@ -45,6 +45,7 @@ TaskStatus, TextGeneration, ) +from exo.shared.types.text_generation import Base64Image from exo.shared.types.topology import Connection, SocketConnection from exo.shared.types.worker.downloads import DownloadCompleted from exo.shared.types.worker.instances import InstanceId @@ -335,7 +336,9 @@ async def plan_step(self): f"from {len(chunk_buffer)} chunks" ) - resolved_images = [by_index[i] for i in sorted(by_index)] + resolved_images = [ + Base64Image(by_index[i]) for i in sorted(by_index) + ] modified_task = task.model_copy( update={ "task_params": task.task_params.model_copy( diff --git a/src/exo/worker/runner/llm_inference/model_output_parsers.py b/src/exo/worker/runner/llm_inference/model_output_parsers.py index 1242909a39..cc4e3518b8 100644 --- a/src/exo/worker/runner/llm_inference/model_output_parsers.py +++ b/src/exo/worker/runner/llm_inference/model_output_parsers.py @@ -307,21 +307,35 @@ def parse_thinking_models( Always yields tokens with finish_reason to avoid hanging the chunk stream. """ is_thinking = starts_in_thinking + accumulated = "" + for response in responses: if response is None: yield None continue + + accumulated += response.text + if response.finish_reason is not None: yield response.model_copy(update={"is_thinking": False}) continue - if response.text == think_start: + if accumulated == think_start and not is_thinking: is_thinking = True + accumulated = "" continue - if response.text == think_end: + if accumulated == think_end and is_thinking: is_thinking = False + accumulated = "" + continue + + if (think_start and accumulated == think_start[: len(accumulated)]) or ( + think_end and accumulated == think_end[: len(accumulated)] + ): continue + accumulated = "" + yield response.model_copy(update={"is_thinking": is_thinking}) diff --git a/src/exo/worker/tests/unittests/test_mlx/test_gemma4_sharding.py b/src/exo/worker/tests/unittests/test_mlx/test_gemma4_sharding.py new file mode 100644 index 0000000000..ebd6194c92 --- /dev/null +++ b/src/exo/worker/tests/unittests/test_mlx/test_gemma4_sharding.py @@ -0,0 +1,779 @@ +# type: ignore +import importlib +import json +import multiprocessing as mp +import os +import tempfile +import traceback +from typing import Any + +import mlx.core as mx +import pytest + +RANDOM_SEED = 42 +INPUT_TOKENS = [1, 100, 200, 300, 400, 500] +MAX_GEN_TOKENS = 50 + + +def _dense_no_kv_shared_config() -> dict[str, Any]: + return { + "model_type": "gemma4_text", + "hidden_size": 256, + "num_hidden_layers": 6, + "intermediate_size": 512, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 64, + "global_head_dim": 128, + "rms_norm_eps": 1e-6, + "vocab_size": 1024, + "vocab_size_per_layer_input": 1024, + "num_kv_shared_layers": 0, + "hidden_size_per_layer_input": 0, + "sliding_window": 32, + "sliding_window_pattern": 3, + "max_position_embeddings": 2048, + "attention_k_eq_v": False, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": False, + "enable_moe_block": False, + "tie_word_embeddings": True, + } + + +def _moe_config() -> dict[str, Any]: + return { + "model_type": "gemma4_text", + "hidden_size": 256, + "num_hidden_layers": 6, + "intermediate_size": 512, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "num_global_key_value_heads": 2, + "head_dim": 64, + "global_head_dim": 128, + "rms_norm_eps": 1e-6, + "vocab_size": 1024, + "vocab_size_per_layer_input": 1024, + "num_kv_shared_layers": 0, + "hidden_size_per_layer_input": 0, + "sliding_window": 32, + "sliding_window_pattern": 3, + "max_position_embeddings": 2048, + "attention_k_eq_v": True, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": False, + "enable_moe_block": True, + "num_experts": 4, + "top_k_experts": 2, + "moe_intermediate_size": 128, + "tie_word_embeddings": True, + } + + +def _kv_shared_config() -> dict[str, Any]: + return { + "model_type": "gemma4_text", + "hidden_size": 256, + "num_hidden_layers": 6, + "intermediate_size": 512, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 64, + "global_head_dim": 128, + "rms_norm_eps": 1e-6, + "vocab_size": 1024, + "vocab_size_per_layer_input": 1024, + "num_kv_shared_layers": 2, + "hidden_size_per_layer_input": 32, + "sliding_window": 32, + "sliding_window_pattern": 3, + "max_position_embeddings": 2048, + "attention_k_eq_v": False, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": False, + "enable_moe_block": False, + "tie_word_embeddings": True, + } + + +def _kv_shared_sources_split_config() -> dict[str, Any]: + """Config where sources land on layers 1 and 2 (one on each of the first + two 2-layer ranks). pattern=2 → layer_types = + [sliding, full, sliding, full, sliding, full]; num_kv_shared_layers=3 → + non-shared layers are 0,1,2 and the last occurrence of each type in that + window gives sources = {sliding:2, full:1}. previous_kvs = + [0, 1, 2, 1, 2, 1].""" + return { + "model_type": "gemma4_text", + "hidden_size": 256, + "num_hidden_layers": 6, + "intermediate_size": 512, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 64, + "global_head_dim": 128, + "rms_norm_eps": 1e-6, + "vocab_size": 1024, + "vocab_size_per_layer_input": 1024, + "num_kv_shared_layers": 3, + "hidden_size_per_layer_input": 32, + "sliding_window": 32, + "sliding_window_pattern": 2, + "max_position_embeddings": 2048, + "attention_k_eq_v": False, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": False, + "enable_moe_block": False, + "tie_word_embeddings": True, + } + + +def _kv_shared_sources_first_config() -> dict[str, Any]: + """Config where both kv-share sources live in layers 0 and 1, so every + shared layer must read from rank 0. pattern=2 → layer_types = + [sliding, full, sliding, full, sliding, full]; num_kv_shared_layers=4 → + non-shared layers are only 0,1 and therefore are the unique sources + (sliding→0, full→1). previous_kvs = [0, 1, 0, 1, 0, 1].""" + return { + "model_type": "gemma4_text", + "hidden_size": 256, + "num_hidden_layers": 6, + "intermediate_size": 512, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 64, + "global_head_dim": 128, + "rms_norm_eps": 1e-6, + "vocab_size": 1024, + "vocab_size_per_layer_input": 1024, + "num_kv_shared_layers": 4, + "hidden_size_per_layer_input": 32, + "sliding_window": 32, + "sliding_window_pattern": 2, + "max_position_embeddings": 2048, + "attention_k_eq_v": False, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": False, + "enable_moe_block": False, + "tie_word_embeddings": True, + } + + +def _build_gemma4_model(text_config: dict[str, Any]): + mx.random.seed(RANDOM_SEED) + gemma4 = importlib.import_module("mlx_lm.models.gemma4") + args = gemma4.ModelArgs.from_dict( + { + "model_type": "gemma4", + "text_config": text_config, + "vocab_size": text_config["vocab_size"], + } + ) + return gemma4.Model(args) + + +def _create_hostfile(world_size: int, base_port: int) -> str: + hosts = [f"127.0.0.1:{base_port + i}" for i in range(world_size)] + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(hosts, f) + return f.name + + +def _greedy_generate(model, prompt_tokens: list[int], n_steps: int) -> list[int]: + cache = model.make_cache() + inputs = mx.array([prompt_tokens]) + logits = model(inputs, cache=cache) + mx.eval(logits) + next_id = int(mx.argmax(logits[0, -1]).item()) + generated = [next_id] + for _ in range(n_steps - 1): + next_inputs = mx.array([[next_id]]) + logits = model(next_inputs, cache=cache) + mx.eval(logits) + next_id = int(mx.argmax(logits[0, -1]).item()) + generated.append(next_id) + return generated + + +def _run_single(text_config: dict[str, Any], result_queue: Any) -> None: + try: + model = _build_gemma4_model(text_config) + cache = model.make_cache() + inputs = mx.array([INPUT_TOKENS]) + logits = model(inputs, cache=cache) + mx.eval(logits) + last_logits = logits[0, -1] + tokens = _greedy_generate( + _build_gemma4_model(text_config), INPUT_TOKENS, MAX_GEN_TOKENS + ) + result_queue.put(("ok", last_logits.tolist(), tokens)) + except Exception as e: + result_queue.put(("err", f"{e}\n{traceback.format_exc()}", None)) + + +def _run_tensor( + rank: int, + world_size: int, + hostfile: str, + text_config: dict[str, Any], + result_queue: Any, +) -> None: + os.environ["MLX_HOSTFILE"] = hostfile + os.environ["MLX_RANK"] = str(rank) + try: + from exo.worker.engines.mlx.auto_parallel import tensor_auto_parallel + + group = mx.distributed.init(backend="ring", strict=True) + model = _build_gemma4_model(text_config) + model = tensor_auto_parallel( + model, group, timeout_seconds=60.0, on_timeout=None, on_layer_loaded=None + ) + + cache = model.make_cache() + inputs = mx.array([INPUT_TOKENS]) + logits = model(inputs, cache=cache) + mx.eval(logits) + last_logits = logits[0, -1] + + gen_model = _build_gemma4_model(text_config) + gen_model = tensor_auto_parallel( + gen_model, + group, + timeout_seconds=60.0, + on_timeout=None, + on_layer_loaded=None, + ) + tokens = _greedy_generate(gen_model, INPUT_TOKENS, MAX_GEN_TOKENS) + + result_queue.put((rank, "ok", last_logits.tolist(), tokens)) + except Exception as e: + result_queue.put((rank, "err", f"{e}\n{traceback.format_exc()}", None)) + + +def _run_pipeline( + rank: int, + world_size: int, + hostfile: str, + splits: list[tuple[int, int]], + text_config: dict[str, Any], + result_queue: Any, +) -> None: + os.environ["MLX_HOSTFILE"] = hostfile + os.environ["MLX_RANK"] = str(rank) + try: + from exo.shared.models.model_cards import ModelCard, ModelTask + from exo.shared.types.common import ModelId + from exo.shared.types.memory import Memory + from exo.shared.types.worker.shards import PipelineShardMetadata + from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel + + group = mx.distributed.init(backend="ring", strict=True) + start, end = splits[rank] + + n_layers = text_config["num_hidden_layers"] + shard_meta = PipelineShardMetadata( + model_card=ModelCard( + model_id=ModelId("test/gemma4-test"), + storage_size=Memory.from_gb(1), + n_layers=n_layers, + hidden_size=text_config["hidden_size"], + supports_tensor=False, + tasks=[ModelTask.TextGeneration], + ), + device_rank=rank, + world_size=world_size, + start_layer=start, + end_layer=end, + n_layers=n_layers, + ) + + def _build(): + model = _build_gemma4_model(text_config) + return pipeline_auto_parallel( + model, group, shard_meta, on_layer_loaded=None + ) + + model = _build() + cache = model.make_cache() + inputs = mx.array([INPUT_TOKENS]) + logits = model(inputs, cache=cache) + mx.eval(logits) + last_logits = logits[0, -1] + + gen_model = _build() + tokens = _greedy_generate(gen_model, INPUT_TOKENS, MAX_GEN_TOKENS) + + result_queue.put((rank, "ok", last_logits.tolist(), tokens)) + except Exception as e: + result_queue.put((rank, "err", f"{e}\n{traceback.format_exc()}", None)) + + +def _spawn_single(text_config: dict[str, Any]) -> tuple[list[float], list[int]]: + ctx = mp.get_context("spawn") + queue: Any = ctx.Queue() + p = ctx.Process(target=_run_single, args=(text_config, queue)) + p.start() + p.join(timeout=120) + status, payload, tokens = queue.get(timeout=5) + if status != "ok": + raise RuntimeError(f"single device failed: {payload}") + return payload, tokens + + +def _spawn_distributed( + target, + args_per_rank: list[tuple], + timeout: float = 240.0, +) -> dict[int, tuple[list[float], list[int]]]: + ctx = mp.get_context("spawn") + queue: Any = ctx.Queue() + procs = [] + for rank_args in args_per_rank: + p = ctx.Process(target=target, args=(*rank_args, queue)) + p.start() + procs.append(p) + for p in procs: + p.join(timeout=timeout) + + results: dict[int, tuple[list[float], list[int]]] = {} + errors: dict[int, str] = {} + while not queue.empty(): + rank, status, payload, tokens = queue.get() + if status == "ok": + results[rank] = (payload, tokens) + else: + errors[rank] = payload + if errors: + raise RuntimeError(f"distributed run errors: {errors}") + return results + + +@pytest.mark.slow +def test_dense_tensor_matches_single() -> None: + text_config = _dense_no_kv_shared_config() + base_port = 31200 + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(2, base_port) + try: + results = _spawn_distributed( + _run_tensor, + [(0, 2, hostfile, text_config), (1, 2, hostfile, text_config)], + ) + finally: + os.unlink(hostfile) + + for rank in range(2): + rank_logits, rank_tokens = results[rank] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"rank {rank} logit diff {diff}" + assert rank_tokens == single_tokens, ( + f"rank {rank} tokens {rank_tokens} != single {single_tokens}" + ) + + +@pytest.mark.slow +def test_dense_pipeline_matches_single() -> None: + text_config = _dense_no_kv_shared_config() + base_port = 31220 + n_layers = text_config["num_hidden_layers"] + splits = [(0, n_layers // 2), (n_layers // 2, n_layers)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(2, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 2, hostfile, splits, text_config), + (1, 2, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[1] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"pipeline rank 1 logit diff {diff}" + assert rank_tokens == single_tokens, ( + f"pipeline tokens {rank_tokens} != single {single_tokens}" + ) + + +@pytest.mark.slow +def test_dense_pipeline_asymmetric() -> None: + text_config = _dense_no_kv_shared_config() + base_port = 31240 + splits = [(0, 2), (2, 6)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(2, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 2, hostfile, splits, text_config), + (1, 2, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[1] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"asymmetric pipeline logit diff {diff}" + assert rank_tokens == single_tokens + + +@pytest.mark.slow +def test_moe_tensor_matches_single() -> None: + text_config = _moe_config() + base_port = 31260 + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(2, base_port) + try: + results = _spawn_distributed( + _run_tensor, + [(0, 2, hostfile, text_config), (1, 2, hostfile, text_config)], + ) + finally: + os.unlink(hostfile) + + for rank in range(2): + rank_logits, rank_tokens = results[rank] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"moe tensor rank {rank} logit diff {diff}" + assert rank_tokens == single_tokens + + +@pytest.mark.slow +def test_moe_pipeline_matches_single() -> None: + text_config = _moe_config() + base_port = 31280 + n_layers = text_config["num_hidden_layers"] + splits = [(0, n_layers // 2), (n_layers // 2, n_layers)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(2, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 2, hostfile, splits, text_config), + (1, 2, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[1] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"moe pipeline logit diff {diff}" + assert rank_tokens == single_tokens + + +@pytest.mark.slow +def test_kv_shared_pipeline_valid_split() -> None: + """KV-shared layers stay with their source — split (0,2),(2,6) keeps both shared sources on rank 1.""" + text_config = _kv_shared_config() + base_port = 31300 + splits = [(0, 2), (2, 6)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(2, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 2, hostfile, splits, text_config), + (1, 2, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[1] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"kv_shared pipeline logit diff {diff}" + assert rank_tokens == single_tokens + + +@pytest.mark.slow +def test_kv_shared_pipeline_split_separates_source_from_shared() -> None: + """Split (0,3),(3,6) puts source layer 2 on rank 0 but shared layer 5 (which + reads layer 2) on rank 1. Cross-rank kv transfer should make this work.""" + text_config = _kv_shared_config() + base_port = 31320 + splits = [(0, 3), (3, 6)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(2, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 2, hostfile, splits, text_config), + (1, 2, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[1] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"kv_shared cross-rank logit diff {diff}" + assert rank_tokens == single_tokens + + +@pytest.mark.slow +def test_kv_shared_pipeline_3node_both_sources_remote() -> None: + """3-rank split (0,2),(2,4),(4,6). Both sources (2 and 3) live on rank 1; + rank 2's shared layers (4 and 5) need to receive both via cross-rank kv + transfer.""" + text_config = _kv_shared_config() + base_port = 31340 + splits = [(0, 2), (2, 4), (4, 6)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(3, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 3, hostfile, splits, text_config), + (1, 3, hostfile, splits, text_config), + (2, 3, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[2] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"kv_shared 3-rank logit diff {diff}" + assert rank_tokens == single_tokens + + +@pytest.mark.slow +def test_kv_shared_pipeline_3node_multi_hop() -> None: + """3-rank split (0,3),(3,4),(4,6). Source layer 2 lives on rank 0, source + layer 3 lives on rank 1. Rank 2's shared layers need both, so source 2 must + be forwarded through rank 1 to rank 2 (multi-hop forwarding via the + received-bundle pass-through).""" + text_config = _kv_shared_config() + base_port = 31360 + splits = [(0, 3), (3, 4), (4, 6)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(3, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 3, hostfile, splits, text_config), + (1, 3, hostfile, splits, text_config), + (2, 3, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[2] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"kv_shared multi-hop logit diff {diff}" + assert rank_tokens == single_tokens + + +@pytest.mark.slow +def test_kv_shared_pipeline_sources_rank0_consumed_by_all() -> None: + """3-rank split (0,2),(2,4),(4,6). Sources (layers 0, 1) live entirely on + rank 0. Rank 1 (layers 2, 3) and rank 2 (layers 4, 5) are ALL shared + layers, and every shared layer reads from one of rank 0's sources.""" + text_config = _kv_shared_sources_first_config() + base_port = 31380 + splits = [(0, 2), (2, 4), (4, 6)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(3, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 3, hostfile, splits, text_config), + (1, 3, hostfile, splits, text_config), + (2, 3, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[2] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"sources-rank0 logit diff {diff}" + assert rank_tokens == single_tokens + + +@pytest.mark.slow +def test_kv_shared_pipeline_sources_split_across_first_two_ranks() -> None: + """3-rank split (0,2),(2,4),(4,6). previous_kvs = [0, 1, 2, 1, 2, 1] puts + the two sources at layers 1 (rank 0) and 2 (rank 1).""" + text_config = _kv_shared_sources_split_config() + base_port = 31400 + splits = [(0, 2), (2, 4), (4, 6)] + + single_logits, single_tokens = _spawn_single(text_config) + + hostfile = _create_hostfile(3, base_port) + try: + results = _spawn_distributed( + _run_pipeline, + [ + (0, 3, hostfile, splits, text_config), + (1, 3, hostfile, splits, text_config), + (2, 3, hostfile, splits, text_config), + ], + ) + finally: + os.unlink(hostfile) + + rank_logits, rank_tokens = results[2] + diff = max(abs(a - b) for a, b in zip(single_logits, rank_logits, strict=True)) + assert diff < 5e-5, f"sources-split logit diff {diff}" + assert rank_tokens == single_tokens + + +def _e2b_like_config() -> dict[str, Any]: + """Exact gemma-4-e2b text_config (30 layers, 10 kv-shared, + hidden_size_per_layer_input=256, sliding_window=512, explicit layer_types).""" + pattern = ["sliding_attention"] * 4 + ["full_attention"] + layer_types = pattern * 6 + return { + "model_type": "gemma4_text", + "hidden_size": 1536, + "num_hidden_layers": 30, + "intermediate_size": 8192, + "num_attention_heads": 8, + "num_key_value_heads": 2, + "head_dim": 256, + "global_head_dim": 256, + "rms_norm_eps": 1e-6, + "vocab_size": 262144, + "vocab_size_per_layer_input": 262144, + "num_kv_shared_layers": 10, + "hidden_size_per_layer_input": 256, + "sliding_window": 512, + "sliding_window_pattern": 5, + "layer_types": layer_types, + "max_position_embeddings": 32768, + "attention_k_eq_v": False, + "final_logit_softcapping": 30.0, + "use_double_wide_mlp": False, + "enable_moe_block": False, + "tie_word_embeddings": True, + } + + +def _load_gemma4_tokenizer(): # type: ignore[no-untyped-def] + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained("mlx-community/gemma-4-e2b-it-4bit") + + +def _run_pipeline_via_prefill( + rank: int, + world_size: int, + hostfile: str, + splits: list[tuple[int, int]], + text_config: dict[str, Any], + result_queue: Any, +) -> None: + """End-to-end reproducer for the warmup hang: load the real + mlx-community/gemma-4-e2b-it-4bit model + tokenizer, shard it across a + 2-rank pipeline, and drive it through exo's `warmup_inference` which + is the *exact same entrypoint* runner.main calls before serving.""" + os.environ["MLX_HOSTFILE"] = hostfile + os.environ["MLX_RANK"] = str(rank) + try: + from mlx_lm.utils import load + + from exo.shared.models.model_cards import ModelCard, ModelTask + from exo.shared.types.common import ModelId + from exo.shared.types.memory import Memory + from exo.shared.types.worker.shards import PipelineShardMetadata + from exo.worker.engines.mlx.auto_parallel import pipeline_auto_parallel + from exo.worker.engines.mlx.generator.generate import warmup_inference + + group = mx.distributed.init(backend="ring", strict=True) + start, end = splits[rank] + + model, tokenizer = load("mlx-community/gemma-4-e2b-it-4bit") + n_layers = len(model.layers) + shard_meta = PipelineShardMetadata( + model_card=ModelCard( + model_id=ModelId("mlx-community/gemma-4-e2b-it-4bit"), + storage_size=Memory.from_gb(3), + n_layers=n_layers, + hidden_size=1536, + supports_tensor=False, + tasks=[ModelTask.TextGeneration], + ), + device_rank=rank, + world_size=world_size, + start_layer=start, + end_layer=end, + n_layers=n_layers, + ) + model = pipeline_auto_parallel(model, group, shard_meta, on_layer_loaded=None) + + tokens_generated = warmup_inference( + model=model, + tokenizer=tokenizer, + group=group, + model_id=ModelId("mlx-community/gemma-4-e2b-it-4bit"), + ) + + result_queue.put((rank, "ok", [], [tokens_generated])) + except Exception as e: + result_queue.put((rank, "err", f"{e}\n{traceback.format_exc()}", None)) + + +@pytest.mark.slow +def test_e2b_like_pipeline_warmup_path() -> None: + """Runs exo's real prefill+decode path on an e2b-shaped config in a + 2-rank pipeline split. Catches bugs that only show up when the cache + goes through trim/snapshot-restore, per_layer_inputs is active, and + the shard boundary crosses mixed cache types.""" + text_config = _e2b_like_config() + base_port = 31420 + n_layers = text_config["num_hidden_layers"] + splits = [(0, n_layers // 2), (n_layers // 2, n_layers)] + + hostfile = _create_hostfile(2, base_port) + try: + results = _spawn_distributed( + _run_pipeline_via_prefill, + [ + (0, 2, hostfile, splits, text_config), + (1, 2, hostfile, splits, text_config), + ], + timeout=120.0, + ) + finally: + os.unlink(hostfile) + + # Reaching this point means both ranks completed `warmup_inference` + # end-to-end (prefill + decode) without hanging or raising — that's the + # only thing this test is here to prove. + assert 0 in results and 1 in results, f"missing rank results: {list(results)}" diff --git a/src/exo/worker/tests/unittests/test_mlx/test_prefix_cache_architectures.py b/src/exo/worker/tests/unittests/test_mlx/test_prefix_cache_architectures.py index 609ea867a1..fec5082a24 100644 --- a/src/exo/worker/tests/unittests/test_mlx/test_prefix_cache_architectures.py +++ b/src/exo/worker/tests/unittests/test_mlx/test_prefix_cache_architectures.py @@ -16,7 +16,11 @@ from exo.shared.types.common import ModelId from exo.shared.types.mlx import Model -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) from exo.worker.engines.mlx.cache import KVPrefixCache from exo.worker.engines.mlx.generator.generate import mlx_generate from exo.worker.engines.mlx.utils_mlx import ( @@ -203,7 +207,9 @@ def _make_task() -> TextGenerationTaskParams: input=[ InputMessage( role="user", - content="Use the calculator to compute 1847 * 263 + 5921", + content=InputMessageContent( + "Use the calculator to compute 1847 * 263 + 5921" + ), ) ], max_output_tokens=20, diff --git a/src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py b/src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py index 0aa901ce72..bb87268fc4 100644 --- a/src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py +++ b/src/exo/worker/tests/unittests/test_plan/test_task_forwarding.py @@ -2,7 +2,11 @@ import exo.worker.plan as plan_mod from exo.shared.types.tasks import Task, TaskId, TaskStatus, TextGeneration -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) from exo.shared.types.worker.instances import BoundInstance, InstanceId from exo.shared.types.worker.runners import ( RunnerIdle, @@ -61,7 +65,8 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready(): task_status=TaskStatus.Pending, command_id=COMMAND_1_ID, task_params=TextGenerationTaskParams( - model=MODEL_A_ID, input=[InputMessage(role="user", content="")] + model=MODEL_A_ID, + input=[InputMessage(role="user", content=InputMessageContent(""))], ), ) @@ -113,7 +118,8 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready(): task_status=TaskStatus.Pending, command_id=COMMAND_1_ID, task_params=TextGenerationTaskParams( - model=MODEL_A_ID, input=[InputMessage(role="user", content="")] + model=MODEL_A_ID, + input=[InputMessage(role="user", content=InputMessageContent(""))], ), ) @@ -162,7 +168,8 @@ def test_plan_does_not_forward_tasks_for_other_instances(): task_status=TaskStatus.Pending, command_id=COMMAND_1_ID, task_params=TextGenerationTaskParams( - model=MODEL_A_ID, input=[InputMessage(role="user", content="")] + model=MODEL_A_ID, + input=[InputMessage(role="user", content=InputMessageContent(""))], ), ) @@ -215,7 +222,8 @@ def test_plan_ignores_non_pending_or_non_chat_tasks(): task_status=TaskStatus.Complete, command_id=COMMAND_1_ID, task_params=TextGenerationTaskParams( - model=MODEL_A_ID, input=[InputMessage(role="user", content="")] + model=MODEL_A_ID, + input=[InputMessage(role="user", content=InputMessageContent(""))], ), ) diff --git a/src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py b/src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py index 26c11fe542..74efbef1db 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py +++ b/src/exo/worker/tests/unittests/test_runner/test_dsml_e2e.py @@ -988,6 +988,7 @@ class TestApplyChatTemplateWithToolCalls: def test_dsml_encoding_with_tool_calls_in_history(self): from exo.shared.types.text_generation import ( InputMessage, + InputMessageContent, TextGenerationTaskParams, ) from exo.worker.engines.mlx.utils_mlx import apply_chat_template @@ -1022,8 +1023,8 @@ def test_dsml_encoding_with_tool_calls_in_history(self): params = TextGenerationTaskParams( model=ModelId("mlx-community/DeepSeek-V3.2-8bit"), - input=[InputMessage(role="user", content="Thanks!")], - instructions="You are a helpful assistant.", + input=[InputMessage(role="user", content=InputMessageContent("Thanks!"))], + instructions=InputMessageContent("You are a helpful assistant."), enable_thinking=True, chat_template_messages=chat_template_messages, tools=_WEATHER_TOOLS, diff --git a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py index cd4c18d701..ffd8fbfdf3 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py +++ b/src/exo/worker/tests/unittests/test_runner/test_event_ordering.py @@ -27,7 +27,11 @@ TaskStatus, TextGeneration, ) -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) from exo.shared.types.worker.runner_response import GenerationResponse from exo.shared.types.worker.runners import ( RunnerConnected, @@ -91,7 +95,7 @@ def nothin(*_1: U, **_2: V) -> T: CHAT_PARAMS = TextGenerationTaskParams( model=MODEL_A_ID, - input=[InputMessage(role="user", content="hello")], + input=[InputMessage(role="user", content=InputMessageContent("hello"))], stream=True, max_output_tokens=4, temperature=0.0, diff --git a/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py b/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py index ecb07f9494..82612d6d67 100644 --- a/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py +++ b/src/exo/worker/tests/unittests/test_runner/test_runner_supervisor.py @@ -9,7 +9,11 @@ from exo.shared.types.common import CommandId, NodeId from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated from exo.shared.types.tasks import Task, TaskId, TextGeneration -from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams +from exo.shared.types.text_generation import ( + InputMessage, + InputMessageContent, + TextGenerationTaskParams, +) from exo.shared.types.worker.instances import BoundInstance, InstanceId from exo.shared.types.worker.runners import RunnerFailed, RunnerId from exo.utils.channels import channel, mp_channel @@ -68,7 +72,7 @@ async def test_check_runner_emits_error_chunk_for_inflight_text_generation() -> command_id=command_id, task_params=TextGenerationTaskParams( model=bound_instance.bound_shard.model_card.model_id, - input=[InputMessage(role="user", content="hi")], + input=[InputMessage(role="user", content=InputMessageContent("hi"))], stream=True, ), ) diff --git a/uv.lock b/uv.lock index f40615ffa9..850dbe0b47 100644 --- a/uv.lock +++ b/uv.lock @@ -1541,7 +1541,7 @@ dependencies = [ [[package]] name = "mlx-vlm" -version = "0.4.1" +version = "0.4.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "datasets", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -1558,9 +1558,9 @@ dependencies = [ { name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "uvicorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/57/8f/31204f1a8c7404e523a5578949ea668e668e10dd67a0f63336f261014c0f/mlx_vlm-0.4.1.tar.gz", hash = "sha256:4e2d8a232715dbca72d346f43cf54d5738452848855792ffb1b285228ae7c7bd", size = 621840, upload-time = "2026-03-21T14:26:04.586Z" } +sdist = { url = "https://files.pythonhosted.org/packages/94/ec/108aec30efb159940ea29d133d5d8ec14840edbec914869b46eaafac5552/mlx_vlm-0.4.4.tar.gz", hash = "sha256:3197e277c1be9ed1712ea04624df029e486f7747ad93e40e7bd1c9c771f8b179", size = 836370, upload-time = "2026-04-04T15:19:01.087Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/78/856f44c6bdd8791427fa59a093a1d00c91cdbe16238506602fc3968017bb/mlx_vlm-0.4.1-py3-none-any.whl", hash = "sha256:89feca2e8be31609770c0e8a6d88fa21d00ee25bd3d56b4aafce59d35dd63b71", size = 768806, upload-time = "2026-03-21T14:26:03.129Z" }, + { url = "https://files.pythonhosted.org/packages/d9/81/235518176c3c8230e5274e91346ecf940591f653e73b0daeb505fb37eea9/mlx_vlm-0.4.4-py3-none-any.whl", hash = "sha256:3ff86ea738ab1914dc1b07e4fa5d4cc34bec5909e540692cfad0af808af13c11", size = 1014936, upload-time = "2026-04-04T15:18:59.328Z" }, ] [[package]]