diff --git a/CHANGELOG.md b/CHANGELOG.md index 709cced4b680..9b6f5353550b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). ### Added +- Added AEROGNN model and AEROConv layer for deep attention in graph neural networks ([#10285](https://github.com/pyg-team/pytorch_geometric/pull/10285)) + ### Changed ### Deprecated diff --git a/examples/ogbn_train.py b/examples/ogbn_train.py index 141780eef78c..a2fb2fcb9110 100644 --- a/examples/ogbn_train.py +++ b/examples/ogbn_train.py @@ -11,7 +11,13 @@ from torch_geometric import seed_everything from torch_geometric.loader import NeighborLoader -from torch_geometric.nn.models import GAT, GraphSAGE, Polynormer, SGFormer +from torch_geometric.nn.models import ( + AEROGNN, + GAT, + GraphSAGE, + Polynormer, + SGFormer, +) from torch_geometric.utils import ( add_self_loops, remove_self_loops, @@ -37,7 +43,7 @@ "--model", type=str.lower, default='SGFormer', - choices=['sage', 'gat', 'sgformer', 'polynormer'], + choices=['sage', 'gat', 'sgformer', 'polynormer', 'aero'], help="Model used for training", ) @@ -55,6 +61,12 @@ parser.add_argument('--lr', type=float, default=0.003) parser.add_argument('--wd', type=float, default=0.0) parser.add_argument('--dropout', type=float, default=0.5) +parser.add_argument('--iterations', type=int, default=10, + help='number of propagation iterations for AERO model') +parser.add_argument('--lambd', type=float, default=1.0, + help='decay weight parameter for AERO model') +parser.add_argument('--add_dropout', action='store_true', + help='apply dropout before final layer for AERO model') parser.add_argument( '--use_directed_graph', action='store_true', @@ -214,6 +226,18 @@ def get_model(model_name: str) -> torch.nn.Module: out_channels=dataset.num_classes, local_layers=num_layers, ) + elif model_name == 'aero': + model = AEROGNN( + in_channels=dataset.num_features, + hidden_channels=num_hidden_channels, + num_layers=num_layers, + out_channels=dataset.num_classes, + iterations=args.iterations, + heads=args.num_heads, + lambd=args.lambd, + dropout=args.dropout, + add_dropout=args.add_dropout, + ) else: raise ValueError(f'Unsupported model type: {model_name}') diff --git a/examples/ogbn_train_cugraph.py b/examples/ogbn_train_cugraph.py index a3045b736ab3..0b7f91108b87 100644 --- a/examples/ogbn_train_cugraph.py +++ b/examples/ogbn_train_cugraph.py @@ -136,6 +136,7 @@ def arg_parse(): 'SAGE', 'GAT', 'GCN', + 'AERO', # TODO: Uncomment when we add support for disjoint sampling # 'SGFormer', ], @@ -145,8 +146,14 @@ def arg_parse(): "--num_heads", type=int, default=1, - help="If using GATConv or GT, number of attention heads to use", + help="If using GATConv, GT, or AERO, number of attention heads to use", ) + parser.add_argument('--iterations', type=int, default=10, + help='number of propagation iterations for AERO model') + parser.add_argument('--lambd', type=float, default=1.0, + help='decay weight parameter for AERO model') + parser.add_argument('--add_dropout', action='store_true', + help='apply dropout before final layer for AERO model') parser.add_argument('--tempdir_root', type=str, default=None) args = parser.parse_args() return args @@ -277,6 +284,18 @@ def test(model, loader): model = torch_geometric.nn.models.GraphSAGE( dataset.num_features, args.hidden_channels, args.num_layers, dataset.num_classes).cuda() + elif args.model == "AERO": + model = torch_geometric.nn.models.AEROGNN( + in_channels=dataset.num_features, + hidden_channels=args.hidden_channels, + num_layers=args.num_layers, + out_channels=dataset.num_classes, + iterations=args.iterations, + heads=args.num_heads, + lambd=args.lambd, + dropout=args.dropout, + add_dropout=args.add_dropout, + ).cuda() elif args.model == 'SGFormer': # TODO add support for this with disjoint sampling model = torch_geometric.nn.models.SGFormer( diff --git a/test/nn/models/test_aero_gnn.py b/test/nn/models/test_aero_gnn.py new file mode 100644 index 000000000000..f660478305cf --- /dev/null +++ b/test/nn/models/test_aero_gnn.py @@ -0,0 +1,224 @@ +import pytest +import torch + +from torch_geometric.data import Data +from torch_geometric.nn.models import AEROGNN +from torch_geometric.testing import withDevice + +out_dims = [None, 8] +dropouts = [0.0, 0.5] +iterations_list = [1, 5, 10] +heads_list = [1, 2] +lambd_list = [0.5, 1.0, 2.0] +num_layers_list = [1, 2] + + +@pytest.mark.parametrize('out_dim', out_dims) +@pytest.mark.parametrize('dropout', dropouts) +@pytest.mark.parametrize('iterations', iterations_list) +@pytest.mark.parametrize('heads', heads_list) +@pytest.mark.parametrize('lambd', lambd_list) +@pytest.mark.parametrize('num_layers', num_layers_list) +def test_aero_gnn(out_dim, dropout, iterations, heads, lambd, num_layers): + x = torch.randn(3, 8) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out_channels = 16 if out_dim is None else out_dim + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=num_layers, + out_channels=out_dim, + iterations=iterations, + heads=heads, + lambd=lambd, + dropout=dropout, + ) + expected_str = (f'AEROGNN(8, 16, num_layers={num_layers}, ' + f'out_channels={out_channels}, iterations={iterations}, ' + f'heads={heads})') + assert str(model) == expected_str + assert model(x, edge_index).size() == (3, out_channels) + + +@pytest.mark.parametrize('add_dropout', [False, True]) +def test_aero_gnn_add_dropout(add_dropout): + x = torch.randn(3, 8) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=5, + heads=1, + add_dropout=add_dropout, + ) + assert model(x, edge_index).size() == (3, 16) + + +def test_aero_gnn_reset_parameters(): + x = torch.randn(3, 8) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=2, + out_channels=16, + iterations=10, + heads=2, + ) + + # Get output before reset + out1 = model(x, edge_index) + + # Reset parameters + model.reset_parameters() + + # Get output after reset (should be different due to random initialization) + out2 = model(x, edge_index) + + # Outputs should be different (very unlikely to be the same after reset) + assert not torch.allclose(out1, out2, atol=1e-6) + + +def test_aero_gnn_single_layer(): + x = torch.randn(3, 8) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=5, + ) + assert model(x, edge_index).size() == (3, 16) + + +def test_aero_gnn_multiple_iterations(): + x = torch.randn(3, 8) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=20, + ) + assert model(x, edge_index).size() == (3, 16) + + +def test_aero_gnn_multi_head(): + x = torch.randn(3, 8) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=5, + heads=4, + ) + assert model(x, edge_index).size() == (3, 16) + + +def test_aero_gnn_with_sparse_tensor(): + from torch_geometric.typing import SparseTensor + + x = torch.randn(3, 8) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3)) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=5, + ) + assert model(x, adj).size() == (3, 16) + + +@withDevice +def test_aero_gnn_device(device): + x = torch.randn(3, 8, device=device) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=5, + ).to(device) + + assert model(x, edge_index).size() == (3, 16) + assert model(x, edge_index).device == device + + +def test_aero_gnn_gradient(): + x = torch.randn(3, 8, requires_grad=True) + edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=5, + ) + + out = model(x, edge_index) + loss = out.sum() + loss.backward() + + assert x.grad is not None + # Check that model parameters have gradients + for param in model.parameters(): + if param.requires_grad: + assert param.grad is not None + + +def test_aero_gnn_with_data(): + data = Data( + x=torch.randn(3, 8), + edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), + ) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=5, + ) + + out = model(data.x, data.edge_index) + assert out.size() == (3, 16) + + +def test_aero_gnn_num_nodes(): + x = torch.randn(5, 8) + edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]) + + model = AEROGNN( + in_channels=8, + hidden_channels=16, + num_layers=1, + out_channels=16, + iterations=5, + ) + + # Test with explicit num_nodes + out1 = model(x, edge_index, num_nodes=5) + # Test without num_nodes (should infer) + out2 = model(x, edge_index) + + assert out1.size() == (5, 16) + assert out2.size() == (5, 16) + assert torch.allclose(out1, out2, atol=1e-6) diff --git a/torch_geometric/nn/conv/__init__.py b/torch_geometric/nn/conv/__init__.py index 871374060055..467545618b54 100644 --- a/torch_geometric/nn/conv/__init__.py +++ b/torch_geometric/nn/conv/__init__.py @@ -14,6 +14,7 @@ from .gatv2_conv import GATv2Conv from .transformer_conv import TransformerConv from .agnn_conv import AGNNConv +from .aero_conv import AEROConv from .tag_conv import TAGConv from .gin_conv import GINConv, GINEConv from .arma_conv import ARMAConv @@ -82,6 +83,7 @@ 'GATv2Conv', 'TransformerConv', 'AGNNConv', + 'AEROConv', 'TAGConv', 'GINConv', 'GINEConv', diff --git a/torch_geometric/nn/conv/aero_conv.py b/torch_geometric/nn/conv/aero_conv.py new file mode 100644 index 000000000000..00e50a8bf072 --- /dev/null +++ b/torch_geometric/nn/conv/aero_conv.py @@ -0,0 +1,335 @@ +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.inits import glorot, ones +from torch_geometric.typing import Adj, SparseTensor +from torch_geometric.utils import scatter + + +class AEROConv(MessagePassing): + r"""The AERO (Attentive dEep pROpagation) graph convolution operator from + the `"Towards Deep Attention in Graph Neural Networks: Problems and + Remedies" `_ paper. + + AERO-GNN addresses problems in deep graph attention, including + vulnerability to over-smoothed features and smooth cumulative attention. + The AERO operator performs iterative message passing with attention + mechanisms that include: + + 1. **Edge attention** :math:`\alpha_{ij}^{(k)}`: Computes attention weights + for edges at iteration :math:`k` using: + .. math:: + \alpha_{ij}^{(k)} = \text{softplus}(\mathbf{a}^{(k)} \cdot + (\mathbf{z}_i^{(k)} + \mathbf{z}_j^{(k)})) + \epsilon + + followed by symmetric normalization: + .. math:: + \hat{\alpha}_{ij}^{(k)} = \frac{\alpha_{ij}^{(k)}}{\sqrt{ + \deg(i) \deg(j)}} + + 2. **Hop attention** :math:`\gamma_i^{(k)}`: Computes attention weights for + each propagation hop: + .. math:: + \gamma_i^{(k)} = \text{ELU}(\mathbf{W}^{(k)} [\mathbf{h}_i^{(k)}, + \mathbf{z}_i^{(k)}]) + \mathbf{b}^{(k)} + + 3. **Decay weights**: Applies exponential decay across propagation + iterations: + .. math:: + w_k = \log\left(\frac{\lambda}{k+1} + 1 + \epsilon\right) + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + heads (int, optional): Number of multi-head attentions. + (default: :obj:`1`) + iterations (int, optional): Number of propagation iterations :math:`K`. + (default: :obj:`10`) + lambd (float, optional): Decay weight parameter :math:`\lambda` for + exponential decay. (default: :obj:`1.0`) + dropout (float, optional): Dropout probability of the normalized + attention coefficients. (default: :obj:`0.0`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`torch_geometric.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` where + :math:`H` is the number of heads + """ + def __init__( + self, + in_channels: int, + out_channels: int, + heads: int = 1, + iterations: int = 10, + lambd: float = 1.0, + dropout: float = 0.0, + bias: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super().__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.iterations = iterations + self.lambd = lambd + self.dropout = dropout + + # Edge attention parameters: one for each iteration k in [1, K] + self.edge_atts = torch.nn.ParameterList([ + Parameter(torch.empty(1, heads, out_channels)) + for _ in range(iterations) + ]) + + # Hop attention parameters: one for each iteration k in [0, K] + # For k=0, we only use h (no z_scale), so dimension is out_channels + # For k>0, we concatenate h and z_scale, so dimension is 2*out_channels + self.hop_atts = torch.nn.ParameterList( + [Parameter(torch.empty(1, heads, out_channels))]) + for _ in range(iterations): + self.hop_atts.append( + Parameter(torch.empty(1, heads, 2 * out_channels))) + + # Hop attention biases + self.hop_biases = torch.nn.ParameterList( + [Parameter(torch.empty(1, heads)) for _ in range(iterations + 1)]) + + # Decay weights: pre-computed log values for efficiency + self.register_buffer( + 'decay_weights', + torch.tensor([ + math.log((lambd / (k + 1)) + (1 + 1e-6)) + for k in range(iterations + 1) + ], dtype=torch.float32)) + + if bias: + self.bias = Parameter(torch.empty(heads * out_channels)) + else: + self.register_parameter('bias', None) + + # Current iteration index (set during forward pass) + self._current_iteration = 0 + # Store edge_index during forward pass for message function + self._edge_index: Optional[Tensor] = None + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + for att in self.edge_atts: + glorot(att) + for att in self.hop_atts: + glorot(att) + for bias in self.hop_biases: + ones(bias) + if self.bias is not None: + ones(self.bias) + + def forward( + self, + x: Tensor, + edge_index: Adj, + num_nodes: Optional[int] = None, + ) -> Tensor: + r"""Runs the forward pass of the module. + + Args: + x (torch.Tensor): The input node features of shape + :math:`(|\mathcal{V}|, H * F_{out})` where :math:`H` is the + number of heads and :math:`F_{out}` is the output feature size. + edge_index (torch.Tensor or SparseTensor): The edge indices. + num_nodes (int, optional): The number of nodes. If not provided, + will be inferred from :obj:`edge_index`. (default: :obj:`None`) + + Returns: + torch.Tensor: The output node features of shape + :math:`(|\mathcal{V}|, H * F_{out})`. + """ + if num_nodes is None: + if isinstance(edge_index, Tensor): + num_nodes = int(edge_index.max()) + 1 + elif isinstance(edge_index, SparseTensor): + num_nodes = edge_index.size(0) + else: + raise ValueError("Cannot infer num_nodes from edge_index type") + + # Reshape input to (num_nodes, heads, out_channels) + h = x.view(-1, self.heads, self.out_channels) + + # Initialize: k=0 + self._current_iteration = 0 + g = self._hop_attention(h, z_scale=None) + z = h * g + z_scale = z * self.decay_weights[0].item() + + # Store edge_index for use in message function + if isinstance(edge_index, Tensor): + self._edge_index = edge_index + else: + # For SparseTensor, convert to edge_index + row, col, _ = edge_index.coo() + self._edge_index = torch.stack([row, col], dim=0) + + # Iterative propagation: k in [1, K] + for k in range(1, self.iterations + 1): + self._current_iteration = k + # Propagate messages + # MessagePassing will split z_scale into z_scale_i and z_scale_j + h = self.propagate( + edge_index, + x=h, + z_scale=z_scale, + num_nodes=num_nodes, + size=None, + ) + # Compute hop attention and accumulate + g = self._hop_attention(h, z_scale) + z = z + h * g + # Update z_scale for next iteration + z_scale = z * self.decay_weights[k].item() + + # Reshape back to (num_nodes, heads * out_channels) + out = z.view(-1, self.heads * self.out_channels) + + if self.bias is not None: + out = out + self.bias + + # Clear stored edge_index + self._edge_index = None + + return out + + def _hop_attention( + self, + h: Tensor, + z_scale: Optional[Tensor], + ) -> Tensor: + r"""Computes hop attention weights. + + Args: + h (torch.Tensor): Current node features of shape + :math:`(|\mathcal{V}|, H, F_{out})`. + z_scale (torch.Tensor, optional): Scaled accumulated features of + shape :math:`(|\mathcal{V}|, H, F_{out})`. If :obj:`None`, + only :obj:`h` is used (for k=0). + + Returns: + torch.Tensor: Hop attention weights of shape + :math:`(|\mathcal{V}|, H, 1)`. + """ + k = self._current_iteration + + if z_scale is None: + # k=0: only use h + x = h + else: + # k>0: concatenate h and z_scale + x = torch.cat([h, z_scale], dim=-1) + + # Apply ELU activation + x = F.elu(x) + + # Compute attention: (hop_att * x).sum(dim=-1) + bias + att = self.hop_atts[k] + g = (att * x).sum(dim=-1) + self.hop_biases[k] + + return g.unsqueeze(-1) + + def _edge_attention( + self, + z_scale_i: Tensor, + z_scale_j: Tensor, + edge_index: Tensor, + num_nodes: int, + ) -> Tensor: + r"""Computes edge attention weights with symmetric normalization. + + Args: + z_scale_i (torch.Tensor): Scaled features for target nodes of + shape :math:`(|\mathcal{E}|, H, F_{out})`. + z_scale_j (torch.Tensor): Scaled features for source nodes of + shape :math:`(|\mathcal{E}|, H, F_{out})`. + edge_index (torch.Tensor): Edge indices of shape + :math:`(2, |\mathcal{E}|)`. + num_nodes (int): Number of nodes. + + Returns: + torch.Tensor: Normalized edge attention weights of shape + :math:`(|\mathcal{E}|,)`. + """ + k = self._current_iteration + + # Compute unnormalized attention: a_ij = softplus(att^T (z_i + z_j)) + eps + a_ij = z_scale_i + z_scale_j + a_ij = F.elu(a_ij) + a_ij = (self.edge_atts[k - 1] * a_ij).sum(dim=-1) + a_ij = F.softplus(a_ij) + 1e-6 + + # Symmetric normalization: a_ij / sqrt(deg(i) * deg(j)) + row, col = edge_index[0], edge_index[1] + # Compute degrees for both source and target nodes + deg_col = scatter(a_ij, col, dim=0, dim_size=num_nodes, reduce='sum') + deg_row = scatter(a_ij, row, dim=0, dim_size=num_nodes, reduce='sum') + + deg_col_inv_sqrt = deg_col.pow(-0.5) + deg_col_inv_sqrt = deg_col_inv_sqrt.masked_fill( + deg_col_inv_sqrt == float('inf'), 0.0) + deg_row_inv_sqrt = deg_row.pow(-0.5) + deg_row_inv_sqrt = deg_row_inv_sqrt.masked_fill( + deg_row_inv_sqrt == float('inf'), 0.0) + + a_ij = deg_row_inv_sqrt[row] * a_ij * deg_col_inv_sqrt[col] + + # Apply dropout + if self.training and self.dropout > 0.0: + a_ij = F.dropout(a_ij, p=self.dropout, training=True) + + return a_ij + + def message( + self, + x_j: Tensor, + z_scale_i: Tensor, + z_scale_j: Tensor, + index: Tensor, + num_nodes: int, + ) -> Tensor: + r"""Constructs messages from source nodes :math:`j` to target nodes :math:`i`. + + Args: + x_j (torch.Tensor): Source node features of shape + :math:`(|\mathcal{E}|, H, F_{out})`. + z_scale_i (torch.Tensor): Scaled features for target nodes of + shape :math:`(|\mathcal{E}|, H, F_{out})`. + z_scale_j (torch.Tensor): Scaled features for source nodes of + shape :math:`(|\mathcal{E}|, H, F_{out})`. + index (torch.Tensor): Target node indices for aggregation of shape + :math:`(|\mathcal{E}|,)`. + num_nodes (int): Number of nodes. + + Returns: + torch.Tensor: Messages of shape :math:`(|\mathcal{E}|, H, F_{out})`. + """ + if self._edge_index is None: + raise RuntimeError("edge_index not set. This should not happen.") + a = self._edge_attention(z_scale_i, z_scale_j, self._edge_index, + num_nodes) + return a.unsqueeze(-1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads}, ' + f'iterations={self.iterations})') diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 9267a302f6a9..37fcede192f2 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -37,6 +37,7 @@ from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) from .attract_repel import ARLinkPredictor +from .aero_gnn import AEROGNN __all__ = classes = [ 'MLP', @@ -86,4 +87,5 @@ 'SGFormer', 'Polynormer', 'ARLinkPredictor', + 'AEROGNN', ] diff --git a/torch_geometric/nn/models/aero_gnn.py b/torch_geometric/nn/models/aero_gnn.py new file mode 100644 index 000000000000..174c6b8b22a0 --- /dev/null +++ b/torch_geometric/nn/models/aero_gnn.py @@ -0,0 +1,182 @@ +from typing import Optional + +import torch +from torch import Tensor + +from torch_geometric.nn.conv import AEROConv +from torch_geometric.nn.dense.linear import Linear +from torch_geometric.typing import Adj + + +class AEROGNN(torch.nn.Module): + r"""The AERO-GNN (Attentive dEep pROpagation-GNN) model from the + `"Towards Deep Attention in Graph Neural Networks: Problems and Remedies" + `_ paper. + + AERO-GNN addresses problems in deep graph attention, including + vulnerability to over-smoothed features and smooth cumulative attention. + The model mitigates these issues through: + + 1. **Edge attention** :math:`\alpha_{ij}^{(k)}`: Learns adaptive attention + weights for edges at each propagation iteration, remaining + edge-adaptive and graph-adaptive even at deep layers. + 2. **Hop attention** :math:`\gamma_i^{(k)}`: Learns adaptive attention weights + for each propagation hop, remaining hop-adaptive, node-adaptive, and + graph-adaptive at deep layers. + 3. **Exponential decay weights**: Applied across propagation iterations to + prevent attention collapse and maintain expressiveness at deep layers. + + The model first applies multi-layer MLP transformations to input features, + then performs iterative message passing with AERO attention mechanisms, + and finally applies a classification layer. + + Args: + in_channels (int): Size of each input sample. + hidden_channels (int): Size of each hidden sample. + num_layers (int): Number of MLP layers for feature transformation. + (default: :obj:`1`) + out_channels (int, optional): Size of each output sample. If set to + :obj:`None`, will be set to :obj:`hidden_channels`. + (default: :obj:`None`) + iterations (int, optional): Number of propagation iterations + :math:`K`. (default: :obj:`10`) + heads (int, optional): Number of multi-head attentions. (default: :obj:`1`) + lambd (float, optional): Decay weight parameter :math:`\lambda` for + exponential decay. (default: :obj:`1.0`) + dropout (float, optional): Dropout probability. (default: :obj:`0.0`) + add_dropout (bool, optional): If set to :obj:`True`, applies dropout + before the final classification layer. (default: :obj:`False`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + def __init__( + self, + in_channels: int, + hidden_channels: int, + num_layers: int = 1, + out_channels: Optional[int] = None, + iterations: int = 10, + heads: int = 1, + lambd: float = 1.0, + dropout: float = 0.0, + add_dropout: bool = False, + bias: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.num_layers = num_layers + self.out_channels = out_channels or hidden_channels + self.iterations = iterations + self.heads = heads + self.lambd = lambd + self.dropout = dropout + self.add_dropout = add_dropout + + # Compute hidden channels with heads + self.hidden_channels_with_heads = heads * hidden_channels + + # MLP layers for feature transformation + self.lins = torch.nn.ModuleList() + # First layer: in_channels -> hidden_channels * heads + self.lins.append( + Linear( + in_channels, + self.hidden_channels_with_heads, + bias=bias, + weight_initializer='glorot', + )) + # Middle layers: hidden_channels * heads -> hidden_channels * heads + for _ in range(num_layers - 1): + self.lins.append( + Linear( + self.hidden_channels_with_heads, + self.hidden_channels_with_heads, + bias=bias, + weight_initializer='glorot', + )) + # Final layer: hidden_channels * heads -> out_channels + self.lins.append( + Linear( + self.hidden_channels_with_heads, + self.out_channels, + bias=bias, + weight_initializer='glorot', + )) + + # AERO convolution layer for iterative propagation + self.conv = AEROConv( + in_channels=self.hidden_channels_with_heads, + out_channels=hidden_channels, + heads=heads, + iterations=iterations, + lambd=lambd, + dropout=dropout, + bias=bias, + ) + + self.dropout_layer = torch.nn.Dropout(p=dropout) + self.elu = torch.nn.ELU() + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + for lin in self.lins: + lin.reset_parameters() + self.conv.reset_parameters() + + def forward( + self, + x: Tensor, + edge_index: Adj, + num_nodes: Optional[int] = None, + ) -> Tensor: + r"""Runs the forward pass of the module. + + Args: + x (torch.Tensor): The input node features of shape + :math:`(|\mathcal{V}|, F_{in})`. + edge_index (torch.Tensor or SparseTensor): The edge indices. + num_nodes (int, optional): The number of nodes. If not provided, + will be inferred from :obj:`edge_index`. (default: :obj:`None`) + + Returns: + torch.Tensor: The output node features of shape + :math:`(|\mathcal{V}|, F_{out})`. + """ + # Feature transformation: MLP layers + x = self.dropout_layer(x) + x = self.lins[0](x) + + # Apply ELU and dropout for middle layers + for i in range(1, self.num_layers): + x = self.elu(x) + x = self.dropout_layer(x) + x = self.lins[i](x) + + # AERO propagation: iterative message passing with attention + x = self.conv(x, edge_index, num_nodes=num_nodes) + + # Final classification layer + x = x.view(-1, self.hidden_channels_with_heads) + x = self.elu(x) + if self.add_dropout: + x = self.dropout_layer(x) + x = self.lins[-1](x) + + return x + + def __repr__(self) -> str: + return ( + f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.hidden_channels}, num_layers={self.num_layers}, ' + f'out_channels={self.out_channels}, iterations={self.iterations}, ' + f'heads={self.heads})')