-
Notifications
You must be signed in to change notification settings - Fork 670
mesh: enable ShardTensor support for mesh conversion/geometry paths #1608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 12 commits
1f310c6
9911959
446d524
4d17ef2
009e6c3
559ad0c
0c9dc3d
e0c7dd3
2370257
0c6ffd8
175d87e
4eadb78
462c432
1980993
0e2684e
8bc39bf
ff70a94
3fc807c
92b6f31
45b3ca4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1461,15 +1461,15 @@ def cell_data_to_point_data(self, overwrite_keys: bool = False) -> "Mesh": | |
| # Shape: (n_cells * n_vertices_per_cell,) | ||
| point_indices = self.cells.flatten() | ||
|
|
||
| # Corresponding cell index for each point | ||
| # Shape: (n_cells * n_vertices_per_cell,) | ||
| cell_indices = torch.arange( | ||
| self.n_cells, device=self.points.device | ||
| ).repeat_interleave(n_vertices_per_cell) | ||
| # Repeat each cell value once per incident vertex. This avoids mixing a | ||
| # ShardTensor data field with a generated dense cell-index tensor. | ||
|
|
||
| converted = self.cell_data.apply( | ||
| lambda cell_values: scatter_aggregate( | ||
| src_data=cell_values[cell_indices], | ||
| src_data=cell_values.unsqueeze(1) | ||
| .expand(-1, n_vertices_per_cell, *cell_values.shape[1:]) | ||
| .clone() | ||
| .reshape(-1, *cell_values.shape[1:]), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is Or is this just an unrelated change (and if so, what motivated this)?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its because we needed a shardTensor to index hence the change. In the point_data_to_cell_data we use cells to index but that is already a ShardTensor. |
||
| src_to_dst_mapping=point_indices, | ||
| n_dst=self.n_points, | ||
| weights=None, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indexing a ShardTensor with a ShardTensor index should work fine?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cell_indicesis not a ShardTensor the way it was before. I was coming from just thetorch.arangefunction