Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions physicsnemo/domain_parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,17 @@


if ST_AVAILABLE:
# In minumum versions are met, we can import the shard tensor and spec.
# If minimum versions are met, we can import the shard tensor and spec.
# Import tensor op handlers here because they are valid on CPU and CUDA.

from ._shard_tensor_spec import ShardTensorSpec
from .custom_ops import _tensor_ops # noqa: F401
from .shard_tensor import ShardTensor, scatter_tensor

def register_custom_ops():
# These imports will register the custom ops with the ShardTensor class.
# It's done here to avoid an import cycle.
from .custom_ops import ( # noqa: F401
_tensor_ops,
mean_wrapper,
sum_wrapper,
)
from .custom_ops import mean_wrapper, sum_wrapper # noqa: F401
from .shard_utils import register_shard_wrappers

register_shard_wrappers()
Expand Down
214 changes: 213 additions & 1 deletion physicsnemo/domain_parallel/custom_ops/_tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,220 @@
Shard,
)

from physicsnemo.domain_parallel import ShardTensor
from physicsnemo.domain_parallel._shard_tensor_spec import (
ShardTensorSpec,
_stride_from_contiguous_shape_C_style,
)
from physicsnemo.domain_parallel.shard_tensor import ShardTensor

aten = torch.ops.aten


def _extract_cross_inputs(
    args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[ShardTensor, ShardTensor]:
    r"""Pull the two cross-product operands out of a call and validate them.

    The first operand is taken from ``args[0]`` when any positional argument
    is present, otherwise from ``kwargs["input"]``. The second operand is
    taken from ``args[1]`` when at least two positional arguments are
    present, otherwise from ``kwargs["other"]``. A missing keyword yields
    ``None``, which is then rejected by ``_validate_cross_inputs``.

    Parameters
    ----------
    args : tuple[Any, ...]
        Positional arguments of the intercepted call.
    kwargs : dict[str, Any]
        Keyword arguments of the intercepted call.

    Returns
    -------
    tuple[ShardTensor, ShardTensor]
        The validated ``(input, other)`` pair.

    Raises
    ------
    RuntimeError
        Propagated from ``_validate_cross_inputs`` when the operands are not
        acceptable.
    """
    if len(args) > 0:
        lhs = args[0]
    else:
        lhs = kwargs.get("input")

    if len(args) > 1:
        rhs = args[1]
    else:
        rhs = kwargs.get("other")

    _validate_cross_inputs(lhs, rhs)
    return lhs, rhs


def _validate_cross_inputs(input_tensor: Any, other_tensor: Any) -> None:
    r"""Check that two cross-product operands are compatible ShardTensors.

    Parameters
    ----------
    input_tensor : Any
        Candidate first operand.
    other_tensor : Any
        Candidate second operand.

    Raises
    ------
    RuntimeError
        If either operand is not a ``ShardTensor``, if their specs refer to
        different device meshes, or if their placements differ.
    """
    both_sharded = isinstance(input_tensor, ShardTensor) and isinstance(
        other_tensor, ShardTensor
    )
    if not both_sharded:
        raise RuntimeError(
            "cross with ShardTensor inputs requires both arguments to be ShardTensor."
        )

    # Mesh is compared before placements so a mesh mismatch is reported first.
    if input_tensor._spec.mesh != other_tensor._spec.mesh:
        raise RuntimeError(
            "cross requires both ShardTensor inputs to share the same device mesh."
        )

    if input_tensor._spec.placements != other_tensor._spec.placements:
        raise RuntimeError(
            "cross requires both ShardTensor inputs to have identical placements."
        )


def _cross_result_from_local(
    input_tensor: ShardTensor,
    local_input: torch.Tensor,
    local_result: torch.Tensor,
) -> ShardTensor:
    r"""Re-wrap a locally computed cross product as a ShardTensor.

    The result is built with ``ShardTensor.from_local`` using the mesh and
    placements of ``input_tensor``.

    Parameters
    ----------
    input_tensor : ShardTensor
        Operand whose spec supplies the mesh, placements and, when
        applicable, the sharding shapes.
    local_input : torch.Tensor
        Local shard of ``input_tensor`` that was fed to the local op.
    local_result : torch.Tensor
        Output of the local cross product.

    Returns
    -------
    ShardTensor
        ``local_result`` wrapped with the input's distribution metadata.
    """
    if local_result.shape == local_input.shape:
        # Shapes are unchanged locally, so reuse the input's sharding shapes.
        shapes = input_tensor._spec.sharding_shapes()
    else:
        # Local shape changed; hand the string "infer" to from_local instead.
        shapes = "infer"

    spec = input_tensor._spec
    return ShardTensor.from_local(
        local_result,
        spec.mesh,
        spec.placements,
        sharding_shapes=shapes,
    )


def _normalize_cross_dim(
    input_tensor: ShardTensor,
    dim: int | None,
    *,
    allow_none: bool,
    op_name: str,
) -> int:
    r"""Return the non-negative cross-product dimension, rejecting invalid ones.

    Parameters
    ----------
    input_tensor : ShardTensor
        Operand whose ``ndim``, ``shape`` and ``_spec.placements`` are
        inspected.
    dim : int | None
        Requested dimension. ``None`` selects the first dimension of size 3
        in ``input_tensor.shape`` (only when ``allow_none`` is true).
        Negative values are wrapped by adding ``ndim``.
    allow_none : bool
        Whether ``dim=None`` is accepted.
    op_name : str
        Operator name used as the prefix of ``TypeError`` messages.

    Returns
    -------
    int
        The dimension as an index in ``[0, ndim)``.

    Raises
    ------
    TypeError
        If ``dim`` is ``None`` while ``allow_none`` is false, or if ``dim``
        is not an ``int`` (``bool`` is rejected as well).
    IndexError
        If ``dim`` lies outside ``[-ndim, ndim - 1]``.
    RuntimeError
        If ``dim`` is ``None`` and no dimension has size 3, or if the
        resolved dimension is sharded according to the input's placements.
    """
    ndim = input_tensor.ndim
    if dim is None:
        if not allow_none:
            raise TypeError(f"{op_name}(): argument 'dim' must be int, not NoneType")
        # next() with a default avoids raising inside an ``except
        # StopIteration`` block, which would attach an irrelevant
        # StopIteration as the error's context in tracebacks.
        normalized_dim = next(
            (i for i, size in enumerate(input_tensor.shape) if size == 3),
            None,
        )
        if normalized_dim is None:
            raise RuntimeError(
                "torch.cross with ShardTensor dim=None requires an input dimension "
                "of size 3."
            )
    else:
        # bool is a subclass of int; without the explicit check ``dim=True``
        # would silently be interpreted as dimension 1.
        if isinstance(dim, bool) or not isinstance(dim, int):
            raise TypeError(
                f"{op_name}(): argument 'dim' must be int, not {type(dim).__name__}"
            )
        if not -ndim <= dim < ndim:
            raise IndexError(
                "Dimension out of range "
                f"(expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})"
            )
        normalized_dim = dim if dim >= 0 else dim + ndim

    # The vector dimension must be fully local: the handlers compute the
    # product shard-by-shard without communication.
    if any(
        isinstance(placement, Shard) and placement.dim == normalized_dim
        for placement in input_tensor._spec.placements
    ):
        raise RuntimeError(
            "cross with ShardTensor inputs is not supported along a sharded dimension."
        )
    return normalized_dim


def linalg_cross_wrapper(
    func: Callable,
    types: tuple[Any, ...],
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> ShardTensor:
    r"""Function-level handler for ``torch.linalg.cross`` on ShardTensor inputs.

    The product is computed on the local shards with ``torch.linalg.cross``
    and re-wrapped with the first operand's distribution metadata.

    Parameters
    ----------
    func : Callable
        The intercepted function (not used by this handler).
    types : tuple[Any, ...]
        Types participating in the override (not used by this handler).
    args : tuple[Any, ...]
        Positional arguments; a third positional value, if present, is taken
        as ``dim`` and overrides any ``dim`` keyword.
    kwargs : dict[str, Any]
        Keyword arguments; ``None`` is treated as an empty mapping. ``dim``
        defaults to ``-1``.

    Returns
    -------
    ShardTensor
        The cross product as a ShardTensor.

    Raises
    ------
    RuntimeError
        If a non-``None`` ``out`` keyword is supplied, or if operand or
        dimension validation fails.
    TypeError, IndexError
        Propagated from ``_normalize_cross_dim`` for an invalid ``dim``.
    """
    kwargs = {} if kwargs is None else kwargs

    if kwargs.get("out") is not None:
        raise RuntimeError(
            "torch.linalg.cross(out=...) is not supported for ShardTensor."
        )

    input_tensor, other_tensor = _extract_cross_inputs(args, kwargs)

    # A positional dim takes precedence over the keyword form.
    requested_dim = args[2] if len(args) > 2 else kwargs.get("dim", -1)
    cross_dim = _normalize_cross_dim(
        input_tensor,
        requested_dim,
        allow_none=False,
        op_name="linalg_cross",
    )

    local_lhs = input_tensor.to_local()
    local_rhs = other_tensor.to_local()
    product = torch.linalg.cross(local_lhs, local_rhs, dim=cross_dim)
    return _cross_result_from_local(input_tensor, local_lhs, product)


def _linalg_cross_dispatch(
    input_tensor: ShardTensor,
    other_tensor: ShardTensor,
    *,
    dim: int = -1,
) -> ShardTensor:
    r"""ATen dispatch handler for ``aten.linalg_cross.default``.

    Validates the operands and the dimension, runs ``torch.linalg.cross`` on
    the local shards, and wraps the result with the first operand's
    distribution metadata.

    Parameters
    ----------
    input_tensor : ShardTensor
        First operand.
    other_tensor : ShardTensor
        Second operand; must share mesh and placements with the first.
    dim : int, optional
        Dimension of the cross product (keyword-only). Defaults to ``-1``.

    Returns
    -------
    ShardTensor
        The cross product as a ShardTensor.

    Raises
    ------
    RuntimeError, TypeError, IndexError
        Propagated from operand and dimension validation.
    """
    _validate_cross_inputs(input_tensor, other_tensor)
    cross_dim = _normalize_cross_dim(
        input_tensor,
        dim,
        allow_none=False,
        op_name="linalg_cross",
    )

    local_lhs = input_tensor.to_local()
    local_rhs = other_tensor.to_local()
    product = torch.linalg.cross(local_lhs, local_rhs, dim=cross_dim)
    return _cross_result_from_local(input_tensor, local_lhs, product)


def cross_wrapper(
    func: Callable,
    types: tuple[Any, ...],
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> ShardTensor:
    r"""Function-level handler for ``torch.cross`` on ShardTensor inputs.

    The dimension is resolved against the first operand before the local
    computation, so ``torch.cross`` is always called on the local shards
    with an explicit, non-negative ``dim``.

    Parameters
    ----------
    func : Callable
        The intercepted function (not used by this handler).
    types : tuple[Any, ...]
        Types participating in the override (not used by this handler).
    args : tuple[Any, ...]
        Positional arguments; a third positional value, if present, is taken
        as ``dim`` and overrides any ``dim`` keyword.
    kwargs : dict[str, Any]
        Keyword arguments; ``None`` is treated as an empty mapping. ``dim``
        defaults to ``None``.

    Returns
    -------
    ShardTensor
        The cross product as a ShardTensor.

    Raises
    ------
    RuntimeError
        If a non-``None`` ``out`` keyword is supplied, or if operand or
        dimension validation fails.
    TypeError, IndexError
        Propagated from ``_normalize_cross_dim`` for an invalid ``dim``.
    """
    kwargs = {} if kwargs is None else kwargs

    if kwargs.get("out") is not None:
        raise RuntimeError("torch.cross(out=...) is not supported for ShardTensor.")

    input_tensor, other_tensor = _extract_cross_inputs(args, kwargs)

    # A positional dim takes precedence over the keyword form.
    requested_dim = args[2] if len(args) > 2 else kwargs.get("dim", None)
    cross_dim = _normalize_cross_dim(
        input_tensor,
        requested_dim,
        allow_none=True,
        op_name="cross",
    )

    local_lhs = input_tensor.to_local()
    local_rhs = other_tensor.to_local()
    product = torch.cross(local_lhs, local_rhs, dim=cross_dim)
    return _cross_result_from_local(input_tensor, local_lhs, product)


def _cross_dispatch(
    input_tensor: ShardTensor,
    other_tensor: ShardTensor,
    dim: int | None = None,
) -> ShardTensor:
    r"""ATen dispatch handler for ``aten.cross.default``.

    Validates the operands, resolves ``dim`` (``None`` is allowed and is
    resolved by ``_normalize_cross_dim``), runs ``torch.cross`` on the local
    shards, and wraps the result with the first operand's distribution
    metadata.

    Parameters
    ----------
    input_tensor : ShardTensor
        First operand.
    other_tensor : ShardTensor
        Second operand; must share mesh and placements with the first.
    dim : int | None, optional
        Dimension of the cross product. Defaults to ``None``.

    Returns
    -------
    ShardTensor
        The cross product as a ShardTensor.

    Raises
    ------
    RuntimeError, TypeError, IndexError
        Propagated from operand and dimension validation.
    """
    _validate_cross_inputs(input_tensor, other_tensor)
    cross_dim = _normalize_cross_dim(
        input_tensor,
        dim,
        allow_none=True,
        op_name="cross",
    )

    local_lhs = input_tensor.to_local()
    local_rhs = other_tensor.to_local()
    product = torch.cross(local_lhs, local_rhs, dim=cross_dim)
    return _cross_result_from_local(input_tensor, local_lhs, product)


def _unbind_output_metadata(
input_spec: ShardTensorSpec, dim: int
) -> tuple[int, list, dict[int, list[torch.Size]]]:
Expand Down Expand Up @@ -210,8 +415,15 @@ def unbind_wrapper(


# Python-level function handlers (__torch_function__).
# Each call maps a callable to the wrapper that ShardTensor should run in its
# place. The cross wrappers are registered for the public torch functions, the
# Tensor method, and the corresponding ATen overloads.
ShardTensor.register_function_handler(torch.linalg.cross, linalg_cross_wrapper)
ShardTensor.register_function_handler(aten.linalg_cross.default, linalg_cross_wrapper)
ShardTensor.register_function_handler(torch.cross, cross_wrapper)
ShardTensor.register_function_handler(torch.Tensor.cross, cross_wrapper)
ShardTensor.register_function_handler(aten.cross.default, cross_wrapper)
ShardTensor.register_function_handler(torch.unbind, unbind_wrapper)
ShardTensor.register_function_handler(torch.Tensor.unbind, unbind_wrapper)

# ATen-level dispatch handler (__torch_dispatch__).
# These handlers take the operator's own arguments directly rather than the
# (func, types, args, kwargs) form used by the function handlers above.
ShardTensor.register_dispatch_handler(aten.linalg_cross.default, _linalg_cross_dispatch)
ShardTensor.register_dispatch_handler(aten.cross.default, _cross_dispatch)
ShardTensor.register_dispatch_handler(aten.unbind.int, _unbind_dispatch)
8 changes: 8 additions & 0 deletions physicsnemo/mesh/boundaries/_facet_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from tensordict import TensorDict

from physicsnemo.mesh.utilities._index_tuple_ops import unique_index_tuples
from physicsnemo.mesh.utilities._scatter_ops import _materialize_shard_tensor
from physicsnemo.mesh.utilities._tolerances import safe_eps

if TYPE_CHECKING:
Expand Down Expand Up @@ -76,6 +77,11 @@ def _generate_combination_indices(n: int, k: int) -> Int[torch.Tensor, "n_choose
return torch.tensor(combos, dtype=torch.int64)


def _materialize_topology_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Materialize a topology tensor ahead of index-heavy topology operations.

    This is a thin, locally named alias: the input is handed unchanged to
    ``_materialize_shard_tensor`` and whatever that helper returns is passed
    back to the caller.

    Args:
        tensor: Topology tensor, possibly sharded.

    Returns:
        The result of ``_materialize_shard_tensor(tensor)``.
    """
    materialized = _materialize_shard_tensor(tensor)
    return materialized


def categorize_facets_by_count(
candidate_facets: Int[torch.Tensor, "n_candidates n_vertices_per_facet"],
target_counts: list[int] | Literal["boundary", "shared", "interior", "all"] = "all",
Expand Down Expand Up @@ -243,6 +249,7 @@ def extract_candidate_facets(
>>> facets, parents = extract_candidate_facets(cells, manifold_codimension=2)
>>> assert facets.shape == (3, 1) # three vertices
"""
cells = _materialize_topology_tensor(cells)
n_cells, n_vertices_per_cell = cells.shape
n_vertices_per_subsimplex = n_vertices_per_cell - manifold_codimension

Expand Down Expand Up @@ -326,6 +333,7 @@ def _aggregate_tensor_data(

### Gather parent cell data for each candidate facet
# Shape: (n_candidate_facets, *data_shape)
parent_data = _materialize_topology_tensor(parent_data)
candidate_data = parent_data[parent_cell_indices]

### Use unified scatter aggregation utility
Expand Down
16 changes: 13 additions & 3 deletions physicsnemo/mesh/calculus/_lsq_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
import torch
from jaxtyping import Float

from physicsnemo.mesh.utilities._scatter_ops import (
_first_shard_tensor,
_materialize_shard_tensor,
_redistribute_like_template,
)

if TYPE_CHECKING:
from physicsnemo.mesh.mesh import Mesh

Expand Down Expand Up @@ -115,16 +121,20 @@ def compute_point_gradient_lsq(
mesh_lsq_gradient,
)

shard_template = _first_shard_tensor(point_values, mesh.points)
gradients = mesh_lsq_gradient(
points=mesh.points,
values=point_values,
points=_materialize_shard_tensor(mesh.points),
values=_materialize_shard_tensor(point_values),
neighbor_offsets=adjacency.offsets,
neighbor_indices=adjacency.indices,
weight_power=weight_power,
min_neighbors=min_neighbors,
implementation="torch",
)
return _to_mesh_gradient_layout(gradients, point_values)
gradients = _to_mesh_gradient_layout(gradients, point_values)
if shard_template is not None:
gradients = _redistribute_like_template(gradients, shard_template)
return gradients


def compute_cell_gradient_lsq(
Expand Down
Loading
Loading