Fix TypeError when passing auto-derived segment_ids to models that don't accept it (Gemma/Llama)#1547
Open
msghik wants to merge 6 commits into
Add multi-modal (vision) support to Gemma 4
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the SigLIP-based integration already present in Gemma 3. The training and sampling infrastructures (peft_trainer, sampler, rl/common) already forward an `images` kwarg to the model; this change closes the gap at the model level so they actually work end-to-end on Gemma 4. Specifically:
- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`, exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b factory classmethods (matching gemma3_*_pt/it).
- `ShardingConfig` gains a `siglip` field for the encoder's sharding.
- `Embedder` optionally builds the `mm_input_projection` and `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config is set; `__call__` accepts an `images` kwarg and merges the soft vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans) and `get_model_input` (used for LoRA tracing) on the model.
- `params_safetensors._get_key_and_transform_mapping` adds vision-tower and multi-modal projector mappings when vision is enabled.
The SigLIP encoder, embedding-merge utility and attention-mask helper are reused from `tunix.models.gemma3` to avoid ~900 lines of duplication; they have no Gemma 3-specific assumptions.
Tests: 4 new tests cover text-only construction, multi-modal forward pass, the helpful error when images are passed without a vision encoder, and the text-only attention mask shape. All existing Gemma 3/4 tests continue to pass.
…pt it.

Fixes google#1539.

Commit 49b63f7 ("Agentic GRPO improvements") started passing a derived non-pad mask as ``segment_ids`` to the model from rl/common.py for every RL training run. Only Qwen3's ``__call__`` was updated to accept the keyword; Gemma 2/3/4 and Llama 3 still don't, so GRPO training with those reference/policy models crashes with::

    TypeError: Gemma3.__call__() got an unexpected keyword argument 'segment_ids'

at the model call site in ``compute_per_token_logps``.

A later commit (9b4a4c6) added an inline ``inspect.signature`` workaround in ``compute_per_token_logps`` only. ``compute_score`` was missed and still unconditionally forwarded the auto-derived ``input_seg_ids``, so reward-model paths in GRPO/GSPO remain broken even on HEAD.

This change:
- Extracts ``_model_accepts_segment_ids(model)`` so the signature-introspection logic lives in one place, with a docstring explaining why this gate exists (until every model accepts ``segment_ids`` natively).
- Applies the gate consistently in both ``compute_per_token_logps`` and ``compute_score``: caller-supplied ``segment_ids`` is always passed through (matches pre-bug behavior); the auto-derived ``input_seg_ids`` is only forwarded to models whose signature accepts it.
- Adds regression tests using mock modules whose ``__call__`` does NOT declare ``segment_ids``, which is the shape of the failure mode for Gemma 2/3/4 / Llama. The existing tests only used ``ToyTransformer``, whose ``__call__`` already accepts ``segment_ids``, which is how this slipped past CI.

Manually verified that ``compute_per_token_logps`` no longer raises when called with a real ``Gemma3`` reference model:

    graphdef, state = nnx.split(Gemma3(ModelConfig.gemma3_270m(), rngs=nnx.Rngs(0)))
    common.compute_per_token_logps(
        graphdef, state, prompt, completion, pad_id=0, eos_id=-1)
…'t accept it (Gemma/Llama)
Fix TypeError: don't pass auto-derived segment_ids to models that don't accept it (Gemma/Llama)
Resolves #1539.
Summary
GRPO/GSPO training crashes with
`TypeError: Gemma3.__call__() got an unexpected keyword argument 'segment_ids'`
(and the same on Gemma 2 / Gemma 4 / Llama 3) whenever the policy or reference model is anything other than Qwen3.
Root cause
Commit 49b63f7 ("Agentic GRPO improvements") changed `tunix/rl/common.py` to auto-derive a per-position non-pad mask in `process_ids` and forward it as `segment_ids=` to the model in both `compute_per_token_logps` and `compute_score`. Only Qwen3's `__call__` was updated to accept that keyword. Gemma 2/3/4 and Llama 3 still don't, so any RL run with those models as the reference/policy model raises a `TypeError` at the model call site.

Commit 9b4a4c6 added an inline `inspect.signature` workaround in `compute_per_token_logps` only. `compute_score` was missed and still unconditionally forwarded the auto-derived `input_seg_ids`, so reward-model paths in GRPO/GSPO remain broken on HEAD.
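
The failure mode can be reproduced without any real model. The following is a minimal sketch using hypothetical stand-in classes (`WithoutSegmentIds`, `WithSegmentIds`) and a hypothetical `call_unconditionally` function; none of these names come from tunix, and the signatures are simplified assumptions rather than the real Gemma or Qwen3 ones.

```python
class WithoutSegmentIds:
    """Hypothetical stand-in for a model whose __call__ lacks segment_ids."""

    def __call__(self, tokens, positions=None, attention_mask=None):
        return tokens


class WithSegmentIds:
    """Hypothetical stand-in for a model that declares segment_ids."""

    def __call__(self, tokens, positions=None, attention_mask=None,
                 segment_ids=None):
        return tokens


def call_unconditionally(model, tokens, seg_ids):
    # Mirrors the problematic pattern: the keyword is always forwarded,
    # regardless of whether the model's __call__ declares it.
    return model(tokens, segment_ids=seg_ids)


tokens = [5, 7, 0]
seg_ids = [1, 1, 0]  # 1 for non-pad positions, 0 for padding

# Works for a model that declares the keyword.
ok_result = call_unconditionally(WithSegmentIds(), tokens, seg_ids)

# Fails for a model that does not.
error_message = None
try:
    call_unconditionally(WithoutSegmentIds(), tokens, seg_ids)
except TypeError as err:
    error_message = str(err)
```

Here `error_message` ends up holding Python's "unexpected keyword argument" message, which is the same class of error reported in the summary.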
This regression slipped past CI because the existing tests in `tests/rl/common_test.py` only use `ToyTransformer`, whose `__call__` already declares `segment_ids`.
Changes
`tunix/rl/common.py`
- Adds `_model_accepts_segment_ids(model)`, which introspects the model's `__call__` signature once and returns `True` if it has an explicit `segment_ids` parameter or a `**kwargs` catch-all. The docstring explains why the gate exists.
- `compute_per_token_logps`: replaces the inline `import inspect` / try-except with a call to the helper. Same semantics, cleaner code.
- `compute_score`: applies the same gating. Caller-supplied `segment_ids` is still forwarded as before (so explicit packed-mode callers see the same `TypeError` they always would have, now pointing at their code rather than a hidden auto-derived value), but the auto-derived `input_seg_ids` is only passed if the model accepts the keyword.
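
The gating described above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code in this PR: `model_accepts_segment_ids`, `call_model`, and the three toy classes are hypothetical names, and returning `False` when the signature cannot be read is an assumed fallback.

```python
import inspect


def model_accepts_segment_ids(model) -> bool:
    """True if model.__call__ takes segment_ids explicitly or via **kwargs."""
    try:
        params = inspect.signature(model.__call__).parameters
    except (TypeError, ValueError):
        # Assumption: if the signature is unavailable, do not forward.
        return False
    if "segment_ids" in params:
        return True
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


def call_model(model, tokens, *, segment_ids=None, derived_segment_ids=None):
    """Forwards caller-supplied ids always, auto-derived ids only if accepted."""
    kwargs = {}
    if segment_ids is not None:
        # Caller asked for it explicitly: pass through, even if that raises.
        kwargs["segment_ids"] = segment_ids
    elif derived_segment_ids is not None and model_accepts_segment_ids(model):
        kwargs["segment_ids"] = derived_segment_ids
    return model(tokens, **kwargs)


class Plain:
    def __call__(self, tokens):
        return tokens, None


class Explicit:
    def __call__(self, tokens, segment_ids=None):
        return tokens, segment_ids


class CatchAll:
    def __call__(self, tokens, **kwargs):
        return tokens, kwargs.get("segment_ids")


plain_out = call_model(Plain(), [1, 2], derived_segment_ids=[1, 1])
explicit_out = call_model(Explicit(), [1, 2], derived_segment_ids=[1, 1])
catchall_out = call_model(CatchAll(), [1, 2], derived_segment_ids=[1, 1])
```

With this shape, a model like `Plain` silently skips the auto-derived value, while a caller who passes `segment_ids=` explicitly to such a model still gets the `TypeError`, matching the behavior described for `compute_score`.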
`tests/rl/common_test.py`
- `test_model_accepts_segment_ids_helper`: unit-tests the helper against modules with an explicit `segment_ids` param, without it, and with `**kwargs`.
- `test_compute_per_token_logps_model_without_segment_ids`: reproduces the issue with a mock module whose `__call__` does NOT declare `segment_ids`. Raised `TypeError` before this fix.
- `test_compute_score_model_without_segment_ids`: same shape, but for the previously-missed `compute_score` callsite.
Verification
All 41 tests in `tests/rl/common_test.py` pass (38 existing + 3 new):