Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion physicsnemo/domain_parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
# If minimum versions are met, we can import the shard tensor and spec.

from ._shard_tensor_spec import ShardTensorSpec
from .shard_tensor import ShardTensor, scatter_tensor
from .shard_tensor import ShardTensor, replicated_zeros_like, scatter_tensor
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @coreyjadams, if I am adding the zeros_like correctly, I might add a few similar ops just for consistency, even though they are not needed.


def register_custom_ops():
# These imports will register the custom ops with the ShardTensor class.
Expand All @@ -69,3 +69,4 @@ def register_custom_ops():
ShardTensor = None
ShardTensorSpec = None
scatter_tensor = None
replicated_zeros_like = None
97 changes: 97 additions & 0 deletions physicsnemo/domain_parallel/shard_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,38 @@ def to_local(

return _ToTorchTensor.apply(self, grad_placements)

def new_replicated_zeros(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To add this function as an API call on shard tensor, vs supporting the underlying dispatch call, has to have a really good motivation.

What is the value of this vs. supporting the backend of torch.zeros_like(a) when a is a shard tensor? And, in fact, I think that should already work?

self,
shape: Sequence[int] | torch.Size,
*,
dtype: torch.dtype | None = None,
) -> "ShardTensor":
r"""Create a replicated zero tensor on this tensor's mesh.

This is useful for reductions/accumulators where an op naturally produces
a replicated output regardless of input placement.

Parameters
----------
shape : Sequence[int] or torch.Size
Global shape of the output tensor.
dtype : torch.dtype, optional
Output dtype. Defaults to this tensor's dtype.

Returns
-------
ShardTensor
A replicated ShardTensor of zeros on the same mesh.
"""
out_dtype = self.dtype if dtype is None else dtype
local = torch.zeros(tuple(shape), dtype=out_dtype, device=self.device)
return ShardTensor.from_local(
local,
self._spec.mesh,
[Replicate() for _ in range(self._spec.mesh.ndim)],
sharding_shapes="infer",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unusual here: typically I'd object loudly to passing "infer" as sharding shapes, since that can trigger a blocking allreduce and is a major perf headache. But you're replicating on all ranks. I don't think I understand this function's role, really.

)

def full_tensor(
self, *, grad_placements: Sequence[Placement] | None = None
) -> torch.Tensor:
Expand Down Expand Up @@ -965,6 +997,71 @@ def backward(self, *args, **kwargs):
return self.to_local().backward(*args, **kwargs)


def replicated_zeros_like(
tensor: torch.Tensor,
shape: Sequence[int] | torch.Size,
*,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think it makes more sense to support torch.zeros_like(a) for a as a replicated tensor.

tl;dr the ShardTensor design philosophy is to make zero code changes on the user side whenever possible, so we support torch calls on shard tensors first and foremost rather than introducing a new API. Is it possible to implement your work without this?

r"""Create zeros matching a tensor's device/mesh semantics.

For ``ShardTensor`` inputs this returns a replicated ``ShardTensor`` on the
same mesh. For regular tensors this falls back to ``torch.zeros`` on the
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that this is part of the public API, please add a complete NumPy-style docstring.

input device.
"""
if isinstance(tensor, ShardTensor):
return tensor.new_replicated_zeros(shape, dtype=dtype)
out_dtype = tensor.dtype if dtype is None else dtype
return torch.zeros(tuple(shape), dtype=out_dtype, device=tensor.device)


def _cross_wrapper(func, types, args, kwargs):
if kwargs is None:
kwargs = {}

if kwargs.get("out", None) is not None:
raise RuntimeError("torch.linalg.cross(out=...) is not supported for ShardTensor.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is a function overload, it's in the wrong place. shard_tensor.py is for the core tensor object only. There is a sub folder for ops.


input_tensor = args[0] if len(args) > 0 else kwargs.get("input")
other_tensor = args[1] if len(args) > 1 else kwargs.get("other")
dim = kwargs.get("dim", -1)
if len(args) > 2:
dim = args[2]

if not isinstance(input_tensor, ShardTensor) or not isinstance(
other_tensor, ShardTensor
):
raise RuntimeError(
"torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 torch.cross dim defaulting diverges from original behavior

When _cross_wrapper is invoked via the torch.cross handler (registered on line 1062) and the caller omits dim, the wrapper defaults to dim=-1. However, torch.cross (pre-deprecation) auto-detects the first dimension of size 3, which may not be the last dimension. Any call like torch.cross(a_shard, b_shard) where the cross-product axis isn't the last one will silently produce a wrong result instead of matching the original op's semantics.

For torch.linalg.cross, dim is keyword-only and defaults to -1, so the current default is correct there. For the torch.cross handler, consider either (a) raising explicitly when dim is absent to force callers to use the unambiguous form, or (b) documenting that only the torch.linalg.cross semantic (dim=-1) is supported.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Hardcoded error message for both torch.cross and torch.linalg.cross

The error string says "torch.linalg.cross with ShardTensor inputs…" but this wrapper is registered for both torch.linalg.cross (line 1060) and torch.cross (line 1062). When the handler fires via torch.cross, the message will mislead users.

Suggested change
if not isinstance(input_tensor, ShardTensor) or not isinstance(
other_tensor, ShardTensor
):
raise RuntimeError(
"torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor."
)
raise RuntimeError(
f"{func.__module__}.{func.__name__} with ShardTensor inputs requires both arguments to be ShardTensor."
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like what we want to be doing is implementing a wrapper for torch.linalg.cross on shard tensor objects. There is not a need to check that all objects are ShardTensor at this time.


if input_tensor._spec.mesh != other_tensor._spec.mesh:
raise RuntimeError(
"torch.linalg.cross requires both ShardTensor inputs to share the same device mesh."
)
if input_tensor._spec.placements != other_tensor._spec.placements:
raise RuntimeError(
"torch.linalg.cross requires both ShardTensor inputs to have identical placements."
)

local_result = torch.linalg.cross(
input_tensor.to_local(),
other_tensor.to_local(),
dim=dim,
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to assume there aren't funky placements. But does this work if input and other are both sharded? Both replicated? What about the tensor_dim of sharding? Any constraints on sharding shapes? Is cross purely an op on the dim channel, and changes the dim channel?

return ShardTensor.from_local(
local_result,
input_tensor._spec.mesh,
input_tensor._spec.placements,
sharding_shapes=input_tensor._spec.sharding_shapes(),
)


ShardTensor.register_function_handler(torch.linalg.cross, _cross_wrapper)
if hasattr(torch, "cross"):
ShardTensor.register_function_handler(torch.cross, _cross_wrapper)


def scatter_tensor(
tensor: torch.Tensor,
global_src: int,
Expand Down
11 changes: 5 additions & 6 deletions physicsnemo/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,15 +1463,14 @@ def cell_data_to_point_data(self, overwrite_keys: bool = False) -> "Mesh":
# Shape: (n_cells * n_vertices_per_cell,)
point_indices = self.cells.flatten()

# Corresponding cell index for each point
# Shape: (n_cells * n_vertices_per_cell,)
cell_indices = torch.arange(
self.n_cells, device=self.points.device
).repeat_interleave(n_vertices_per_cell)
# Repeat each cell value once per incident vertex.
# This avoids advanced indexing with an explicit cell-index tensor.

converted = self.cell_data.apply(
lambda cell_values: scatter_aggregate(
src_data=cell_values[cell_indices],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indexing a ShardTensor with a ShardTensor index should work fine?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cell_indices is not a ShardTensor the way it was before. It was coming from just the torch.arange function.

src_data=cell_values.unsqueeze(1)
.expand(-1, n_vertices_per_cell, *cell_values.shape[1:])
.reshape(-1, *cell_values.shape[1:]),
Copy link
Copy Markdown
Collaborator

@peterdsharpe peterdsharpe Apr 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is Mesh.cell_data_to_point_data() updated, but Mesh.point_data_to_cell_data() is not? The asymmetry seems suspect - seems like they should either both need updates for ShardTensor or neither should, but I might be missing something?

Or is this just an unrelated change (and if so, what motivated this)?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because we needed a ShardTensor to index, hence the change. In point_data_to_cell_data we use cells to index, but that is already a ShardTensor.

src_to_dst_mapping=point_indices,
n_dst=self.n_points,
weights=None,
Expand Down
36 changes: 28 additions & 8 deletions physicsnemo/mesh/utilities/_scatter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,28 @@
import torch
from jaxtyping import Float, Int

from physicsnemo.domain_parallel import replicated_zeros_like
from physicsnemo.mesh.utilities._tolerances import safe_eps


def _is_sharded_tensor(tensor: torch.Tensor) -> bool:
return hasattr(tensor, "_spec") and hasattr(type(tensor), "from_local")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@coreyjadams is this best-practices for type-narrowing ShardTensor?

Feels like isinstance(tensor, ShardTensor) would be better, unless there's something I'm missing?

(And if this does get replaced with isinstance, then this can be inlined rather than keeping a separate function.)



def _replicated_zeros_like(
    tensor: torch.Tensor, shape: tuple[int, ...], dtype: torch.dtype
) -> torch.Tensor:
    """Allocate a zero tensor of ``shape`` and ``dtype`` alongside ``tensor``.

    When ``tensor`` passes ``_is_sharded_tensor`` and the
    ``replicated_zeros_like`` helper is available (not ``None``), allocation
    is delegated to that helper. Otherwise a plain ``torch.zeros`` is created
    on ``tensor.device``.
    """
    delegate_to_shard_layer = (
        _is_sharded_tensor(tensor) and replicated_zeros_like is not None
    )
    if delegate_to_shard_layer:
        # Delegate replicated temporary allocation to the ShardTensor layer.
        return replicated_zeros_like(tensor, shape, dtype=dtype)

    return torch.zeros(shape, dtype=dtype, device=tensor.device)


def scatter_aggregate(
src_data: Float[torch.Tensor, "n_src ..."],
src_to_dst_mapping: Int[torch.Tensor, " n_src"],
Expand Down Expand Up @@ -93,7 +112,9 @@ def scatter_aggregate(

### Fast path: unweighted sum is a single scatter_add_ with no extra work
if weights is None and aggregation == "sum":
aggregated_data = torch.zeros((n_dst, *data_shape), dtype=dtype, device=device)
aggregated_data = _replicated_zeros_like(
src_data, (n_dst, *data_shape), dtype
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The distributed-friendly update to make here is, if mesh can support it, to port torch.zeros away from this shape and onto torch.zeros_like. So whatever is building data_shape as input, we build zeros_like(that_object) and then mesh can work on single device and sharded inputs too.

expanded_indices = src_to_dst_mapping.view(
-1, *([1] * len(data_shape))
).expand_as(src_data)
Expand All @@ -102,7 +123,10 @@ def scatter_aggregate(

### Initialize weights if not provided
if weights is None:
weights = torch.ones(len(src_to_dst_mapping), dtype=dtype, device=device)
if _is_sharded_tensor(src_data):
weights = torch.ones_like(src_to_dst_mapping, dtype=dtype)
else:
weights = torch.ones(len(src_to_dst_mapping), dtype=dtype, device=device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is much closer to the "right" way for domain parallelism, but in fact we can probably consolidate to just "ones_like" for all paths.


### Ensure weights have same dtype as data (avoid dtype mismatch in multiplication)
if weights.dtype != dtype:
Expand All @@ -114,11 +138,7 @@ def scatter_aggregate(
weighted_data = src_data * weights.view(weight_shape)

### Scatter-add weighted data to destinations
aggregated_data = torch.zeros(
(n_dst, *data_shape),
dtype=dtype,
device=device,
)
aggregated_data = _replicated_zeros_like(src_data, (n_dst, *data_shape), dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here re:zeros_like


# Expand src_to_dst_mapping to match data dimensions
expanded_indices = src_to_dst_mapping.view(-1, *([1] * len(data_shape))).expand_as(
Expand All @@ -134,7 +154,7 @@ def scatter_aggregate(
### Normalize weighted sum to weighted mean
if aggregation == "mean":
### Compute sum of weights at each destination
weight_sums = torch.zeros(n_dst, dtype=dtype, device=device)
weight_sums = _replicated_zeros_like(src_data, (n_dst,), dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here re:zeros_like

weight_sums.scatter_add_(
dim=0,
index=src_to_dst_mapping,
Expand Down
Loading
Loading