diff --git a/pyproject.toml b/pyproject.toml index 32b7dc5d9d2f..6ec04c243f54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ full = [ "torch_geometric[graphgym, modelhub]", "torchmetrics", "trimesh", + "gfn-layer" ] [project.urls] diff --git a/test/nn/unpool/test_gfn_unpooling.py b/test/nn/unpool/test_gfn_unpooling.py new file mode 100644 index 000000000000..b3a8a3693962 --- /dev/null +++ b/test/nn/unpool/test_gfn_unpooling.py @@ -0,0 +1,41 @@ +import torch + +from torch_geometric.nn import GFNUnpooling +from torch_geometric.testing import withPackage + + +def _setup_gfn_unpooling(): + weight = torch.tensor( + [[1, 10, 100], [2, 20, 200], [3, 30, 300], [4, 40, 400]], dtype=float) + bias = torch.tensor([1, 10, 100, 1000], dtype=float) + + out_graph = torch.tensor([(-3, 2), (1, 0), (2, 1), (3, 2)]) + + unpool = GFNUnpooling(3, out_graph) + + with torch.no_grad(): + unpool._gfn.weight.copy_(weight) + unpool._gfn.bias.copy_(bias) + + return unpool + + +@withPackage('gfn') +def test_gfn_unpooling(): + unpool = _setup_gfn_unpooling() + + x = torch.tensor([[1.0], [10.0], [100.0], [-1.0], [-10.0], [-100.0]]) + pos_y = torch.tensor([ + [-1.0, -1.0], + [1.0, 1.0], + [-2.0, -2.0], + [2.0, 2.0], + ]) + batch_x = torch.tensor([0, 0, 0, 1, 1, 1]) + batch_y = torch.tensor([0, 0, 1, 1]) + + y = unpool(x, pos_y, batch_x, batch_y) + + expected = torch.tensor([[15157.], [30673.], [-15146.], [-29933.]]) + + torch.testing.assert_close(y, expected) diff --git a/torch_geometric/nn/unpool/__init__.py b/torch_geometric/nn/unpool/__init__.py index ce01900cd7bc..bba878797b68 100644 --- a/torch_geometric/nn/unpool/__init__.py +++ b/torch_geometric/nn/unpool/__init__.py @@ -1,9 +1,8 @@ r"""Unpooling package.""" from .knn_interpolate import knn_interpolate +from .gfn import GFNUnpooling -__all__ = [ - 'knn_interpolate', -] +__all__ = ['knn_interpolate', 'GFNUnpooling'] classes = __all__ diff --git a/torch_geometric/nn/unpool/gfn.py b/torch_geometric/nn/unpool/gfn.py new file mode 100644 index 000000000000..7c71b47af4b4 --- /dev/null +++ b/torch_geometric/nn/unpool/gfn.py @@ -0,0 +1,88 @@ +import torch + + +class GFNUnpooling(torch.nn.Module): + r"""The Graph Feedforward Network unpooling layer from + `"GFN: A graph feedforward network for resolution-invariant + reduced operator learning in multifidelity applications" + `_. + + The GFN unpooling equation is given by: + + .. math:: + :nowrap: + + \begin{equation*} + \begin{aligned} + \tilde{W}_{i_{\mathcal{M}_{n}}j} &= \underset{\forall + k_{\mathcal{M}_{o}} \text{ s.t } k_{\mathcal{M}_{o}} + {\leftarrow}\!{\backslash}\!{\rightarrow} + i_{\mathcal{M}_{n}}}{\operatorname{mean}} + {W}_{k_{\mathcal{M}_{o}}j}, \\ + \tilde{b}_{i_{\mathcal{M}_{n}}} &= \underset{\forall + k_{\mathcal{M}_{o}} \text{ s.t } k_{\mathcal{M}_{o}} + {\leftarrow}\!{\backslash}\!{\rightarrow} + i_{\mathcal{M}_{n}}}{\operatorname{mean}} + {b}_{k_{\mathcal{M}_{o}}}. + \end{aligned} + \end{equation*} + + where: + + - :math:`\mathcal{M}_{o}` is the original output graph, + - :math:`\mathcal{M}_{n}` is the new output graph, + - :math:`W` and :math:`b` are the weights and biases + associated to the original graph, + - :math:`\tilde{W}` and :math:`\tilde{b}` are the new + weights and biases associated to the new graph, + - :math:`i_{\mathcal{M}_o} {\leftarrow}\!{\backslash}\! + {\rightarrow} j_{\mathcal{M}_n}` indicates that either + node :math:`i` in graph :math:`\mathcal{M}_o` is the + nearest neighbor of node :math:`j` in graph + :math:`\mathcal{M}_n` or vice versa. + + Args: + in_size (int): Size of the input vector. + pos_y (torch.tensor): Original output graph (node position matrix) + :math:`\in \mathbb{R}^{M^{\prime} \times d}`. + **kwargs (optional): Additional arguments of :class:`gfn.GFN`. + """ + def __init__(self, in_size: int, pos_y: torch.Tensor, **kwargs): + try: + import gfn # noqa + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "GFNUnpooling requires `gfn` to be installed. " + "Please install it via `pip install gfn-layer`.") from e + super().__init__() + self._gfn = gfn.GFN(in_features=in_size, out_features=pos_y, **kwargs) + + def forward(self, x: torch.Tensor, pos_y: torch.Tensor, + batch_x: torch.Tensor = None, batch_y: torch.Tensor = None): + r"""Runs the forward pass of the module. + + Args: + x (torch.Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. + pos_y (torch.Tensor): New output graph (new node position matrix) + :math:`\in \mathbb{R}^{M \times d}`. + batch_x (torch.Tensor, optional): Batch vector + :math:`\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^{N}`, which + assigns each node from :math:`\mathbf{X}` to a specific + example. (default: :obj:`None`) + batch_y (torch.Tensor, optional): Batch vector + :math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^{M}`, which + assigns each node from :math:`\mathbf{Y}` to a specific + example. (default: :obj:`None`) + + :rtype: :class:`torch.Tensor` + """ + out = torch.empty((batch_y.shape[0]), *x.shape[1:], + dtype=self._gfn.weight.dtype, + device=self._gfn.weight.device) + for batch_label in batch_x.unique(): + mask = batch_y == batch_label + pos = pos_y[mask, ...] + x_ = x[batch_x == batch_label, ...] + out[mask] = self._gfn(x_.T, out_graph=pos).T + return out diff --git a/torch_geometric/nn/unpool/knn_interpolate.py b/torch_geometric/nn/unpool/knn_interpolate.py index 2989e48998d9..bc05939e97f7 100644 --- a/torch_geometric/nn/unpool/knn_interpolate.py +++ b/torch_geometric/nn/unpool/knn_interpolate.py @@ -35,13 +35,15 @@ def knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor, each node from :math:`\mathbf{X}` to a specific example. (default: :obj:`None`) batch_y (torch.Tensor, optional): Batch vector - :math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^N`, which assigns + :math:`\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each node from :math:`\mathbf{Y}` to a specific example. (default: :obj:`None`) k (int, optional): Number of neighbors. (default: :obj:`3`) num_workers (int, optional): Number of workers to use for computation. Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + + :rtype: :class:`torch.Tensor` """ with torch.no_grad(): assign_index = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y,