Support CUDA 12.8 #309

Merged

gbg141 merged 5 commits into geometric-intelligence:zaz/fix-deprecations from zaz:support-cu128 on Apr 17, 2026

Conversation

zaz (Collaborator) commented Apr 11, 2026

Checklist

  • My pull request has a clear and explanatory title.
  • My pull request passes the Linting test.
  • I added appropriate unit tests and I made sure the code passes all unit tests. (refer to comment below)
  • My PR follows PEP8 guidelines. (refer to comment below)
  • My code is properly documented, using numpy docs conventions, and I made sure the documentation renders properly.
  • I linked to issues and PRs that are relevant to this PR.

PR series: #303 #307 #308 #309

This PR covers only commits da969a2..9741e93. It is number 3 in a series of staged PRs, building on #308; if #308 is rejected, the commits in this PR will need to be reviewed manually. In particular, 2953cfb is required so that the OGB dataset loader does not crash.

Description

Commits:

  1. Loosen the torch pin from ==2.3.0 to >=2.3.0.
  2. Set 3 torch sparse packages to never build from source, because the source build was taking priority and causing issues.
  3. Add CUDA 12.8 support. Fixes #306 (Support cu128 for Blackwell GPUs).
  4. Make backbone-specific imports optional.
  5. Add tests for 4.
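Commits 2 and 3 amount to pyproject.toml changes along the lines of the sketch below. It uses uv's documented no-build-package and [[tool.uv.index]] settings; the exact package list, index name, and URL are assumptions and may differ from the actual diff:

```toml
[tool.uv]
# Never build these from PyPI sdists; resolve them from pre-built wheels
# that match the installed PyTorch + CUDA version instead.
no-build-package = ["torch-scatter", "torch-sparse", "torch-cluster"]

# Extra index for CUDA 12.8 wheels (needed by Blackwell-class GPUs).
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
```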

The last two shouldn't affect users who do the automated install, but for users experimenting with later CUDA versions, later torch versions, etc., making those dependencies optional leaves fewer dependencies to manage if you're not using those backbones.

If you don't want to support torch >2.3.0, we could add a CLI option to pin a specific torch version and update the installation instructions to use it, noting that non-2.3.0 versions are experimental. However, it would be good to support torch versions >2.3.0, because GPUs that require them are only becoming more common.
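For context on why the torch version can't stay hardcoded: the PyG find-links wheel pages are keyed by the exact torch + CUDA combination, so the setup script has to derive that tag from whatever version uv resolves. A minimal sketch of the idea; the helper name and tag format are illustrative, based on PyG's public wheel index convention, not TopoBench code:

```python
def pyg_wheel_tag(torch_version: str, cuda: str) -> str:
    """Build the version tag used in PyG's find-links wheel URLs,
    e.g. 'torch-2.3.0+cu128'. Illustrative helper, not TopoBench code."""
    base = torch_version.split("+")[0]  # drop any local tag like '+cu121'
    return f"torch-{base}+{cuda}"

# The find-links page would then look like (URL format is an assumption):
# https://data.pyg.org/whl/torch-2.3.0+cu128.html
```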

Issue

Fixes #306.

zaz added 5 commits on April 11, 2026:

1. Replace hardcoded TORCH_VER="2.3.0" with auto-detection so the setup script works with any torch version resolved by uv. Loosen the torch pin from ==2.3.0 to >=2.3.0 to allow newer versions.

2. Add no-build-package to pyproject.toml to prevent uv from building these packages from PyPI sdists. This forces resolution from the PyG find-links wheels, which are pre-built for the correct PyTorch + CUDA version. Applies to all uv commands, not just the setup script. Remove the extra-build-dependencies section (no longer needed since we never build from source).

3. Add a pytorch-cu128 index to pyproject.toml and a cu128 option to the setup script. This is required for newer GPUs (e.g. the Blackwell architecture) that need CUDA 12.8+.

4. This only affects users doing a manual install; the setup script installs them via --all-extras. Making them optional avoids install failures for users who don't need the NSD, ED-GNN, or point cloud lifting backbones, as these packages require pre-built wheels matching the exact PyTorch + CUDA version. Move top-level imports of torch_sparse, torch_scatter, and torch_cluster to lazy imports inside the functions that use them, so that importing topobench doesn't crash without the [sparse] extra. Add [sparse] to the [all] extra group.

5. Verify that importing topobench and triggering backbone auto-discovery works without the [sparse] extra installed.
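The lazy-import pattern those last two commits describe can be sketched as below. The helper is hypothetical (the actual change simply moves the import torch_scatter / import torch_sparse statements inside the methods that use them), but it shows the idea: the optional dependency is only resolved when a backbone that needs it actually runs, and the error message points at the missing extra:

```python
import importlib


def require_optional(name: str, extra: str):
    """Import an optional dependency on first use, with a helpful error.

    Hypothetical helper: the module is imported only when a backbone
    that needs it is actually called, so `import topobench` itself
    succeeds even when the [sparse] extra is not installed.
    """
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            f"{name} is required by this backbone; install it with the "
            f"optional extra, e.g. `pip install topobench[{extra}]`."
        ) from err
```

A backbone method would then call e.g. `scatter = require_optional("torch_scatter", "sparse")` at the top of its forward pass instead of importing at module scope.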
gbg141 (Collaborator) left a comment

Only that one question before the merge. Thank you!

Tensor
Output features.
"""
import torch_scatter

Any reason why we need import torch_scatter within methods instead of globally?

torch.Tensor
Output node features of shape [num_nodes, output_dim].
"""
import torch_sparse

Any reason why we need import torch_scatter within methods instead of globally?

saved_tril_maps : torch.Tensor
Saved lower triangular transport maps for analysis.
"""
from torch_scatter import scatter_add

(same comment!)

zaz (Collaborator, Author) commented Apr 17, 2026

@gbg141 Thank you for all your reviews 🙏 It's because importing BACKBONE_CLASSES will fail if those imports sit at the top level and the (now optional) packages aren't installed. I added a test to cover this.
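One way to see the failure mode: backbone auto-discovery imports every backbone module, so a single top-level import of a missing optional package breaks `import topobench` entirely. The sketch below is hypothetical (TopoBench instead moves the imports inside the methods that use them); it shows an auto-discovery loop that tolerates missing optional dependencies:

```python
import importlib


def discover(module_names):
    """Import each backbone module, skipping those whose optional
    dependencies are missing instead of crashing the whole registry.

    Hypothetical sketch; module names are illustrative.
    """
    available, skipped = {}, []
    for name in module_names:
        try:
            available[name] = importlib.import_module(name)
        except ModuleNotFoundError:
            skipped.append(name)  # e.g. a backbone needing torch_scatter
    return available, skipped
```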

I left that commit for last because it is a bit unusual. I reasoned that it adds negligible overhead and shouldn't delay ModuleNotFoundError exceptions in most practical use cases. But it does look weird. The argument for it is that it may help users who run into dependency friction, and that friction will only become more prominent as we expand TopoBench.

So it's your call on that commit.

gbg141 (Collaborator) left a comment

Got it! I just opened an issue (#312) to explore more in detail these optional imports in the future, but merging your PR into main right away.

Thanks to you for all of your contributions to TB!

@gbg141 gbg141 merged commit 964bca4 into geometric-intelligence:zaz/fix-deprecations Apr 17, 2026
3 checks passed