[Tunix] Add multi-modal (vision) support to Gemma 4.#1545
Open
msghik wants to merge 1 commit into
Open
Conversation
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling infrastructures (peft_trainer, sampler, rl/common) already forward an `images` kwarg to the model; this change closes the gap at the model level so they actually work end-to-end on Gemma 4. Specifically: - `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`, exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b factory classmethods (matching gemma3_*_pt/it). - `ShardingConfig` gains a `siglip` field for the encoder's sharding. - `Embedder` optionally builds the `mm_input_projection` and `mm_soft_embedding_norm` layers and exposes `encode_vision`. - `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config is set; `__call__` accepts an `images` kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions. - Adds `get_attention_mask` (bidirectional over image spans) and `get_model_input` (used for LoRA tracing) on the model. - `params_safetensors._get_key_and_transform_mapping` adds vision-tower and multi-modal projector mappings when vision is enabled. The SigLIP encoder, embedding-merge utility and attention-mask helper are reused from `tunix.models.gemma3` to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Tests: 4 new tests cover text-only construction, multi-modal forward pass, the helpful error when images are passed without a vision encoder, and the text-only attention mask shape. All existing Gemma 3/4 tests continue to pass.
Collaborator
|
Hi @msghik thanks for adding the vision support. Can you paste a sample output for the multi-modal change? |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Resolves #1543
Enables SFT / LoRA / RL of Gemma 4 with text + image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling code paths (
peft_trainer,sampler,rl/common) already forward animageskwarg to the model — they just had no Gemma 4 model that accepted it. This change closes that gap at the model level so they work end-to-end on Gemma 4.What changed
tunix/models/gemma4/model.pyModelConfiggains an optionalvision_config: SigLIPConfig | None, exposed via atext_only: bool = Trueparameter on thegemma4_e2b/gemma4_e4b/gemma4_31b/gemma4_26b_a4bfactory classmethods (matches the existinggemma3_*_pt/it(text_only=...)pattern).ShardingConfiggains asiglipfield for the vision encoder's sharding.RMSNormnow accepts either aShardingConfig(existing call sites) or a sharding tuple (new vision call site) — fully backward compatible.Embedderoptionally builds themm_input_projectionandmm_soft_embedding_normlayers and exposesencode_vision.Gemma4.__init__constructs aSigLiPencoder whenvision_configis set;Gemma4.__call__accepts animageskwarg and merges the soft vision tokens into the text embeddings at the placeholder positions.get_attention_mask(bidirectional over image spans, causal over text) andget_model_input(used for LoRA tracing) on the model.tunix/models/gemma4/params_safetensors.py_get_key_and_transform_mappingadds vision-tower and multi-modal projector mappings whenvision_configis set (no change for text-only configs).tests/models/gemma4/model_test.pyDesign notes
vision.py), the embedding-merge helper (merge_embeddings.py), and the bidirectional-causal attention mask (utils.py) are reused fromtunix.models.gemma3to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions. Happy to refactor these into a sharedtunix.models.commonnamespace if reviewers prefer.vision_configdefaults toNone,imagesdefaults toNone. Existing Gemma 4 call sites and tests are unaffected.Usage