Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 150 additions & 28 deletions tunix/models/gemma4/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,40 @@ def __init__(
param_dtype=config.param_dtype,
)

# Retrieve TP size from the active JAX mesh
tp_size = 1
mesh = None
if hasattr(jax.sharding, 'get_abstract_mesh'):
try:
m = jax.sharding.get_abstract_mesh()
if m is not None and not getattr(m, 'empty', False):
mesh = m
except Exception:
pass
if mesh is None:
mesh = pxla.thread_resources.env.physical_mesh
if mesh is not None and 'tp' in mesh.axis_names:
tp_size = mesh.shape['tp']

self.act_btnh_kv = config.shd_config.act_btnh
if self.num_kv_heads % tp_size != 0:
import logging as python_logging
python_logging.info(
f"num_kv_heads={self.num_kv_heads} is not divisible by TP size {tp_size}, "
f"sharding k/v projections on head_dim instead of kv-heads."
)
fsdp = config.shd_config.act_btnh[0]
self.act_btnh_kv = (fsdp, None, None, 'tp')

k_eq_v = (
config.k_eq_v_global if attn_type == AttentionType.GLOBAL else False
)
if k_eq_v:
k_sharding = config.shd_config.q_weight_ndh
if self.num_kv_heads % tp_size != 0:
fsdp = config.shd_config.q_weight_ndh[1]
k_sharding = (None, fsdp, 'tp')

self.k_einsum = Einsum(
einsum_str='BSD,KDH->BSKH',
shape=(
Expand All @@ -608,11 +638,19 @@ def __init__(
self.head_dim,
),
rngs=rngs,
sharding=config.shd_config.q_weight_ndh,
sharding=k_sharding,
dtype=config.dtype,
param_dtype=config.param_dtype,
)
else:
if self.num_kv_heads == 1:
kv_sharding = (None, None, 'fsdp', None)
else:
kv_sharding = config.shd_config.kv_weight_cndh
if self.num_kv_heads % tp_size != 0:
fsdp = config.shd_config.kv_weight_cndh[2]
kv_sharding = (None, None, fsdp, 'tp')

self.kv_einsum = Einsum(
einsum_str='BSD,CKDH->CBSKH',
shape=(
Expand All @@ -622,9 +660,7 @@ def __init__(
self.head_dim,
),
rngs=rngs,
sharding=(None, None, 'fsdp', None)
if self.num_kv_heads == 1
else config.shd_config.kv_weight_cndh,
sharding=kv_sharding,
dtype=config.dtype,
param_dtype=config.param_dtype,
)
Expand All @@ -650,6 +686,7 @@ def block(
cache: LayerCache | None,
attn_mask: jaxtyping.Array,
kv_shared_cache: LayerCache | None = None,
segment_ids: jaxtyping.Array | None = None,
) -> tuple[
LayerCache | None,
jaxtyping.Array,
Expand Down Expand Up @@ -743,12 +780,13 @@ def block(
'k': key_proj,
}

_, _, qh, _ = query_proj.shape
b, _, qh, _ = query_proj.shape
_, _, kh, _ = key_proj.shape

if self.config.use_flash_attention and seq_len > 1:
query_proj = query_proj.transpose(0, 2, 1, 3)
key_proj = key_proj.transpose(0, 2, 1, 3)
value_proj = value_proj.transpose(0, 2, 1, 3)
query_proj_splash = query_proj.transpose(0, 2, 1, 3)
key_proj_splash = key_proj.transpose(0, 2, 1, 3)
value_proj_splash = value_proj.transpose(0, 2, 1, 3)

mesh = pxla.thread_resources.env.physical_mesh
if self.attn_type == AttentionType.LOCAL_SLIDING:
Expand All @@ -773,6 +811,8 @@ def block(
)

shd_b, shd_t, shd_n, shd_h = self.config.shd_config.act_btnh
if mesh is not None and shd_b is not None and shd_b in mesh.shape and b % mesh.shape[shd_b] != 0:
shd_b = None
head_shards = (
mesh.shape[shd_n] if shd_n is not None and shd_n in mesh.shape else 1
)
Expand All @@ -788,24 +828,76 @@ def block(
)

shd_spec = P(shd_b, shd_n, shd_t, shd_h)
unsharded_seq = P(shd_b, shd_n, None, shd_h)
shd_n_kv = (
shd_n
if mesh is not None
and shd_n is not None
and shd_n in mesh.shape
and kh % mesh.shape[shd_n] == 0
else None
)
unsharded_seq_kv = P(shd_b, shd_n_kv, None, shd_h)
kernel_spec = splash_attn_kernel.manual_sharding_spec(
shd.NamedSharding(mesh, P(shd_n, shd_t))
)

@partial(
shard_map,
mesh=mesh,
in_specs=(kernel_spec, shd_spec, unsharded_seq, unsharded_seq),
out_specs=shd_spec,
check_rep=False,
)
def sharded_splash_attn(kernel, q_block, k_block, v_block):
return jax.vmap(kernel)(q_block, k_block, v_block)
if segment_ids is not None:
seg_spec = P(shd_b, shd_t)
unsharded_seg_spec = P(shd_b, None)

@partial(
shard_map,
mesh=mesh,
in_specs=(
kernel_spec,
shd_spec,
unsharded_seq_kv,
unsharded_seq_kv,
seg_spec,
unsharded_seg_spec,
),
out_specs=shd_spec,
check_rep=False,
)
def sharded_splash_attn(
kernel, q_block, k_block, v_block, q_seg_block, kv_seg_block
):
seg_ids = splash.SegmentIds(q=q_seg_block, kv=kv_seg_block)
return jax.vmap(kernel)(
q_block, k_block, v_block, segment_ids=seg_ids
)

qkv: jaxtyping.Array = sharded_splash_attn(
splash_attn_kernel, query_proj, key_proj, value_proj
)
qkv: jaxtyping.Array = sharded_splash_attn(
splash_attn_kernel,
query_proj_splash,
key_proj_splash,
value_proj_splash,
segment_ids,
segment_ids,
)
else:

@partial(
shard_map,
mesh=mesh,
in_specs=(
kernel_spec,
shd_spec,
unsharded_seq_kv,
unsharded_seq_kv,
),
out_specs=shd_spec,
check_rep=False,
)
def sharded_splash_attn(kernel, q_block, k_block, v_block):
return jax.vmap(kernel)(q_block, k_block, v_block)

qkv: jaxtyping.Array = sharded_splash_attn(
splash_attn_kernel,
query_proj_splash,
key_proj_splash,
value_proj_splash,
)
encoded = qkv.transpose(0, 2, 1, 3)
else:
if self.use_gqa:
Expand Down Expand Up @@ -892,7 +984,15 @@ def sharded_splash_attn(kernel, q_block, k_block, v_block):
def use_gqa(self):
return self.num_kv_heads != self.config.num_heads and self.num_kv_heads > 1

def __call__(self, x, segment_pos, cache, attn_mask, kv_shared_cache=None):
def __call__(
self,
x,
segment_pos,
cache,
attn_mask,
kv_shared_cache=None,
segment_ids=None,
):
remat_config = getattr(self.config, 'remat_config', RematConfig.NONE)
if (
remat_config == RematConfig.BLOCK
Expand All @@ -902,11 +1002,16 @@ def __call__(self, x, segment_pos, cache, attn_mask, kv_shared_cache=None):
# as the first argument. graph_updates=False prevents TraceContextError
# when mutating params across jax transformation trace levels.
return nnx.remat(self.block.__func__, graph_updates=False)(
self, x, segment_pos, cache, attn_mask, kv_shared_cache
self, x, segment_pos, cache, attn_mask, kv_shared_cache, segment_ids
)
else:
return self.block(
x, segment_pos, cache, attn_mask, kv_shared_cache=kv_shared_cache
x,
segment_pos,
cache,
attn_mask,
kv_shared_cache=kv_shared_cache,
segment_ids=segment_ids,
)

def init_cache(self, batch_size, max_seq_len, dtype):
Expand Down Expand Up @@ -1114,10 +1219,16 @@ def block(
attn_mask,
per_layer_input=None,
kv_shared_cache=None,
segment_ids=None,
):
norm = self.pre_attention_norm(x)
cache, attn, kv = self.attn(
norm, segment_pos, cache, attn_mask, kv_shared_cache=kv_shared_cache
norm,
segment_pos,
cache,
attn_mask,
kv_shared_cache=kv_shared_cache,
segment_ids=segment_ids,
)
attn = self.post_attention_norm(attn)
attn += x
Expand All @@ -1127,7 +1238,7 @@ def block(
if self.config.enable_moe:
ffw = self.dense_post_ffw_norm(ffw)
moe_norm_ffw = self.moe_pre_ffw_norm(attn)
moe_out = self.moe(moe_norm_ffw)
moe_out = self.moe(moe_norm_ffw, router_input=attn)
moe_out = self.moe_post_ffw_norm(moe_out)
ffw += moe_out
ffw = self.post_ffw_norm(ffw)
Expand All @@ -1153,6 +1264,7 @@ def __call__(
attn_mask,
per_layer_input=None,
kv_shared_cache=None,
segment_ids=None,
):
remat_config = getattr(self.config, 'remat_config', RematConfig.NONE)
if (
Expand All @@ -1167,10 +1279,17 @@ def __call__(
attn_mask,
per_layer_input,
kv_shared_cache,
segment_ids,
)
else:
return self.block(
x, segment_pos, cache, attn_mask, per_layer_input, kv_shared_cache
x,
segment_pos,
cache,
attn_mask,
per_layer_input,
kv_shared_cache,
segment_ids=segment_ids,
)

def init_cache(self, batch_size, max_seq_len, dtype):
Expand Down Expand Up @@ -1236,12 +1355,14 @@ def __call__(
positions=None,
cache=None,
attention_mask=None,
segment_ids=None,
decode_only_last_token=False,
):
if positions is None:
B, T = tokens.shape # pylint: disable=invalid-name
positions = jnp.tile(jnp.arange(T)[None, :], (B, 1))

return_cache = cache is not None
new_cache = {}
x = self.embedder.encode(tokens)

Expand Down Expand Up @@ -1284,6 +1405,7 @@ def __call__(
if per_layer_inputs is not None
else None,
kv_shared_cache=kv_shared_cache,
segment_ids=segment_ids,
)
if is_prefill and i in self.shared_layer_origins:
transient_kvs[layer_name] = layers_kvs
Expand All @@ -1302,7 +1424,7 @@ def __call__(
logits /= self.config.final_logit_softcap
logits = jnp.tanh(logits) * self.config.final_logit_softcap

return logits, (None if cache is None else new_cache)
return logits, (new_cache if return_cache else None) # pytype: disable=container-type-mismatch

def init_cache(self, batch_size, max_seq_len, dtype):
cache = {}
Expand Down
10 changes: 6 additions & 4 deletions tunix/models/gemma4/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
def _router(self, router_logits: jax.Array):
router_logits = router_logits.astype(jnp.float32)
router_probs = jax.nn.softmax(router_logits, axis=-1)
_, choices = jax.lax.approx_max_k(
weights, choices = jax.lax.top_k(
router_logits,
k=self.num_experts_per_datapoint,
)
Expand Down Expand Up @@ -184,9 +184,11 @@ def _run_ffw_and_routing(
)
return out

def __call__(self, x):
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
router_input = x * jax.lax.rsqrt(var + 1e-06).astype(x.dtype)
def __call__(self, x, router_input=None):
if router_input is None:
router_input = x
var = jnp.mean(jnp.square(router_input.astype(jnp.float32)), axis=-1, keepdims=True)
router_input = router_input * jax.lax.rsqrt(var + 1e-06).astype(router_input.dtype)

root_size = jax.lax.rsqrt(
jnp.array(self.features, dtype=router_input.dtype)
Expand Down
Loading