From 2cfe5cfdb7e814d1c38615070d7d5d1330e25166 Mon Sep 17 00:00:00 2001 From: yurekami Date: Fri, 26 Dec 2025 01:52:16 +0900 Subject: [PATCH 1/3] fix: correct outdated function reference in docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed `flash_mla_with_kvcache_sm90` to `flash_mla_with_kvcache` in get_mla_metadata docstring to match the actual function name. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- flash_mla/flash_mla_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 4d276214..747dbb0f 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -19,7 +19,7 @@ def get_mla_metadata( num_heads_k: The number of k heads. num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. - topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. + topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache` will be attended to. Returns: tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. From b79616df0dca6bf217689a48c478ab228a6688d6 Mon Sep 17 00:00:00 2001 From: yurekami Date: Mon, 29 Dec 2025 03:42:04 +0900 Subject: [PATCH 2/3] fix: add CUDA include path for CCCL headers (CUDA 13+ compatibility) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CUDA 13 moved cuda/std/utility and other standard library headers to CCCL (CUDA C++ Core Library). This adds the CUDA include path explicitly to resolve build errors on CUDA 13+ ARM64 systems. Fixes #121 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 15fa6717..d74e096f 100644 --- a/setup.py +++ b/setup.py @@ -97,6 +97,8 @@ def get_nvcc_thread_args(): Path(this_dir) / "csrc" / "sm90", Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", + # CUDA 13+ moved cuda/std headers to CCCL - add CUDA include path explicitly + Path(CUDA_HOME) / "include", ], ) ) From 7ded643b6df546891dcb3369f282cf567eddabc0 Mon Sep 17 00:00:00 2001 From: yurekami Date: Tue, 30 Dec 2025 01:58:06 +0900 Subject: [PATCH 3/3] fix: add SM120 (RTX PRO 6000 Blackwell) support (#124) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add get_gpu_arch() function to detect GPU compute capability - Add FLASH_MLA_DISABLE_SM120 environment variable - Generate SM120 arch flags when NVCC 12.9+ is available - Auto-detect SM120 GPUs and log detection message 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- setup.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d74e096f..a1e14a12 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,21 @@ def get_features_args(): features_args.append("-DFLASH_MLA_DISABLE_FP16") return features_args +def get_gpu_arch(): + """Detect GPU architecture using nvidia-smi.""" + try: + result = subprocess.run( + ['nvidia-smi', '--query-gpu=compute_cap', '--format=csv,noheader'], + capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0: + compute_cap = result.stdout.strip().split('\n')[0] + major, minor = compute_cap.split('.') + return int(major) * 10 + int(minor) + except Exception: + pass + return None + def get_arch_flags(): # Check NVCC Version # NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py` @@ -33,12 +48,24 @@ def get_arch_flags(): major, minor = map(int, nvcc_version_number.split('.')) print(f'Compiling using NVCC {major}.{minor}') + DISABLE_SM120 = is_flag_set("FLASH_MLA_DISABLE_SM120") DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100") DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90") + + # SM120 requires NVCC 12.9+ if major < 12 or (major == 12 and minor <= 8): - assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." + DISABLE_SM120 = True + if not DISABLE_SM100: + assert False, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." + + # Auto-detect SM120 (RTX PRO 6000 Blackwell workstation) + gpu_arch = get_gpu_arch() + if gpu_arch == 120 and not DISABLE_SM120: + print(f'Detected SM120 GPU (RTX PRO 6000 Blackwell workstation)') arch_flags = [] + if not DISABLE_SM120: + arch_flags.extend(["-gencode", "arch=compute_120,code=sm_120"]) if not DISABLE_SM100: arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) if not DISABLE_SM90: