Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*.err
*.out
runs/
outputs/

Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ scheduler:

optimizer:
_target_: torch.optim.AdamW
lr: 1.0e-3
lr: 3.0e-4
weight_decay: 1.0e-4
betas: [0.9, 0.999]
eps: 1.0e-8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ precision: float32 # float32, float16, bfloat16, or float8
compile: true
profile: false

# Domain parallelism: number of GPUs collaborating on one sample.
# world_size must be divisible by this. Set to 1 to disable.
domain_parallel_size: 1

data:
include_sdf: false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ run_id: "volume/float32"

# Performance considerations:
precision: float32 # float32, float16, bfloat16, or float8
compile: true
compile: false
profile: false

# Domain parallelism: number of GPUs collaborating on one sample.
# world_size must be divisible by this. Set to 1 to disable.
domain_parallel_size: 1

model:
out_dim: 5
embedding_dim: 7
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,44 @@


def all_reduce_dict(
metrics: dict[str, torch.Tensor], dm: DistributedManager
metrics: dict[str, torch.Tensor],
dm: DistributedManager,
data_mesh=None,
) -> dict[str, torch.Tensor]:
"""
Reduces a dictionary of metrics across all distributed processes.
Reduces a dictionary of metrics across data-parallel replicas.

When *data_mesh* is provided the reduction runs only over the
data-parallel dimension (so domain-parallel ranks are not
double-counted). Otherwise falls back to the full world.

Args:
metrics: Dictionary of metric names to torch.Tensor values.
dm: DistributedManager instance for distributed context.
data_mesh: Optional DeviceMesh for the data-parallel dimension.

Returns:
Dictionary of reduced metrics.
"""
# TODO - update this to use domains and not the full world
if data_mesh is not None:
num_replicas = data_mesh.size()
else:
num_replicas = dm.world_size

if dm.world_size == 1:
if num_replicas <= 1:
return metrics

# Pack the metrics together:
merged_metrics = torch.stack(list(metrics.values()), dim=-1)

dist.all_reduce(merged_metrics)
merged_metrics = merged_metrics / dm.world_size
# ShardTensors (from domain parallelism) must be materialised before
# the plain dist.all_reduce call.
if isinstance(merged_metrics, ShardTensor):
merged_metrics = merged_metrics.full_tensor()

group = data_mesh.get_group() if data_mesh is not None else None
dist.all_reduce(merged_metrics, group=group)
merged_metrics = merged_metrics / num_replicas

# Unstack metrics:
metrics = {key: merged_metrics[i] for i, key in enumerate(metrics.keys())}
Expand All @@ -57,6 +73,7 @@ def metrics_fn(
target: torch.Tensor,
dm: DistributedManager,
mode: str,
data_mesh=None,
) -> dict[str, torch.Tensor]:
"""
Computes metrics for either surface or volume data.
Expand All @@ -67,6 +84,7 @@ def metrics_fn(
others: Dictionary containing normalization statistics.
dm: DistributedManager instance for distributed context.
mode: Either "surface" or "volume".
data_mesh: Optional DeviceMesh for the data-parallel dimension.

Returns:
Dictionary of computed metrics.
Expand All @@ -79,7 +97,7 @@ def metrics_fn(
else:
raise ValueError(f"Unknown data mode: {mode}")

metrics = all_reduce_dict(metrics, dm)
metrics = all_reduce_dict(metrics, dm, data_mesh=data_mesh)
return metrics


Expand Down
Loading