From 2faf8d48ee0467bf87ba51d5cb0fafe503e48439 Mon Sep 17 00:00:00 2001 From: msghik Date: Wed, 27 May 2026 20:12:02 +0000 Subject: [PATCH] [Tunix] Add multi-modal (vision) support to Gemma 4. Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling infrastructures (peft_trainer, sampler, rl/common) already forward an `images` kwarg to the model; this change closes the gap at the model level so they actually work end-to-end on Gemma 4. Specifically: - `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`, exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b factory classmethods (matching gemma3_*_pt/it). - `ShardingConfig` gains a `siglip` field for the encoder's sharding. - `Embedder` optionally builds the `mm_input_projection` and `mm_soft_embedding_norm` layers and exposes `encode_vision`. - `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config is set; `__call__` accepts an `images` kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions. - Adds `get_attention_mask` (bidirectional over image spans) and `get_model_input` (used for LoRA tracing) on the model. - `params_safetensors._get_key_and_transform_mapping` adds vision-tower and multi-modal projector mappings when vision is enabled. The SigLIP encoder, embedding-merge utility and attention-mask helper are reused from `tunix.models.gemma3` to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Tests: 4 new tests cover text-only construction, multi-modal forward pass, the helpful error when images are passed without a vision encoder, and the text-only attention mask shape. All existing Gemma 3/4 tests continue to pass. --- tests/models/gemma4/model_test.py | 141 ++++++++++++++++ tunix/models/gemma4/model.py | 188 +++++++++++++++++++++- tunix/models/gemma4/params_safetensors.py | 98 +++++++++++ 3 files changed, 422 insertions(+), 5 deletions(-) diff --git a/tests/models/gemma4/model_test.py b/tests/models/gemma4/model_test.py index fdf8b457e..3449361bc 100644 --- a/tests/models/gemma4/model_test.py +++ b/tests/models/gemma4/model_test.py @@ -16,10 +16,13 @@ from __future__ import annotations +import dataclasses + from absl.testing import absltest from flax import nnx import jax import jax.numpy as jnp +from tunix.models.gemma3 import vision as vision_lib from tunix.models.gemma4 import model as model_lib @@ -195,6 +198,144 @@ def body_fn(step, _): _, logits = compiled_decode(state) self.assertEqual(logits.shape, (2, 32, config.num_embed)) + def test_text_only_no_vision_encoder(self): + config = model_lib.ModelConfig.gemma4_e2b() + self.assertIsNone(config.vision_config) + config.num_layers = 1 + config.embed_dim = 256 + config.hidden_dim = 512 + config.num_heads = 4 + config.head_dim = 64 + config.num_kv_heads = 1 + config.frac_shared_layers = 0.0 + + rngs = nnx.Rngs(0) + model = model_lib.Gemma4(config, rngs=rngs) + self.assertIsNone(model.vision_encoder) + self.assertFalse(hasattr(model.embedder, "mm_input_projection")) + + def test_forward_pass_multimodal(self): + # Use a tiny SigLIP config so the test stays fast/light. + small_vision_config = vision_lib.SigLIPConfig( + num_mm_tokens_per_image_prepool=16, + num_mm_tokens_per_image=4, + image_height=32, + image_width=32, + image_channels=3, + soft_token_placeholder=219, + patch_size=(8, 8), + width=32, + depth=1, + mlp_dim=64, + num_heads=4, + ) + base_config = model_lib.ModelConfig.gemma4_e2b() + config = dataclasses.replace( + base_config, + num_layers=1, + embed_dim=256, + hidden_dim=512, + num_heads=4, + head_dim=64, + num_kv_heads=1, + frac_shared_layers=0.0, + vision_config=small_vision_config, + ) + + rngs = nnx.Rngs(0) + model = model_lib.Gemma4(config, rngs=rngs) + self.assertIsNotNone(model.vision_encoder) + self.assertTrue(hasattr(model.embedder, "mm_input_projection")) + self.assertTrue(hasattr(model.embedder, "mm_soft_embedding_norm")) + + batch_size = 2 + seq_len = 16 + num_images = 1 + num_mm = small_vision_config.num_mm_tokens_per_image + + # Place the soft-image placeholder in positions [1:1+num_mm] of each row. + tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + tokens = tokens.at[:, 1 : 1 + num_mm].set( + small_vision_config.soft_token_placeholder + ) + + positions = jnp.tile( + jnp.arange(seq_len)[None, :], (batch_size, 1) + ) + attn_mask = model.get_attention_mask(tokens) + + images = jnp.zeros( + ( + batch_size, + num_images, + small_vision_config.image_height, + small_vision_config.image_width, + small_vision_config.image_channels, + ), + dtype=jnp.float32, + ) + + logits, _ = model( + tokens, + positions=positions, + attention_mask=attn_mask, + images=images, + ) + self.assertEqual(logits.shape, (batch_size, seq_len, config.num_embed)) + + def test_multimodal_call_without_vision_encoder_raises(self): + config = model_lib.ModelConfig.gemma4_e2b() + config.num_layers = 1 + config.embed_dim = 256 + config.hidden_dim = 512 + config.num_heads = 4 + config.head_dim = 64 + config.num_kv_heads = 1 + config.frac_shared_layers = 0.0 + + rngs = nnx.Rngs(0) + model = model_lib.Gemma4(config, rngs=rngs) + + tokens = jnp.ones((1, 4), dtype=jnp.int32) + positions = jnp.arange(4)[None, :] + attn_mask = jnp.tril(jnp.ones((4, 4), dtype=jnp.bool_))[None, ...] + images = jnp.zeros((1, 1, 32, 32, 3), dtype=jnp.float32) + + with self.assertRaises(ValueError): + model( + tokens, + positions=positions, + attention_mask=attn_mask, + images=images, + ) + + def test_get_attention_mask_text_only(self): + config = model_lib.ModelConfig.gemma4_e2b() + config.num_layers = 1 + config.embed_dim = 256 + config.hidden_dim = 512 + config.num_heads = 4 + config.head_dim = 64 + config.num_kv_heads = 1 + config.frac_shared_layers = 0.0 + + rngs = nnx.Rngs(0) + model = model_lib.Gemma4(config, rngs=rngs) + # No vision config => no bidirectional span; mask should be purely causal. + tokens = jnp.array([[1, 2, 3, 0, 0]]) + mask = model.get_attention_mask(tokens) + expected = jnp.array( + [[ + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + ]], + dtype=jnp.bool_, + ) + self.assertTrue(bool(jnp.all(mask == expected))) + if __name__ == "__main__": absltest.main() diff --git a/tunix/models/gemma4/model.py b/tunix/models/gemma4/model.py index c64a32522..8997fdc82 100644 --- a/tunix/models/gemma4/model.py +++ b/tunix/models/gemma4/model.py @@ -18,7 +18,8 @@ import enum from functools import partial import itertools -from typing import Tuple +from typing import Tuple, Union +import einops import flax from flax import nnx import jax @@ -31,6 +32,9 @@ from jax.sharding import PartitionSpec as P import jaxtyping from tunix.generate.mappings import BackendMappingMixin +from tunix.models.gemma3 import merge_embeddings as merge_embeddings_lib +from tunix.models.gemma3 import utils as mm_utils +from tunix.models.gemma3 import vision from tunix.models.gemma4 import moe from tunix.utils import compat from tunix.utils import env_utils @@ -75,6 +79,8 @@ class ShardingConfig: per_layer_input_gate: Tuple[str | None, ...] per_layer_projection: Tuple[str | None, ...] per_layer_input_embedding: Tuple[str | None, ...] + # SigLIP vision encoder sharding. + siglip: vision.SigLIPShardingConfig | None = None @staticmethod def get_default_sharding(is_sampling: bool = False): @@ -100,6 +106,7 @@ def get_default_sharding(is_sampling: bool = False): per_layer_input_gate=(fsdp, 'tp'), per_layer_projection=('tp', fsdp), per_layer_input_embedding=('tp', None, fsdp), + siglip=vision.SigLIPShardingConfig.get_default_sharding(is_sampling), ) @@ -131,6 +138,10 @@ class ModelConfig: local_scale_factor: float = 1.0 global_scale_factor: float = 1.0 + # Vision config. If set, the model includes a SigLIP vision encoder and + # accepts `images` in the forward pass for multi-modal inputs. + vision_config: vision.SigLIPConfig | None = None + shd_config: ShardingConfig = ShardingConfig.get_default_sharding() remat_config: RematConfig = RematConfig.NONE param_dtype: jnp.dtype = jnp.float32 @@ -157,6 +168,7 @@ def __post_init__(self): def gemma4_e2b( cls, sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + text_only: bool = True, ) -> 'ModelConfig': return cls( num_layers=35, @@ -171,6 +183,7 @@ def gemma4_e2b( per_layer_input_dim=256, frac_shared_layers=20.0 / 35, override_kv_shared_ffw_hidden=int(1536 * 4 * 2), + vision_config=None if text_only else vision.SigLIPConfig(), attention_pattern=( AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, @@ -184,6 +197,7 @@ def gemma4_e2b( def gemma4_e4b( cls, sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + text_only: bool = True, ) -> 'ModelConfig': return cls( num_layers=42, @@ -197,6 +211,7 @@ def gemma4_e4b( shd_config=sharding_config, per_layer_input_dim=256, frac_shared_layers=18.0 / 42, + vision_config=None if text_only else vision.SigLIPConfig(), attention_pattern=( AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, @@ -211,6 +226,7 @@ def gemma4_e4b( def gemma4_31b( cls, sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + text_only: bool = True, ) -> 'ModelConfig': return cls( num_layers=60, @@ -224,6 +240,7 @@ def gemma4_31b( sliding_window_size=1024, shd_config=sharding_config, k_eq_v_global=True, + vision_config=None if text_only else vision.SigLIPConfig(), attention_pattern=( AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, @@ -238,6 +255,7 @@ def gemma4_31b( def gemma4_26b_a4b( cls, sharding_config: ShardingConfig = ShardingConfig.get_default_sharding(), + text_only: bool = True, ) -> 'ModelConfig': return cls( num_layers=30, @@ -257,6 +275,7 @@ def gemma4_26b_a4b( moe_dense_hidden_dim=2112, k_eq_v_global=True, global_rope_proportion=0.25, + vision_config=None if text_only else vision.SigLIPConfig(), attention_pattern=( AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, @@ -275,6 +294,7 @@ def __init__( self, config: ModelConfig, rngs: nnx.Rngs, + vision_proj_dim: int | None = None, ): self.config = config self.vocab_size = config.num_embed @@ -314,6 +334,23 @@ def __init__( sharding=config.shd_config.per_layer_input_embedding, ) + if vision_proj_dim: + self.mm_soft_embedding_norm = RMSNorm( + vision_proj_dim, + rngs=rngs, + sharding=config.shd_config.vision_soft_emb_norm_weight, + dtype=self.config.dtype, + param_dtype=self.param_dtype, + ) + self.mm_input_projection = Einsum( + einsum_str='...TM,MD->...TD', + shape=(vision_proj_dim, self.embed_dim), + rngs=rngs, + sharding=config.shd_config.vision_proj, + dtype=self.config.dtype, + param_dtype=self.param_dtype, + ) + def encode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: x = self.input_embedding[(x,)] x *= jnp.sqrt(x.shape[-1]).astype(x.dtype) @@ -333,6 +370,12 @@ def encode_per_layer_input( y *= jnp.sqrt(self.config.per_layer_input_dim).astype(y.dtype) return (x + y) * jax.lax.rsqrt(2.0).astype(x.dtype) + def encode_vision(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: + """Projects vision encoder outputs into the language model embedding space.""" + x = self.mm_soft_embedding_norm(x) + x = self.mm_input_projection(x) + return x + def decode(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array: x = jnp.astype(x, self.config.dtype) w = jnp.astype(self.input_embedding.value, self.config.dtype) @@ -428,13 +471,20 @@ def __init__( dim: int, *, rngs: nnx.Rngs, - sharding: ShardingConfig = ShardingConfig.get_default_sharding(), + sharding: Union[ShardingConfig, Tuple[str | None, ...]] = ( + ShardingConfig.get_default_sharding() + ), dtype: jnp.dtype, param_dtype: jnp.dtype, ): + scale_sharding = ( + sharding.rms_norm_weight + if isinstance(sharding, ShardingConfig) + else sharding + ) self.scale = nnx.Param( nnx.initializers.ones_init()(rngs.params(), dim).astype(param_dtype), - sharding=sharding.rms_norm_weight, + sharding=scale_sharding, ) self.dtype = dtype @@ -1186,7 +1236,25 @@ class Gemma4(BackendMappingMixin, nnx.Module): def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs): self.config = config - self.embedder = Embedder(config, rngs=rngs) + + if config.vision_config is not None: + self.vision_encoder = vision.SigLiP( + config=config.vision_config, + shd_config=config.shd_config.siglip, + rngs=rngs, + ) + else: + self.vision_encoder = None + + self.embedder = Embedder( + config, + rngs=rngs, + vision_proj_dim=( + self.vision_encoder.siglip_encoder.width + if self.vision_encoder is not None + else None + ), + ) pattern = ( config.attention_pattern @@ -1241,13 +1309,15 @@ def __call__( cache=None, attention_mask=None, decode_only_last_token=False, + *, + images: jaxtyping.Array | None = None, ): if positions is None: B, T = tokens.shape # pylint: disable=invalid-name positions = jnp.tile(jnp.arange(T)[None, :], (B, 1)) new_cache = {} - x = self.embedder.encode(tokens) + x = self._encode_and_get_inputs(tokens=tokens, images=images) per_layer_inputs = None if self.config.per_layer_input_dim > 0: @@ -1315,3 +1385,111 @@ def init_cache(self, batch_size, max_seq_len, dtype): continue # Skip shared layers. cache[f'layer_{i}'] = layer.init_cache(batch_size, max_seq_len, dtype) return cache + + def _encode_and_get_inputs( + self, + *, + tokens: jaxtyping.Array, # (B, L) + images: jaxtyping.Array | None = None, # (B, H, W, C) or (B, N, H, W, C) + ) -> jaxtyping.Array: + """Encode the text tokens, eventually including the vision embeddings.""" + if self.config.vision_config is not None and images is not None: + self._assert_support_mm() + if len(images.shape) == 4: # If num_images is 1, add an axis. + images = einops.rearrange(images, 'b h w c -> b 1 h w c') + + x = self.embedder.encode(tokens) + + if images is not None: + x = self._merge_mm_embeddings(tokens=tokens, embeddings=x, images=images) + return x + + def _assert_support_mm(self) -> None: + if self.vision_encoder is None: + raise ValueError( + f'The model {type(self).__name__!r} does not have a vision encoder,' + ' yet images are provided. Construct the model with a `vision_config`' + ' (e.g. `ModelConfig.gemma4_e4b(text_only=False)`) to enable' + ' multi-modal inputs.' + ) + + def _merge_mm_embeddings( + self, + *, + tokens: jaxtyping.ArrayLike, # (B, L) + embeddings: jaxtyping.ArrayLike, # (B, L, D) + images: jaxtyping.ArrayLike, # (B, N, H, W, C) + ) -> jaxtyping.ArrayLike: + """Update the text embeddings to include the vision soft tokens.""" + soft_embeddings = self._encode_vision(images) + if self.config.vision_config is None: + raise ValueError( + '`vision_config` is required for `_merge_mm_embeddings`.' + ) + return merge_embeddings_lib.merge_embeddings( + text_embeddings=embeddings, + vision_embeddings=soft_embeddings, + mask=tokens == self.config.vision_config.soft_token_placeholder, + ) + + def _encode_vision( + self, images: jaxtyping.ArrayLike # (B, N, H, W, C) + ) -> jaxtyping.ArrayLike: # (B, N, P, D) + """Encode the images into the same space as the text embeddings.""" + if self.vision_encoder is None: + raise ValueError('`vision_encoder` is None, cannot encode images.') + soft_embeddings = self.vision_encoder(images=images) + soft_embeddings = self.embedder.encode_vision(soft_embeddings) + return soft_embeddings + + def get_attention_mask( + self, + tokens: jaxtyping.ArrayLike, # (B, L) + *, + inputs_mask: jaxtyping.ArrayLike | None = None, # (B, L) + ): + """Returns the attention mask for the transformer. + + For multi-modal inputs, the mask is bidirectional over image soft-token + spans (so all soft tokens of an image attend to each other) while remaining + causal over text tokens, matching the Gemma 3 behavior. + """ + token_placeholder_id = ( + None + if self.config.vision_config is None + else self.config.vision_config.soft_token_placeholder + ) + return mm_utils.get_attention_mask( + tokens, + inputs_mask=inputs_mask, + token_placeholder_id=token_placeholder_id, + ) + + def get_model_input(self): + """Returns a dummy model input for the transformer. + + Used to trace the graph for LoRA application and similar transforms. + """ + dummy_batch_size = 2 + dummy_seq_len = 1 + inputs = { + 'tokens': jnp.ones( + (dummy_batch_size, dummy_seq_len), dtype=jnp.int32 + ), + 'positions': jnp.ones( + (dummy_batch_size, dummy_seq_len), dtype=jnp.int32 + ), + 'cache': None, + 'attention_mask': jnp.ones( + (dummy_batch_size, 1, dummy_seq_len), dtype=jnp.bool + ), + } + + if self.vision_encoder is not None: + vc = self.config.vision_config + inputs['images'] = jnp.ones( + (dummy_batch_size, 1, vc.image_height, vc.image_width, + vc.image_channels), + dtype=jnp.float32, + ) + return inputs diff --git a/tunix/models/gemma4/params_safetensors.py b/tunix/models/gemma4/params_safetensors.py index 9eae3f9a8..e1b5bc254 100644 --- a/tunix/models/gemma4/params_safetensors.py +++ b/tunix/models/gemma4/params_safetensors.py @@ -246,6 +246,104 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig): ), } + # Vision Tower (SigLIP) and multi-modal projector. + if cfg.vision_config is not None: + mapping.update({ + r"vision_tower\.vision_model\.embeddings\.patch_embedding\.weight": ( + "vision_encoder.siglip_encoder.embedding.kernel", + ((2, 3, 1, 0), None), + ), + r"vision_tower\.vision_model\.embeddings\.patch_embedding\.bias": ( + "vision_encoder.siglip_encoder.embedding.bias", + None, + ), + r"vision_tower\.vision_model\.embeddings\.position_embedding\.weight": ( + "vision_encoder.siglip_encoder.pos_embedding", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.weight": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.attn.query_proj.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.weight": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.attn.key_proj.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.weight": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.attn.value_proj.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.weight": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.attn.out_proj.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.q_proj\.bias": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.attn.query_proj.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.k_proj\.bias": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.attn.key_proj.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.v_proj\.bias": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.attn.value_proj.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.bias": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.attn.out_proj.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.layer_norm1\.weight": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.ln1.scale", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.layer_norm1\.bias": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.ln1.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.layer_norm2\.weight": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.ln2.scale", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.layer_norm2\.bias": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.ln2.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc1\.weight": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.mlp.fc1.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc2\.weight": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.mlp.fc2.kernel", + ((1, 0), None), + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc1\.bias": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.mlp.fc1.bias", + None, + ), + r"vision_tower\.vision_model\.encoder\.layers\.([0-9]+)\.mlp\.fc2\.bias": ( + r"vision_encoder.siglip_encoder.transformer.blocks.\1.mlp.fc2.bias", + None, + ), + r"vision_tower\.vision_model\.post_layernorm\.weight": ( + "vision_encoder.siglip_encoder.transformer.encoder_norm.scale", + None, + ), + r"vision_tower\.vision_model\.post_layernorm\.bias": ( + "vision_encoder.siglip_encoder.transformer.encoder_norm.bias", + None, + ), + # Multi-modal Projector + r"multi_modal_projector\.mm_input_projection_weight": ( + "embedder.mm_input_projection.w", + None, + ), + r"multi_modal_projector\.mm_soft_emb_norm\.weight": ( + "embedder.mm_soft_embedding_norm.scale", + None, + ), + }) + return mapping