Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flash_mla/flash_mla_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_mla_metadata(
num_heads_k: The number of k heads.
num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled
is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to.
topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache` will be attended to.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
Expand Down
31 changes: 30 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ def get_features_args():
features_args.append("-DFLASH_MLA_DISABLE_FP16")
return features_args

def get_gpu_arch():
"""Detect GPU architecture using nvidia-smi."""
try:
result = subprocess.run(
['nvidia-smi', '--query-gpu=compute_cap', '--format=csv,noheader'],
capture_output=True, text=True, timeout=10
)
if result.returncode == 0:
compute_cap = result.stdout.strip().split('\n')[0]
major, minor = compute_cap.split('.')
return int(major) * 10 + int(minor)
except Exception:
pass
return None

def get_arch_flags():
# Check NVCC Version
# NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py`
Expand All @@ -33,12 +48,24 @@ def get_arch_flags():
major, minor = map(int, nvcc_version_number.split('.'))
print(f'Compiling using NVCC {major}.{minor}')

DISABLE_SM120 = is_flag_set("FLASH_MLA_DISABLE_SM120")
DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")

# SM120 requires NVCC 12.9+
if major < 12 or (major == 12 and minor <= 8):
assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."
DISABLE_SM120 = True
if not DISABLE_SM100:
assert False, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."

# Auto-detect SM120 (RTX PRO 6000 Blackwell workstation)
gpu_arch = get_gpu_arch()
if gpu_arch == 120 and not DISABLE_SM120:
print(f'Detected SM120 GPU (RTX PRO 6000 Blackwell workstation)')

arch_flags = []
if not DISABLE_SM120:
arch_flags.extend(["-gencode", "arch=compute_120,code=sm_120"])
if not DISABLE_SM100:
arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"])
if not DISABLE_SM90:
Expand Down Expand Up @@ -97,6 +124,8 @@ def get_nvcc_thread_args():
Path(this_dir) / "csrc" / "sm90",
Path(this_dir) / "csrc" / "cutlass" / "include",
Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
# CUDA 13+ moved cuda/std headers to CCCL - add CUDA include path explicitly
Path(CUDA_HOME) / "include",
],
)
)
Expand Down