Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ TopoBench now uses [**uv**](https://docs.astral.sh/uv/), an extremely fast Pytho
3. **Initialize Environment**:
Use our centralized setup script to handle Python 3.11 virtualization and specialized hardware (CUDA) mapping.
```bash
# Usage: source uv_env_setup.sh [cpu|cu118|cu121]
# Usage: source uv_env_setup.sh [cpu|cu118|cu121|cu128]
source uv_env_setup.sh cpu
```
*This script performs the following:*
Expand Down
25 changes: 15 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
requires-python = ">=3.11, <3.12"

dependencies=[
"torch==2.3.0",
"torch>=2.3.0",
"tqdm",
"charset-normalizer",
"numpy",
Expand Down Expand Up @@ -56,12 +56,16 @@ dependencies=[
"topomodelx @ git+https://github.com/pyt-team/TopoModelX.git",
"toponetx @ git+https://github.com/pyt-team/TopoNetX.git@c378925",
"lightning==2.4.0",
]

[project.optional-dependencies]
# Required for NSD, ED-GNN, and point cloud lifting backbones.
# Wheels must match your PyTorch + CUDA version; see find-links below.
sparse = [
"torch-scatter",
"torch-sparse",
"torch-cluster",
]

[project.optional-dependencies]
doc = [
"jupyter",
"nbsphinx",
Expand All @@ -87,7 +91,7 @@ test = [
]

dev = ["TopoBench[test, lint]"]
all = ["TopoBench[dev, doc]"]
all = ["TopoBench[dev, doc, sparse]"]

[project.urls]
homepage="https://geometric-intelligence.github.io/topobench/index.html"
Expand All @@ -112,21 +116,22 @@ name = "pytorch-cu121"
url = "https://download.pytorch.org/whl/cu121"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

# Default find-links (will be overwritten by bash script)
[tool.uv]
find-links = ["https://data.pyg.org/whl/torch-2.3.0+cu121.html"]
no-build-package = ["torch-scatter", "torch-sparse", "torch-cluster"]

[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin' or sys_platform == 'win32'" },
{ index = "pytorch-cu121", marker = "sys_platform == 'linux'" },
]

[tool.uv.extra-build-dependencies]
torch-cluster = ["torch==2.3.0"]
torch-scatter = ["torch==2.3.0"]
torch-sparse = ["torch==2.3.0"]

# ==============================================================================
# TOOL CONFIGS
# ==============================================================================
Expand Down Expand Up @@ -172,7 +177,7 @@ disable_error_code = ["import-untyped"]
plugins = "numpy.typing.mypy_plugin"

[[tool.mypy.overrides]]
module = ["torch_cluster.*","networkx.*","scipy.spatial","scipy.sparse","toponetx.classes.simplicial_complex"]
module = ["torch_cluster.*","torch_sparse.*","torch_scatter.*","networkx.*","scipy.spatial","scipy.sparse","toponetx.classes.simplicial_complex"]
ignore_missing_imports = true

[tool.pytest.ini_options]
Expand Down
32 changes: 32 additions & 0 deletions test/test_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Tests for package import behavior."""

import subprocess
import sys


def test_import_without_sparse():
"""Verify topobench imports without torch_sparse/scatter/cluster.

Runs in a subprocess that blocks the sparse imports via sys.modules
to avoid modifying the test environment.
"""
result = subprocess.run(
[
sys.executable,
"-c",
"import sys;"
"sys.modules['torch_sparse'] = None;"
"sys.modules['torch_scatter'] = None;"
"sys.modules['torch_cluster'] = None;"
"import topobench;"
"from topobench.nn.backbones.graph import BACKBONE_CLASSES;"
"from topobench.nn.backbones.hypergraph import BACKBONE_CLASSES;"
"print('ok')",
],
capture_output=True,
text=True,
)
assert result.returncode == 0, (
f"import topobench failed without sparse packages:\n{result.stderr}"
)
assert "ok" in result.stdout
5 changes: 4 additions & 1 deletion topobench/data/utils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torch_geometric
from toponetx.classes import SimplicialComplex
from torch_geometric.data import Data
from torch_sparse import coalesce

from topobench.data.utils import get_complex_connectivity

Expand Down Expand Up @@ -370,6 +369,8 @@ def load_hypergraph_pickle_dataset(data_dir, data_name):
torch_geometric.data.Data
Hypergraph dataset.
"""
from torch_sparse import coalesce

data_dir = osp.join(data_dir, data_name)

# Load node features:
Expand Down Expand Up @@ -478,6 +479,8 @@ def load_hypergraph_content_dataset(data_dir, data_name):
torch_geometric.data.Data
Hypergraph dataset.
"""
from torch_sparse import coalesce

# data_dir = osp.join(data_dir, data_name)

p2idx_features_labels = osp.join(data_dir, f"{data_name}.content")
Expand Down
3 changes: 2 additions & 1 deletion topobench/dataloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import torch
import torch_geometric
from torch_sparse import SparseTensor


class DomainData(torch_geometric.data.Data):
Expand Down Expand Up @@ -70,6 +69,8 @@ def to_data_list(batch):
list
List of data objects.
"""
from torch_sparse import SparseTensor

for key, _ in batch:
if batch[key].is_sparse:
sparse_data = batch[key].coalesce()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import torch
import torch.nn.functional as F
import torch_sparse
from torch import nn

from .laplacian_builders import (
Expand Down Expand Up @@ -102,6 +101,8 @@ def forward(self, x, edge_index):
torch.Tensor
Output node features of shape [num_nodes, output_dim].
"""
import torch_sparse

# Get actual number of nodes dynamically
actual_num_nodes = x.size(0)

Expand Down Expand Up @@ -281,6 +282,8 @@ def forward(self, x, edge_index):
torch.Tensor
Output node features of shape [num_nodes, output_dim].
"""
import torch_sparse
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why we need import torch_scatter within methods instead of globally?


# Get actual number of nodes dynamically
actual_num_nodes = x.size(0)

Expand Down Expand Up @@ -453,6 +456,8 @@ def forward(self, x, edge_index):
torch.Tensor
Output node features of shape [num_nodes, output_dim].
"""
import torch_sparse

# Get actual number of nodes dynamically
actual_num_nodes = x.size(0)

Expand Down
5 changes: 4 additions & 1 deletion topobench/nn/backbones/graph/nsd_utils/laplacian_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import torch
from torch import nn
from torch_geometric.utils import degree
from torch_scatter import scatter_add

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from .laplace import (
Expand Down Expand Up @@ -170,6 +169,8 @@ def forward(self, maps):
saved_tril_maps : torch.Tensor
Saved lower triangular restriction maps for analysis.
"""
from torch_scatter import scatter_add

assert len(maps.size()) == 2
assert maps.size(1) == self.d
left_idx, right_idx = self.left_right_idx
Expand Down Expand Up @@ -352,6 +353,8 @@ def forward(self, maps):
saved_tril_maps : torch.Tensor
Saved lower triangular transport maps for analysis.
"""
from torch_scatter import scatter_add
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(same comment!)


left_idx, right_idx = self.left_right_idx
tril_row, tril_col = self.vertex_tril_idx
tril_indices, diag_indices = self.tril_indices, self.diag_indices
Expand Down
7 changes: 6 additions & 1 deletion topobench/nn/backbones/hypergraph/edgnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_scatter


class EDGNN(nn.Module):
Expand Down Expand Up @@ -482,6 +481,8 @@ def forward(self, X, vertex, edges, X0):
Tensor
Output features.
"""
import torch_scatter

N = X.shape[-2]

Xve = self.W1(X)[..., vertex, :] # [nnz, C]
Expand Down Expand Up @@ -562,6 +563,8 @@ def forward(self, X, vertex, edges, X0, beta=1.0):
Tensor
Output features.
"""
import torch_scatter
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why we need import torch_scatter within methods instead of globally?


N = X.shape[-2]

Xve = X[..., vertex, :] # [nnz, C]
Expand Down Expand Up @@ -666,6 +669,8 @@ def forward(self, X, vertex, edges, X0):
Tensor
Output features.
"""
import torch_scatter

N = X.shape[-2]

Xve = self.W1(X[..., vertex, :]) # [nnz, C]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch
import torch_geometric
from torch_cluster import fps, radius

from topobench.transforms.liftings.pointcloud2hypergraph.base import (
PointCloud2HypergraphLifting,
Expand Down Expand Up @@ -45,6 +44,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
dict
The lifted topology.
"""
from torch_cluster import fps, radius

batch = (
torch.zeros(data.num_nodes, dtype=torch.long)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch
import torch_geometric
from torch_cluster import fps, knn

from topobench.transforms.liftings.pointcloud2hypergraph.base import (
PointCloud2HypergraphLifting,
Expand Down Expand Up @@ -37,6 +36,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
dict
The lifted topology.
"""
from torch_cluster import fps, knn

# Sample FPS induced Voronoi graph
support_idcs = fps(data.x, ratio=self.support_ratio)
Expand Down
38 changes: 23 additions & 15 deletions uv_env_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# ==============================================================================
# 🛠️ TopoBench Environment Setup Script (Py3.11 + Dynamic CUDA)
# ==============================================================================
# usage: bash uv_env_setup.sh [cpu|cu118|cu121]
# usage: bash uv_env_setup.sh [cpu|cu118|cu121|cu128]
# ==============================================================================

PLATFORM="${1:-cpu}"
Expand All @@ -17,35 +17,25 @@ echo "======================================================="
# ------------------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------------------
TORCH_VER="2.3.0"

case "$PLATFORM" in
cpu|cu118|cu121)
cpu|cu118|cu121|cu128)
TARGET_INDEX="pytorch-${PLATFORM}"
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}+${PLATFORM}.html"
;;
*)
echo "❌ Error: Invalid platform '$PLATFORM'. Use: cpu, cu118, or cu121."
echo "❌ Error: Invalid platform '$PLATFORM'. Use: cpu, cu118, cu121, or cu128."
return 1 2>/dev/null || exit 1
;;
esac

echo "⚙️ Updating pyproject.toml..."

# 1. Update the 'find-links' URL for PyG extensions
# Update the torch source index for Linux
if [[ "$OSTYPE" == "darwin"* ]]; then
# MacOS sed
sed -i '' "s|find-links = \[\".*\"\]|find-links = [\"${PYG_URL}\"]|g" pyproject.toml
# Update Linux Source Marker
sed -i '' "s/index = \"pytorch-[a-z0-9]*\", marker = \"sys_platform == 'linux'/index = \"${TARGET_INDEX}\", marker = \"sys_platform == 'linux'/g" pyproject.toml
else
# Linux sed
sed -i "s|find-links = \[\".*\"\]|find-links = [\"${PYG_URL}\"]|g" pyproject.toml
# Update Linux Source Marker
sed -i "s/index = \"pytorch-[a-z0-9]*\", marker = \"sys_platform == 'linux'/index = \"${TARGET_INDEX}\", marker = \"sys_platform == 'linux'/g" pyproject.toml
fi

echo "✅ Set PyG Links to : ${PYG_URL}"
echo "✅ Set Torch Index to: ${TARGET_INDEX}"

# ------------------------------------------------------------------------------
Expand All @@ -55,8 +45,26 @@ echo ""
echo "🧹 Cleaning old lockfile..."
rm -f uv.lock

# Dry-run to detect which torch version will be installed
TORCH_VER=$(uv sync --dry-run --python 3.11 2>&1 \
| grep '+ torch==' | sed 's/.*+ torch==//')
if [ -z "$TORCH_VER" ]; then
# Fallback: read from existing venv (dry-run reports nothing if already installed)
TORCH_VER=$(.venv/bin/python -c "import torch; print(torch.__version__)" 2>/dev/null)
fi
if [ -z "$TORCH_VER" ]; then
echo "❌ Error: Could not detect torch version."
return 1 2>/dev/null || exit 1
fi
PYG_URL="https://data.pyg.org/whl/torch-${TORCH_VER}.html"
if [[ "$OSTYPE" == "darwin"* ]]; then
sed -i '' "s|find-links = \[\".*\"\]|find-links = [\"${PYG_URL}\"]|g" pyproject.toml
else
sed -i "s|find-links = \[\".*\"\]|find-links = [\"${PYG_URL}\"]|g" pyproject.toml
fi
echo "✅ Set PyG Links to : ${PYG_URL} (torch ${TORCH_VER})"

echo "📦 Syncing Environment (Python 3.11)..."
# Force Python 3.11 creation
if ! uv sync --python 3.11 --all-extras; then
echo "❌ uv sync failed."
return 1 2>/dev/null || exit 1
Expand Down
Loading