Skip to content

Fix TypeError when passing auto-derived segment_ids to models that don't accept it (Gemma/Llama)#1547

Open
msghik wants to merge 6 commits into
google:mainfrom
msghik:main
Open

Fix TypeError when passing auto-derived segment_ids to models that don't accept it (Gemma/Llama)#1547
msghik wants to merge 6 commits into
google:mainfrom
msghik:main

Conversation

@msghik

@msghik msghik commented May 28, 2026

Copy link
Copy Markdown

Resolves #1539.

Summary

GRPO/GSPO training crashes with TypeError: Gemma3.__call__() got an unexpected keyword argument 'segment_ids' (and the same on Gemma 2 / Gemma 4 / Llama 3)
whenever the policy or reference model is anything other than Qwen3.

Root cause

Commit 49b63f7 ("Agentic GRPO improvements") changed tunix/rl/common.py to
auto-derive a per-position non-pad mask in process_ids and forward it as
segment_ids= to the model in both compute_per_token_logps and
compute_score. Only Qwen3's __call__ was updated to accept that keyword —
Gemma 2/3/4 and Llama 3 still don't, so any RL run with those models as the
reference/policy model raises a TypeError at the model call site.

Commit 9b4a4c6 added an inline inspect.signature workaround in
compute_per_token_logps only. compute_score was missed and still
unconditionally forwarded the auto-derived input_seg_ids, so reward-model
paths in GRPO/GSPO remain broken on HEAD.

This regression slipped past CI because the existing tests in
tests/rl/common_test.py only use ToyTransformer, whose __call__ already
declares segment_ids.

Changes

  • tunix/rl/common.py

    • New helper _model_accepts_segment_ids(model) that introspects the model's
      __call__ signature once and returns True if it has an explicit
      segment_ids parameter or a **kwargs catch-all. Docstring explains why
      the gate exists.
    • compute_per_token_logps: replaces the inline import inspect / try-except
      with a call to the helper. Same semantics, cleaner code.
    • compute_score: applies the same gating. Caller-supplied segment_ids is
      still forwarded as before (so explicit packed-mode callers see the same
      TypeError they always would have, now pointing at their code rather than
      a hidden auto-derived value), but the auto-derived input_seg_ids is only
      passed if the model accepts the keyword.
  • tests/rl/common_test.py

    • test_model_accepts_segment_ids_helper — unit-tests the helper against
      modules with an explicit segment_ids param, without it, and with
      **kwargs.
    • test_compute_per_token_logps_model_without_segment_ids — reproduces the
      issue with a mock module whose __call__ does NOT declare segment_ids.
      Would TypeError before this fix.
    • test_compute_score_model_without_segment_ids — same shape, but for the
      previously-missed compute_score callsite.

Verification

All 41 tests in tests/rl/common_test.py pass (38 existing + 3 new):

msghik added 5 commits May 27, 2026 20:14
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the
SigLIP-based integration already present in Gemma 3. The training and
sampling infrastructures (peft_trainer, sampler, rl/common) already
forward an `images` kwarg to the model; this change closes the gap at
the model level so they actually work end-to-end on Gemma 4.

Specifically:
- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`,
  exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b
  factory classmethods (matching gemma3_*_pt/it).
- `ShardingConfig` gains a `siglip` field for the encoder's sharding.
- `Embedder` optionally builds the `mm_input_projection` and
  `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config
  is set; `__call__` accepts an `images` kwarg and merges the soft
  vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans) and
  `get_model_input` (used for LoRA tracing) on the model.
- `params_safetensors._get_key_and_transform_mapping` adds vision-tower
  and multi-modal projector mappings when vision is enabled.

The SigLIP encoder, embedding-merge utility and attention-mask helper
are reused from `tunix.models.gemma3` to avoid ~900 lines of
duplication; they have no Gemma 3-specific assumptions.

Tests: 4 new tests cover text-only construction, multi-modal forward
pass, the helpful error when images are passed without a vision encoder,
and the text-only attention mask shape. All existing Gemma 3/4 tests
continue to pass.
Add multi-modal (vision) support to Gemma 4
Enables SFT/LoRA/RL of Gemma 4 with text+image inputs, mirroring the
SigLIP-based integration already present in Gemma 3. The training and
sampling infrastructures (peft_trainer, sampler, rl/common) already
forward an `images` kwarg to the model; this change closes the gap at
the model level so they actually work end-to-end on Gemma 4.

Specifically:
- `ModelConfig` gains an optional `vision_config: SigLIPConfig | None`,
  exposed via a `text_only=True` parameter on the e2b/e4b/31b/26b-a4b
  factory classmethods (matching gemma3_*_pt/it).
- `ShardingConfig` gains a `siglip` field for the encoder's sharding.
- `Embedder` optionally builds the `mm_input_projection` and
  `mm_soft_embedding_norm` layers and exposes `encode_vision`.
- `Gemma4.__init__` constructs a `SigLiP` encoder when a vision config
  is set; `__call__` accepts an `images` kwarg and merges the soft
  vision tokens into the text embeddings at the placeholder positions.
- Adds `get_attention_mask` (bidirectional over image spans) and
  `get_model_input` (used for LoRA tracing) on the model.
- `params_safetensors._get_key_and_transform_mapping` adds vision-tower
  and multi-modal projector mappings when vision is enabled.

The SigLIP encoder, embedding-merge utility and attention-mask helper
are reused from `tunix.models.gemma3` to avoid ~900 lines of
duplication; they have no Gemma 3-specific assumptions.

Tests: 4 new tests cover text-only construction, multi-modal forward
pass, the helpful error when images are passed without a vision encoder,
and the text-only attention mask shape. All existing Gemma 3/4 tests
continue to pass.
…pt it.

Fixes google#1539.

Commit 49b63f7 ("Agentic GRPO improvements") started passing a derived
non-pad mask as ``segment_ids`` to the model from rl/common.py for every
RL training run. Only Qwen3's ``__call__`` was updated to accept the
keyword; Gemma 2/3/4 and Llama 3 still don't, so GRPO training with
those reference/policy models crashes with::

  TypeError: Gemma3.__call__() got an unexpected keyword argument
    'segment_ids'

at the model call site in ``compute_per_token_logps``.

A later commit (9b4a4c6) added an inline ``inspect.signature``
workaround in ``compute_per_token_logps`` only. ``compute_score`` was
missed and still unconditionally forwarded the auto-derived
``input_seg_ids``, so reward-model paths in GRPO/GSPO remain broken
even on HEAD.

This change:

- Extracts ``_model_accepts_segment_ids(model)`` so the
  signature-introspection logic lives in one place, with a docstring
  explaining why this gate exists (until every model accepts
  ``segment_ids`` natively).
- Applies the gate consistently in both ``compute_per_token_logps`` and
  ``compute_score``: caller-supplied ``segment_ids`` is always passed
  through (matches pre-bug behavior); the auto-derived
  ``input_seg_ids`` is only forwarded to models whose signature accepts
  it.
- Adds regression tests using mock modules whose ``__call__`` does NOT
  declare ``segment_ids``, which is the shape of the failure mode for
  Gemma 2/3/4 / Llama. The existing tests only used
  ``ToyTransformer``, whose ``__call__`` already accepts
  ``segment_ids`` — which is how this slipped past CI.

Manually verified that ``compute_per_token_logps`` no longer raises
when called with a real ``Gemma3`` reference model:

  graphdef, state = nnx.split(Gemma3(ModelConfig.gemma3_270m(),
                                     rngs=nnx.Rngs(0)))
  common.compute_per_token_logps(
      graphdef, state, prompt, completion, pad_id=0, eos_id=-1)
…'t accept it (Gemma/Llama)

Fix TypeError: don't pass auto-derived segment_ids to models that don't accept it (Gemma/Llama)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[bug] TypeError: Gemma3.__call__() got an unexpected keyword argument 'segment_ids'

2 participants