Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .mlx_typings/mlx/nn/layers/quantized.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This type stub file was generated by pyright.
"""

from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union

import mlx.core as mx
from base import Module
Expand All @@ -13,8 +13,10 @@ def quantize(
bits: int = ...,
*,
mode: str = ...,
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = ...,
): # -> None:
class_predicate: Optional[
Callable[[str, Module], Union[bool, dict[str, Any]]]
] = ...,
) -> None:
"""Quantize the sub-modules of a module according to a predicate.

By default all layers that define a ``to_quantized(group_size, bits)``
Expand Down
31 changes: 31 additions & 0 deletions .mlx_typings/mlx_lm/models/gemma4.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from . import gemma4_text
from .base import BaseModelArgs
from .cache import KVCache, RotatingKVCache

@dataclass
class ModelArgs(BaseModelArgs):
    """Type stub: top-level configuration for the multimodal gemma4 model.

    Only the fields and methods visible in the implementation are declared;
    this file carries no runtime behavior.
    """

    # Model-type discriminator string (e.g. read from the model config).
    model_type: str
    # Nested config for the text sub-model; presumably forwarded to
    # gemma4_text.ModelArgs — TODO confirm against the implementation.
    text_config: Optional[dict[str, Any]]
    vocab_size: int

    # Dataclass post-init hook; normalizes/derives fields in the implementation.
    def __post_init__(self) -> None: ...

class Model(nn.Module):
    """Type stub: gemma4 wrapper model that delegates text generation to
    an inner ``gemma4_text.Model`` (``language_model``)."""

    args: ModelArgs
    model_type: str
    # Underlying text-only model that provides the decoder layers.
    language_model: gemma4_text.Model

    def __init__(self, args: ModelArgs) -> None: ...
    # Forward pass; returns an mx.array (logits, per the annotation).
    def __call__(self, *args: Any, **kwargs: Any) -> mx.array: ...
    # Rewrites a checkpoint weight dict before loading (key remapping /
    # filtering is the usual purpose — confirm in the implementation).
    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
    @property
    def layers(self) -> list[gemma4_text.DecoderLayer]: ...
    # Predicate controlling which sub-modules get quantized.
    @property
    def quant_predicate(self) -> Any: ...
    # One cache per layer; rotating caches are used for sliding-window layers
    # per the return type — confirm mapping in the implementation.
    def make_cache(self) -> list[KVCache | RotatingKVCache]: ...
179 changes: 179 additions & 0 deletions .mlx_typings/mlx_lm/models/gemma4_text.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .cache import KVCache, RotatingKVCache
from .switch_layers import SwitchGLU

@dataclass
class ModelArgs(BaseModelArgs):
    """Type stub: configuration for the gemma4 text model.

    Field names mirror the model's config file; semantics below are inferred
    from the names and should be confirmed against the implementation.
    """

    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    # Attention geometry.
    num_attention_heads: int
    head_dim: int
    global_head_dim: int
    global_partial_rotary_factor: float
    rms_norm_eps: float
    vocab_size: int
    vocab_size_per_layer_input: int
    num_key_value_heads: int
    num_global_key_value_heads: Optional[int]
    # Number of trailing layers that reuse (share) earlier layers' KV.
    num_kv_shared_layers: int
    pad_token_id: int
    hidden_size_per_layer_input: int
    # Rotary position embedding settings.
    rope_traditional: bool
    partial_rotary_factor: float
    rope_parameters: Optional[Dict[str, Any]]
    # Sliding-window attention settings.
    sliding_window: int
    sliding_window_pattern: int
    max_position_embeddings: int
    attention_k_eq_v: bool
    final_logit_softcapping: float
    use_double_wide_mlp: bool
    # Mixture-of-experts settings (only meaningful when enable_moe_block).
    enable_moe_block: bool
    num_experts: Optional[int]
    top_k_experts: Optional[int]
    moe_intermediate_size: Optional[int]
    # Per-layer attention type tags (e.g. sliding vs. global) — TODO confirm.
    layer_types: Optional[List[str]]
    tie_word_embeddings: bool

    # Dataclass post-init hook; derives/validates fields in the implementation.
    def __post_init__(self) -> None: ...

class MLP(nn.Module):
    """Type stub: gated feed-forward block (gate/up projections into a
    down projection, per the declared sub-layers)."""

    gate_proj: nn.Linear
    down_proj: nn.Linear
    up_proj: nn.Linear

    def __init__(self, config: ModelArgs, layer_idx: int = 0) -> None: ...
    def __call__(self, x: mx.array) -> mx.array: ...

class Router(nn.Module):
    """Type stub: MoE router that projects tokens to expert scores.

    ``__call__`` returns a 2-tuple of arrays — presumably (top-k indices,
    top-k weights); confirm order in the implementation.
    """

    proj: nn.Linear
    scale: mx.array
    per_expert_scale: mx.array

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...

class Experts(nn.Module):
    """Type stub: MoE expert bank backed by a SwitchGLU; combines expert
    outputs using the router's top-k indices and weights."""

    switch_glu: SwitchGLU

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(
        self, x: mx.array, top_k_indices: mx.array, top_k_weights: mx.array
    ) -> mx.array: ...

class Attention(nn.Module):
    """Type stub: per-layer attention module supporting sliding-window and
    global layer types (per ``layer_type`` / ``is_sliding``)."""

    layer_idx: int
    # Layer attention variant tag; drives is_sliding — TODO confirm values.
    layer_type: str
    is_sliding: bool
    head_dim: int
    n_heads: int
    n_kv_heads: int
    # When set, K is reused as V (cf. ModelArgs.attention_k_eq_v).
    use_k_eq_v: bool
    scale: float
    q_proj: nn.Linear
    k_proj: nn.Linear
    v_proj: nn.Linear
    o_proj: nn.Linear
    # Per-projection normalization layers applied to Q/K/V.
    q_norm: nn.Module
    k_norm: nn.Module
    v_norm: nn.Module
    rope: nn.Module

    def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

class DecoderLayer(nn.Module):
    """Type stub: one transformer decoder layer: attention plus either a
    dense MLP or a routed MoE block (``enable_moe``), with per-layer input
    mixing (``per_layer_input_*``)."""

    layer_idx: int
    layer_type: str
    self_attn: Attention
    mlp: MLP
    # True when this layer uses the router/experts path instead of mlp.
    enable_moe: bool
    router: Router
    experts: Experts
    input_layernorm: nn.Module
    post_attention_layernorm: nn.Module
    pre_feedforward_layernorm: nn.Module
    post_feedforward_layernorm: nn.Module
    post_feedforward_layernorm_1: nn.Module
    post_feedforward_layernorm_2: nn.Module
    pre_feedforward_layernorm_2: nn.Module
    hidden_size_per_layer_input: int
    # Optional per-layer input path (None when the feature is disabled —
    # TODO confirm against the implementation).
    per_layer_input_gate: Optional[nn.Linear]
    per_layer_projection: Optional[nn.Linear]
    post_per_layer_input_norm: Optional[nn.Module]
    layer_scalar: mx.array

    def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
    # Returns (hidden_states, (k, v), extra) per the annotation — presumably
    # the new KV pair is surfaced for KV-sharing layers; confirm semantics.
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = ...,
        cache: Optional[Any] = ...,
        per_layer_input: Optional[mx.array] = ...,
        shared_kv: Optional[tuple[mx.array, mx.array]] = ...,
        offset: Optional[mx.array] = ...,
    ) -> tuple[mx.array, tuple[mx.array, mx.array], mx.array]: ...

class Gemma4TextModel(nn.Module):
    """Type stub: the gemma4 text transformer backbone — embeddings,
    a stack of ``DecoderLayer``s, and a final norm."""

    config: ModelArgs
    vocab_size: int
    window_size: int
    sliding_window_pattern: int
    num_hidden_layers: int
    embed_tokens: nn.Embedding
    # Multiplier applied to token embeddings.
    embed_scale: float
    layers: list[DecoderLayer]
    norm: nn.Module
    hidden_size_per_layer_input: int
    # Optional per-layer input embedding path (None when disabled —
    # TODO confirm against the implementation).
    embed_tokens_per_layer: Optional[nn.Embedding]
    per_layer_model_projection: Optional[nn.Linear]
    per_layer_projection_norm: Optional[nn.Module]
    # Indices used by KV-sharing layers — presumably maps each sharing layer
    # to the earlier layer whose KV it reuses; confirm in the implementation.
    previous_kvs: list[int]

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(
        self,
        inputs: Optional[mx.array] = ...,
        cache: Optional[list[Any]] = ...,
        input_embeddings: Optional[mx.array] = ...,
        per_layer_inputs: Optional[mx.array] = ...,
    ) -> mx.array: ...
    # Builds the per-layer input embeddings from ids or precomputed embeddings.
    def _get_per_layer_inputs(
        self,
        input_ids: Optional[mx.array],
        input_embeddings: Optional[mx.array] = ...,
    ) -> mx.array: ...
    # Projects/combines per-layer inputs with the main embeddings.
    def _project_per_layer_inputs(
        self,
        input_embeddings: mx.array,
        per_layer_inputs: Optional[mx.array] = ...,
    ) -> mx.array: ...
    # Builds one attention mask per layer (sliding vs. global) from the cache.
    def _make_masks(self, h: mx.array, cache: list[Any]) -> list[Any]: ...

class Model(nn.Module):
    """Type stub: the text generation model — ``Gemma4TextModel`` backbone
    plus an ``lm_head`` (or tied embeddings) producing logits."""

    args: ModelArgs
    model_type: str
    model: Gemma4TextModel
    # Softcap applied to final logits (cf. ModelArgs.final_logit_softcapping).
    final_logit_softcapping: float
    tie_word_embeddings: bool
    lm_head: nn.Linear

    def __init__(self, args: ModelArgs) -> None: ...
    # Forward pass; returns logits as an mx.array per the annotation.
    def __call__(self, *args: Any, **kwargs: Any) -> mx.array: ...
    # Rewrites a checkpoint weight dict before loading (key remapping /
    # filtering is the usual purpose — confirm in the implementation).
    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
    @property
    def layers(self) -> list[DecoderLayer]: ...
    @property
    def head_dim(self) -> int: ...
    @property
    def n_kv_heads(self) -> int: ...
    # Predicate controlling which sub-modules get quantized.
    @property
    def quant_predicate(self) -> Any: ...
    # One cache per layer; rotating caches presumably back the sliding-window
    # layers — confirm mapping in the implementation.
    def make_cache(self) -> list[KVCache | RotatingKVCache]: ...
2 changes: 2 additions & 0 deletions .mlx_typings/mlx_lm/tokenizer_utils.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ class TokenizerWrapper:
think_end: str | None
think_start_id: int | None
think_end_id: int | None
think_start_tokens: list[int] | None
think_end_tokens: list[int] | None

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-26b-a4b-it-4bit"
n_layers = 30
hidden_size = 2816
num_key_value_heads = 8
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "4bit"
base_model = "Gemma 4 26B A4B"
capabilities = ["text", "vision"]

context_length = 262144

[storage_size]
in_bytes = 15608614044
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-26b-a4b-it-6bit"
n_layers = 30
hidden_size = 2816
num_key_value_heads = 8
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "6bit"
base_model = "Gemma 4 26B A4B"
capabilities = ["text", "vision"]

context_length = 262144

[storage_size]
in_bytes = 21781015708
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-26b-a4b-it-8bit"
n_layers = 30
hidden_size = 2816
num_key_value_heads = 8
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "8bit"
base_model = "Gemma 4 26B A4B"
capabilities = ["text", "vision"]

context_length = 262144

[storage_size]
in_bytes = 27953417372
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-26b-a4b-it-bf16"
n_layers = 30
hidden_size = 2816
num_key_value_heads = 8
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "bf16"
base_model = "Gemma 4 26B A4B"
capabilities = ["text", "vision"]

context_length = 262144

[storage_size]
in_bytes = 51611872412
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-31b-it-4bit"
n_layers = 60
hidden_size = 5376
num_key_value_heads = 16
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "4bit"
base_model = "Gemma 4 31B"
capabilities = ["text", "vision"]

context_length = 262144

[storage_size]
in_bytes = 18411755224
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-31b-it-6bit"
n_layers = 60
hidden_size = 5376
num_key_value_heads = 16
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "6bit"
base_model = "Gemma 4 31B"
capabilities = ["text", "vision"]

context_length = 262144

[storage_size]
in_bytes = 26087306968
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-31b-it-8bit"
n_layers = 60
hidden_size = 5376
num_key_value_heads = 16
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "8bit"
base_model = "Gemma 4 31B"
capabilities = ["text", "vision"]

context_length = 262144

[storage_size]
in_bytes = 33762858712
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-31b-it-bf16"
n_layers = 60
hidden_size = 5376
num_key_value_heads = 16
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "bf16"
base_model = "Gemma 4 31B"
capabilities = ["text", "vision"]

context_length = 262144

[storage_size]
in_bytes = 62546177752
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model_id = "mlx-community/gemma-4-e2b-it-4bit"
n_layers = 35
hidden_size = 1536
num_key_value_heads = 1
supports_tensor = true
tasks = ["TextGeneration"]
family = "gemma"
quantization = "4bit"
base_model = "Gemma 4 E2B"
capabilities = ["text", "vision"]

context_length = 131072

[storage_size]
in_bytes = 3580765126
Loading
Loading