diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index 4d276214..747dbb0f 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -19,7 +19,7 @@ def get_mla_metadata(
         num_heads_k: The number of k heads.
         num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled
         is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
-        topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to.
+        topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache` will be attended to.
 
     Returns:
         tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
diff --git a/setup.py b/setup.py
index 15fa6717..a1e14a12 100644
--- a/setup.py
+++ b/setup.py
@@ -22,6 +22,30 @@ def get_features_args():
         features_args.append("-DFLASH_MLA_DISABLE_FP16")
     return features_args
 
+def get_gpu_arch():
+    """
+    Best-effort detection of the local GPU architecture via nvidia-smi.
+
+    Returns:
+        Compute capability of the first listed GPU as major * 10 + minor
+        (e.g. 120 for "12.0"), or None if nvidia-smi is unavailable, fails,
+        times out, or reports a value that is not "<major>.<minor>".
+    """
+    try:
+        result = subprocess.run(
+            ['nvidia-smi', '--query-gpu=compute_cap', '--format=csv,noheader'],
+            capture_output=True, text=True, timeout=10
+        )
+        if result.returncode != 0:
+            return None
+        # nvidia-smi prints one line per GPU; only the first one is inspected
+        compute_cap = result.stdout.strip().split('\n')[0]
+        major, minor = compute_cap.split('.')
+        return int(major) * 10 + int(minor)
+    except (OSError, subprocess.SubprocessError, ValueError):
+        # nvidia-smi is missing or timed out, or its output could not be parsed
+        return None
+
 def get_arch_flags():
     # Check NVCC Version
     # NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py`
@@ -33,12 +57,24 @@ def get_arch_flags():
     major, minor = map(int, nvcc_version_number.split('.'))
     print(f'Compiling using NVCC {major}.{minor}')
 
+    DISABLE_SM120 = is_flag_set("FLASH_MLA_DISABLE_SM120")
     DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
     DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
+
+    # With NVCC 12.8 or older, sm100 must be disabled explicitly and sm120 is turned off here
     if major < 12 or (major == 12 and minor <= 8):
         assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."
+        # Unlike sm100, sm120 is skipped without requiring an opt-out flag
+        DISABLE_SM120 = True
+
+    # Informational only: the sm120 target below is built whenever it is not disabled, GPU present or not
+    gpu_arch = get_gpu_arch()
+    if gpu_arch == 120 and not DISABLE_SM120:
+        print('Detected SM120 GPU (RTX PRO 6000 Blackwell workstation)')
 
     arch_flags = []
+    if not DISABLE_SM120:
+        arch_flags.extend(["-gencode", "arch=compute_120,code=sm_120"])
     if not DISABLE_SM100:
         arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"])
     if not DISABLE_SM90:
@@ -97,6 +133,8 @@ def get_nvcc_thread_args():
             Path(this_dir) / "csrc" / "sm90",
             Path(this_dir) / "csrc" / "cutlass" / "include",
             Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
+            # CUDA 13+ moved cuda/std headers to CCCL - add CUDA include path explicitly
+            Path(CUDA_HOME) / "include",
         ],
     )
 )