diff --git a/CODING_STANDARDS/FUNCTIONAL_APIS.md b/CODING_STANDARDS/FUNCTIONAL_APIS.md index 024f5c4765..58dd1ed45d 100644 --- a/CODING_STANDARDS/FUNCTIONAL_APIS.md +++ b/CODING_STANDARDS/FUNCTIONAL_APIS.md @@ -53,7 +53,7 @@ This document is structured in two main sections: | Rule ID | Summary | Apply When | |---------|---------|------------| -| [`FNC-000`](#fnc-000-functionals-must-use-functionspec) | Functionals must use FunctionSpec | Creating new functional APIs | +| [`FNC-000`](#fnc-000-functionals-must-use-functionspec) | Functionals must use FunctionSpec unless they are lightweight tensor helpers | Creating new functional APIs | | [`FNC-001`](#fnc-001-functional-location-and-public-api) | Functional location and public API | Organizing or exporting functionals | | [`FNC-002`](#fnc-002-file-layout-for-functionals) | File layout for functionals | Adding or refactoring functional files | | [`FNC-003`](#fnc-003-registration-and-dispatch-rules) | Registration and dispatch rules | Registering implementations | @@ -71,15 +71,30 @@ This document is structured in two main sections: **Description:** -All functionals must be implemented with `FunctionSpec`, even if only a single -implementation exists. This ensures the operation participates in validation -and benchmarking through input generators and `compare_forward` (and +All functionals with backend dispatch, optional accelerated implementations, or +meaningful benchmark coverage must be implemented with `FunctionSpec`, even if +only a single implementation exists. This ensures the operation participates in +validation and benchmarking through input generators and `compare_forward` (and `compare_backward` where needed). +Small pure-PyTorch tensor helpers can remain plain functions when all of the +following are true: + +- The implementation is a thin composition of PyTorch tensor operations. +- There is no optional backend, custom kernel, or dispatch-selection behavior. +- Benchmarking the helper independently would not provide actionable + performance data. +- The function has focused tests or coverage through its owning feature area. + +When a helper later grows an alternate backend, optional dependency, or +performance-sensitive implementation, convert it to `FunctionSpec`. + **Rationale:** `FunctionSpec` provides a consistent structure for backend registration, -selection, benchmarking and verification across the codebase. +selection, benchmarking and verification across the codebase. The lightweight +helper exception avoids adding ceremony to simple tensor algebra that has no +backend-selection or benchmark surface. 
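+
+A minimal sketch of a helper that may stay a plain function under this
+exception (the helper name and body here are illustrative only, not an
+existing API in the codebase):
+
+```python
+import torch
+
+
+def normalize_last_dim(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+    """Normalize ``x`` along its last dimension (pure tensor algebra)."""
+    # Thin composition of PyTorch ops: no optional backend, no dispatch
+    # selection, and nothing to benchmark on its own.
+    return x / x.norm(dim=-1, keepdim=True).clamp_min(eps)
+```
+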
**Example:** diff --git a/benchmarks/physicsnemo/nn/functional/registry.py b/benchmarks/physicsnemo/nn/functional/registry.py index ee2f9af68d..58bd794cd1 100644 --- a/benchmarks/physicsnemo/nn/functional/registry.py +++ b/benchmarks/physicsnemo/nn/functional/registry.py @@ -17,6 +17,11 @@ """Registry of FunctionSpec classes to benchmark with ASV.""" from physicsnemo.core.function_spec import FunctionSpec +from physicsnemo.nn.functional.attention.neighborhood_attention import ( + NeighborhoodAttention1D, + NeighborhoodAttention2D, + NeighborhoodAttention3D, +) from physicsnemo.nn.functional.derivatives import ( MeshGreenGaussGradient, MeshlessFDDerivatives, @@ -34,7 +39,11 @@ Real, ViewAsComplex, ) -from physicsnemo.nn.functional.geometry import SignedDistanceField +from physicsnemo.nn.functional.geometry import ( + MeshPoissonDiskSample, + MeshToVoxelFraction, + SignedDistanceField, +) from physicsnemo.nn.functional.interpolation import ( GridToPointInterpolation, PointToGridInterpolation, @@ -62,6 +71,8 @@ SpectralGridGradient, MeshlessFDDerivatives, # Geometry. + MeshPoissonDiskSample, + MeshToVoxelFraction, SignedDistanceField, # Interpolation. GridToPointInterpolation, @@ -74,6 +85,10 @@ ViewAsComplex, Real, Imag, + # Neighborhood attention. + NeighborhoodAttention1D, + NeighborhoodAttention2D, + NeighborhoodAttention3D, ) __all__ = ["FUNCTIONAL_SPECS"] diff --git a/docs/api/nn/functionals/neighborhood_attention.rst b/docs/api/nn/functionals/neighborhood_attention.rst new file mode 100644 index 0000000000..9c5398c555 --- /dev/null +++ b/docs/api/nn/functionals/neighborhood_attention.rst @@ -0,0 +1,17 @@ +Neighborhood Attention Functionals +================================== + +NATTEN 1D +--------- + +.. autofunction:: physicsnemo.nn.functional.na1d + +NATTEN 2D +--------- + +.. autofunction:: physicsnemo.nn.functional.na2d + +NATTEN 3D +--------- + +.. autofunction:: physicsnemo.nn.functional.na3d diff --git a/docs/api/physicsnemo.nn.functionals.rst b/docs/api/physicsnemo.nn.functionals.rst index 3edfc950a1..97d08a6e5e 100644 --- a/docs/api/physicsnemo.nn.functionals.rst +++ b/docs/api/physicsnemo.nn.functionals.rst @@ -24,3 +24,4 @@ in the documentation for performance comparisons. nn/functionals/fourier_spectral nn/functionals/regularization_parameterization nn/functionals/interpolation + nn/functionals/neighborhood_attention diff --git a/physicsnemo/domain_parallel/shard_utils/natten_patches.py b/physicsnemo/domain_parallel/shard_utils/natten_patches.py index 5c3edcf669..e1d06918c5 100644 --- a/physicsnemo/domain_parallel/shard_utils/natten_patches.py +++ b/physicsnemo/domain_parallel/shard_utils/natten_patches.py @@ -32,7 +32,7 @@ MissingShardPatch, UndeterminedShardingError, ) -from physicsnemo.nn.functional.natten import na1d, na2d, na3d +from physicsnemo.nn.functional.attention.neighborhood_attention import na1d, na2d, na3d _natten = OptionalImport("natten") _raw_func_map = { @@ -221,9 +221,9 @@ def _natten_wrapper( r"""Shared wrapper for natten functions to support sharded tensors. 
Registered with :meth:`ShardTensor.register_function_handler` so that calls - to :func:`~physicsnemo.nn.functional.natten.na1d`, - :func:`~physicsnemo.nn.functional.natten.na2d`, or - :func:`~physicsnemo.nn.functional.natten.na3d` automatically route through + to :func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na1d`, + :func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na2d`, or + :func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na3d` automatically route through this handler when any argument is a :class:`ShardTensor`. Parameters @@ -250,7 +250,14 @@ def _natten_wrapper( q, k, v, kernel_size = args[0], args[1], args[2], args[3] dilation = kwargs.get("dilation", 1) - natten_kwargs = {_k: _v for _k, _v in kwargs.items() if _k != "dilation"} + implementation = kwargs.get("implementation") + if implementation not in (None, "natten"): + raise KeyError( + f"No implementation named '{implementation}' for neighborhood attention" + ) + natten_kwargs = { + _k: _v for _k, _v in kwargs.items() if _k not in ("dilation", "implementation") + } if all(type(_t) is torch.Tensor for _t in (q, k, v)): return func( diff --git a/physicsnemo/nn/functional/__init__.py b/physicsnemo/nn/functional/__init__.py index 2eb680a06a..80b986bf90 100644 --- a/physicsnemo/nn/functional/__init__.py +++ b/physicsnemo/nn/functional/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .attention import na1d, na2d, na3d from .derivatives import ( mesh_green_gauss_gradient, mesh_lsq_gradient, @@ -40,7 +41,6 @@ interpolation, point_to_grid_interpolation, ) -from .natten import na1d, na2d, na3d from .neighbors import knn, radius_search from .regularization_parameterization import drop_path, weight_fact diff --git a/physicsnemo/nn/functional/attention/__init__.py b/physicsnemo/nn/functional/attention/__init__.py new file mode 100644 index 0000000000..8fe84edcdd --- /dev/null +++ b/physicsnemo/nn/functional/attention/__init__.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .neighborhood_attention import ( + NeighborhoodAttention1D, + NeighborhoodAttention2D, + NeighborhoodAttention3D, + na1d, + na2d, + na3d, +) + +__all__ = [ + "NeighborhoodAttention1D", + "NeighborhoodAttention2D", + "NeighborhoodAttention3D", + "na1d", + "na2d", + "na3d", +] diff --git a/physicsnemo/nn/functional/attention/neighborhood_attention.py b/physicsnemo/nn/functional/attention/neighborhood_attention.py new file mode 100644 index 0000000000..afcaaed1c2 --- /dev/null +++ b/physicsnemo/nn/functional/attention/neighborhood_attention.py @@ -0,0 +1,401 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch +from torch.overrides import handle_torch_function, has_torch_function + +from physicsnemo.core.function_spec import FunctionSpec +from physicsnemo.core.version_check import OptionalImport, get_installed_version + +_natten = OptionalImport("natten") + + +def _make_qkv( + shape: tuple[int, ...], + device: torch.device | str, + *, + requires_grad: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Create query/key/value tensors for NAT benchmark inputs.""" + + q = torch.randn(shape, device=device, requires_grad=requires_grad) + k = torch.randn(shape, device=device, requires_grad=requires_grad) + v = torch.randn(shape, device=device, requires_grad=requires_grad) + return q, k, v + + +class _NeighborhoodAttention(FunctionSpec): + """Shared behavior for NATTEN-backed function specs.""" + + _NATTEN_REQUIREMENT = "natten>=0.21.5" + + @classmethod + def dispatch(cls, *args, **kwargs): + try: + return super().dispatch(*args, **kwargs) + except ImportError as exc: + installed_version = get_installed_version("natten") + if installed_version is not None: + raise ImportError( + f"{cls._NATTEN_REQUIREMENT} is required for {cls.__name__}, " + f"but found natten {installed_version}" + ) from exc + try: + _natten.functional + except ImportError as missing_exc: + raise missing_exc from exc + raise ImportError( + f"No available NATTEN implementation found for {cls.__name__}. " + f"Expected {cls._NATTEN_REQUIREMENT}; verify that the installed " + "natten package exposes the required functional backend." 
+ ) from exc + + +class NeighborhoodAttention1D(_NeighborhoodAttention): + """Compute 1D neighborhood attention through the NATTEN backend.""" + + _BENCHMARK_CASES = ( + ("small-l64-h2-d32-k3", (1, 64, 2, 32), 3, 1), + ("medium-l256-h4-d32-k5", (1, 256, 4, 32), 5, 1), + ("large-l1024-h8-d64-k7-d2", (1, 1024, 8, 64), 7, 2), + ) + + @FunctionSpec.register( + name="natten", + required_imports=(_NeighborhoodAttention._NATTEN_REQUIREMENT,), + rank=0, + baseline=True, + ) + def natten_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + dilation: int = 1, + **kwargs: Any, + ) -> torch.Tensor: + """Run the 1D NATTEN backend implementation.""" + return _natten.functional.na1d( + q, k, v, kernel_size, dilation=dilation, **kwargs + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield labeled forward benchmark cases for 1D neighborhood attention.""" + for label, shape, kernel_size, dilation in cls._BENCHMARK_CASES: + yield ( + label, + (*_make_qkv(shape, device), kernel_size), + {"dilation": dilation}, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield differentiable benchmark cases for 1D neighborhood attention.""" + for label, shape, kernel_size, dilation in cls._BENCHMARK_CASES: + yield ( + label, + (*_make_qkv(shape, device, requires_grad=True), kernel_size), + {"dilation": dilation}, + ) + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare 1D neighborhood-attention outputs against a reference.""" + torch.testing.assert_close(output, reference) + + +class NeighborhoodAttention2D(_NeighborhoodAttention): + """Compute 2D neighborhood attention through the NATTEN backend.""" + + _BENCHMARK_CASES = ( + ("small-32x32-h2-d32-k3", (1, 32, 32, 2, 32), 3, 1), + ("medium-64x64-h4-d32-k5", (1, 64, 64, 4, 32), 5, 1), + ("large-128x128-h8-d64-k7-d2", (1, 128, 128, 8, 64), 7, 2), + ) + + @FunctionSpec.register( + name="natten", + required_imports=(_NeighborhoodAttention._NATTEN_REQUIREMENT,), + rank=0, + baseline=True, + ) + def natten_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + dilation: int = 1, + **kwargs: Any, + ) -> torch.Tensor: + """Run the 2D NATTEN backend implementation.""" + return _natten.functional.na2d( + q, k, v, kernel_size, dilation=dilation, **kwargs + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield labeled forward benchmark cases for 2D neighborhood attention.""" + for label, shape, kernel_size, dilation in cls._BENCHMARK_CASES: + yield ( + label, + (*_make_qkv(shape, device), kernel_size), + {"dilation": dilation}, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield differentiable benchmark cases for 2D neighborhood attention.""" + for label, shape, kernel_size, dilation in cls._BENCHMARK_CASES: + yield ( + label, + (*_make_qkv(shape, device, requires_grad=True), kernel_size), + {"dilation": dilation}, + ) + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare 2D neighborhood-attention outputs against a reference.""" + torch.testing.assert_close(output, reference) + + +class NeighborhoodAttention3D(_NeighborhoodAttention): + """Compute 3D neighborhood attention through the NATTEN backend.""" + + _BENCHMARK_CASES = ( + ("small-8x8x8-h2-d16-k3", (1, 8, 8, 8, 2, 16), 3, 1), + ("medium-16x16x16-h4-d32-k5", (1, 
16, 16, 16, 4, 32), 5, 1), + ("large-32x32x32-h4-d32-k7", (1, 32, 32, 32, 4, 32), 7, 1), + ) + + @FunctionSpec.register( + name="natten", + required_imports=(_NeighborhoodAttention._NATTEN_REQUIREMENT,), + rank=0, + baseline=True, + ) + def natten_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + dilation: int = 1, + **kwargs: Any, + ) -> torch.Tensor: + """Run the 3D NATTEN backend implementation.""" + return _natten.functional.na3d( + q, k, v, kernel_size, dilation=dilation, **kwargs + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield labeled forward benchmark cases for 3D neighborhood attention.""" + for label, shape, kernel_size, dilation in cls._BENCHMARK_CASES: + yield ( + label, + (*_make_qkv(shape, device), kernel_size), + {"dilation": dilation}, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield differentiable benchmark cases for 3D neighborhood attention.""" + for label, shape, kernel_size, dilation in cls._BENCHMARK_CASES: + yield ( + label, + (*_make_qkv(shape, device, requires_grad=True), kernel_size), + {"dilation": dilation}, + ) + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare 3D neighborhood-attention outputs against a reference.""" + torch.testing.assert_close(output, reference) + + +# Keep the FunctionSpec-produced callables private and expose a second public +# wrapper layer so NAT can first route ShardTensor inputs through PyTorch's +# ``__torch_function__`` protocol. This preserves the domain-parallel halo +# exchange path before falling back to normal FunctionSpec backend dispatch. +# TODO: Generalize this pattern in FunctionSpec once there is a broader design +# for functionals that need tensor-subclass dispatch before backend selection. +_na1d = NeighborhoodAttention1D.make_function("na1d") +_na2d = NeighborhoodAttention2D.make_function("na2d") +_na3d = NeighborhoodAttention3D.make_function("na3d") + + +def na1d( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + dilation: int = 1, + **kwargs: Any, +) -> torch.Tensor: + r"""Compute 1D neighborhood attention, with ``__torch_function__`` dispatch. + + This is a thin wrapper around :func:`natten.functional.na1d` that enables + automatic dispatch through PyTorch's ``__torch_function__`` protocol. When + called with a tensor subclass (e.g. ``ShardTensor``), the registered handler + is invoked instead of the underlying natten implementation. + + Parameters + ---------- + q : torch.Tensor + Query tensor of shape :math:`(B, L, \text{heads}, D)`. + k : torch.Tensor + Key tensor of shape :math:`(B, L, \text{heads}, D)`. + v : torch.Tensor + Value tensor of shape :math:`(B, L, \text{heads}, D)`. + kernel_size : int + Size of the attention kernel window. + dilation : int, default=1 + Dilation factor for the attention kernel. + **kwargs : Any + Additional keyword arguments forwarded to :func:`natten.functional.na1d` + (e.g. ``is_causal``, ``scale``). + + Returns + ------- + torch.Tensor + Output tensor of the same shape as ``q``. 
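+
+    Examples
+    --------
+    A shape-only sketch; running it requires the optional ``natten``
+    dependency to be installed.
+
+    >>> import torch
+    >>> from physicsnemo.nn.functional import na1d
+    >>> q = k = v = torch.randn(1, 16, 2, 8)
+    >>> na1d(q, k, v, kernel_size=3).shape
+    torch.Size([1, 16, 2, 8])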
+ """ + if has_torch_function((q, k, v)): + return handle_torch_function( + na1d, + (q, k, v), + q, + k, + v, + kernel_size, + dilation=dilation, + **kwargs, + ) + return _na1d(q, k, v, kernel_size, dilation=dilation, **kwargs) + + +def na2d( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + dilation: int = 1, + **kwargs: Any, +) -> torch.Tensor: + r"""Compute 2D neighborhood attention, with ``__torch_function__`` dispatch. + + This is a thin wrapper around :func:`natten.functional.na2d` that enables + automatic dispatch through PyTorch's ``__torch_function__`` protocol. When + called with a tensor subclass (e.g. ``ShardTensor``), the registered handler + is invoked instead of the underlying natten implementation. + + Parameters + ---------- + q : torch.Tensor + Query tensor of shape :math:`(B, H, W, \text{heads}, D)`. + k : torch.Tensor + Key tensor of shape :math:`(B, H, W, \text{heads}, D)`. + v : torch.Tensor + Value tensor of shape :math:`(B, H, W, \text{heads}, D)`. + kernel_size : int + Size of the attention kernel window. + dilation : int, default=1 + Dilation factor for the attention kernel. + **kwargs : Any + Additional keyword arguments forwarded to :func:`natten.functional.na2d` + (e.g. ``is_causal``, ``scale``). + + Returns + ------- + torch.Tensor + Output tensor of the same shape as ``q``. + """ + if has_torch_function((q, k, v)): + return handle_torch_function( + na2d, + (q, k, v), + q, + k, + v, + kernel_size, + dilation=dilation, + **kwargs, + ) + return _na2d(q, k, v, kernel_size, dilation=dilation, **kwargs) + + +def na3d( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + dilation: int = 1, + **kwargs: Any, +) -> torch.Tensor: + r"""Compute 3D neighborhood attention, with ``__torch_function__`` dispatch. + + This is a thin wrapper around :func:`natten.functional.na3d` that enables + automatic dispatch through PyTorch's ``__torch_function__`` protocol. When + called with a tensor subclass (e.g. ``ShardTensor``), the registered handler + is invoked instead of the underlying natten implementation. + + Parameters + ---------- + q : torch.Tensor + Query tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. + k : torch.Tensor + Key tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. + v : torch.Tensor + Value tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. + kernel_size : int + Size of the attention kernel window. + dilation : int, default=1 + Dilation factor for the attention kernel. + **kwargs : Any + Additional keyword arguments forwarded to :func:`natten.functional.na3d` + (e.g. ``is_causal``, ``scale``). + + Returns + ------- + torch.Tensor + Output tensor of the same shape as ``q``. + """ + if has_torch_function((q, k, v)): + return handle_torch_function( + na3d, + (q, k, v), + q, + k, + v, + kernel_size, + dilation=dilation, + **kwargs, + ) + return _na3d(q, k, v, kernel_size, dilation=dilation, **kwargs) + + +__all__ = [ + "NeighborhoodAttention1D", + "NeighborhoodAttention2D", + "NeighborhoodAttention3D", + "na1d", + "na2d", + "na3d", +] diff --git a/physicsnemo/nn/functional/equivariant_ops.py b/physicsnemo/nn/functional/equivariant_ops.py index 1b8f16369e..65473cef73 100644 --- a/physicsnemo/nn/functional/equivariant_ops.py +++ b/physicsnemo/nn/functional/equivariant_ops.py @@ -22,11 +22,15 @@ @overload -def smooth_log(x: Float[torch.Tensor, "..."]) -> Float[torch.Tensor, "..."]: ... 
+def smooth_log(x: Float[torch.Tensor, "..."]) -> Float[torch.Tensor, "..."]: + """Apply smooth log elementwise to a tensor.""" + ... @overload -def smooth_log(x: TensorDict) -> TensorDict: ... +def smooth_log(x: TensorDict) -> TensorDict: + """Apply smooth log elementwise to a TensorDict.""" + ... def smooth_log( @@ -59,11 +63,15 @@ def smooth_log( @overload def legendre_polynomials( x: Float[torch.Tensor, "..."], n: int -) -> list[Float[torch.Tensor, "..."]]: ... +) -> list[Float[torch.Tensor, "..."]]: + """Compute Legendre polynomials for a tensor input.""" + ... @overload -def legendre_polynomials(x: TensorDict, n: int) -> list[TensorDict]: ... +def legendre_polynomials(x: TensorDict, n: int) -> list[TensorDict]: + """Compute Legendre polynomials for a TensorDict input.""" + ... def legendre_polynomials( diff --git a/physicsnemo/nn/functional/natten.py b/physicsnemo/nn/functional/natten.py deleted file mode 100644 index 225acf33d3..0000000000 --- a/physicsnemo/nn/functional/natten.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Any - -import torch -from torch.overrides import handle_torch_function, has_torch_function - -from physicsnemo.core.version_check import OptionalImport - -_natten = OptionalImport("natten") - - -def na1d( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kernel_size: int, - dilation: int = 1, - **kwargs: Any, -) -> torch.Tensor: - r"""Compute 1D neighborhood attention, with ``__torch_function__`` dispatch. - - This is a thin wrapper around :func:`natten.functional.na1d` that enables - automatic dispatch through PyTorch's ``__torch_function__`` protocol. When - called with a tensor subclass (e.g. ``ShardTensor``), the registered handler - is invoked instead of the underlying natten implementation. - - Parameters - ---------- - q : torch.Tensor - Query tensor of shape :math:`(B, L, \text{heads}, D)`. - k : torch.Tensor - Key tensor of shape :math:`(B, L, \text{heads}, D)`. - v : torch.Tensor - Value tensor of shape :math:`(B, L, \text{heads}, D)`. - kernel_size : int - Size of the attention kernel window. - dilation : int, default=1 - Dilation factor for the attention kernel. - **kwargs : Any - Additional keyword arguments forwarded to :func:`natten.functional.na1d` - (e.g. ``is_causal``, ``scale``). - - Returns - ------- - torch.Tensor - Output tensor of the same shape as ``q``. 
- """ - if has_torch_function((q, k, v)): - return handle_torch_function( - na1d, - (q, k, v), - q, - k, - v, - kernel_size, - dilation=dilation, - **kwargs, - ) - return _natten.functional.na1d(q, k, v, kernel_size, dilation=dilation, **kwargs) - - -def na2d( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kernel_size: int, - dilation: int = 1, - **kwargs: Any, -) -> torch.Tensor: - r"""Compute 2D neighborhood attention, with ``__torch_function__`` dispatch. - - This is a thin wrapper around :func:`natten.functional.na2d` that enables - automatic dispatch through PyTorch's ``__torch_function__`` protocol. When - called with a tensor subclass (e.g. ``ShardTensor``), the registered handler - is invoked instead of the underlying natten implementation. - - Parameters - ---------- - q : torch.Tensor - Query tensor of shape :math:`(B, H, W, \text{heads}, D)`. - k : torch.Tensor - Key tensor of shape :math:`(B, H, W, \text{heads}, D)`. - v : torch.Tensor - Value tensor of shape :math:`(B, H, W, \text{heads}, D)`. - kernel_size : int - Size of the attention kernel window. - dilation : int, default=1 - Dilation factor for the attention kernel. - **kwargs : Any - Additional keyword arguments forwarded to :func:`natten.functional.na2d` - (e.g. ``is_causal``, ``scale``). - - Returns - ------- - torch.Tensor - Output tensor of the same shape as ``q``. - """ - if has_torch_function((q, k, v)): - return handle_torch_function( - na2d, - (q, k, v), - q, - k, - v, - kernel_size, - dilation=dilation, - **kwargs, - ) - return _natten.functional.na2d(q, k, v, kernel_size, dilation=dilation, **kwargs) - - -def na3d( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kernel_size: int, - dilation: int = 1, - **kwargs: Any, -) -> torch.Tensor: - r"""Compute 3D neighborhood attention, with ``__torch_function__`` dispatch. - - This is a thin wrapper around :func:`natten.functional.na3d` that enables - automatic dispatch through PyTorch's ``__torch_function__`` protocol. When - called with a tensor subclass (e.g. ``ShardTensor``), the registered handler - is invoked instead of the underlying natten implementation. - - Parameters - ---------- - q : torch.Tensor - Query tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. - k : torch.Tensor - Key tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. - v : torch.Tensor - Value tensor of shape :math:`(B, X, Y, Z, \text{heads}, D)`. - kernel_size : int - Size of the attention kernel window. - dilation : int, default=1 - Dilation factor for the attention kernel. - **kwargs : Any - Additional keyword arguments forwarded to :func:`natten.functional.na3d` - (e.g. ``is_causal``, ``scale``). - - Returns - ------- - torch.Tensor - Output tensor of the same shape as ``q``. 
- """ - if has_torch_function((q, k, v)): - return handle_torch_function( - na3d, - (q, k, v), - q, - k, - v, - kernel_size, - dilation=dilation, - **kwargs, - ) - return _natten.functional.na3d(q, k, v, kernel_size, dilation=dilation, **kwargs) - - -__all__ = ["na1d", "na2d", "na3d"] diff --git a/physicsnemo/nn/module/dit_layers.py b/physicsnemo/nn/module/dit_layers.py index c069c04025..fab69fb471 100644 --- a/physicsnemo/nn/module/dit_layers.py +++ b/physicsnemo/nn/module/dit_layers.py @@ -27,7 +27,9 @@ from physicsnemo.core import Module from physicsnemo.core.version_check import OptionalImport, check_version_spec -from physicsnemo.nn.functional.natten import na2d as _na2d_func +from physicsnemo.nn.functional.attention.neighborhood_attention import ( + na2d as _na2d_func, +) from physicsnemo.nn.module.drop import DropPath from physicsnemo.nn.module.hpx.tokenizer import ( HEALPixPatchDetokenizer, diff --git a/test/ci_tests/interrogate_baseline.txt b/test/ci_tests/interrogate_baseline.txt index eb3b1e1140..58edf9f986 100644 --- a/test/ci_tests/interrogate_baseline.txt +++ b/test/ci_tests/interrogate_baseline.txt @@ -591,10 +591,6 @@ physicsnemo/models/graphcast/utils/graph_backend.py:PyGGraphBackend.create_heter physicsnemo/models/graphcast/utils/graph_backend.py:PyGGraphBackend.khop_adj_all_k physicsnemo/models/pix2pix/pix2pixunet.py:Pix2PixUnet.load_networks physicsnemo/models/pix2pix/pix2pixunet.py:Pix2PixUnet.test -physicsnemo/nn/functional/equivariant_ops.py:legendre_polynomials -physicsnemo/nn/functional/equivariant_ops.py:legendre_polynomials -physicsnemo/nn/functional/equivariant_ops.py:smooth_log -physicsnemo/nn/functional/equivariant_ops.py:smooth_log physicsnemo/nn/functional/fourier_spectral/fft.py:IRFFT.make_inputs_backward physicsnemo/nn/functional/fourier_spectral/fft.py:IRFFT.make_inputs_forward physicsnemo/nn/functional/fourier_spectral/fft.py:IRFFT.torch_forward diff --git a/test/domain_parallel/ops/test_natten.py b/test/domain_parallel/ops/test_natten.py index f54fe2fcce..398f425ed8 100644 --- a/test/domain_parallel/ops/test_natten.py +++ b/test/domain_parallel/ops/test_natten.py @@ -16,9 +16,9 @@ r"""Tests for 1D, 2D, and 3D neighborhood attention on sharded tensors. -This module validates the correctness of :func:`physicsnemo.nn.functional.natten.na1d`, -:func:`physicsnemo.nn.functional.natten.na2d`, and -:func:`physicsnemo.nn.functional.natten.na3d` over sharded inputs, covering both +This module validates the correctness of :func:`physicsnemo.nn.functional.attention.neighborhood_attention.na1d`, +:func:`physicsnemo.nn.functional.attention.neighborhood_attention.na2d`, and +:func:`physicsnemo.nn.functional.attention.neighborhood_attention.na3d` over sharded inputs, covering both forward and backward passes. Sharding is performed over spatial dimensions which correspond to ``Shard(1)``, ``Shard(2)``, etc. in the natten heads-last layout. 
""" @@ -115,7 +115,7 @@ class TestNA1D: def test_na1d_shard_l( self, distributed_mesh, L, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.attention.neighborhood_attention import na1d _run_natten_check( na1d, @@ -148,7 +148,7 @@ class TestNA2D: def test_na2d_shard_h( self, distributed_mesh, H, W, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.attention.neighborhood_attention import na2d _run_natten_check( na2d, @@ -171,7 +171,7 @@ def test_na2d_shard_h( def test_na2d_shard_w( self, distributed_mesh, H, W, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.attention.neighborhood_attention import na2d _run_natten_check( na2d, @@ -205,7 +205,7 @@ class TestNA3D: def test_na3d_shard_x( self, distributed_mesh, X, Y, Z, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.attention.neighborhood_attention import na3d _run_natten_check( na3d, @@ -229,7 +229,7 @@ def test_na3d_shard_x( def test_na3d_shard_y( self, distributed_mesh, X, Y, Z, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.attention.neighborhood_attention import na3d _run_natten_check( na3d, @@ -253,7 +253,7 @@ def test_na3d_shard_y( def test_na3d_shard_z( self, distributed_mesh, X, Y, Z, num_heads, head_dim, kernel_size, backward ): - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.attention.neighborhood_attention import na3d _run_natten_check( na3d, diff --git a/test/models/globe/test_utilities.py b/test/models/globe/test_utilities.py index f5a130d7ad..d1f0dfa7eb 100644 --- a/test/models/globe/test_utilities.py +++ b/test/models/globe/test_utilities.py @@ -17,7 +17,7 @@ import pytest import torch -from physicsnemo.nn.functional.equivariant_ops import legendre_polynomials +from physicsnemo.nn.functional import legendre_polynomials @pytest.mark.parametrize( diff --git a/test/nn/functional/test_natten.py b/test/nn/functional/attention/test_neighborhood_attention.py similarity index 69% rename from test/nn/functional/test_natten.py rename to test/nn/functional/attention/test_neighborhood_attention.py index eaeeaf91e6..b93bc25a2b 100644 --- a/test/nn/functional/test_natten.py +++ b/test/nn/functional/attention/test_neighborhood_attention.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -r"""Lightweight unit tests for :mod:`physicsnemo.nn.functional.natten`. +r"""Lightweight unit tests for :mod:`physicsnemo.nn.functional.attention.neighborhood_attention`. Validates that the ``na1d``, ``na2d``, and ``na3d`` wrappers: @@ -25,10 +25,18 @@ covers the entire spatial extent. 
""" +import importlib.util +from dataclasses import replace + import pytest import torch import torch.nn.functional as F +from physicsnemo.nn.functional.attention.neighborhood_attention import ( + NeighborhoodAttention1D, + NeighborhoodAttention2D, + NeighborhoodAttention3D, +) from test.conftest import requires_module # --------------------------------------------------------------------------- @@ -72,6 +80,92 @@ def _sdpa_reference(q, k, v): return out.permute(0, 2, 1, 3).reshape(*leading, heads, d) +@pytest.mark.parametrize( + "spec", + [ + NeighborhoodAttention1D, + NeighborhoodAttention2D, + NeighborhoodAttention3D, + ], +) +def test_natten_function_specs_make_inputs_forward(device, spec): + """NATTEN FunctionSpecs expose labeled forward benchmark cases.""" + cases = list(spec.make_inputs_forward(device=device)) + + assert len(cases) == len(spec._BENCHMARK_CASES) + assert [case[0] for case in cases] == [case[0] for case in spec._BENCHMARK_CASES] + + _label, args, kwargs = cases[0] + q, k, v, kernel_size = args + assert q.shape == k.shape == v.shape + assert isinstance(kernel_size, int) + assert isinstance(kwargs["dilation"], int) + + +@pytest.mark.parametrize( + "spec", + [ + NeighborhoodAttention1D, + NeighborhoodAttention2D, + NeighborhoodAttention3D, + ], +) +def test_natten_function_specs_make_inputs_backward(device, spec): + """NATTEN FunctionSpecs expose differentiable benchmark cases.""" + cases = list(spec.make_inputs_backward(device=device)) + + assert len(cases) == len(spec._BENCHMARK_CASES) + _label, args, _kwargs = cases[0] + q, k, v, _kernel_size = args + assert q.requires_grad + assert k.requires_grad + assert v.requires_grad + + +@requires_module("natten") +@pytest.mark.parametrize( + "spec", + [ + NeighborhoodAttention1D, + NeighborhoodAttention2D, + NeighborhoodAttention3D, + ], +) +def test_natten_function_spec_dispatch_matches_compare_contract(device, spec): + """FunctionSpec dispatch and compare hooks are valid for NATTEN.""" + _label, args, kwargs = next(iter(spec.make_inputs_forward(device=device))) + output = spec.dispatch(*args, implementation="natten", **kwargs) + reference = spec.dispatch(*args, implementation="natten", **kwargs) + + spec.compare_forward(output, reference) + + +def test_natten_missing_dependency_error_message(device): + """Missing NATTEN errors should keep the optional-dependency install hint.""" + if importlib.util.find_spec("natten") is not None: + pytest.skip("natten is installed") + + from physicsnemo.nn.functional.attention.neighborhood_attention import na1d + + q = torch.randn(1, 8, 1, 4, device=device) + with pytest.raises(ImportError, match="Missing optional dependency: natten"): + na1d(q, q, q, kernel_size=3) + + +def test_natten_version_mismatch_error_message(device, monkeypatch): + """Unavailable NATTEN versions should ask users to upgrade.""" + import physicsnemo.nn.functional.attention.neighborhood_attention as natten_functionals + from physicsnemo.nn.functional.attention.neighborhood_attention import na1d + + monkeypatch.setattr(natten_functionals, "get_installed_version", lambda _: "0.21.4") + impls = NeighborhoodAttention1D._get_impls() + monkeypatch.setitem(impls, "natten", replace(impls["natten"], available=False)) + + q = torch.randn(1, 8, 1, 4, device=device) + with pytest.raises(ImportError, match="natten>=0.21.5 is required"): + na1d(q, q, q, kernel_size=3) + + # --------------------------------------------------------------------------- # 1-D neighbourhood attention # 
--------------------------------------------------------------------------- @@ -79,7 +173,7 @@ def _sdpa_reference(q, k, v): @requires_module("natten") class TestNA1D: - """Unit tests for :func:`physicsnemo.nn.functional.natten.na1d`.""" + """Unit tests for :func:`physicsnemo.nn.functional.attention.neighborhood_attention.na1d`.""" @pytest.mark.parametrize("kernel_size", [3, 5]) @pytest.mark.parametrize("dilation", [1, 2]) @@ -87,7 +181,7 @@ def test_matches_natten_directly(self, device, kernel_size, dilation): """Wrapper output must be identical to ``natten.functional.na1d``.""" import natten.functional as nf - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.attention.neighborhood_attention import na1d B, L, H, D = 2, 16, 4, 8 q = torch.randn(B, L, H, D, device=device) @@ -101,7 +195,7 @@ def test_matches_natten_directly(self, device, kernel_size, dilation): def test_output_shape(self, device): """Output shape must equal the query shape.""" - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.attention.neighborhood_attention import na1d B, L, H, D = 1, 12, 2, 16 q = torch.randn(B, L, H, D, device=device) @@ -110,7 +204,7 @@ def test_output_shape(self, device): def test_backward(self, device): """Gradients must flow back through all three inputs.""" - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.attention.neighborhood_attention import na1d B, L, H, D = 1, 12, 2, 8 q = torch.randn(B, L, H, D, device=device, requires_grad=True) @@ -126,7 +220,7 @@ def test_backward(self, device): def test_torch_function_dispatch(self, device): """``__torch_function__`` must be invoked for tensor subclasses.""" - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.attention.neighborhood_attention import na1d B, L, H, D = 1, 8, 2, 8 q = torch.randn(B, L, H, D, device=device).as_subclass(_DispatchRecorder) @@ -140,7 +234,7 @@ def test_torch_function_dispatch(self, device): def test_full_window_matches_sdpa(self, device): """When kernel covers the entire sequence, NA degenerates to SDPA.""" - from physicsnemo.nn.functional.natten import na1d + from physicsnemo.nn.functional.attention.neighborhood_attention import na1d B, L, H, D = 2, 7, 2, 8 q = torch.randn(B, L, H, D, device=device, dtype=torch.float32) @@ -161,7 +255,7 @@ def test_full_window_matches_sdpa(self, device): @requires_module("natten") class TestNA2D: - """Unit tests for :func:`physicsnemo.nn.functional.natten.na2d`.""" + """Unit tests for :func:`physicsnemo.nn.functional.attention.neighborhood_attention.na2d`.""" @pytest.mark.parametrize("kernel_size", [3, 5]) @pytest.mark.parametrize("dilation", [1, 2]) @@ -169,7 +263,7 @@ def test_matches_natten_directly(self, device, kernel_size, dilation): """Wrapper output must be identical to ``natten.functional.na2d``.""" import natten.functional as nf - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.attention.neighborhood_attention import na2d B, Ht, W, H, D = 2, 16, 16, 4, 8 q = torch.randn(B, Ht, W, H, D, device=device) @@ -183,7 +277,7 @@ def test_matches_natten_directly(self, device, kernel_size, dilation): def test_output_shape(self, device): """Output shape must equal the query shape.""" - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.attention.neighborhood_attention import na2d B, Ht, W, H, D = 1, 6, 6, 2, 16 q = torch.randn(B, Ht, W, H, D, device=device) @@ -192,7 +286,7 @@ def 
test_output_shape(self, device): def test_backward(self, device): """Gradients must flow back through all three inputs.""" - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.attention.neighborhood_attention import na2d B, Ht, W, H, D = 1, 6, 6, 2, 8 q = torch.randn(B, Ht, W, H, D, device=device, requires_grad=True) @@ -208,7 +302,7 @@ def test_backward(self, device): def test_torch_function_dispatch(self, device): """``__torch_function__`` must be invoked for tensor subclasses.""" - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.attention.neighborhood_attention import na2d B, Ht, W, H, D = 1, 4, 4, 2, 8 q = torch.randn(B, Ht, W, H, D, device=device).as_subclass(_DispatchRecorder) @@ -222,7 +316,7 @@ def test_torch_function_dispatch(self, device): def test_full_window_matches_sdpa(self, device): """When kernel covers the full spatial extent, NA degenerates to SDPA.""" - from physicsnemo.nn.functional.natten import na2d + from physicsnemo.nn.functional.attention.neighborhood_attention import na2d B, Ht, W, H, D = 2, 5, 5, 2, 8 q = torch.randn(B, Ht, W, H, D, device=device, dtype=torch.float32) @@ -242,14 +336,14 @@ def test_full_window_matches_sdpa(self, device): @requires_module("natten") class TestNA3D: - """Unit tests for :func:`physicsnemo.nn.functional.natten.na3d`.""" + """Unit tests for :func:`physicsnemo.nn.functional.attention.neighborhood_attention.na3d`.""" @pytest.mark.parametrize("kernel_size", [3, 5]) def test_matches_natten_directly(self, device, kernel_size): """Wrapper output must be identical to ``natten.functional.na3d``.""" import natten.functional as nf - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.attention.neighborhood_attention import na3d B, X, Y, Z, H, D = 1, 16, 16, 16, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device) @@ -263,7 +357,7 @@ def test_matches_natten_directly(self, device, kernel_size): def test_output_shape(self, device): """Output shape must equal the query shape.""" - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.attention.neighborhood_attention import na3d B, X, Y, Z, H, D = 1, 4, 4, 4, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device) @@ -272,7 +366,7 @@ def test_output_shape(self, device): def test_backward(self, device): """Gradients must flow back through all three inputs.""" - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.attention.neighborhood_attention import na3d B, X, Y, Z, H, D = 1, 4, 4, 4, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device, requires_grad=True) @@ -288,7 +382,7 @@ def test_backward(self, device): def test_torch_function_dispatch(self, device): """``__torch_function__`` must be invoked for tensor subclasses.""" - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.attention.neighborhood_attention import na3d B, X, Y, Z, H, D = 1, 4, 4, 4, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device).as_subclass(_DispatchRecorder) @@ -302,7 +396,7 @@ def test_torch_function_dispatch(self, device): def test_full_window_matches_sdpa(self, device): """When kernel covers the full spatial extent, NA degenerates to SDPA.""" - from physicsnemo.nn.functional.natten import na3d + from physicsnemo.nn.functional.attention.neighborhood_attention import na3d B, X, Y, Z, H, D = 1, 5, 5, 5, 2, 8 q = torch.randn(B, X, Y, Z, H, D, device=device, dtype=torch.float32)