Support CUDA 12.8#309
Conversation
Replace hardcoded TORCH_VER="2.3.0" with auto-detection so the setup script works with any torch version resolved by uv. Loosen torch pin from ==2.3.0 to >=2.3.0 to allow newer versions.
Add no-build-package to pyproject.toml to prevent uv from building these packages from PyPI sdists. This forces resolution from the PyG find-links wheels, which are pre-built for the correct PyTorch + CUDA version. Applies to all uv commands, not just the setup script. Remove extra-build-dependencies section (no longer needed since we never build from source).
Add pytorch-cu128 index to pyproject.toml and cu128 option to the setup script. This is required for newer GPUs (e.g. Blackwell architecture) that need CUDA 12.8+.
This only affects users doing a manual install; the setup script installs them via --all-extras. Making them optional avoids install failures for users who don't need NSD, ED-GNN, or point cloud lifting backbones, as these packages require pre-built wheels matching the exact PyTorch + CUDA version. Move top-level imports of torch_sparse, torch_scatter, and torch_cluster to lazy imports inside the functions that use them, so that importing topobench doesn't crash without the [sparse] extra. Add [sparse] to the [all] extra group.
Verify that importing topobench and triggering backbone auto-discovery works without the [sparse] extra installed.
gbg141
left a comment
There was a problem hiding this comment.
Only that one question before the merge. Thank you!
| Tensor | ||
| Output features. | ||
| """ | ||
| import torch_scatter |
There was a problem hiding this comment.
Any reason why we need import torch_scatter within methods instead of globally?
| torch.Tensor | ||
| Output node features of shape [num_nodes, output_dim]. | ||
| """ | ||
| import torch_sparse |
There was a problem hiding this comment.
Any reason why we need import torch_scatter within methods instead of globally?
| saved_tril_maps : torch.Tensor | ||
| Saved lower triangular transport maps for analysis. | ||
| """ | ||
| from torch_scatter import scatter_add |
|
@gbg141 Thank you for all your reviews 🙏 It's because I left that commit last because it is a bit unusual. I reasoned that it adds negligible overhead and shouldn't delay So it's your call on that commit. |
964bca4
into
geometric-intelligence:zaz/fix-deprecations
Checklist
PR series: #303 #307 #308 309
This PR is only for commits da969a2..9741e93. It is number 3 in a series of staged PRs, building upon #308 by adding commits da969a2..9741e93; if #308 is rejected, the commits in this PR need to be manually reviewed. In particular, 2953cfb is required so that the OGB dataset loader does not crash.
Description
Commits:
==2.3.0to>=2.3.0.The last two shouldn't affect users who do the automated install, but for users who are playing around with later CUDA versions, later torch versions, etc, having those dependencies as optional makes things easier as you have less dependencies to manage if you're not using those backbones.
If you don't want to support torch >2.3.0, we could add a CLI option to use a specific torch version and update the installation instructions to use that, noting that non-2.3.0 versions are experimental. However, it would be good to support >2.3.0 torch versions because GPUs that require them are only becoming more common.
Issue
Fixes #306.