Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions tests/models/gemma4/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

from __future__ import annotations

import dataclasses

from absl.testing import absltest
from flax import nnx
import jax
import jax.numpy as jnp
from tunix.models.gemma3 import vision as vision_lib
from tunix.models.gemma4 import model as model_lib


Expand Down Expand Up @@ -195,6 +198,144 @@ def body_fn(step, _):
_, logits = compiled_decode(state)
self.assertEqual(logits.shape, (2, 32, config.num_embed))

def test_text_only_no_vision_encoder(self):
    """Checks that a text-only config yields a model without vision parts.

    The `gemma4_e2b` preset is asserted to carry no `vision_config`. A model
    built from a shrunken copy of that preset must then expose
    `vision_encoder` as `None`, and its embedder must not have an
    `mm_input_projection` attribute.
    """
    base_config = model_lib.ModelConfig.gemma4_e2b()
    self.assertIsNone(base_config.vision_config)

    # Build a tiny copy instead of assigning to the preset's fields in
    # place, so the object returned by the preset is left untouched.
    config = dataclasses.replace(
        base_config,
        num_layers=1,
        embed_dim=256,
        hidden_dim=512,
        num_heads=4,
        head_dim=64,
        num_kv_heads=1,
        frac_shared_layers=0.0,
    )

    rngs = nnx.Rngs(0)
    model = model_lib.Gemma4(config, rngs=rngs)
    self.assertIsNone(model.vision_encoder)
    self.assertFalse(hasattr(model.embedder, "mm_input_projection"))

def test_forward_pass_multimodal(self):
    """Runs one forward pass with image input through a tiny multimodal model.

    A small SigLIP vision config is attached to a shrunken `gemma4_e2b`
    config. The test verifies that the model then has a vision encoder and
    the multimodal embedder attributes, feeds it tokens containing the
    soft-image placeholder together with all-zero images, and checks that
    the logits have shape `(batch_size, seq_len, config.num_embed)`.
    """
    # A deliberately small SigLIP setup keeps the test quick and cheap.
    vision_cfg = vision_lib.SigLIPConfig(
        num_mm_tokens_per_image_prepool=16,
        num_mm_tokens_per_image=4,
        image_height=32,
        image_width=32,
        image_channels=3,
        soft_token_placeholder=219,
        patch_size=(8, 8),
        width=32,
        depth=1,
        mlp_dim=64,
        num_heads=4,
    )
    config = dataclasses.replace(
        model_lib.ModelConfig.gemma4_e2b(),
        num_layers=1,
        embed_dim=256,
        hidden_dim=512,
        num_heads=4,
        head_dim=64,
        num_kv_heads=1,
        frac_shared_layers=0.0,
        vision_config=vision_cfg,
    )

    model = model_lib.Gemma4(config, rngs=nnx.Rngs(0))
    self.assertIsNotNone(model.vision_encoder)
    for attr_name in ("mm_input_projection", "mm_soft_embedding_norm"):
        self.assertTrue(hasattr(model.embedder, attr_name))

    batch_size, seq_len, num_images = 2, 16, 1
    num_mm = vision_cfg.num_mm_tokens_per_image

    # Every row is all ones except columns 1 .. num_mm, which hold the
    # soft-image placeholder id.
    mm_span = slice(1, 1 + num_mm)
    tokens = (
        jnp.ones((batch_size, seq_len), dtype=jnp.int32)
        .at[:, mm_span]
        .set(vision_cfg.soft_token_placeholder)
    )

    # Positions 0 .. seq_len - 1, repeated for every batch row.
    position_row = jnp.arange(seq_len)[None, :]
    positions = jnp.tile(position_row, (batch_size, 1))
    attn_mask = model.get_attention_mask(tokens)

    image_shape = (
        batch_size,
        num_images,
        vision_cfg.image_height,
        vision_cfg.image_width,
        vision_cfg.image_channels,
    )
    images = jnp.zeros(image_shape, dtype=jnp.float32)

    logits, _unused = model(
        tokens,
        positions=positions,
        attention_mask=attn_mask,
        images=images,
    )
    expected_shape = (batch_size, seq_len, config.num_embed)
    self.assertEqual(logits.shape, expected_shape)

def test_multimodal_call_without_vision_encoder_raises(self):
    """Checks that passing images to a text-only model raises `ValueError`.

    The model is built from a shrunken copy of the `gemma4_e2b` preset and
    is first asserted to have no vision encoder. Calling it with an
    `images` argument must then raise `ValueError`.
    """
    # Build a tiny copy instead of assigning to the preset's fields in
    # place, so the object returned by the preset is left untouched.
    config = dataclasses.replace(
        model_lib.ModelConfig.gemma4_e2b(),
        num_layers=1,
        embed_dim=256,
        hidden_dim=512,
        num_heads=4,
        head_dim=64,
        num_kv_heads=1,
        frac_shared_layers=0.0,
    )

    rngs = nnx.Rngs(0)
    model = model_lib.Gemma4(config, rngs=rngs)
    # Make the premise explicit: if the preset ever gains a vision encoder,
    # fail here rather than with a confusing "ValueError not raised".
    self.assertIsNone(model.vision_encoder)

    tokens = jnp.ones((1, 4), dtype=jnp.int32)
    positions = jnp.arange(4)[None, :]
    # Lower-triangular mask with a leading axis of size 1.
    attn_mask = jnp.tril(jnp.ones((4, 4), dtype=jnp.bool_))[None, ...]
    images = jnp.zeros((1, 1, 32, 32, 3), dtype=jnp.float32)

    with self.assertRaises(ValueError):
        model(
            tokens,
            positions=positions,
            attention_mask=attn_mask,
            images=images,
        )

def test_get_attention_mask_text_only(self):
    """Checks `get_attention_mask` on a text-only model against a fixed mask.

    For the token row `[1, 2, 3, 0, 0]` the expected mask is lower
    triangular over the first three positions, has all-zero columns for the
    two trailing `0` tokens, and lets the rows of those trailing tokens see
    the first three positions.
    """
    # Build a tiny copy instead of assigning to the preset's fields in
    # place, so the object returned by the preset is left untouched.
    config = dataclasses.replace(
        model_lib.ModelConfig.gemma4_e2b(),
        num_layers=1,
        embed_dim=256,
        hidden_dim=512,
        num_heads=4,
        head_dim=64,
        num_kv_heads=1,
        frac_shared_layers=0.0,
    )

    rngs = nnx.Rngs(0)
    model = model_lib.Gemma4(config, rngs=rngs)
    # No vision config => no bidirectional span; mask should be purely
    # causal.
    tokens = jnp.array([[1, 2, 3, 0, 0]])
    mask = model.get_attention_mask(tokens)
    expected = jnp.array(
        [[
            [1, 0, 0, 0, 0],
            [1, 1, 0, 0, 0],
            [1, 1, 1, 0, 0],
            [1, 1, 1, 0, 0],
            [1, 1, 1, 0, 0],
        ]],
        dtype=jnp.bool_,
    )
    # Include the actual mask in the failure message; a bare assertTrue
    # would only report "False is not true".
    self.assertTrue(
        bool(jnp.all(mask == expected)),
        msg=f"Unexpected attention mask:\n{mask}",
    )


if __name__ == "__main__":
    # Run the absltest test runner when this file is executed as a script.
    absltest.main()
Loading
Loading