diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 4d276214..747dbb0f 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -19,7 +19,7 @@ def get_mla_metadata( num_heads_k: The number of k heads. num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. - topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. + topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache` will be attended to. Returns: tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.