25 changes: 20 additions & 5 deletions CODING_STANDARDS/FUNCTIONAL_APIS.md
@@ -53,7 +53,7 @@ This document is structured in two main sections:

| Rule ID | Summary | Apply When |
|---------|---------|------------|
| [`FNC-000`](#fnc-000-functionals-must-use-functionspec) | Functionals must use FunctionSpec | Creating new functional APIs |
| [`FNC-000`](#fnc-000-functionals-must-use-functionspec) | Functionals must use FunctionSpec unless they are lightweight tensor helpers | Creating new functional APIs |
| [`FNC-001`](#fnc-001-functional-location-and-public-api) | Functional location and public API | Organizing or exporting functionals |
| [`FNC-002`](#fnc-002-file-layout-for-functionals) | File layout for functionals | Adding or refactoring functional files |
| [`FNC-003`](#fnc-003-registration-and-dispatch-rules) | Registration and dispatch rules | Registering implementations |
@@ -71,15 +71,30 @@ This document is structured in two main sections:

**Description:**

All functionals must be implemented with `FunctionSpec`, even if only a single
implementation exists. This ensures the operation participates in validation
and benchmarking through input generators and `compare_forward` (and
All functionals with backend dispatch, optional accelerated implementations, or
meaningful benchmark coverage must be implemented with `FunctionSpec`, even if
only a single implementation exists. This ensures the operation participates in
validation and benchmarking through input generators and `compare_forward` (and
`compare_backward` where needed).

> Review note (author): making some space for very simple Python functions
> where we don't want to implement the whole `FunctionSpec` machinery.

Small pure-PyTorch tensor helpers can remain plain functions when all of the
following are true:

- The implementation is a thin composition of PyTorch tensor operations.
- There is no optional backend, custom kernel, or dispatch-selection behavior.
- Benchmarking the helper independently would not provide actionable
performance data.
- The function has focused tests or coverage through its owning feature area.

When a helper later grows an alternate backend, optional dependency, or
performance-sensitive implementation, convert it to `FunctionSpec`.
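As a concrete illustration, a minimal sketch of a helper that meets all four
criteria (`normalize_rows` is a hypothetical function, not part of the
codebase):

```python
import torch


def normalize_rows(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale each row of a 2D tensor to unit L2 norm.

    A thin composition of PyTorch tensor ops with no optional backend,
    no dispatch selection, and no independent benchmark value, so under
    FNC-000 it may stay a plain function instead of a FunctionSpec.
    """
    # clamp_min guards against division by zero for all-zero rows.
    return x / x.norm(dim=-1, keepdim=True).clamp_min(eps)
```

The moment such a helper gains a fused kernel or an optional dependency, the
criteria above stop holding and it should be converted to `FunctionSpec`.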

**Rationale:**

`FunctionSpec` provides a consistent structure for backend registration,
selection, benchmarking and verification across the codebase.
selection, benchmarking and verification across the codebase. The lightweight
helper exception avoids adding ceremony to simple tensor algebra that has no
backend-selection or benchmark surface.

**Example:**

17 changes: 16 additions & 1 deletion benchmarks/physicsnemo/nn/functional/registry.py
@@ -17,6 +17,11 @@
"""Registry of FunctionSpec classes to benchmark with ASV."""

from physicsnemo.core.function_spec import FunctionSpec
from physicsnemo.nn.functional.attention.neighborhood_attention import (
    NeighborhoodAttention1D,
    NeighborhoodAttention2D,
    NeighborhoodAttention3D,
)
from physicsnemo.nn.functional.derivatives import (
MeshGreenGaussGradient,
MeshlessFDDerivatives,
@@ -34,7 +39,11 @@
    Real,
    ViewAsComplex,
)
from physicsnemo.nn.functional.geometry import SignedDistanceField
from physicsnemo.nn.functional.geometry import (
    MeshPoissonDiskSample,
    MeshToVoxelFraction,
    # Review note (author): some of the functionals got left out, so adding
    # them back in.
    SignedDistanceField,
)
from physicsnemo.nn.functional.interpolation import (
    GridToPointInterpolation,
    PointToGridInterpolation,
@@ -62,6 +71,8 @@
    SpectralGridGradient,
    MeshlessFDDerivatives,
    # Geometry.
    MeshPoissonDiskSample,
    MeshToVoxelFraction,
    SignedDistanceField,
    # Interpolation.
    GridToPointInterpolation,
@@ -74,6 +85,10 @@
    ViewAsComplex,
    Real,
    Imag,
    # Neighborhood attention.
    NeighborhoodAttention1D,
    NeighborhoodAttention2D,
    NeighborhoodAttention3D,
)

__all__ = ["FUNCTIONAL_SPECS"]
17 changes: 17 additions & 0 deletions docs/api/nn/functionals/neighborhood_attention.rst
@@ -0,0 +1,17 @@
Neighborhood Attention Functionals
==================================

NATTEN 1D
---------

.. autofunction:: physicsnemo.nn.functional.na1d

NATTEN 2D
---------

.. autofunction:: physicsnemo.nn.functional.na2d

NATTEN 3D
---------

.. autofunction:: physicsnemo.nn.functional.na3d
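For orientation, what these functionals compute can be sketched with a naive
pure-PyTorch reference for the 1D case. This is a hypothetical helper, not the
physicsnemo `na1d` API: the `(batch, length, dim)` layout, the boundary
clamping, the `1/sqrt(dim)` scaling, and the `kernel_size <= length`
requirement are all assumptions here.

```python
import torch
import torch.nn.functional as F


def naive_na1d(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, kernel_size: int
) -> torch.Tensor:
    """Naive 1D neighborhood attention: each query attends to a window of
    `kernel_size` keys centered on its position, with the window clamped at
    the sequence boundaries so every query sees exactly `kernel_size` keys.
    Assumes kernel_size <= sequence length."""
    B, L, D = q.shape
    half = kernel_size // 2
    out = torch.empty_like(q)
    for i in range(L):
        # Clamp the window so it stays inside [0, L).
        start = min(max(i - half, 0), L - kernel_size)
        window = slice(start, start + kernel_size)
        scores = q[:, i : i + 1] @ k[:, window].transpose(1, 2)
        attn = F.softmax(scores / D**0.5, dim=-1)
        out[:, i] = (attn @ v[:, window]).squeeze(1)
    return out
```

With `kernel_size` equal to the sequence length this reduces to ordinary full
self-attention, which makes the sketch easy to sanity-check.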
1 change: 1 addition & 0 deletions docs/api/physicsnemo.nn.functionals.rst
@@ -24,3 +24,4 @@ in the documentation for performance comparisons.
nn/functionals/fourier_spectral
nn/functionals/regularization_parameterization
nn/functionals/interpolation
nn/functionals/neighborhood_attention
17 changes: 12 additions & 5 deletions physicsnemo/domain_parallel/shard_utils/natten_patches.py
@@ -32,7 +32,7 @@
    MissingShardPatch,
    UndeterminedShardingError,
)
from physicsnemo.nn.functional.natten import na1d, na2d, na3d
from physicsnemo.nn.functional.attention.neighborhood_attention import na1d, na2d, na3d

_natten = OptionalImport("natten")
_raw_func_map = {
@@ -221,9 +221,9 @@ def _natten_wrapper(
    r"""Shared wrapper for natten functions to support sharded tensors.

    Registered with :meth:`ShardTensor.register_function_handler` so that calls
    to :func:`~physicsnemo.nn.functional.natten.na1d`,
    :func:`~physicsnemo.nn.functional.natten.na2d`, or
    :func:`~physicsnemo.nn.functional.natten.na3d` automatically route through
    to :func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na1d`,
    :func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na2d`, or
    :func:`~physicsnemo.nn.functional.attention.neighborhood_attention.na3d` automatically route through
    this handler when any argument is a :class:`ShardTensor`.

    Parameters
@@ -250,7 +250,14 @@
    q, k, v, kernel_size = args[0], args[1], args[2], args[3]

    dilation = kwargs.get("dilation", 1)
    natten_kwargs = {_k: _v for _k, _v in kwargs.items() if _k != "dilation"}
    implementation = kwargs.get("implementation")
    if implementation not in (None, "natten"):
        raise KeyError(
            f"No implementation named '{implementation}' for neighborhood attention"
        )
    natten_kwargs = {
        _k: _v for _k, _v in kwargs.items() if _k not in ("dilation", "implementation")
    }

    if all(type(_t) is torch.Tensor for _t in (q, k, v)):
        return func(
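The validate-and-strip logic added to `_natten_wrapper` can be exercised in
isolation. A standalone sketch of the same pattern (`split_natten_kwargs` is a
hypothetical helper name, not part of the patch):

```python
def split_natten_kwargs(kwargs: dict) -> tuple[int, dict]:
    """Validate the optional 'implementation' selector and strip the
    wrapper-only keys before forwarding to the raw natten function."""
    implementation = kwargs.get("implementation")
    # The sharded wrapper only supports the natten backend.
    if implementation not in (None, "natten"):
        raise KeyError(
            f"No implementation named '{implementation}' for neighborhood attention"
        )
    dilation = kwargs.get("dilation", 1)
    # Everything else is passed straight through to the underlying function.
    passthrough = {
        k: v for k, v in kwargs.items() if k not in ("dilation", "implementation")
    }
    return dilation, passthrough
```

Rejecting unknown selectors eagerly keeps a typo from silently falling through
to the default backend.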
2 changes: 1 addition & 1 deletion physicsnemo/nn/functional/__init__.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .attention import na1d, na2d, na3d
from .derivatives import (
    mesh_green_gauss_gradient,
    mesh_lsq_gradient,
@@ -40,7 +41,6 @@
    interpolation,
    point_to_grid_interpolation,
)
from .natten import na1d, na2d, na3d
from .neighbors import knn, radius_search
from .regularization_parameterization import drop_path, weight_fact

33 changes: 33 additions & 0 deletions physicsnemo/nn/functional/attention/__init__.py
@@ -0,0 +1,33 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .neighborhood_attention import (
    NeighborhoodAttention1D,
    NeighborhoodAttention2D,
    NeighborhoodAttention3D,
    na1d,
    na2d,
    na3d,
)

__all__ = [
    "NeighborhoodAttention1D",
    "NeighborhoodAttention2D",
    "NeighborhoodAttention3D",
    "na1d",
    "na2d",
    "na3d",
]