Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

### Added

- Added AEROGNN model and AEROConv layer for deep attention in graph neural networks ([#10285](https://github.com/pyg-team/pytorch_geometric/pull/10285))

### Changed

### Deprecated
Expand Down
28 changes: 26 additions & 2 deletions examples/ogbn_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@

from torch_geometric import seed_everything
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.models import GAT, GraphSAGE, Polynormer, SGFormer
from torch_geometric.nn.models import (
AEROGNN,
GAT,
GraphSAGE,
Polynormer,
SGFormer,
)
from torch_geometric.utils import (
add_self_loops,
remove_self_loops,
Expand All @@ -37,7 +43,7 @@
"--model",
type=str.lower,
default='SGFormer',
choices=['sage', 'gat', 'sgformer', 'polynormer'],
choices=['sage', 'gat', 'sgformer', 'polynormer', 'aero'],
help="Model used for training",
)

Expand All @@ -55,6 +61,12 @@
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--wd', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--iterations', type=int, default=10,
help='number of propagation iterations for AERO model')
parser.add_argument('--lambd', type=float, default=1.0,
help='decay weight parameter for AERO model')
parser.add_argument('--add_dropout', action='store_true',
help='apply dropout before final layer for AERO model')
parser.add_argument(
'--use_directed_graph',
action='store_true',
Expand Down Expand Up @@ -214,6 +226,18 @@ def get_model(model_name: str) -> torch.nn.Module:
out_channels=dataset.num_classes,
local_layers=num_layers,
)
elif model_name == 'aero':
model = AEROGNN(
in_channels=dataset.num_features,
hidden_channels=num_hidden_channels,
num_layers=num_layers,
out_channels=dataset.num_classes,
iterations=args.iterations,
heads=args.num_heads,
lambd=args.lambd,
dropout=args.dropout,
add_dropout=args.add_dropout,
)
else:
raise ValueError(f'Unsupported model type: {model_name}')

Expand Down
21 changes: 20 additions & 1 deletion examples/ogbn_train_cugraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def arg_parse():
'SAGE',
'GAT',
'GCN',
'AERO',
# TODO: Uncomment when we add support for disjoint sampling
# 'SGFormer',
],
Expand All @@ -145,8 +146,14 @@ def arg_parse():
"--num_heads",
type=int,
default=1,
help="If using GATConv or GT, number of attention heads to use",
help="If using GATConv, GT, or AERO, number of attention heads to use",
)
parser.add_argument('--iterations', type=int, default=10,
help='number of propagation iterations for AERO model')
parser.add_argument('--lambd', type=float, default=1.0,
help='decay weight parameter for AERO model')
parser.add_argument('--add_dropout', action='store_true',
help='apply dropout before final layer for AERO model')
parser.add_argument('--tempdir_root', type=str, default=None)
args = parser.parse_args()
return args
Expand Down Expand Up @@ -277,6 +284,18 @@ def test(model, loader):
model = torch_geometric.nn.models.GraphSAGE(
dataset.num_features, args.hidden_channels, args.num_layers,
dataset.num_classes).cuda()
elif args.model == "AERO":
model = torch_geometric.nn.models.AEROGNN(
in_channels=dataset.num_features,
hidden_channels=args.hidden_channels,
num_layers=args.num_layers,
out_channels=dataset.num_classes,
iterations=args.iterations,
heads=args.num_heads,
lambd=args.lambd,
dropout=args.dropout,
add_dropout=args.add_dropout,
).cuda()
elif args.model == 'SGFormer':
# TODO add support for this with disjoint sampling
model = torch_geometric.nn.models.SGFormer(
Expand Down
224 changes: 224 additions & 0 deletions test/nn/models/test_aero_gnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.nn.models import AEROGNN
from torch_geometric.testing import withDevice

# Hyper-parameter grids shared by the parametrized AEROGNN smoke tests below.
# out_dim=None presumably makes the model fall back to hidden_channels (16);
# the expected-size computation in test_aero_gnn relies on that — TODO confirm
# against the AEROGNN implementation.
out_dims = [None, 8]
dropouts = [0.0, 0.5]
iterations_list = [1, 5, 10]
heads_list = [1, 2]
lambd_list = [0.5, 1.0, 2.0]
num_layers_list = [1, 2]


@pytest.mark.parametrize('out_dim', out_dims)
@pytest.mark.parametrize('dropout', dropouts)
@pytest.mark.parametrize('iterations', iterations_list)
@pytest.mark.parametrize('heads', heads_list)
@pytest.mark.parametrize('lambd', lambd_list)
@pytest.mark.parametrize('num_layers', num_layers_list)
def test_aero_gnn(out_dim, dropout, iterations, heads, lambd, num_layers):
    """Smoke-test AEROGNN's repr and output shape across the full grid."""
    node_feats = torch.randn(3, 8)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    # With out_dim=None the expected output width is the hidden size (16).
    expected_width = out_dim if out_dim is not None else 16

    model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=num_layers,
        out_channels=out_dim,
        iterations=iterations,
        heads=heads,
        lambd=lambd,
        dropout=dropout,
    )
    expected_str = (f'AEROGNN(8, 16, num_layers={num_layers}, '
                    f'out_channels={expected_width}, '
                    f'iterations={iterations}, '
                    f'heads={heads})')
    assert str(model) == expected_str
    assert model(node_feats, edges).size() == (3, expected_width)


@pytest.mark.parametrize('add_dropout', [False, True])
def test_aero_gnn_add_dropout(add_dropout):
    """The `add_dropout` flag must not change the output shape."""
    node_feats = torch.randn(3, 8)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=5,
        heads=1,
        add_dropout=add_dropout,
    )
    out = model(node_feats, edges)
    assert out.size() == (3, 16)


def test_aero_gnn_reset_parameters():
    """`reset_parameters()` must re-initialize weights and change the output.

    The model is put in eval mode first: if any dropout were active, two
    forward passes would differ even *without* a reset, and the final
    assertion would pass vacuously instead of testing the reset.
    """
    x = torch.randn(3, 8)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=2,
        out_channels=16,
        iterations=10,
        heads=2,
    )
    model.eval()  # Disable stochastic layers so only the reset matters.

    # Output before reset:
    out1 = model(x, edge_index)

    model.reset_parameters()

    # Output after reset — should differ due to fresh random initialization:
    out2 = model(x, edge_index)

    # It is overwhelmingly unlikely that re-drawn weights reproduce out1.
    assert not torch.allclose(out1, out2, atol=1e-6)


def test_aero_gnn_single_layer():
    """A one-layer AEROGNN still produces the expected output shape."""
    node_feats = torch.randn(3, 8)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    single_layer_model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=5,
    )
    out = single_layer_model(node_feats, edges)
    assert out.size() == (3, 16)


def test_aero_gnn_multiple_iterations():
    """A large propagation-iteration count does not affect the output shape."""
    node_feats = torch.randn(3, 8)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    deep_prop_model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=20,
    )
    out = deep_prop_model(node_feats, edges)
    assert out.size() == (3, 16)


def test_aero_gnn_multi_head():
    """Multi-head attention keeps the output shape unchanged."""
    node_feats = torch.randn(3, 8)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    multi_head_model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=5,
        heads=4,
    )
    out = multi_head_model(node_feats, edges)
    assert out.size() == (3, 16)


def test_aero_gnn_with_sparse_tensor():
    """The forward pass accepts a SparseTensor adjacency as well."""
    from torch_geometric.typing import SparseTensor

    node_feats = torch.randn(3, 8)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    adjacency = SparseTensor.from_edge_index(edges, sparse_sizes=(3, 3))

    model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=5,
    )
    out = model(node_feats, adjacency)
    assert out.size() == (3, 16)


@withDevice
def test_aero_gnn_device(device):
    """The model runs on, and produces outputs on, the requested device."""
    node_feats = torch.randn(3, 8, device=device)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)

    model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=5,
    ).to(device)

    out = model(node_feats, edges)
    assert out.size() == (3, 16)
    assert out.device == device


def test_aero_gnn_gradient():
    """Backprop must populate gradients for the input and all parameters."""
    node_feats = torch.randn(3, 8, requires_grad=True)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

    model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=5,
    )

    model(node_feats, edges).sum().backward()

    # The input received a gradient:
    assert node_feats.grad is not None
    # Every trainable parameter received a gradient:
    for param in model.parameters():
        if param.requires_grad:
            assert param.grad is not None


def test_aero_gnn_with_data():
    """The model works on tensors pulled from a PyG `Data` object."""
    graph = Data(
        x=torch.randn(3, 8),
        edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    )

    model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=5,
    )

    assert model(graph.x, graph.edge_index).size() == (3, 16)


def test_aero_gnn_num_nodes():
    """Passing `num_nodes` explicitly matches letting the model infer it.

    The model is switched to eval mode before the comparison: the
    `torch.allclose` check between two forward passes is only valid when
    the model is deterministic, and any active dropout would make it flaky.
    """
    x = torch.randn(5, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])

    model = AEROGNN(
        in_channels=8,
        hidden_channels=16,
        num_layers=1,
        out_channels=16,
        iterations=5,
    )
    model.eval()  # Ensure the two forward passes below are deterministic.

    # Test with explicit num_nodes:
    out1 = model(x, edge_index, num_nodes=5)
    # Test without num_nodes (should be inferred from `x`):
    out2 = model(x, edge_index)

    assert out1.size() == (5, 16)
    assert out2.size() == (5, 16)
    assert torch.allclose(out1, out2, atol=1e-6)
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .gatv2_conv import GATv2Conv
from .transformer_conv import TransformerConv
from .agnn_conv import AGNNConv
from .aero_conv import AEROConv
from .tag_conv import TAGConv
from .gin_conv import GINConv, GINEConv
from .arma_conv import ARMAConv
Expand Down Expand Up @@ -82,6 +83,7 @@
'GATv2Conv',
'TransformerConv',
'AGNNConv',
'AEROConv',
'TAGConv',
'GINConv',
'GINEConv',
Expand Down
Loading