Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions tests/models/gemma4/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

from __future__ import annotations

import dataclasses

from absl.testing import absltest
from flax import nnx
import jax
import jax.numpy as jnp
from tunix.models.gemma3 import vision as vision_lib
from tunix.models.gemma4 import model as model_lib


Expand Down Expand Up @@ -195,6 +198,144 @@ def body_fn(step, _):
_, logits = compiled_decode(state)
self.assertEqual(logits.shape, (2, 32, config.num_embed))

def test_text_only_no_vision_encoder(self):
config = model_lib.ModelConfig.gemma4_e2b()
self.assertIsNone(config.vision_config)
config.num_layers = 1
config.embed_dim = 256
config.hidden_dim = 512
config.num_heads = 4
config.head_dim = 64
config.num_kv_heads = 1
config.frac_shared_layers = 0.0

rngs = nnx.Rngs(0)
model = model_lib.Gemma4(config, rngs=rngs)
self.assertIsNone(model.vision_encoder)
self.assertFalse(hasattr(model.embedder, "mm_input_projection"))

def test_forward_pass_multimodal(self):
# Use a tiny SigLIP config so the test stays fast/light.
small_vision_config = vision_lib.SigLIPConfig(
num_mm_tokens_per_image_prepool=16,
num_mm_tokens_per_image=4,
image_height=32,
image_width=32,
image_channels=3,
soft_token_placeholder=219,
patch_size=(8, 8),
width=32,
depth=1,
mlp_dim=64,
num_heads=4,
)
base_config = model_lib.ModelConfig.gemma4_e2b()
config = dataclasses.replace(
base_config,
num_layers=1,
embed_dim=256,
hidden_dim=512,
num_heads=4,
head_dim=64,
num_kv_heads=1,
frac_shared_layers=0.0,
vision_config=small_vision_config,
)

rngs = nnx.Rngs(0)
model = model_lib.Gemma4(config, rngs=rngs)
self.assertIsNotNone(model.vision_encoder)
self.assertTrue(hasattr(model.embedder, "mm_input_projection"))
self.assertTrue(hasattr(model.embedder, "mm_soft_embedding_norm"))

batch_size = 2
seq_len = 16
num_images = 1
num_mm = small_vision_config.num_mm_tokens_per_image

# Place the soft-image placeholder in positions [1:1+num_mm] of each row.
tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
tokens = tokens.at[:, 1 : 1 + num_mm].set(
small_vision_config.soft_token_placeholder
)

positions = jnp.tile(
jnp.arange(seq_len)[None, :], (batch_size, 1)
)
attn_mask = model.get_attention_mask(tokens)

images = jnp.zeros(
(
batch_size,
num_images,
small_vision_config.image_height,
small_vision_config.image_width,
small_vision_config.image_channels,
),
dtype=jnp.float32,
)

logits, _ = model(
tokens,
positions=positions,
attention_mask=attn_mask,
images=images,
)
self.assertEqual(logits.shape, (batch_size, seq_len, config.num_embed))

def test_multimodal_call_without_vision_encoder_raises(self):
config = model_lib.ModelConfig.gemma4_e2b()
config.num_layers = 1
config.embed_dim = 256
config.hidden_dim = 512
config.num_heads = 4
config.head_dim = 64
config.num_kv_heads = 1
config.frac_shared_layers = 0.0

rngs = nnx.Rngs(0)
model = model_lib.Gemma4(config, rngs=rngs)

tokens = jnp.ones((1, 4), dtype=jnp.int32)
positions = jnp.arange(4)[None, :]
attn_mask = jnp.tril(jnp.ones((4, 4), dtype=jnp.bool_))[None, ...]
images = jnp.zeros((1, 1, 32, 32, 3), dtype=jnp.float32)

with self.assertRaises(ValueError):
model(
tokens,
positions=positions,
attention_mask=attn_mask,
images=images,
)

def test_get_attention_mask_text_only(self):
config = model_lib.ModelConfig.gemma4_e2b()
config.num_layers = 1
config.embed_dim = 256
config.hidden_dim = 512
config.num_heads = 4
config.head_dim = 64
config.num_kv_heads = 1
config.frac_shared_layers = 0.0

rngs = nnx.Rngs(0)
model = model_lib.Gemma4(config, rngs=rngs)
# No vision config => no bidirectional span; mask should be purely causal.
tokens = jnp.array([[1, 2, 3, 0, 0]])
mask = model.get_attention_mask(tokens)
expected = jnp.array(
[[
[1, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
]],
dtype=jnp.bool_,
)
self.assertTrue(bool(jnp.all(mask == expected)))


if __name__ == "__main__":
absltest.main()
95 changes: 95 additions & 0 deletions tests/rl/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,101 @@ def test_compute_per_token_logps(
logits.shape, (expected_logps.shape[0], expected_logps.shape[1], 256)
)

def test_model_accepts_segment_ids_helper(self):
# ToyTransformer.__call__ accepts segment_ids explicitly.
model_with = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))
self.assertTrue(common._model_accepts_segment_ids(model_with))

class _NoSegIdsModel(nnx.Module):

def __call__(self, tokens, positions=None, cache=None,
attention_mask=None):
del tokens, positions, cache, attention_mask
return None, None

self.assertFalse(common._model_accepts_segment_ids(_NoSegIdsModel()))

class _VarKwargsModel(nnx.Module):

def __call__(self, tokens, **kwargs):
del tokens, kwargs
return None, None

# **kwargs catch-all should be treated as accepting segment_ids.
self.assertTrue(common._model_accepts_segment_ids(_VarKwargsModel()))

def test_compute_per_token_logps_model_without_segment_ids(self):
# Reproduces google/tunix#1539: when the model's __call__ does NOT accept
# segment_ids, compute_per_token_logps must not pass it through (it would
# raise TypeError on Gemma3 etc.).
class _NoSegIdsTransformer(nnx.Module):

def __init__(self, vocab_size: int, rngs: nnx.Rngs):
self.vocab_size = vocab_size
self.emb = nnx.Embed(vocab_size, 8, rngs=rngs)
self.head = nnx.Linear(8, vocab_size, rngs=rngs)

def __call__(
self,
tokens,
positions=None,
cache=None,
attention_mask=None,
):
del positions, cache, attention_mask
x = self.emb(tokens)
return self.head(x), None

model = _NoSegIdsTransformer(vocab_size=8, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

prompt_tokens = jnp.array([[1, 2, 3, 4], [0, 0, 1, 2]], dtype=jnp.int32)
completion_tokens = jnp.array(
[[5, 6, 7, 0], [3, 4, 5, 6]], dtype=jnp.int32
)

# Should not raise even though input_seg_ids would be derived from
# process_ids — the gating helper must suppress passing segment_ids.
per_token_logps = common.compute_per_token_logps(
graphdef,
state,
prompt_tokens,
completion_tokens,
pad_id=0,
eos_id=-1,
return_logits=False,
)
self.assertEqual(per_token_logps.shape, (2, 4))

def test_compute_score_model_without_segment_ids(self):
# Same regression but for compute_score, which also unconditionally passed
# the derived segment_ids before this fix.
class _NoSegIdsScorer(nnx.Module):

def __init__(self, rngs: nnx.Rngs):
self.emb = nnx.Embed(8, 8, rngs=rngs)
self.head = nnx.Linear(8, 1, rngs=rngs)

def __call__(self, tokens, positions=None, cache=None,
attention_mask=None):
del positions, cache, attention_mask
return self.head(self.emb(tokens))

model = _NoSegIdsScorer(rngs=nnx.Rngs(0))

prompt_tokens = jnp.array([[1, 2, 3, 4]], dtype=jnp.int32)
completion_tokens = jnp.array([[5, 6, 7, 0]], dtype=jnp.int32)

scores = common.compute_score(
model,
prompt_tokens,
completion_tokens,
pad_id=0,
eos_id=-1,
)
# [B, T] after the squeeze inside compute_score.
self.assertEqual(scores.shape, (1, 8))

def test_np_make_completion_mask(self):
completion_ids = np.array(
[
Expand Down
Loading