Skip to content

[Tunix] Add multi-modal (vision) support to Gemma 4.#1545

Open
msghik wants to merge 1 commit into
google:mainfrom
msghik:add-gemma4-vision-support
Open

[Tunix] Add multi-modal (vision) support to Gemma 4.#1545
msghik wants to merge 1 commit into
google:mainfrom
msghik:add-gemma4-vision-support

Conversation

@msghik

@msghik msghik commented May 27, 2026

Copy link
Copy Markdown

Resolves #1543

Enables SFT / LoRA / RL of Gemma 4 with text + image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling code paths (peft_trainer, sampler, rl/common) already forward an images kwarg to the model — they just had no Gemma 4 model that accepted it. This change closes that gap at the model level so they work end-to-end on Gemma 4.

What changed

tunix/models/gemma4/model.py

  • ModelConfig gains an optional vision_config: SigLIPConfig | None, exposed via a text_only: bool = True parameter on the gemma4_e2b / gemma4_e4b / gemma4_31b / gemma4_26b_a4b factory classmethods (matches the existing gemma3_*_pt/it(text_only=...) pattern).
  • ShardingConfig gains a siglip field for the vision encoder's sharding.
  • RMSNorm now accepts either a ShardingConfig (existing call sites) or a sharding tuple (new vision call site) — fully backward compatible.
  • Embedder optionally builds the mm_input_projection and mm_soft_embedding_norm layers and exposes encode_vision.
  • Gemma4.__init__ constructs a SigLiP encoder when vision_config is set; Gemma4.__call__ accepts an images kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions.
  • Adds get_attention_mask (bidirectional over image spans, causal over text) and get_model_input (used for LoRA tracing) on the model.

tunix/models/gemma4/params_safetensors.py

  • _get_key_and_transform_mapping adds vision-tower and multi-modal projector mappings when vision_config is set (no change for text-only configs).

tests/models/gemma4/model_test.py

  • 4 new tests: multi-modal forward pass, text-only construction has no vision encoder, helpful error when images are passed to a text-only model, and the text-only attention-mask shape.

Design notes

  • The SigLIP encoder (vision.py), the embedding-merge helper (merge_embeddings.py), and the bidirectional-causal attention mask (utils.py) are reused from tunix.models.gemma3 to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Happy to refactor these into a shared tunix.models.common namespace if reviewers prefer.
  • The change is strictly additive: vision_config defaults to None, images defaults to None. Existing Gemma 4 call sites and tests are unaffected.

Usage

from tunix.models.gemma4 import model as model_lib
from tunix.models.gemma4 import params_safetensors

config = model_lib.ModelConfig.gemma4_e4b(text_only=False)  # enables SigLIP
model = params_safetensors.create_model_from_safe_tensors(
    file_dir=ckpt_path, config=config, mesh=mesh, dtype=jnp.bfloat16,
)

logits, _ = model(
    tokens,
    positions=positions,
    attention_mask=model.get_attention_mask(tokens),
    images=images,
)

Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the
SigLIP-based integration already present in Gemma 3. The training and
sampling infrastructures (peft_trainer, sampler, rl/common) already
forward an `images` kwarg to the model; this change closes the gap at
the model level so they actually work end-to-end on Gemma 4.

Specifically:
- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`,
  exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b
  factory classmethods (matching gemma3_*_pt/it).
- `ShardingConfig` gains a `siglip` field for the encoder's sharding.
- `Embedder` optionally builds the `mm_input_projection` and
  `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config
  is set; `__call__` accepts an `images` kwarg and merges the soft
  vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans) and
  `get_model_input` (used for LoRA tracing) on the model.
- `params_safetensors._get_key_and_transform_mapping` adds vision-tower
  and multi-modal projector mappings when vision is enabled.

The SigLIP encoder, embedding-merge utility and attention-mask helper
are reused from `tunix.models.gemma3` to avoid ~900 lines of
duplication; they have no Gemma 3-specific assumptions.

Tests: 4 new tests cover text-only construction, multi-modal forward
pass, the helpful error when images are passed without a vision encoder,
and the text-only attention mask shape. All existing Gemma 3/4 tests
continue to pass.
@tianshub

tianshub commented Jun 2, 2026

Copy link
Copy Markdown
Collaborator

Hi @msghik thanks for adding the vision support. Can you paste a sample output for the multi-modal change?

@tianshub tianshub left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @msghik thanks for adding the vision support. Can you paste a sample output for the multi-modal change?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MM SFT of Gemma4

3 participants