diff --git a/examples/cfd/external_aerodynamics/transformer_models/.gitignore b/examples/cfd/external_aerodynamics/transformer_models/.gitignore new file mode 100644 index 0000000000..a18185d7fe --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/.gitignore @@ -0,0 +1,5 @@ +*.err +*.out +runs/ +outputs/ + diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/training/base.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/training/base.yaml index a562499f3d..a0cee391ab 100644 --- a/examples/cfd/external_aerodynamics/transformer_models/src/conf/training/base.yaml +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/training/base.yaml @@ -25,7 +25,7 @@ scheduler: optimizer: _target_: torch.optim.AdamW - lr: 1.0e-3 + lr: 3.0e-4 weight_decay: 1.0e-4 betas: [0.9, 0.999] eps: 1.0e-8 diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_surface.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_surface.yaml index 1d19b3d74e..5ec93719fb 100644 --- a/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_surface.yaml +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_surface.yaml @@ -28,6 +28,10 @@ precision: float32 # float32, float16, bfloat16, or float8 compile: true profile: false +# Domain parallelism: number of GPUs collaborating on one sample. +# world_size must be divisible by this. Set to 1 to disable. +domain_parallel_size: 1 + data: include_sdf: false diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_volume.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_volume.yaml index 2480b8395e..488d4d5ff9 100644 --- a/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_volume.yaml +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_volume.yaml @@ -25,9 +25,13 @@ run_id: "volume/float32" # Performance considerations: precision: float32 # float32, float16, bfloat16, or float8 -compile: true +compile: false profile: false +# Domain parallelism: number of GPUs collaborating on one sample. +# world_size must be divisible by this. Set to 1 to disable. +domain_parallel_size: 1 + model: out_dim: 5 embedding_dim: 7 diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/metrics.py b/examples/cfd/external_aerodynamics/transformer_models/src/metrics.py index cacbda7988..1adb37ffe4 100644 --- a/examples/cfd/external_aerodynamics/transformer_models/src/metrics.py +++ b/examples/cfd/external_aerodynamics/transformer_models/src/metrics.py @@ -23,28 +23,44 @@ def all_reduce_dict( - metrics: dict[str, torch.Tensor], dm: DistributedManager + metrics: dict[str, torch.Tensor], + dm: DistributedManager, + data_mesh=None, ) -> dict[str, torch.Tensor]: """ - Reduces a dictionary of metrics across all distributed processes. + Reduces a dictionary of metrics across data-parallel replicas. + + When *data_mesh* is provided the reduction runs only over the + data-parallel dimension (so domain-parallel ranks are not + double-counted). Otherwise falls back to the full world. Args: metrics: Dictionary of metric names to torch.Tensor values. dm: DistributedManager instance for distributed context. + data_mesh: Optional DeviceMesh for the data-parallel dimension. Returns: Dictionary of reduced metrics. 
""" - # TODO - update this to use domains and not the full world + if data_mesh is not None: + num_replicas = data_mesh.size() + else: + num_replicas = dm.world_size - if dm.world_size == 1: + if num_replicas <= 1: return metrics # Pack the metrics together: merged_metrics = torch.stack(list(metrics.values()), dim=-1) - dist.all_reduce(merged_metrics) - merged_metrics = merged_metrics / dm.world_size + # ShardTensors (from domain parallelism) must be materialised before + # the plain dist.all_reduce call. + if isinstance(merged_metrics, ShardTensor): + merged_metrics = merged_metrics.full_tensor() + + group = data_mesh.get_group() if data_mesh is not None else None + dist.all_reduce(merged_metrics, group=group) + merged_metrics = merged_metrics / num_replicas # Unstack metrics: metrics = {key: merged_metrics[i] for i, key in enumerate(metrics.keys())} @@ -57,6 +73,7 @@ def metrics_fn( target: torch.Tensor, dm: DistributedManager, mode: str, + data_mesh=None, ) -> dict[str, torch.Tensor]: """ Computes metrics for either surface or volume data. @@ -67,6 +84,7 @@ def metrics_fn( others: Dictionary containing normalization statistics. dm: DistributedManager instance for distributed context. mode: Either "surface" or "volume". + data_mesh: Optional DeviceMesh for the data-parallel dimension. Returns: Dictionary of computed metrics. @@ -79,7 +97,7 @@ def metrics_fn( else: raise ValueError(f"Unknown data mode: {mode}") - metrics = all_reduce_dict(metrics, dm) + metrics = all_reduce_dict(metrics, dm, data_mesh=data_mesh) return metrics diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/train.py b/examples/cfd/external_aerodynamics/transformer_models/src/train.py index a01d645944..576a4a6bc2 100644 --- a/examples/cfd/external_aerodynamics/transformer_models/src/train.py +++ b/examples/cfd/external_aerodynamics/transformer_models/src/train.py @@ -15,15 +15,15 @@ # limitations under the License. # Core python imports: +import json import os import time +from datetime import datetime, timezone from pathlib import Path -from typing import Literal, Any, Callable, Sequence +from typing import Literal, Any import collections from contextlib import nullcontext -from collections.abc import Sequence - # Configuration: import hydra import omegaconf @@ -31,7 +31,6 @@ # Pytorch imports: import torch -from torch.optim import Optimizer from torch.amp import autocast, GradScaler from torch.utils.tensorboard import SummaryWriter @@ -54,6 +53,11 @@ TransolverDataPipe, ) +# Domain parallelism imports: +from torch.distributed.fsdp import fully_shard +from torch.distributed.tensor import distribute_module +from physicsnemo.domain_parallel import ShardTensor + # Local folder imports for this example from metrics import metrics_fn @@ -85,64 +89,56 @@ torch.serialization.add_safe_globals([omegaconf.base.Metadata]) -class CombinedOptimizer(Optimizer): - """Combine multiple PyTorch optimizers into a single Optimizer-like interface. - - The wrapper concatenates the *param_groups* from all contained optimizers so - that learning-rate schedulers (e.g., ReduceLROnPlateau, CosineAnnealingLR) - operate transparently across every parameter. Only a minimal subset of the - *torch.optim.Optimizer* API is implemented—extend as needed. +def _append_jsonl(path, records): + """Append a list of dicts as JSON Lines to *path*.""" + with open(path, "a") as f: + for rec in records: + f.write(json.dumps(rec, default=str) + "\n") - Note: - This will get upstreamed to physicsnemo shortly. 
Don't count on this - class existing here in the future! - In other words, this is already marked for deprecation! +def setup_domain_parallelism( + cfg: DictConfig, + dist_manager: DistributedManager, + logger, +): + """Set up a 2D DeviceMesh for domain + data parallelism. + + The mesh has shape ``(data_parallel_size, domain_parallel_size)`` with + dimension names ``("ddp", "domain")``. When ``domain_parallel_size`` + is 1 (or single-GPU), both returned meshes are ``None`` and the caller + should fall back to standard DDP. + + Returns + ------- + domain_mesh : DeviceMesh | None + Sub-mesh for the domain-parallel dimension. + data_mesh : DeviceMesh | None + Sub-mesh for the data-parallel (DDP/FSDP) dimension. """ + domain_parallel_size = getattr(cfg, "domain_parallel_size", 1) - def __init__( - self, - optimizers: Sequence[Optimizer], - torch_compile_kwargs: dict[str, Any] | None = None, - ): - if not optimizers: - raise ValueError("`optimizers` must contain at least one optimizer.") - - self.optimizers = optimizers - - # Collect parameter groups from all optimizers. We pass an empty - # *defaults* dict because hyper-parameters are managed by the inner - # optimizers, not this wrapper. - param_groups = [g for opt in optimizers for g in opt.param_groups] - super().__init__(param_groups, defaults={}) + if domain_parallel_size <= 1 or dist_manager.world_size <= 1: + return None, None - if torch_compile_kwargs is None: - self.step_fns: list[Callable] = [opt.step for opt in optimizers] - else: - self.step_fns: list[Callable] = [ - torch.compile(opt.step, **torch_compile_kwargs) for opt in optimizers - ] - - def zero_grad(self, *args, **kwargs) -> None: - """Nullify gradients""" - for opt in self.optimizers: - opt.zero_grad(*args, **kwargs) - - def step(self, closure=None) -> None: - for step_fn in self.step_fns: - if closure is None: - step_fn() - else: - step_fn(closure) + if dist_manager.world_size % domain_parallel_size != 0: + raise ValueError( + f"world_size ({dist_manager.world_size}) must be divisible " + f"by domain_parallel_size ({domain_parallel_size})" + ) - def state_dict(self): - return {"optimizers": [opt.state_dict() for opt in self.optimizers]} + mesh = dist_manager.initialize_mesh( + mesh_shape=(-1, domain_parallel_size), + mesh_dim_names=("ddp", "domain"), + ) + domain_mesh = mesh["domain"] + data_mesh = mesh["ddp"] - def load_state_dict(self, state_dict): - for opt, sd in zip(self.optimizers, state_dict["optimizers"]): - opt.load_state_dict(sd) + logger.info( + f"Domain parallelism enabled: domain_size={domain_parallel_size}, " + f"data_parallel_size={data_mesh.size()}" + ) - self.param_groups = [g for opt in self.optimizers for g in opt.param_groups] + return domain_mesh, data_mesh def get_autocast_context(precision: str) -> nullcontext: @@ -260,6 +256,7 @@ def forward_pass( dist_manager: DistributedManager, data_mode: Literal["surface", "volume"], datapipe: TransolverDataPipe, + data_mesh=None, ): """ Run the forward pass of the model for one batch, including metrics and loss calculation. @@ -271,10 +268,24 @@ def forward_pass( """ - features = batch["fx"] embeddings = batch["embeddings"] targets = batch["fields"] + # Hack: sharded reads in CAEDataset drop Zarr attributes, so the datapipe + # may not build "fx". Reconstruct it from the raw scalars if needed. 
+ if "fx" in batch: + features = batch["fx"] + elif "air_density" in batch and "stream_velocity" in batch: + features = torch.stack( + [batch["air_density"], batch["stream_velocity"]], dim=-1 + ) + features = features.broadcast_to(*embeddings.shape[:-1], -1) + else: + raise KeyError( + f"Batch has neither 'fx' nor 'air_density'/'stream_velocity'. " + f"Keys: {list(batch.keys())}" + ) + # Cast precisions: features = cast_precisions(features, precision=precision) embeddings = cast_precisions(embeddings, precision=precision) @@ -325,7 +336,7 @@ def forward_pass( # This is the Transolver path outputs = model(fx=features, embedding=embeddings) outputs = unpad_output_for_fp8(outputs, output_pad_size) - full_loss = torch.nn.functional.mse_loss(outputs, targets) + full_loss = ((outputs - targets) ** 2).mean() all_metrics[f"loss/{modes[0]}"] = full_loss @@ -346,7 +357,7 @@ def forward_pass( stream_velocity=stream_velocity, factor_type=modes, ) - metrics = metrics_fn(unscaled_outputs, unscaled_targets, dist_manager, modes) + metrics = metrics_fn(unscaled_outputs, unscaled_targets, dist_manager, modes, data_mesh=data_mesh) # In the combined mode, this is a list of dicts. Merge them. metrics = ( @@ -356,6 +367,13 @@ def forward_pass( ) all_metrics.update(metrics) + # Materialise any ShardTensor scalars so downstream code (TensorBoard, + # tabulate, etc.) sees plain tensors. + all_metrics = { + k: v.full_tensor() if isinstance(v, ShardTensor) else v + for k, v in all_metrics.items() + } + return full_loss, all_metrics, (unscaled_outputs, unscaled_targets) @@ -373,6 +391,8 @@ def train_epoch( cfg: DictConfig, dist_manager: DistributedManager, scaler: GradScaler | None = None, + data_mesh=None, + metrics_log_path: str | None = None, ) -> float: """ Train the model for one epoch. @@ -396,13 +416,12 @@ def train_epoch( model.train() total_loss = 0 total_metrics = {} + step_records: list[dict] = [] precision = getattr(cfg, "precision", "float32") start_time = time.time() for i, batch in enumerate(dataloader): - # TransolverX has a different forward pass: - loss, metrics, _ = forward_pass( batch, model, @@ -411,6 +430,7 @@ def train_epoch( dist_manager, cfg.data.mode, dataloader, + data_mesh=data_mesh, ) optimizer.zero_grad() @@ -460,6 +480,20 @@ def train_epoch( f"batch/{metric_name}", metric_value, i + epoch_len * epoch ) + step_records.append({ + "timestamp": datetime.now(timezone.utc).isoformat(), + "phase": "train", + "epoch": epoch, + "iteration": i, + "global_step": i + epoch_len * epoch, + "loss": this_loss, + "learning_rate": optimizer.param_groups[0]["lr"], + "duration_s": duration, + "throughput_per_gpu": images_per_second, + "memory_reserved_gb": mem_usage, + **{k: float(v) if hasattr(v, "item") else v for k, v in metrics.items()}, + }) + if cfg.profile and i >= 10: break # Stop profiling after 10 batches @@ -476,6 +510,10 @@ def train_epoch( tablefmt="pretty", ) print(f"\nEpoch {epoch} Average Metrics:\n{metrics_table}\n") + + if metrics_log_path: + _append_jsonl(metrics_log_path, step_records) + return avg_loss @@ -490,6 +528,8 @@ def val_epoch( epoch: int, cfg: DictConfig, dist_manager: DistributedManager, + data_mesh=None, + metrics_log_path: str | None = None, ) -> float: """ Run validation for one epoch. 
@@ -511,6 +551,7 @@ def val_epoch( model.eval() # Set model to evaluation mode total_loss = 0 total_metrics = {} + step_records: list[dict] = [] precision = getattr(cfg.training, "precision", "float32") @@ -525,6 +566,7 @@ def val_epoch( dist_manager, cfg.data.mode, dataloader, + data_mesh=data_mesh, ) if i == 0: @@ -542,11 +584,26 @@ def val_epoch( duration = end_time - start_time start_time = end_time + mem_usage = torch.cuda.memory_reserved() / 1024**3 + logger.info( f"Val [{i}/{epoch_len}] Loss: {this_loss:.6f} Duration: {duration:.2f}s" ) # We don't add individual loss measurements to tensorboard in the validation loop. + if dist_manager.rank == 0: + step_records.append({ + "timestamp": datetime.now(timezone.utc).isoformat(), + "phase": "val", + "epoch": epoch, + "iteration": i, + "global_step": i + epoch_len * epoch, + "loss": this_loss, + "duration_s": duration, + "memory_reserved_gb": mem_usage, + **{k: float(v) if hasattr(v, "item") else v for k, v in metrics.items()}, + }) + if cfg.profile and i >= 10: break # Stop profiling after 10 batches @@ -563,6 +620,10 @@ def val_epoch( tablefmt="pretty", ) print(f"\nEpoch {epoch} Validation Average Metrics:\n{metrics_table}\n") + + if metrics_log_path: + _append_jsonl(metrics_log_path, step_records) + return avg_loss @@ -623,9 +684,35 @@ def main(cfg: DictConfig): # Set up distributed training dist_manager = DistributedManager() + # Debug: show rank, local rank, device, and node info for each process + import socket + hostname = socket.gethostname() + device = dist_manager.device + rank = dist_manager.rank + local_rank = dist_manager.local_rank + world_size = dist_manager.world_size + gpu_name = torch.cuda.get_device_name(device) if torch.cuda.is_available() else "N/A" + print( + f"[DEBUG] Rank {rank}/{world_size - 1} | " + f"Local Rank {local_rank} | " + f"Node: {hostname} | " + f"Device: {device} | " + f"GPU: {gpu_name}" + ) + if torch.cuda.is_available(): + mem_total = torch.cuda.get_device_properties(device).total_memory / (1024**3) + print( + f"[DEBUG] Rank {rank} | GPU Memory: {mem_total:.1f} GB | " + f"CUDA Version: {torch.version.cuda}" + ) + torch.distributed.barrier() + # Set up logging logger = RankZeroLoggingWrapper(PythonLogger(name="training"), dist_manager) + # Set up domain parallelism (2D mesh: data-parallel x domain-parallel) + domain_mesh, data_mesh = setup_domain_parallelism(cfg, dist_manager, logger) + # Set checkpoint directory - defaults to output_dir if not specified checkpoint_dir = getattr(cfg, "checkpoint_dir", None) if checkpoint_dir is None: @@ -644,9 +731,13 @@ def main(cfg: DictConfig): cfg.output_dir + "/" + cfg.run_id + "/val", ) ) + metrics_log_path = os.path.join( + cfg.output_dir, cfg.run_id, "step_metrics.jsonl" + ) else: writer = None val_writer = None + metrics_log_path = None logger.info(f"Config:\n{omegaconf.OmegaConf.to_yaml(cfg, resolve=True)}") logger.info(f"Output directory: {cfg.output_dir}/{cfg.run_id}") @@ -661,11 +752,18 @@ def main(cfg: DictConfig): model.to(dist_manager.device) - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[dist_manager.local_rank], - output_device=dist_manager.device, - ) + if domain_mesh is not None: + # Domain parallelism: distribute_module makes the model + # ShardTensor-aware on the domain mesh; fully_shard handles + # gradient sync across data-parallel replicas (FSDP2). 
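+            # (Note: fully_shard is the FSDP2 API, so parameters and gradients are
+            # sharded across data_mesh rather than replicated as in DDP; this is
+            # also why checkpoint saving below must involve every rank.)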
+ model = distribute_module(model, device_mesh=domain_mesh) + model = fully_shard(model, mesh=data_mesh) + else: + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist_manager.local_rank], + output_device=dist_manager.device, + ) num_params = sum(p.numel() for p in model.parameters()) logger.info(f"Number of parameters: {num_params}") @@ -698,19 +796,26 @@ def main(cfg: DictConfig): phase="train", surface_factors=surface_factors, volume_factors=volume_factors, + device_mesh=domain_mesh, ) # Validation dataset - val_dataloader = create_transolver_dataset( cfg.data, phase="val", surface_factors=surface_factors, volume_factors=volume_factors, + device_mesh=domain_mesh, ) - num_replicas = dist_manager.world_size - data_rank = dist_manager.rank + # With domain parallelism, only the data-parallel dimension + # determines which samples each group sees. + if data_mesh is not None: + num_replicas = data_mesh.size() + data_rank = data_mesh.get_local_rank() + else: + num_replicas = dist_manager.world_size + data_rank = dist_manager.rank # Set up distributed samplers train_sampler = torch.utils.data.distributed.DistributedSampler( @@ -729,23 +834,8 @@ def main(cfg: DictConfig): drop_last=True, ) - muon_params = [p for p in model.parameters() if p.ndim == 2] - other_params = [p for p in model.parameters() if p.ndim != 2] - # Set up optimizer and scheduler - optimizer = hydra.utils.instantiate(cfg.training.optimizer, params=other_params) - - optimizer = CombinedOptimizer( - optimizers=[ - torch.optim.Muon( - muon_params, - lr=cfg.training.optimizer.lr, - weight_decay=cfg.training.optimizer.weight_decay, - adjust_lr_fn="match_rms_adamw", - ), - optimizer, - ], - ) + optimizer = hydra.utils.instantiate(cfg.training.optimizer, params=model.parameters()) # Set up learning rate scheduler based on config scheduler_cfg = cfg.training.scheduler @@ -799,6 +889,8 @@ def main(cfg: DictConfig): cfg, dist_manager, scaler, + data_mesh=data_mesh, + metrics_log_path=metrics_log_path, ) end_time = time.time() train_duration = end_time - start_time @@ -815,6 +907,8 @@ def main(cfg: DictConfig): epoch, cfg, dist_manager, + data_mesh=data_mesh, + metrics_log_path=metrics_log_path, ) end_time = time.time() val_duration = end_time - start_time @@ -824,8 +918,28 @@ def main(cfg: DictConfig): f"Epoch [{epoch}/{cfg.training.num_epochs}] Train Loss: {train_loss:.6f} [duration: {train_duration:.2f}s] Val Loss: {val_loss:.6f} [duration: {val_duration:.2f}s]" ) - # save checkpoint - if epoch % cfg.training.save_interval == 0 and dist_manager.rank == 0: + # Write epoch-level summary records + if dist_manager.rank == 0 and metrics_log_path: + _append_jsonl(metrics_log_path, [ + { + "timestamp": datetime.now(timezone.utc).isoformat(), + "phase": "train_epoch", + "epoch": epoch, + "avg_loss": train_loss, + "duration_s": train_duration, + }, + { + "timestamp": datetime.now(timezone.utc).isoformat(), + "phase": "val_epoch", + "epoch": epoch, + "avg_loss": val_loss, + "duration_s": val_duration, + }, + ]) + + # save checkpoint (all ranks must participate for FSDP/DTensor gathers; + # save_checkpoint internally gates file I/O to rank 0) + if epoch % cfg.training.save_interval == 0: save_checkpoint(**ckpt_args, epoch=epoch + 1) if scheduler_name == "StepLR": diff --git a/examples/cfd/external_aerodynamics/transformer_models/train_transolver_sharded.sh b/examples/cfd/external_aerodynamics/transformer_models/train_transolver_sharded.sh new file mode 100644 index 0000000000..cf2a896d38 --- /dev/null +++ 
b/examples/cfd/external_aerodynamics/transformer_models/train_transolver_sharded.sh @@ -0,0 +1,81 @@ +#!/bin/bash +#SBATCH --account=coreai_modulus_cae +#SBATCH --job-name=transolver_volume-drivaer_ml +#SBATCH --nodes=16 +#SBATCH --ntasks-per-node=4 +#SBATCH --gpus-per-node=4 +#SBATCH --time=4:00:00 +#SBATCH --output=transolver_volume-drivaer_ml_%A.out +#SBATCH --error=transolver_volume-drivaer_ml_%A.err +#SBATCH --partition=batch + +# Paths +export USER_LUSTRE=/lustre/fsw/portfolios/coreai/users/coreya +export GROUP_LUSTRE=/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae +export HOME=${USER_LUSTRE} + +# Container setup +CONTAINER_IMAGE=$GROUP_LUSTRE/containers/pytorch26.01-py3.sqsh +CONTAINER_MOUNTS="$USER_LUSTRE:/user_data/,$GROUP_LUSTRE:/group_data,$HOME:/root/,/lustre:/lustre,/tmp:/tmp" + +# Virtual environment path +VENV_PATH="$USER_LUSTRE/venvs/shard_tensor_benchmarks/" + +WORKDIR="$USER_LUSTRE/workdir/shard_tensor_benchmarks/examples/cfd/external_aerodynamics/transformer_models/" + +# Hydra (src/conf/transolver_volume.yaml) +TRAIN_SCRIPT="src/train.py --config-name transolver_volume" + +PRECISION="bfloat16" +SAMPLING_RESOLUTION_PER_GPU=200000 + +# DrivAer ML Zarr dataset paths (set these to your Zarr directories): +ZARR_TRAIN_PATH="/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/train/" +ZARR_VAL_PATH="/lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/val/" + +# Domain parallelism: number of GPUs collaborating on one sample. +# world_size must be divisible by this. Set to 1 to disable. +DOMAIN_PARALLEL_SIZE=8 + +# Total resolution scales with domain parallel size so each GPU +# keeps SAMPLING_RESOLUTION_PER_GPU points. +TOTAL_RESOLUTION=$((SAMPLING_RESOLUTION_PER_GPU * DOMAIN_PARALLEL_SIZE)) + +NODES=${SLURM_NNODES:-1} +GPUS_PER_NODE=${SLURM_NTASKS_PER_NODE:-1} +TOTAL_GPUS=$((NODES * GPUS_PER_NODE)) + +RUN_ID="transolver/volume/drivaer_ml_${PRECISION}_res${TOTAL_RESOLUTION}_${SAMPLING_RESOLUTION_PER_GPU}ppg_dp${DOMAIN_PARALLEL_SIZE}_${TOTAL_GPUS}gpu" + +EXTRA_HYDRA_OVERRIDES="" + +OVERRIDES="run_id=${RUN_ID} " +OVERRIDES+="precision=${PRECISION} " +OVERRIDES+="data.resolution=${TOTAL_RESOLUTION} " +OVERRIDES+="data.train.data_path=${ZARR_TRAIN_PATH} " +OVERRIDES+="data.val.data_path=${ZARR_VAL_PATH} " +OVERRIDES+="domain_parallel_size=${DOMAIN_PARALLEL_SIZE} " +OVERRIDES+="${EXTRA_HYDRA_OVERRIDES}" + +echo "Overrides: ${OVERRIDES}" + +# Launch the job inside the container. +# Note: the ${...} variables in the srun command below are expanded by the +# submitting shell on the host, before the container starts, so each one must +# resolve to a path that is also valid inside the container mounts above.
+ +export PATH="/cm/local/apps/slurm/current/bin:${PATH}" + +srun --ntasks-per-node=4 \ + --container-image=${CONTAINER_IMAGE} \ + --container-mounts ${CONTAINER_MOUNTS} \ + bash -c " + + # Set up virtual environment + source ${VENV_PATH}/bin/activate + + # This is where I have the training script in the container: + cd ${WORKDIR} + + # Run the training script with overrides + python ${TRAIN_SCRIPT} ${OVERRIDES} + " diff --git a/physicsnemo/datapipes/cae/transolver_datapipe.py b/physicsnemo/datapipes/cae/transolver_datapipe.py index ec9c576962..95695da5de 100644 --- a/physicsnemo/datapipes/cae/transolver_datapipe.py +++ b/physicsnemo/datapipes/cae/transolver_datapipe.py @@ -44,7 +44,9 @@ unnormalize, unstandardize, ) +from physicsnemo.domain_parallel import ShardTensor from physicsnemo.nn.functional import signed_distance_field +from torch.distributed.tensor.placement_types import Replicate @dataclass @@ -177,17 +179,36 @@ def preprocess_surface_data( scale_factor: torch.Tensor | None = None, ): positions = data_dict["surface_mesh_centers"] + _is_sharded = isinstance(positions, ShardTensor) if self.config.resolution is not None: - idx = torch.multinomial( - torch.ones(data_dict["surface_mesh_centers"].shape[0]), - self.config.resolution, - ) + if _is_sharded: + _mesh = positions._spec.mesh + _placements = positions._spec.placements + _domain_size = _mesh.size() + _local_res = self.config.resolution // _domain_size + local_pos = positions.to_local() + idx = torch.multinomial( + torch.ones(local_pos.shape[0]), + _local_res, + ) + positions = ShardTensor.from_local( + local_pos[idx], device_mesh=_mesh, placements=_placements, + ) + else: + idx = torch.multinomial( + torch.ones(data_dict["surface_mesh_centers"].shape[0]), + self.config.resolution, + ) + positions = positions[idx] else: idx = None - if idx is not None: - positions = positions[idx] + if _is_sharded: + if scale_factor is not None: + scale_factor = self._promote_to_replicated(scale_factor, positions) + if center_of_mass is not None: + center_of_mass = self._promote_to_replicated(center_of_mass, positions) # This is a center of mass computation for the stl surface, # using the size of each mesh point as weight. 
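The sampling hunks above and below all follow the same pattern when the input is a ShardTensor: draw indices only on this rank's local shard (resolution // mesh.size() points per rank), then rewrap with ShardTensor.from_local so the mesh and placements carry through. A condensed sketch of that pattern; the helper name _subsample_local_shard is illustrative and does not exist in this patch.

def _subsample_local_shard(shard_t, local_idx):
    # Index only the local shard, then rebuild a ShardTensor on the
    # original mesh/placements so downstream ops stay domain-parallel.
    spec = shard_t._spec
    return ShardTensor.from_local(
        shard_t.to_local()[local_idx],
        device_mesh=spec.mesh,
        placements=spec.placements,
    )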
@@ -203,7 +224,12 @@ def preprocess_surface_data( if self.config.include_normals: normals = data_dict["surface_normals"] if idx is not None: - normals = normals[idx] + if _is_sharded: + normals = ShardTensor.from_local( + normals.to_local()[idx], device_mesh=_mesh, placements=_placements, + ) + else: + normals = normals[idx] normals = normals / torch.norm(normals, dim=-1, keepdim=True) embeddings_inputs.append(normals) @@ -211,7 +237,12 @@ def preprocess_surface_data( fields = data_dict["surface_fields"] if idx is not None: - fields = fields[idx] + if _is_sharded: + fields = ShardTensor.from_local( + fields.to_local()[idx], device_mesh=_mesh, placements=_placements, + ) + else: + fields = fields[idx] if self.config.scaling_type is not None: fields = self.scale_model_targets(fields, self.config.surface_factors) @@ -248,16 +279,36 @@ def preprocess_volume_data( scale_factor: torch.Tensor | None = None, ): positions = data_dict["volume_mesh_centers"] + _is_sharded = isinstance(positions, ShardTensor) if self.config.resolution is not None: - idx = poisson_sample_indices_fixed( - positions.shape[0], self.config.resolution, device=positions.device - ) + if _is_sharded: + _mesh = positions._spec.mesh + _placements = positions._spec.placements + _domain_size = _mesh.size() + _local_res = self.config.resolution // _domain_size + local_pos = positions.to_local() + idx = poisson_sample_indices_fixed( + local_pos.shape[0], _local_res, device=local_pos.device + ) + positions = ShardTensor.from_local( + local_pos[idx], device_mesh=_mesh, placements=_placements, + ) + else: + idx = poisson_sample_indices_fixed( + positions.shape[0], self.config.resolution, device=positions.device + ) + positions = positions[idx] else: idx = None - if idx is not None: - positions = positions[idx] + # When domain-parallel, promote plain tensors to replicated ShardTensors + # so arithmetic with sharded positions/coords doesn't hit mixed-type errors. + if _is_sharded: + if scale_factor is not None: + scale_factor = self._promote_to_replicated(scale_factor, positions) + if center_of_mass is not None: + center_of_mass = self._promote_to_replicated(center_of_mass, positions) # We need the CoM for some operations, regardless of translation invariance: if center_of_mass is None: @@ -309,8 +360,14 @@ def preprocess_volume_data( distance_to_closest_point = torch.norm(positions - closest_points, dim=-1) null_points = distance_to_closest_point < 1e-6 - # In these cases, we update the vector to be from the center of mass - normals[null_points] = positions[null_points] - center_of_mass + # In these cases, we update the vector to be from the center of mass. + # Use torch.where instead of boolean indexing since DTensor + # doesn't support aten.nonzero (dynamic output shape). 
+ normals = torch.where( + null_points.unsqueeze(-1), + positions - center_of_mass, + normals, + ) norm = torch.norm(normals, dim=-1, keepdim=True) + 1e-6 normals = normals / norm @@ -321,7 +378,12 @@ def preprocess_volume_data( fields = data_dict["volume_fields"] if idx is not None: - fields = fields[idx] + if _is_sharded: + fields = ShardTensor.from_local( + fields.to_local()[idx], device_mesh=_mesh, placements=_placements, + ) + else: + fields = fields[idx] if self.config.scaling_type is not None: fields = self.scale_model_targets(fields, self.config.volume_factors) @@ -525,6 +587,14 @@ def process_data(self, data_dict): return outputs + def _promote_to_replicated(self, tensor, ref_shard_tensor): + """Wrap a plain tensor as a Replicate-placed ShardTensor on the same mesh.""" + if isinstance(tensor, ShardTensor): + return tensor + return ShardTensor.from_local( + tensor, device_mesh=ref_shard_tensor._spec.mesh, placements=[Replicate()], + ) + def scale_model_targets( self, fields: torch.Tensor, factors: torch.Tensor ) -> torch.Tensor: @@ -534,10 +604,16 @@ def scale_model_targets( if self.config.scaling_type == "mean_std_scaling": field_mean = factors["mean"] field_std = factors["std"] + if isinstance(fields, ShardTensor): + field_mean = self._promote_to_replicated(field_mean, fields) + field_std = self._promote_to_replicated(field_std, fields) return standardize(fields, field_mean, field_std) elif self.config.scaling_type == "min_max_scaling": field_min = factors["min"] field_max = factors["max"] + if isinstance(fields, ShardTensor): + field_min = self._promote_to_replicated(field_min, fields) + field_max = self._promote_to_replicated(field_max, fields) return normalize(fields, field_max, field_min) def unscale_model_targets( @@ -571,10 +647,16 @@ def unscale_model_targets( if self.config.scaling_type == "mean_std_scaling": field_mean = factors["mean"] field_std = factors["std"] + if isinstance(fields, ShardTensor): + field_mean = self._promote_to_replicated(field_mean, fields) + field_std = self._promote_to_replicated(field_std, fields) fields = unstandardize(fields, field_mean, field_std) elif self.config.scaling_type == "min_max_scaling": field_min = factors["min"] field_max = factors["max"] + if isinstance(fields, ShardTensor): + field_min = self._promote_to_replicated(field_min, fields) + field_max = self._promote_to_replicated(field_max, fields) fields = unnormalize(fields, field_max, field_min) # if air_density is not None and stream_velocity is not None: @@ -636,6 +718,7 @@ def __call__(self, data_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor """ outputs = self.process_data(data_dict) + for key in outputs.keys(): if isinstance(outputs[key], list): outputs[key] = [item.unsqueeze(0) for item in outputs[key]] @@ -724,10 +807,18 @@ def create_transolver_dataset( if cfg.get(optional_key, None) is not None: overrides[optional_key] = cfg[optional_key] + # Defaults for Zarr attributes that the sharded reader currently drops. + # Non-sharded reads get the real per-sample values from group attrs; + # sharded reads fall back to these until read_file_sharded is fixed. + keys_to_read_if_available = { + "air_density": torch.tensor(1.0, dtype=torch.float32), + "stream_velocity": torch.tensor(1.0, dtype=torch.float32), + } + dataset = CAEDataset( data_dir=input_path, keys_to_read=keys_to_read, - keys_to_read_if_available={}, + keys_to_read_if_available=keys_to_read_if_available, output_device=device, preload_depth=preload_depth, pin_memory=pin_memory,
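To make the sizing knobs in this change concrete: the numbers below come directly from train_transolver_sharded.sh and the domain_parallel_size plumbing above; the snippet is only a back-of-envelope check, not additional configuration.

# 16 nodes x 4 tasks/GPUs per node, per the SBATCH header
world_size = 16 * 4                                            # 64 ranks
domain_parallel_size = 8                                       # DOMAIN_PARALLEL_SIZE
assert world_size % domain_parallel_size == 0                  # enforced in setup_domain_parallelism
data_parallel_replicas = world_size // domain_parallel_size    # 8 groups, each sees distinct samples

points_per_gpu = 200_000                                       # SAMPLING_RESOLUTION_PER_GPU
data_resolution = points_per_gpu * domain_parallel_size        # 1_600_000 -> data.resolution override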