-
Notifications
You must be signed in to change notification settings - Fork 669
mesh: enable ShardTensor support for mesh conversion/geometry paths #1608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
1f310c6
9911959
446d524
4d17ef2
009e6c3
559ad0c
0c9dc3d
e0c7dd3
2370257
0c6ffd8
175d87e
4eadb78
462c432
1980993
0e2684e
8bc39bf
ff70a94
3fc807c
92b6f31
45b3ca4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -908,6 +908,38 @@ def to_local( | |||||||||||||||||||
|
|
||||||||||||||||||||
| return _ToTorchTensor.apply(self, grad_placements) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def new_replicated_zeros( | ||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To add this function as an API call on shard tensor, vs supporting the underlying dispatch call, has to have a really good motivation. What is the value of this vs. supporting the backend of |
||||||||||||||||||||
| self, | ||||||||||||||||||||
| shape: Sequence[int] | torch.Size, | ||||||||||||||||||||
| *, | ||||||||||||||||||||
| dtype: torch.dtype | None = None, | ||||||||||||||||||||
| ) -> "ShardTensor": | ||||||||||||||||||||
| r"""Create a replicated zero tensor on this tensor's mesh. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| This is useful for reductions/accumulators where an op naturally produces | ||||||||||||||||||||
| a replicated output regardless of input placement. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Parameters | ||||||||||||||||||||
| ---------- | ||||||||||||||||||||
| shape : Sequence[int] or torch.Size | ||||||||||||||||||||
| Global shape of the output tensor. | ||||||||||||||||||||
| dtype : torch.dtype, optional | ||||||||||||||||||||
| Output dtype. Defaults to this tensor's dtype. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Returns | ||||||||||||||||||||
| ------- | ||||||||||||||||||||
| ShardTensor | ||||||||||||||||||||
| A replicated ShardTensor of zeros on the same mesh. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| out_dtype = self.dtype if dtype is None else dtype | ||||||||||||||||||||
| local = torch.zeros(tuple(shape), dtype=out_dtype, device=self.device) | ||||||||||||||||||||
| return ShardTensor.from_local( | ||||||||||||||||||||
| local, | ||||||||||||||||||||
| self._spec.mesh, | ||||||||||||||||||||
| [Replicate() for _ in range(self._spec.mesh.ndim)], | ||||||||||||||||||||
| sharding_shapes="infer", | ||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is unusual here: typically I'd object loudly to passing "infer" as sharding shapes since that can trigger a blocking allreduce and is a major perf headache. but you're replicating on all ranks. I don't think I understand this function's role, really. |
||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def full_tensor( | ||||||||||||||||||||
| self, *, grad_placements: Sequence[Placement] | None = None | ||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||
|
|
@@ -965,6 +997,71 @@ def backward(self, *args, **kwargs): | |||||||||||||||||||
| return self.to_local().backward(*args, **kwargs) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def replicated_zeros_like( | ||||||||||||||||||||
| tensor: torch.Tensor, | ||||||||||||||||||||
| shape: Sequence[int] | torch.Size, | ||||||||||||||||||||
| *, | ||||||||||||||||||||
| dtype: torch.dtype | None = None, | ||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still think it makes more sense to support tl;dr the ShardTensor design philosophy is to make zero code changes on user side, whenever possible, so we support torch calls on shard tensors first and foremost rather then introduce new API. Is it possible to implement your work without this? |
||||||||||||||||||||
| r"""Create zeros matching a tensor's device/mesh semantics. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| For ``ShardTensor`` inputs this returns a replicated ``ShardTensor`` on the | ||||||||||||||||||||
| same mesh. For regular tensors this falls back to ``torch.zeros`` on the | ||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that this is part of the public API, please add complete NumPy-style docstring |
||||||||||||||||||||
| input device. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| if isinstance(tensor, ShardTensor): | ||||||||||||||||||||
| return tensor.new_replicated_zeros(shape, dtype=dtype) | ||||||||||||||||||||
| out_dtype = tensor.dtype if dtype is None else dtype | ||||||||||||||||||||
| return torch.zeros(tuple(shape), dtype=out_dtype, device=tensor.device) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def _cross_wrapper(func, types, args, kwargs): | ||||||||||||||||||||
| if kwargs is None: | ||||||||||||||||||||
| kwargs = {} | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if kwargs.get("out", None) is not None: | ||||||||||||||||||||
| raise RuntimeError("torch.linalg.cross(out=...) is not supported for ShardTensor.") | ||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is a function overload, it's in the wrong place. |
||||||||||||||||||||
|
|
||||||||||||||||||||
| input_tensor = args[0] if len(args) > 0 else kwargs.get("input") | ||||||||||||||||||||
| other_tensor = args[1] if len(args) > 1 else kwargs.get("other") | ||||||||||||||||||||
| dim = kwargs.get("dim", -1) | ||||||||||||||||||||
| if len(args) > 2: | ||||||||||||||||||||
| dim = args[2] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if not isinstance(input_tensor, ShardTensor) or not isinstance( | ||||||||||||||||||||
| other_tensor, ShardTensor | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||
| "torch.linalg.cross with ShardTensor inputs requires both arguments to be ShardTensor." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When For
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The error string says
Suggested change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like what we want to be doing is implementing a wrapper for |
||||||||||||||||||||
|
|
||||||||||||||||||||
| if input_tensor._spec.mesh != other_tensor._spec.mesh: | ||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||
| "torch.linalg.cross requires both ShardTensor inputs to share the same device mesh." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| if input_tensor._spec.placements != other_tensor._spec.placements: | ||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||
| "torch.linalg.cross requires both ShardTensor inputs to have identical placements." | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| local_result = torch.linalg.cross( | ||||||||||||||||||||
| input_tensor.to_local(), | ||||||||||||||||||||
| other_tensor.to_local(), | ||||||||||||||||||||
| dim=dim, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This appears to assume there aren't funky placements. But does this work if input and other are both sharded? Both replicated? What about the tensor_dim of sharding? Any constraints on sharding shapes? Is cross purely an op on the |
||||||||||||||||||||
| return ShardTensor.from_local( | ||||||||||||||||||||
| local_result, | ||||||||||||||||||||
| input_tensor._spec.mesh, | ||||||||||||||||||||
| input_tensor._spec.placements, | ||||||||||||||||||||
| sharding_shapes=input_tensor._spec.sharding_shapes(), | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| ShardTensor.register_function_handler(torch.linalg.cross, _cross_wrapper) | ||||||||||||||||||||
| if hasattr(torch, "cross"): | ||||||||||||||||||||
| ShardTensor.register_function_handler(torch.cross, _cross_wrapper) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def scatter_tensor( | ||||||||||||||||||||
| tensor: torch.Tensor, | ||||||||||||||||||||
| global_src: int, | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1463,15 +1463,14 @@ def cell_data_to_point_data(self, overwrite_keys: bool = False) -> "Mesh": | |
| # Shape: (n_cells * n_vertices_per_cell,) | ||
| point_indices = self.cells.flatten() | ||
|
|
||
| # Corresponding cell index for each point | ||
| # Shape: (n_cells * n_vertices_per_cell,) | ||
| cell_indices = torch.arange( | ||
| self.n_cells, device=self.points.device | ||
| ).repeat_interleave(n_vertices_per_cell) | ||
| # Repeat each cell value once per incident vertex. | ||
| # This avoids advanced indexing with an explicit cell-index tensor. | ||
|
|
||
| converted = self.cell_data.apply( | ||
| lambda cell_values: scatter_aggregate( | ||
| src_data=cell_values[cell_indices], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indexing a ShardTensor with a ShardTensor index should work fine?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| src_data=cell_values.unsqueeze(1) | ||
| .expand(-1, n_vertices_per_cell, *cell_values.shape[1:]) | ||
| .reshape(-1, *cell_values.shape[1:]), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is Or is this just an unrelated change (and if so, what motivated this)?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its because we needed a shardTensor to index hence the change. In the point_data_to_cell_data we use cells to index but that is already a ShardTensor. |
||
| src_to_dst_mapping=point_indices, | ||
| n_dst=self.n_points, | ||
| weights=None, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,9 +24,28 @@ | |
| import torch | ||
| from jaxtyping import Float, Int | ||
|
|
||
| from physicsnemo.domain_parallel import replicated_zeros_like | ||
| from physicsnemo.mesh.utilities._tolerances import safe_eps | ||
|
|
||
|
|
||
| def _is_sharded_tensor(tensor: torch.Tensor) -> bool: | ||
| return hasattr(tensor, "_spec") and hasattr(type(tensor), "from_local") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @coreyjadams is this best-practices for type-narrowing ShardTensor? Feels like (And if this does get replaced with |
||
|
|
||
|
|
||
| def _replicated_zeros_like( | ||
| tensor: torch.Tensor, shape: tuple[int, ...], dtype: torch.dtype | ||
| ) -> torch.Tensor: | ||
| if not _is_sharded_tensor(tensor) or replicated_zeros_like is None: | ||
| return torch.zeros(shape, dtype=dtype, device=tensor.device) | ||
|
|
||
| # Delegate replicated temporary allocation to the ShardTensor layer. | ||
| return replicated_zeros_like( | ||
| tensor, | ||
| shape, | ||
| dtype=dtype, | ||
| ) | ||
|
|
||
|
|
||
| def scatter_aggregate( | ||
| src_data: Float[torch.Tensor, "n_src ..."], | ||
| src_to_dst_mapping: Int[torch.Tensor, " n_src"], | ||
|
|
@@ -93,7 +112,9 @@ def scatter_aggregate( | |
|
|
||
| ### Fast path: unweighted sum is a single scatter_add_ with no extra work | ||
| if weights is None and aggregation == "sum": | ||
| aggregated_data = torch.zeros((n_dst, *data_shape), dtype=dtype, device=device) | ||
| aggregated_data = _replicated_zeros_like( | ||
| src_data, (n_dst, *data_shape), dtype | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The distributed-friendly update to make here is, if mesh can support it, to port |
||
| expanded_indices = src_to_dst_mapping.view( | ||
| -1, *([1] * len(data_shape)) | ||
| ).expand_as(src_data) | ||
|
|
@@ -102,7 +123,10 @@ def scatter_aggregate( | |
|
|
||
| ### Initialize weights if not provided | ||
| if weights is None: | ||
| weights = torch.ones(len(src_to_dst_mapping), dtype=dtype, device=device) | ||
| if _is_sharded_tensor(src_data): | ||
| weights = torch.ones_like(src_to_dst_mapping, dtype=dtype) | ||
| else: | ||
| weights = torch.ones(len(src_to_dst_mapping), dtype=dtype, device=device) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is much closer to the "right" way for domain parallelism, but in fact we can probably consolidate to just "ones_like" for all paths. |
||
|
|
||
| ### Ensure weights have same dtype as data (avoid dtype mismatch in multiplication) | ||
| if weights.dtype != dtype: | ||
|
|
@@ -114,11 +138,7 @@ def scatter_aggregate( | |
| weighted_data = src_data * weights.view(weight_shape) | ||
|
|
||
| ### Scatter-add weighted data to destinations | ||
| aggregated_data = torch.zeros( | ||
| (n_dst, *data_shape), | ||
| dtype=dtype, | ||
| device=device, | ||
| ) | ||
| aggregated_data = _replicated_zeros_like(src_data, (n_dst, *data_shape), dtype) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here re:zeros_like |
||
|
|
||
| # Expand src_to_dst_mapping to match data dimensions | ||
| expanded_indices = src_to_dst_mapping.view(-1, *([1] * len(data_shape))).expand_as( | ||
|
|
@@ -134,7 +154,7 @@ def scatter_aggregate( | |
| ### Normalize weighted sum to weighted mean | ||
| if aggregation == "mean": | ||
| ### Compute sum of weights at each destination | ||
| weight_sums = torch.zeros(n_dst, dtype=dtype, device=device) | ||
| weight_sums = _replicated_zeros_like(src_data, (n_dst,), dtype) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here re:zeros_like |
||
| weight_sums.scatter_add_( | ||
| dim=0, | ||
| index=src_to_dst_mapping, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @coreyjadams, if I am adding the zeros_like correct I might add a few similar ops just consistency even though they are not needed.