diff --git a/tunix/models/gemma4/model.py b/tunix/models/gemma4/model.py index 157d7fed6..7a10a02fe 100644 --- a/tunix/models/gemma4/model.py +++ b/tunix/models/gemma4/model.py @@ -596,10 +596,40 @@ def __init__( param_dtype=config.param_dtype, ) + # Retrieve TP size from the active JAX mesh + tp_size = 1 + mesh = None + if hasattr(jax.sharding, 'get_abstract_mesh'): + try: + m = jax.sharding.get_abstract_mesh() + if m is not None and not getattr(m, 'empty', False): + mesh = m + except Exception: + pass + if mesh is None: + mesh = pxla.thread_resources.env.physical_mesh + if mesh is not None and 'tp' in mesh.axis_names: + tp_size = mesh.shape['tp'] + + self.act_btnh_kv = config.shd_config.act_btnh + if self.num_kv_heads % tp_size != 0: + import logging as python_logging + python_logging.info( + f"num_kv_heads={self.num_kv_heads} is not divisible by TP size {tp_size}, " + f"sharding k/v projections on head_dim instead of kv-heads." + ) + fsdp = config.shd_config.act_btnh[0] + self.act_btnh_kv = (fsdp, None, None, 'tp') + k_eq_v = ( config.k_eq_v_global if attn_type == AttentionType.GLOBAL else False ) if k_eq_v: + k_sharding = config.shd_config.q_weight_ndh + if self.num_kv_heads % tp_size != 0: + fsdp = config.shd_config.q_weight_ndh[1] + k_sharding = (None, fsdp, 'tp') + self.k_einsum = Einsum( einsum_str='BSD,KDH->BSKH', shape=( @@ -608,11 +638,19 @@ def __init__( self.head_dim, ), rngs=rngs, - sharding=config.shd_config.q_weight_ndh, + sharding=k_sharding, dtype=config.dtype, param_dtype=config.param_dtype, ) else: + if self.num_kv_heads == 1: + kv_sharding = (None, None, 'fsdp', None) + else: + kv_sharding = config.shd_config.kv_weight_cndh + if self.num_kv_heads % tp_size != 0: + fsdp = config.shd_config.kv_weight_cndh[2] + kv_sharding = (None, None, fsdp, 'tp') + self.kv_einsum = Einsum( einsum_str='BSD,CKDH->CBSKH', shape=( @@ -622,9 +660,7 @@ def __init__( self.head_dim, ), rngs=rngs, - sharding=(None, None, 'fsdp', None) - if self.num_kv_heads == 1 - else config.shd_config.kv_weight_cndh, + sharding=kv_sharding, dtype=config.dtype, param_dtype=config.param_dtype, ) @@ -650,6 +686,7 @@ def block( cache: LayerCache | None, attn_mask: jaxtyping.Array, kv_shared_cache: LayerCache | None = None, + segment_ids: jaxtyping.Array | None = None, ) -> tuple[ LayerCache | None, jaxtyping.Array, @@ -743,12 +780,13 @@ def block( 'k': key_proj, } - _, _, qh, _ = query_proj.shape + b, _, qh, _ = query_proj.shape + _, _, kh, _ = key_proj.shape if self.config.use_flash_attention and seq_len > 1: - query_proj = query_proj.transpose(0, 2, 1, 3) - key_proj = key_proj.transpose(0, 2, 1, 3) - value_proj = value_proj.transpose(0, 2, 1, 3) + query_proj_splash = query_proj.transpose(0, 2, 1, 3) + key_proj_splash = key_proj.transpose(0, 2, 1, 3) + value_proj_splash = value_proj.transpose(0, 2, 1, 3) mesh = pxla.thread_resources.env.physical_mesh if self.attn_type == AttentionType.LOCAL_SLIDING: @@ -773,6 +811,8 @@ def block( ) shd_b, shd_t, shd_n, shd_h = self.config.shd_config.act_btnh + if mesh is not None and shd_b is not None and shd_b in mesh.shape and b % mesh.shape[shd_b] != 0: + shd_b = None head_shards = ( mesh.shape[shd_n] if shd_n is not None and shd_n in mesh.shape else 1 ) @@ -788,24 +828,76 @@ def block( ) shd_spec = P(shd_b, shd_n, shd_t, shd_h) - unsharded_seq = P(shd_b, shd_n, None, shd_h) + shd_n_kv = ( + shd_n + if mesh is not None + and shd_n is not None + and shd_n in mesh.shape + and kh % mesh.shape[shd_n] == 0 + else None + ) + unsharded_seq_kv = P(shd_b, shd_n_kv, None, shd_h) kernel_spec = splash_attn_kernel.manual_sharding_spec( shd.NamedSharding(mesh, P(shd_n, shd_t)) ) - @partial( - shard_map, - mesh=mesh, - in_specs=(kernel_spec, shd_spec, unsharded_seq, unsharded_seq), - out_specs=shd_spec, - check_rep=False, - ) - def sharded_splash_attn(kernel, q_block, k_block, v_block): - return jax.vmap(kernel)(q_block, k_block, v_block) + if segment_ids is not None: + seg_spec = P(shd_b, shd_t) + unsharded_seg_spec = P(shd_b, None) + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + shd_spec, + unsharded_seq_kv, + unsharded_seq_kv, + seg_spec, + unsharded_seg_spec, + ), + out_specs=shd_spec, + check_rep=False, + ) + def sharded_splash_attn( + kernel, q_block, k_block, v_block, q_seg_block, kv_seg_block + ): + seg_ids = splash.SegmentIds(q=q_seg_block, kv=kv_seg_block) + return jax.vmap(kernel)( + q_block, k_block, v_block, segment_ids=seg_ids + ) - qkv: jaxtyping.Array = sharded_splash_attn( - splash_attn_kernel, query_proj, key_proj, value_proj - ) + qkv: jaxtyping.Array = sharded_splash_attn( + splash_attn_kernel, + query_proj_splash, + key_proj_splash, + value_proj_splash, + segment_ids, + segment_ids, + ) + else: + + @partial( + shard_map, + mesh=mesh, + in_specs=( + kernel_spec, + shd_spec, + unsharded_seq_kv, + unsharded_seq_kv, + ), + out_specs=shd_spec, + check_rep=False, + ) + def sharded_splash_attn(kernel, q_block, k_block, v_block): + return jax.vmap(kernel)(q_block, k_block, v_block) + + qkv: jaxtyping.Array = sharded_splash_attn( + splash_attn_kernel, + query_proj_splash, + key_proj_splash, + value_proj_splash, + ) encoded = qkv.transpose(0, 2, 1, 3) else: if self.use_gqa: @@ -892,7 +984,15 @@ def sharded_splash_attn(kernel, q_block, k_block, v_block): def use_gqa(self): return self.num_kv_heads != self.config.num_heads and self.num_kv_heads > 1 - def __call__(self, x, segment_pos, cache, attn_mask, kv_shared_cache=None): + def __call__( + self, + x, + segment_pos, + cache, + attn_mask, + kv_shared_cache=None, + segment_ids=None, + ): remat_config = getattr(self.config, 'remat_config', RematConfig.NONE) if ( remat_config == RematConfig.BLOCK @@ -902,11 +1002,16 @@ def __call__(self, x, segment_pos, cache, attn_mask, kv_shared_cache=None): # as the first argument. graph_updates=False prevents TraceContextError # when mutating params across jax transformation trace levels. return nnx.remat(self.block.__func__, graph_updates=False)( - self, x, segment_pos, cache, attn_mask, kv_shared_cache + self, x, segment_pos, cache, attn_mask, kv_shared_cache, segment_ids ) else: return self.block( - x, segment_pos, cache, attn_mask, kv_shared_cache=kv_shared_cache + x, + segment_pos, + cache, + attn_mask, + kv_shared_cache=kv_shared_cache, + segment_ids=segment_ids, ) def init_cache(self, batch_size, max_seq_len, dtype): @@ -1114,10 +1219,16 @@ def block( attn_mask, per_layer_input=None, kv_shared_cache=None, + segment_ids=None, ): norm = self.pre_attention_norm(x) cache, attn, kv = self.attn( - norm, segment_pos, cache, attn_mask, kv_shared_cache=kv_shared_cache + norm, + segment_pos, + cache, + attn_mask, + kv_shared_cache=kv_shared_cache, + segment_ids=segment_ids, ) attn = self.post_attention_norm(attn) attn += x @@ -1127,7 +1238,7 @@ def block( if self.config.enable_moe: ffw = self.dense_post_ffw_norm(ffw) moe_norm_ffw = self.moe_pre_ffw_norm(attn) - moe_out = self.moe(moe_norm_ffw) + moe_out = self.moe(moe_norm_ffw, router_input=attn) moe_out = self.moe_post_ffw_norm(moe_out) ffw += moe_out ffw = self.post_ffw_norm(ffw) @@ -1153,6 +1264,7 @@ def __call__( attn_mask, per_layer_input=None, kv_shared_cache=None, + segment_ids=None, ): remat_config = getattr(self.config, 'remat_config', RematConfig.NONE) if ( @@ -1167,10 +1279,17 @@ def __call__( attn_mask, per_layer_input, kv_shared_cache, + segment_ids, ) else: return self.block( - x, segment_pos, cache, attn_mask, per_layer_input, kv_shared_cache + x, + segment_pos, + cache, + attn_mask, + per_layer_input, + kv_shared_cache, + segment_ids=segment_ids, ) def init_cache(self, batch_size, max_seq_len, dtype): @@ -1236,12 +1355,14 @@ def __call__( positions=None, cache=None, attention_mask=None, + segment_ids=None, decode_only_last_token=False, ): if positions is None: B, T = tokens.shape # pylint: disable=invalid-name positions = jnp.tile(jnp.arange(T)[None, :], (B, 1)) + return_cache = cache is not None new_cache = {} x = self.embedder.encode(tokens) @@ -1284,6 +1405,7 @@ def __call__( if per_layer_inputs is not None else None, kv_shared_cache=kv_shared_cache, + segment_ids=segment_ids, ) if is_prefill and i in self.shared_layer_origins: transient_kvs[layer_name] = layers_kvs @@ -1302,7 +1424,7 @@ def __call__( logits /= self.config.final_logit_softcap logits = jnp.tanh(logits) * self.config.final_logit_softcap - return logits, (None if cache is None else new_cache) + return logits, (new_cache if return_cache else None) # pytype: disable=container-type-mismatch def init_cache(self, batch_size, max_seq_len, dtype): cache = {} diff --git a/tunix/models/gemma4/moe.py b/tunix/models/gemma4/moe.py index 082590b36..c377a59f0 100644 --- a/tunix/models/gemma4/moe.py +++ b/tunix/models/gemma4/moe.py @@ -119,7 +119,7 @@ def __init__( def _router(self, router_logits: jax.Array): router_logits = router_logits.astype(jnp.float32) router_probs = jax.nn.softmax(router_logits, axis=-1) - _, choices = jax.lax.approx_max_k( + weights, choices = jax.lax.top_k( router_logits, k=self.num_experts_per_datapoint, ) @@ -184,9 +184,11 @@ def _run_ffw_and_routing( ) return out - def __call__(self, x): - var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) - router_input = x * jax.lax.rsqrt(var + 1e-06).astype(x.dtype) + def __call__(self, x, router_input=None): + if router_input is None: + router_input = x + var = jnp.mean(jnp.square(router_input.astype(jnp.float32)), axis=-1, keepdims=True) + router_input = router_input * jax.lax.rsqrt(var + 1e-06).astype(router_input.dtype) root_size = jax.lax.rsqrt( jnp.array(self.features, dtype=router_input.dtype)