diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d3996dabe..caaf17aad1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -111,6 +111,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 end-to-end training and sampling of epsilon-parameterized models. Losses gain an `epsilon_to_x0_fn` kwarg used for the epsilon-to-x0 conversion required during DSM training. +- Adds `DiffusionUNet3D` 3D U-Net diffusion backbone for volumetric data at + `physicsnemo.experimental.models.diffusion_unets`. Implements the + `DiffusionModel` protocol. Exposes reusable 3D building blocks + (`Conv3D`, `GroupNorm3D`, `UNetAttention3D`, `UNetBlock3D`) at + `physicsnemo.experimental.nn`. - Added support for Batched radius search, which enables Domino and GeoTransolver with local features and batch size > 1. - Added the underfill recipe. diff --git a/physicsnemo/experimental/models/diffusion_unets/__init__.py b/physicsnemo/experimental/models/diffusion_unets/__init__.py new file mode 100644 index 0000000000..8131881584 --- /dev/null +++ b/physicsnemo/experimental/models/diffusion_unets/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from .diffusion_unet_3d import DiffusionUNet3D
+
+__all__ = ["DiffusionUNet3D"]
diff --git a/physicsnemo/experimental/models/diffusion_unets/diffusion_unet_3d.py b/physicsnemo/experimental/models/diffusion_unets/diffusion_unet_3d.py
new file mode 100644
index 0000000000..016b2b594f
--- /dev/null
+++ b/physicsnemo/experimental/models/diffusion_unets/diffusion_unet_3d.py
@@ -0,0 +1,570 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from dataclasses import dataclass +from typing import List, Literal, cast + +import numpy as np +import torch +from jaxtyping import Float +from tensordict import TensorDict +from torch.nn.functional import silu +from torch.utils.checkpoint import checkpoint + +from physicsnemo.core.meta import ModelMetaData +from physicsnemo.core.module import Module +from physicsnemo.experimental.nn import Conv3D, GroupNorm3D, UNetBlock3D +from physicsnemo.nn import ( + FourierEmbedding, + Linear, + PositionalEmbedding, +) + + +@dataclass +class MetaData(ModelMetaData): + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp_cpu: bool = False + amp_gpu: bool = True + torch_fx: bool = False + # Data type + bf16: bool = True + # Inference + onnx: bool = False + # Physics informed + func_torch: bool = False + auto_grad: bool = False + + +class DiffusionUNet3D(Module): + r""" + 3D U-Net diffusion backbone for volumetric data. + + Implements the :class:`~physicsnemo.diffusion.base.DiffusionModel` protocol + and can be used directly with preconditioners, losses, and samplers from + :mod:`physicsnemo.diffusion`. Conceptually a 3D counterpart of + :class:`~physicsnemo.models.diffusion_unets.SongUNet`; refer to that class + for the underlying architectural overview. + + Based on the architecture described in `Diff-SPORT: Diffusion-based Sensor + Placement Optimization and Reconstruction of Turbulent flows in urban + environments `_. + + Parameters + ---------- + x_channels : int + Number of channels :math:`C_x` in the input/output state + :math:`\mathbf{x}`. + The output has the same number of channels as the input. + vol_cond_channels : int, optional, default=0 + Number of channels :math:`C_{cond,v}` in the optional volume-based + conditioning. When non-zero, a volume condition tensor of shape + :math:`(B, C_{cond,v}, D, H, W)` may be passed via ``condition["volume"]``; + it is concatenated channel-wise to ``x`` before the first convolution. 
Set to ``0`` for no volume conditioning. + vec_cond_dim : int, optional, default=0 + Dimension :math:`D_v` of the optional vector-valued condition. + When non-zero, a condition tensor of shape + :math:`(B, D_v)` may be passed via ``condition["vector"]``. The vector + condition is mapped through a linear layer and added to the diffusion time + embedding; the resulting embedding then conditions all 3D U-Net blocks + via adaptive group norm. + num_levels : int, optional, default=4 + Number of encoder/decoder levels. ``len(channel_mult)`` must equal this. + model_channels : int, optional, default=128 + Base channel count at the first U-Net level. + channel_mult : list[int], optional, default=[1, 2, 2, 2] + Per-level channel multipliers. Channels at level :math:`l` equal + ``channel_mult[l] * model_channels``. Length must equal ``num_levels``. + channel_mult_emb : int, optional, default=4 + Multiplier for the conditioning embedding dimension: + ``emb_channels = model_channels * channel_mult_emb``. + num_blocks : int, optional, default=4 + Number of 3D U-Net blocks per level. The decoder has ``num_blocks + 1`` + blocks per level for the extra skip connection. + attention_levels : list[int], optional, default=[] + 0-indexed encoder levels at which to apply 3D self-attention. Level 0 + is the outermost (highest resolution). All values must be in + ``[0, num_levels)``. + dropout : float, optional, default=0.10 + Dropout probability inside the 3D U-Net blocks. + embedding_type : Literal["positional", "fourier", "zero"], optional, default="positional" + Embedding type used for both the diffusion time and (when present) the + vector condition. ``"positional"`` is the DDPM++ style, ``"fourier"`` is + the NCSN++ style, and ``"zero"`` replaces the time embedding by a zero + buffer and disables vector conditioning (so ``vec_cond_dim`` must be + ``0``). 
Volume conditioning is independent of ``embedding_type`` since + it is concatenated channel-wise to ``x`` before the first convolution. + channel_mult_noise : int, optional, default=1 + Multiplier for the noise-level embedding dimension: + ``noise_channels = model_channels * channel_mult_noise``. + encoder_type : Literal["standard", "skip", "residual"], optional, default="standard" + Encoder architecture variant (``"standard"`` = DDPM++, + ``"residual"`` = NCSN++, ``"skip"`` = skip connections). + decoder_type : Literal["standard", "skip"], optional, default="standard" + Decoder architecture variant. + resample_filter : list[int], optional, default=[1, 1] + 1D coefficients for the separable up/downsampling filter. The 3D filter + is constructed as their outer product, normalized to sum to 1. Use + ``[1, 1]`` for bilinear (DDPM++) or ``[1, 3, 3, 1]`` for bicubic + (NCSN++) resampling. + checkpoint_level : int, optional, default=0 + Gradient checkpointing aggressiveness. Higher values checkpoint more + layers, trading memory for compute. ``0`` disables checkpointing. + bottleneck_attention : bool, optional, default=True + If ``True``, applies 3D self-attention at the innermost bottleneck block. + Set to ``False`` for faster inference without bottleneck attention. + activation : Literal["silu", "gelu"], optional, default="silu" + Activation function used inside the 3D U-Net blocks. + + Forward + ------- + x : torch.Tensor + Input state of shape :math:`(B, C_x, D, H, W)`. Spatial dimensions must + be powers of 2 or multiples of :math:`2^{\text{num_levels}-1}`. + t : torch.Tensor + Batched diffusion time (or noise level) of shape :math:`(B,)`. + condition : TensorDict or None, optional, default=None + Conditioning information. ``None`` for unconditional models. Otherwise + a :class:`~tensordict.TensorDict` with a subset of: + + - ``"vector"``: tensor of shape :math:`(B, D_v)` (requires + ``vec_cond_dim > 0``). 
+ - ``"volume"``: tensor of shape :math:`(B, C_v, D, H, W)` (requires + ``vol_cond_channels > 0`` and matching spatial dimensions). + + Any other key raises ``ValueError``. + + Outputs + ------- + torch.Tensor + Output of shape :math:`(B, C_x, D, H, W)`. The channels match + :math:`C_x` so the model can be used as any predictor type + (:math:`\mathbf{x}_0`, :math:`\boldsymbol{\epsilon}`, score, + velocity, etc.); the interpretation depends on the predictor / loss. + + Raises + ------ + ValueError + If ``len(channel_mult) != num_levels``. + ValueError + If any value in ``attention_levels`` is outside ``[0, num_levels)``. + ValueError + If ``embedding_type == "zero"`` is combined with non-zero + ``vec_cond_dim``. + + See Also + -------- + :class:`~physicsnemo.models.diffusion_unets.SongUNet` : 2D counterpart. + :class:`~physicsnemo.diffusion.base.DiffusionModel` : Protocol this model + implements. + + Examples + -------- + Unconditional model on a non-cubic grid: + + >>> import torch + >>> from physicsnemo.experimental.models.diffusion_unets import DiffusionUNet3D + >>> model = DiffusionUNet3D( + ... x_channels=4, num_levels=2, + ... model_channels=16, channel_mult=[1, 2], num_blocks=1, + ... ) + >>> x = torch.randn(2, 4, 4, 12, 16) + >>> out = model(x, torch.randn(2)) + >>> out.shape + torch.Size([2, 4, 4, 12, 16]) + + Conditional model with vector and volume conditioning: + + >>> from tensordict import TensorDict + >>> model = DiffusionUNet3D( + ... x_channels=4, vol_cond_channels=2, vec_cond_dim=8, + ... num_levels=2, model_channels=16, channel_mult=[1, 2], num_blocks=1, + ... ) + >>> cond = TensorDict( + ... {"vector": torch.randn(2, 8), "volume": torch.randn(2, 2, 4, 12, 16)}, + ... batch_size=[2], + ... 
) + >>> out = model(x, torch.randn(2), condition=cond) + >>> out.shape + torch.Size([2, 4, 4, 12, 16]) + + Larger conditional model with custom encoder/decoder, attention at level 1, + NCSN++-style filter, no bottleneck attention, and gelu activation: + + >>> model = DiffusionUNet3D( + ... x_channels=2, vol_cond_channels=1, vec_cond_dim=4, + ... num_levels=3, model_channels=16, channel_mult=[1, 2, 2], num_blocks=2, + ... attention_levels=[1], encoder_type="residual", decoder_type="skip", + ... resample_filter=[1, 3, 3, 1], bottleneck_attention=False, + ... activation="gelu", + ... ) + >>> x = torch.randn(2, 2, 4, 12, 16) + >>> cond = TensorDict( + ... {"vector": torch.randn(2, 4), "volume": torch.randn(2, 1, 4, 12, 16)}, + ... batch_size=[2], + ... ) + >>> out = model(x, torch.randn(2), condition=cond) + >>> out.shape + torch.Size([2, 2, 4, 12, 16]) + """ + + def __init__( + self, + x_channels: int, + vol_cond_channels: int = 0, + vec_cond_dim: int = 0, + num_levels: int = 4, + model_channels: int = 128, + channel_mult: List[int] = [1, 2, 2, 2], + channel_mult_emb: int = 4, + num_blocks: int = 4, + attention_levels: List[int] = [], + dropout: float = 0.10, + embedding_type: Literal["fourier", "positional", "zero"] = "positional", + channel_mult_noise: int = 1, + encoder_type: Literal["standard", "skip", "residual"] = "standard", + decoder_type: Literal["standard", "skip"] = "standard", + resample_filter: List[int] = [1, 1], + checkpoint_level: int = 0, + bottleneck_attention: bool = True, + activation: Literal["silu", "gelu"] = "silu", + ): + if len(channel_mult) != num_levels: + raise ValueError( + f"len(channel_mult) must equal num_levels, got " + f"len(channel_mult)={len(channel_mult)} and num_levels={num_levels}" + ) + + if any(not (0 <= lvl < num_levels) for lvl in attention_levels): + raise ValueError( + f"All values in attention_levels must be in [0, num_levels=" + f"{num_levels}), got {attention_levels}" + ) + + if embedding_type == "zero" and vec_cond_dim 
> 0: + raise ValueError( + "embedding_type='zero' disables the conditioning embedding; " + "vec_cond_dim must be 0 in that case " + f"(got vec_cond_dim={vec_cond_dim})." + ) + + super().__init__(meta=MetaData()) + + self.x_channels = x_channels + self.vol_cond_channels = vol_cond_channels + self.vec_cond_dim = vec_cond_dim + self.embedding_type = embedding_type + self.num_levels = num_levels + self._input_shape_mult = 2 ** (num_levels - 1) + self.checkpoint_level = checkpoint_level + + emb_channels = model_channels * channel_mult_emb + self.emb_channels = emb_channels + noise_channels = model_channels * channel_mult_noise + + init = dict(init_mode="xavier_uniform") + init_zero = dict(init_mode="xavier_uniform", init_weight=1e-5) + init_attn = dict(init_mode="xavier_uniform", init_weight=np.sqrt(0.2)) + + block_kwargs = dict( + emb_channels=emb_channels, + num_heads=1, + dropout=dropout, + skip_scale=np.sqrt(0.5), + eps=1e-6, + resample_filter=resample_filter, + resample_proj=True, + adaptive_scale=False, + activation=activation, + init=init, + init_zero=init_zero, + init_attn=init_attn, + ) + + if self.embedding_type != "zero": + self.map_noise = ( + PositionalEmbedding(num_channels=noise_channels, endpoint=True) + if embedding_type == "positional" + else FourierEmbedding(num_channels=noise_channels) + ) + self.map_condition = ( + Linear(in_features=vec_cond_dim, out_features=noise_channels, **init) + if vec_cond_dim > 0 + else None + ) + self.map_layer0 = Linear( + in_features=noise_channels, out_features=emb_channels, **init + ) + self.map_layer1 = Linear( + in_features=emb_channels, out_features=emb_channels, **init + ) + else: + # FSDP-compatible zero buffer; persistent=False keeps it out of state_dict + self.register_buffer( + "zero_emb", torch.zeros(1, emb_channels), persistent=False + ) + self.map_condition = None + + # Encoder + self.enc = torch.nn.ModuleDict() + cout = x_channels + vol_cond_channels + caux = x_channels + vol_cond_channels + for level, mult 
in enumerate(channel_mult): + if level == 0: + cin = cout + cout = model_channels + self.enc[f"l{level}_conv"] = Conv3D( + in_channels=cin, out_channels=cout, kernel=3, **init + ) + else: + self.enc[f"l{level}_down"] = UNetBlock3D( + in_channels=cout, out_channels=cout, down=True, **block_kwargs + ) + if encoder_type == "skip": + self.enc[f"l{level}_aux_down"] = Conv3D( + in_channels=caux, + out_channels=caux, + kernel=0, + down=True, + resample_filter=resample_filter, + ) + self.enc[f"l{level}_aux_skip"] = Conv3D( + in_channels=caux, out_channels=cout, kernel=1, **init + ) + if encoder_type == "residual": + self.enc[f"l{level}_aux_residual"] = Conv3D( + in_channels=caux, + out_channels=cout, + kernel=3, + down=True, + resample_filter=resample_filter, + **init, + ) + caux = cout + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + attn = level in attention_levels + self.enc[f"l{level}_block{idx}"] = UNetBlock3D( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + + skips = [ + block.out_channels + for name, block in self.enc.items() + if "aux" not in name + ] + + # Decoder + self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + if level == len(channel_mult) - 1: + self.dec[f"l{level}_in0"] = UNetBlock3D( + in_channels=cout, + out_channels=cout, + attention=bottleneck_attention, + **block_kwargs, + ) + self.dec[f"l{level}_in1"] = UNetBlock3D( + in_channels=cout, out_channels=cout, **block_kwargs + ) + else: + self.dec[f"l{level}_up"] = UNetBlock3D( + in_channels=cout, out_channels=cout, up=True, **block_kwargs + ) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + attn = idx == num_blocks and level in attention_levels + self.dec[f"l{level}_block{idx}"] = UNetBlock3D( + in_channels=cin, out_channels=cout, attention=attn, **block_kwargs + ) + if decoder_type == "skip" or level == 0: + if decoder_type == "skip" and level < 
len(channel_mult) - 1: + self.dec[f"l{level}_aux_up"] = Conv3D( + in_channels=x_channels, + out_channels=x_channels, + kernel=0, + up=True, + resample_filter=resample_filter, + ) + self.dec[f"l{level}_aux_norm"] = GroupNorm3D( + num_channels=cout, eps=1e-6 + ) + self.dec[f"l{level}_aux_conv"] = Conv3D( + in_channels=cout, out_channels=x_channels, kernel=3, **init_zero + ) + + def forward( + self, + x: Float[torch.Tensor, "B C_x D H W"], + t: Float[torch.Tensor, " B"], + condition: TensorDict | None = None, + ) -> Float[torch.Tensor, "B C_x D H W"]: + + # Tensor shape validation + if not torch.compiler.is_compiling(): + if x.ndim != 5: + raise ValueError( + f"Expected x to be a 5D tensor, " + f"got {x.ndim}D tensor with shape {tuple(x.shape)}" + ) + + B, _, D, H, W = x.shape + + if x.shape[1] != self.x_channels: + raise ValueError( + f"Expected x to have {self.x_channels} channels (x_channels), " + f"got {x.shape[1]}" + ) + + for d in (D, H, W): + is_power_of_2 = (d & (d - 1)) == 0 and d > 0 + if not ( + (is_power_of_2 and d < self._input_shape_mult) + or (d % self._input_shape_mult == 0) + ): + raise ValueError( + f"Input spatial dimensions (D, H, W)={(D, H, W)} must be " + f"powers of 2 or multiples of 2**(num_levels-1)=" + f"{self._input_shape_mult}" + ) + + if t.shape != (B,): + raise ValueError( + f"Expected t to have shape ({B},), got {tuple(t.shape)}" + ) + + if condition is not None: + valid_keys = {"vector", "volume"} + extra_keys = set(condition.keys()) - valid_keys + if extra_keys: + raise ValueError( + f"Unexpected condition keys: {extra_keys}. " + f"Allowed keys: {valid_keys}" + ) + + vector_cond = condition.get("vector", None) + volume_cond = condition.get("volume", None) + + if vector_cond is not None: + if self.embedding_type == "zero": + raise ValueError( + "condition['vector'] cannot be used with " + "embedding_type='zero'." 
+ ) + if self.vec_cond_dim == 0: + raise ValueError( + "condition['vector'] provided but vec_cond_dim=0" + ) + if vector_cond.shape != (B, self.vec_cond_dim): + raise ValueError( + f"Expected condition['vector'] to have shape " + f"{(B, self.vec_cond_dim)}, got {tuple(vector_cond.shape)}" + ) + + if volume_cond is not None: + if self.vol_cond_channels == 0: + raise ValueError( + "condition['volume'] provided but vol_cond_channels=0" + ) + if volume_cond.shape != (B, self.vol_cond_channels, D, H, W): + raise ValueError( + f"Expected condition['volume'] to have shape " + f"{(B, self.vol_cond_channels, D, H, W)}, got {tuple(volume_cond.shape)}" + ) + + # Extract condition components (no isinstance under torch.compile) + if condition is not None: + vector_cond = condition.get("vector", None) + volume_cond = condition.get("volume", None) + else: + vector_cond = None + volume_cond = None + + # Prepend volume condition channels to x + if volume_cond is not None: + x = torch.cat([x, volume_cond], dim=1) # (B, C_x + C_v, D, H, W) + + # Compute conditioning embedding from t and optional vector_cond + if self.embedding_type != "zero": + emb = self.map_noise(t) + emb_shape = emb.shape + # Swap sin/cos halves to match the DDPM++ convention + emb = emb.reshape(emb.shape[0], 2, -1) + emb = torch.concat([emb[:, 1:], emb[:, :1]], dim=1).reshape(*emb_shape) + if self.map_condition is not None and vector_cond is not None: + emb = emb + self.map_condition( + vector_cond * np.sqrt(self.map_condition.in_features) + ) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + else: + emb = self.zero_emb.repeat(t.shape[0], 1) + + # Gradient-checkpointing threshold from current spatial extent + max_dim = max(x.shape[-3], x.shape[-2], x.shape[-1]) + threshold = (max_dim >> self.checkpoint_level) + 1 + + # Encoder: progressively downsample and cache skip connections + skips = [] + aux = x + for name, block in self.enc.items(): + if "aux_down" in name: + aux = block(aux) + elif 
"aux_skip" in name: + x = skips[-1] = x + block(aux) + elif "aux_residual" in name: + # Normalize by 1/sqrt(2) to preserve activation variance + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + elif "_conv" in name: + x = block(x) + skips.append(x) + else: + if isinstance(block, UNetBlock3D): + if max(x.shape[-3], x.shape[-2], x.shape[-1]) > threshold: + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + else: + x = block(x) + skips.append(x) + + # Decoder: progressively upsample and merge skip connections. + out = None + tmp = None + for name, block in self.dec.items(): + if "aux_up" in name: + out = block(out) + elif "aux_norm" in name: + tmp = block(x) + elif "aux_conv" in name: + tmp = block(silu(tmp)) + out = tmp if out is None else tmp + out + else: + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + cur_max = max(x.shape[-3], x.shape[-2], x.shape[-1]) + if (cur_max > threshold and "_block" in name) or ( + cur_max > (threshold / 2) and "_up" in name + ): + x = checkpoint(block, x, emb, use_reentrant=False) + else: + x = block(x, emb) + + return cast(torch.Tensor, out) diff --git a/physicsnemo/experimental/nn/__init__.py b/physicsnemo/experimental/nn/__init__.py index 0e9cc95dee..a983424f91 100644 --- a/physicsnemo/experimental/nn/__init__.py +++ b/physicsnemo/experimental/nn/__init__.py @@ -22,5 +22,6 @@ """ from .flare_attention import FLARE +from .diffusion_unet_3d_blocks import UNetBlock3D, Conv3D, GroupNorm3D, UNetAttention3D -__all__ = ["FLARE"] +__all__ = ["FLARE", "UNetBlock3D", "Conv3D", "GroupNorm3D", "UNetAttention3D"] diff --git a/physicsnemo/experimental/nn/diffusion_unet_3d_blocks.py b/physicsnemo/experimental/nn/diffusion_unet_3d_blocks.py new file mode 100644 index 0000000000..405a0b1e73 --- /dev/null +++ b/physicsnemo/experimental/nn/diffusion_unet_3d_blocks.py @@ -0,0 +1,583 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, List, Literal + +import torch +from einops import rearrange +from jaxtyping import Float +from torch.nn.functional import dropout, scaled_dot_product_attention, silu + +from physicsnemo.core.meta import ModelMetaData +from physicsnemo.core.module import Module +from physicsnemo.nn.module.fully_connected_layers import Linear +from physicsnemo.nn.module.utils.weight_init import _weight_init + + +class GroupNorm3D(Module): + r""" + Group Normalization for 5D tensors :math:`(B, C, D, H, W)`. + + Divides the channel dimension into groups and normalizes within each group + independently. During training, uses ``torch.nn.functional.group_norm``. + During inference, uses a manual implementation compatible with channels-last + memory layouts. + + Parameters + ---------- + num_channels : int + Number of channels in the input tensor. + num_groups : int, optional, default=32 + Target number of groups. Adjusted downward if + ``num_channels // num_groups < min_channels_per_group``. + min_channels_per_group : int, optional, default=4 + Minimum channels allowed per group. + eps : float, optional, default=1e-5 + Epsilon for numerical stability. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, C, D, H, W)`. 
+ + Outputs + ------- + torch.Tensor + Normalized tensor of shape :math:`(B, C, D, H, W)`. + + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.nn import GroupNorm3D + >>> gn = GroupNorm3D(num_channels=32) + >>> x = torch.randn(2, 32, 4, 12, 16) + >>> y = gn(x) + >>> y.shape + torch.Size([2, 32, 4, 12, 16]) + """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + min_channels_per_group: int = 4, + eps: float = 1e-5, + ): + super().__init__(meta=ModelMetaData()) + self.num_groups = min(num_groups, num_channels // min_channels_per_group) + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(num_channels)) + self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + + def forward( + self, x: Float[torch.Tensor, "B C D H W"] + ) -> Float[torch.Tensor, "B C D H W"]: + if self.training: + x = torch.nn.functional.group_norm( + x, + num_groups=self.num_groups, + weight=self.weight.to(x.dtype), + bias=self.bias.to(x.dtype), + eps=self.eps, + ) + else: + # Manual implementation that supports channels-last memory layout + dtype = x.dtype + x = x.float() + x = rearrange(x, "b (g c) d h w -> b g c d h w", g=self.num_groups) + mean = x.mean(dim=[2, 3, 4, 5], keepdim=True) + var = x.var(dim=[2, 3, 4, 5], keepdim=True) + x = (x - mean) * (var + self.eps).rsqrt() + x = rearrange(x, "b g c d h w -> b (g c) d h w") + x = x * rearrange(self.weight, "c -> 1 c 1 1 1") + rearrange( + self.bias, "c -> 1 c 1 1 1" + ) + x = x.to(dtype) + return x + + +class Conv3D(Module): + r""" + 3D convolution with optional fused up/downsampling. + + Implements a 3D convolution with optional 2x upsampling or downsampling via + separable bilinear/bicubic filters. When a convolution weight is present + (``kernel > 0``), resampling is fused with the convolution for efficiency. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. 
+ kernel : int + Convolution kernel size applied uniformly across all spatial dimensions. + Set to 0 to apply resampling only (no learned convolution). + bias : bool, optional, default=True + Whether to include a learnable bias. + up : bool, optional, default=False + Apply 2x upsampling. Cannot be ``True`` simultaneously with ``down``. + down : bool, optional, default=False + Apply 2x downsampling. Cannot be ``True`` simultaneously with ``up``. + resample_filter : list[int], optional, default=[1, 1] + 1D coefficients for the separable up/downsampling filter. The 3D filter + is constructed as their outer product, normalized so it sums to 1. + Use ``[1, 1]`` for bilinear resampling or ``[1, 3, 3, 1]`` for bicubic. + Must be a non-empty list of positive integers. + init_mode : Literal["xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"], optional, default="kaiming_normal" + Weight initialization mode. + init_weight : float, optional, default=1.0 + Multiplier applied to the initialized weight tensor. + init_bias : float, optional, default=0.0 + Multiplier applied to the initialized bias tensor. + + Raises + ------ + ValueError + If both ``up`` and ``down`` are ``True``, or if ``resample_filter`` is + empty / contains non-positive values when ``up`` or ``down`` is ``True``. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, C_{in}, D, H, W)`. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, C_{out}, D', H', W')`. The spatial + dimensions are doubled (``up=True``), halved (``down=True``), or unchanged. 
+ + Examples + -------- + >>> import torch + >>> from physicsnemo.experimental.nn import Conv3D + >>> conv = Conv3D(in_channels=4, out_channels=8, kernel=3) + >>> x = torch.randn(2, 4, 4, 12, 16) + >>> y = conv(x) + >>> y.shape + torch.Size([2, 8, 4, 12, 16]) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + bias: bool = True, + up: bool = False, + down: bool = False, + resample_filter: List[int] = [1, 1], + init_mode: Literal[ + "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" + ] = "kaiming_normal", + init_weight: float = 1.0, + init_bias: float = 0.0, + ): + if up and down: + raise ValueError("Both 'up' and 'down' cannot be True at the same time.") + if (up or down) and ( + not resample_filter or any(v <= 0 for v in resample_filter) + ): + raise ValueError( + f"resample_filter must be a non-empty list of positive integers " + f"when up=True or down=True, got {resample_filter}" + ) + + super().__init__(meta=ModelMetaData()) + self.in_channels = in_channels + self.out_channels = out_channels + self.up = up + self.down = down + + init_kwargs = dict( + mode=init_mode, + fan_in=in_channels * kernel * kernel * kernel, + fan_out=out_channels * kernel * kernel * kernel, + ) + self.weight = ( + torch.nn.Parameter( + _weight_init( + (out_channels, in_channels, kernel, kernel, kernel), **init_kwargs + ) + * init_weight + ) + if kernel + else None + ) + self.bias = ( + torch.nn.Parameter( + _weight_init((out_channels,), **init_kwargs) * init_bias + ) + if kernel and bias + else None + ) + + f = torch.as_tensor(resample_filter, dtype=torch.float32) + f = (f.ger(f).unsqueeze(2) * f.view(1, 1, -1)).unsqueeze(0).unsqueeze( + 1 + ) / f.sum().pow(3) + self.register_buffer("resample_filter", f.contiguous() if up or down else None) + + def forward( + self, x: Float[torch.Tensor, "B C_in D H W"] + ) -> Float[torch.Tensor, "B C_out D_out H_out W_out"]: + w = self.weight.to(x.dtype) if self.weight is not None else None + b 
= self.bias.to(x.dtype) if self.bias is not None else None + f = ( + self.resample_filter.to(x.dtype) + if self.resample_filter is not None + else None + ) + w_pad = w.shape[-1] // 2 if w is not None else 0 + f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 + + if self.up and w is not None: + # Fused upsample + conv + x = torch.nn.functional.conv_transpose3d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=max(f_pad - w_pad, 0), + ) + x = torch.nn.functional.conv3d(x, w, padding=max(w_pad - f_pad, 0)) + elif self.down and w is not None: + # Fused conv + downsample + x = torch.nn.functional.conv3d(x, w, padding=w_pad + f_pad) + x = torch.nn.functional.conv3d( + x, + f.tile([self.out_channels, 1, 1, 1, 1]), + groups=self.out_channels, + stride=2, + ) + else: + if self.up: + x = torch.nn.functional.conv_transpose3d( + x, + f.mul(4).tile([self.in_channels, 1, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if self.down: + x = torch.nn.functional.conv3d( + x, + f.tile([self.in_channels, 1, 1, 1, 1]), + groups=self.in_channels, + stride=2, + padding=f_pad, + ) + if w is not None: + x = torch.nn.functional.conv3d(x, w, padding=w_pad) + + if b is not None: + x = x.add_(b.reshape(1, -1, 1, 1, 1)) + return x + + +class UNetAttention3D(Module): + r""" + Multi-head 3D self-attention block. + + Applies group normalization followed by multi-head self-attention with a + residual connection. Operates on volumetric feature maps of shape + :math:`(B, C, D, H, W)`, flattening the spatial dimensions for the + attention operation. + + Parameters + ---------- + out_channels : int + Number of channels :math:`C` in the input and output feature maps. + Must be divisible by ``num_heads``. + num_heads : int + Number of attention heads. Must be a positive integer. + eps : float, optional, default=1e-5 + Epsilon for numerical stability in :class:`GroupNorm3D`. 
class UNetAttention3D(Module):
    r"""
    Multi-head self-attention over the flattened spatial dimensions of a 3D
    feature map, with a residual connection.

    The input is group-normalized, projected to queries/keys/values with a
    single fused :math:`1 \times 1 \times 1` convolution, attended over the
    flattened ``D*H*W`` axis, projected back, and added to the input.

    Parameters
    ----------
    out_channels : int
        Number of channels :math:`C` of the input/output feature map.
    num_heads : int
        Number of attention heads. Must evenly divide ``out_channels``.
    eps : float, optional, default=1e-5
        Epsilon for the :class:`GroupNorm3D` layer.
    init_zero : dict or None, optional, default=None
        Initialization kwargs with near-zero weights for the output
        projection. ``None`` resolves to ``{'init_weight': 0}``.
    init_attn : dict or None, optional, default=None
        Initialization kwargs for the QKV projection. Defaults to ``init``
        if ``None``.
    init : dict or None, optional, default=None
        Initialization kwargs for linear and convolutional layers. ``None``
        resolves to ``{}``.

    Raises
    ------
    ValueError
        If ``num_heads`` is not a positive integer, or ``out_channels`` is not
        divisible by ``num_heads``.

    Forward
    -------
    x : torch.Tensor
        Input feature map of shape :math:`(B, C, D, H, W)`.

    Outputs
    -------
    torch.Tensor
        Output feature map of shape :math:`(B, C, D, H, W)`, identical to input shape.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.experimental.nn import UNetAttention3D
    >>> attn = UNetAttention3D(out_channels=32, num_heads=4)
    >>> x = torch.randn(2, 32, 4, 12, 16)
    >>> y = attn(x)
    >>> y.shape
    torch.Size([2, 32, 4, 12, 16])
    """

    def __init__(
        self,
        *,
        out_channels: int,
        num_heads: int,
        eps: float = 1e-5,
        init_zero: Dict[str, Any] | None = None,
        init_attn: Any = None,
        init: Dict[str, Any] | None = None,
    ) -> None:
        super().__init__(meta=ModelMetaData())
        # Resolve sentinel defaults here instead of using mutable default
        # arguments, which would be shared across all instances.
        init = dict() if init is None else init
        init_zero = dict(init_weight=0) if init_zero is None else init_zero
        if not isinstance(num_heads, int) or num_heads <= 0:
            raise ValueError(
                f"num_heads must be a positive integer, got {num_heads}"
            )
        if out_channels % num_heads != 0:
            raise ValueError(
                f"out_channels must be divisible by num_heads, "
                f"got out_channels={out_channels} and num_heads={num_heads}"
            )
        self.num_heads = num_heads
        self.norm = GroupNorm3D(num_channels=out_channels, eps=eps)
        # Single fused projection producing Q, K, V stacked on the channel axis.
        self.qkv = Conv3D(
            in_channels=out_channels,
            out_channels=out_channels * 3,
            kernel=1,
            **(init_attn if init_attn is not None else init),
        )
        # Zero-initialized output projection: the block starts as an identity
        # mapping (residual only), a common trick for stable diffusion training.
        self.proj = Conv3D(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel=1,
            **init_zero,
        )

    def forward(
        self, x: Float[torch.Tensor, "B C D H W"]
    ) -> Float[torch.Tensor, "B C D H W"]:
        x1 = self.qkv(self.norm(x))  # (B, 3C, D, H, W)

        # Reshape for multi-head attention over flattened spatial dims D*H*W
        qkv = (
            x1.reshape(x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1)
        ).permute(0, 1, 4, 3, 2)  # (B, num_heads, D*H*W, 3, C//num_heads)
        q, k, v = (qkv[..., i, :] for i in range(3))

        # Explicit scale matches SDPA's default 1/sqrt(head_dim).
        attn = scaled_dot_product_attention(
            q, k, v, scale=1 / math.sqrt(k.shape[-1])
        )  # (B, num_heads, D*H*W, C//num_heads)

        attn = attn.transpose(-1, -2)  # (B, num_heads, C//num_heads, D*H*W)
        return self.proj(attn.reshape(*x.shape)).add_(x)  # residual, (B, C, D, H, W)
class UNetBlock3D(Module):
    r"""
    Residual U-Net block for 3D volumetric inputs with an external embedding input.

    Applies a residual block with optional up/downsampling and self-attention,
    conditioned on an external vector input :math:`\mathbf{e}` via an affine
    transformation on intermediate features. The architecture combines elements
    from the DDPM++, NCSN++, and ADM U-Net designs and is suitable for any
    backbone that needs a conditioned 3D residual block.

    Parameters
    ----------
    in_channels : int
        Number of input channels :math:`C_{in}`.
    out_channels : int
        Number of output channels :math:`C_{out}`.
    emb_channels : int
        Dimension :math:`C_{emb}` of the external embedding vector
        :math:`\mathbf{e}`. It is broadcast spatially and consumed by the
        affine conditioning step; it can be any vector-valued input (e.g. a
        diffusion-time embedding, a positional code, a class embedding).
    up : bool, optional, default=False
        Apply 2x upsampling to the feature map in the first convolution.
    down : bool, optional, default=False
        Apply 2x downsampling to the feature map in the first convolution.
    attention : bool, optional, default=False
        Apply 3D self-attention after the residual branch.
    num_heads : int or None, optional, default=None
        Number of attention heads when ``attention=True``. Defaults to 1 if
        ``None``. Ignored when ``attention=False``.
    dropout : float, optional, default=0.0
        Dropout probability applied before the second convolution.
    skip_scale : float, optional, default=1.0
        Scale factor applied to the residual output and (if attention is
        enabled) to the attention residual.
    eps : float, optional, default=1e-5
        Epsilon for :class:`GroupNorm3D` normalization layers.
    resample_filter : list[int] or None, optional, default=None
        1D filter coefficients for up/downsampling, passed to :class:`Conv3D`.
        ``None`` resolves to ``[1, 1]``.
    resample_proj : bool, optional, default=False
        Use a :math:`1 \times 1 \times 1` projection in the skip path when
        the number of channels or the resolution changes.
    adaptive_scale : bool, optional, default=True
        If ``True``, apply FiLM-style scale-and-shift affine conditioning.
        If ``False``, apply additive shift only.
    activation : Literal["silu", "gelu"], optional, default="silu"
        Activation function applied after normalization layers.
    init : dict or None, optional, default=None
        Weight initialization kwargs for convolutions and linear layers.
        ``None`` resolves to ``{}``.
    init_zero : dict or None, optional, default=None
        Near-zero weight initialization kwargs for the output convolution.
        ``None`` resolves to ``{'init_weight': 0}``.
    init_attn : dict or None, optional, default=None
        Weight initialization kwargs for the attention QKV projection.
        Defaults to ``init`` if ``None``.

    Forward
    -------
    x : torch.Tensor
        Input feature map of shape :math:`(B, C_{in}, D, H, W)`.
    emb : torch.Tensor
        External vector input of shape :math:`(B, C_{emb})`. Used for affine
        conditioning of intermediate features.

    Outputs
    -------
    torch.Tensor
        Output feature map of shape :math:`(B, C_{out}, D', H', W')` where
        :math:`D', H', W'` are halved (``down=True``), doubled (``up=True``),
        or equal to the input spatial dimensions.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.experimental.nn import UNetBlock3D
    >>> block = UNetBlock3D(in_channels=8, out_channels=16, emb_channels=32)
    >>> x = torch.randn(2, 8, 4, 12, 16)
    >>> emb = torch.randn(2, 32)
    >>> y = block(x, emb)
    >>> y.shape
    torch.Size([2, 16, 4, 12, 16])
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        emb_channels: int,
        up: bool = False,
        down: bool = False,
        attention: bool = False,
        num_heads: int | None = None,
        dropout: float = 0.0,
        skip_scale: float = 1.0,
        eps: float = 1e-5,
        resample_filter: List[int] | None = None,
        resample_proj: bool = False,
        adaptive_scale: bool = True,
        activation: Literal["silu", "gelu"] = "silu",
        init: Dict[str, Any] | None = None,
        init_zero: Dict[str, Any] | None = None,
        init_attn: Any = None,
    ):
        super().__init__(meta=ModelMetaData())

        # Resolve sentinel defaults here instead of using mutable default
        # arguments, which would be shared across all instances.
        resample_filter = [1, 1] if resample_filter is None else resample_filter
        init = dict() if init is None else init
        init_zero = dict(init_weight=0) if init_zero is None else init_zero

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.emb_channels = emb_channels
        self.attention = attention
        self.dropout = dropout
        self.skip_scale = skip_scale
        self.adaptive_scale = adaptive_scale
        self.act = silu if activation == "silu" else torch.nn.functional.gelu

        self.norm0 = GroupNorm3D(num_channels=in_channels, eps=eps)
        self.conv0 = Conv3D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel=3,
            up=up,
            down=down,
            resample_filter=resample_filter,
            **init,
        )
        # Projects emb to one (shift) or two (scale, shift) channel vectors.
        self.affine = Linear(
            in_features=emb_channels,
            out_features=out_channels * (2 if adaptive_scale else 1),
            **init,
        )
        self.norm1 = GroupNorm3D(num_channels=out_channels, eps=eps)
        # Zero-initialized second conv: the block starts near-identity.
        self.conv1 = Conv3D(
            in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero
        )

        self.skip = None
        if out_channels != in_channels or up or down:
            # kernel=0 means filter-only resampling (no learned projection);
            # a 1x1x1 projection is required when channel counts differ.
            kernel = 1 if resample_proj or out_channels != in_channels else 0
            self.skip = Conv3D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel=kernel,
                up=up,
                down=down,
                resample_filter=resample_filter,
                **init,
            )

        if self.attention:
            self.attn = UNetAttention3D(
                out_channels=out_channels,
                num_heads=num_heads if num_heads is not None else 1,
                eps=eps,
                init=init,
                init_zero=init_zero,
                init_attn=init_attn,
            )

    def forward(
        self,
        x: Float[torch.Tensor, "B C_in D H W"],
        emb: Float[torch.Tensor, "B C_emb"],
    ) -> Float[torch.Tensor, "B C_out D_out H_out W_out"]:
        orig = x

        # First norm + conv (with optional up/down)
        x = self.conv0(self.act(self.norm0(x)))

        # Affine conditioning from emb, broadcast over spatial dims
        params = self.affine(emb).unsqueeze(2).unsqueeze(3).unsqueeze(4).to(x.dtype)
        if self.adaptive_scale:
            # FiLM-style: scale+1 keeps the zero-init affine near identity.
            scale, shift = params.chunk(chunks=2, dim=1)
            x = self.act(torch.addcmul(shift, self.norm1(x), scale + 1))
        else:
            x = self.act(self.norm1(x.add_(params)))

        # Second conv with dropout and residual connection
        x = self.conv1(dropout(x, p=self.dropout, training=self.training))
        x = x.add_(self.skip(orig) if self.skip is not None else orig)
        x = x * self.skip_scale

        # Optional self-attention with residual scaling
        if self.attention:
            x = self.attn(x)
            x = x * self.skip_scale

        return x
def _build_model(x_channels, vol_cond_channels=0, vec_cond_dim=0, device="cpu"):
    """Build a compact DiffusionUNet3D sized for distributed sanity checks."""
    config = dict(
        x_channels=x_channels,
        vol_cond_channels=vol_cond_channels,
        vec_cond_dim=vec_cond_dim,
        num_levels=2,
        model_channels=16,
        channel_mult=[1, 2],
        num_blocks=1,
        attention_levels=[1],
        dropout=0.0,
    )
    return DiffusionUNet3D(**config).to(device)


@pytest.mark.multigpu_static
def test_diffusion_unet_3d_distributed(distributed_mesh):
    """Unconditional forward with `x` sharded along the H axis."""
    dm = DistributedManager()
    batch, channels, depth, height, width = 2, 2, 8, 32, 32

    net = distribute_module(
        _build_model(x_channels=channels, device=dm.device),
        device_mesh=distributed_mesh,
    )

    dense_x = torch.randn(batch, channels, depth, height, width, device=dm.device)
    t = torch.rand(batch, device=dm.device)

    # Shard along H (dim 3): every spatial dim of a 5D volume is fair game,
    # H is the conventional choice in the 2D SongUNet tests.
    placements = (Shard(3),)
    sharded_x = scatter_tensor(
        dense_x, 0, distributed_mesh, placements, requires_grad=False
    )

    result = net(sharded_x, t)

    assert result.shape == (batch, channels, depth, height, width)
    assert result._spec.placements == sharded_x._spec.placements


@pytest.mark.multigpu_static
def test_diffusion_unet_3d_conditional_distributed(distributed_mesh):
    """Forward with vector + volume conditioning, sharded along H."""
    dm = DistributedManager()
    batch, channels, depth, height, width = 2, 2, 8, 32, 32
    vol_channels, vec_dim = 2, 8

    net = distribute_module(
        _build_model(
            x_channels=channels,
            vol_cond_channels=vol_channels,
            vec_cond_dim=vec_dim,
            device=dm.device,
        ),
        device_mesh=distributed_mesh,
    )

    dense_x = torch.randn(batch, channels, depth, height, width, device=dm.device)
    t = torch.rand(batch, device=dm.device)
    dense_volume = torch.randn(
        batch, vol_channels, depth, height, width, device=dm.device
    )
    vector = torch.randn(batch, vec_dim, device=dm.device)

    placements = (Shard(3),)
    sharded_x = scatter_tensor(
        dense_x, 0, distributed_mesh, placements, requires_grad=False
    )
    sharded_volume = scatter_tensor(
        dense_volume, 0, distributed_mesh, placements, requires_grad=False
    )

    # Vector conditioning stays replicated; only the volumetric tensors shard.
    condition = TensorDict(
        {"vector": vector, "volume": sharded_volume},
        batch_size=[batch],
    )

    result = net(sharded_x, t, condition=condition)

    assert result.shape == (batch, channels, depth, height, width)
    assert result._spec.placements == sharded_x._spec.placements
"""Shared helper functions for diffusion model/block tests in this package.

Pytest fixtures (``deterministic_settings``, ``nop_compile``, ``reset_dynamo``)
live in ``conftest.py`` and are auto-discovered. The ``tolerances`` fixture is
defined locally by the test modules that need it, so it can be tuned per file.
"""

from pathlib import Path
from typing import Any, Callable, Dict, Tuple

import torch

# =============================================================================
# Constants
# =============================================================================

# Seed shared by all reference-data generation in this package.
GLOBAL_SEED = 42
# Committed reference tensors / checkpoints live next to the tests.
DATA_DIR = Path(__file__).parent / "data"


# =============================================================================
# Helper functions
# =============================================================================


def instantiate_model_deterministic(cls, seed: int = 0, **kwargs: Any):
    """Instantiate ``cls(**kwargs)`` and overwrite every parameter with
    seeded Gaussian noise, so the model is fully reproducible."""
    model = cls(**kwargs)
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed)
    with torch.no_grad():
        for param in model.parameters():
            # copy_ handles any cpu->device transfer transparently.
            param.copy_(torch.randn(param.shape, generator=gen, dtype=param.dtype))
    return model


def load_or_create_reference(
    file_name: str,
    compute_fn: Callable[[], Dict[str, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    """Load a saved reference file, or create+save it on first run.

    On the creation path the freshly computed tensors are returned as-is
    (possibly on GPU); the persisted copy is always moved to CPU.
    """
    path = DATA_DIR / file_name
    if path.exists():
        return torch.load(path, weights_only=True)
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    data = compute_fn()
    data_cpu = {
        k: (v.cpu() if isinstance(v, torch.Tensor) else v) for k, v in data.items()
    }
    torch.save(data_cpu, path)
    return data


def load_or_create_checkpoint(
    checkpoint_name: str, create_fn: Callable[[], "physicsnemo.core.Module"]
):
    """Load a saved checkpoint, or create+save it on first run."""
    # Imported lazily so the pure-tensor helpers in this module remain usable
    # (and importable) without pulling in the full physicsnemo package.
    import physicsnemo.core

    path = DATA_DIR / checkpoint_name
    if path.exists():
        return physicsnemo.core.Module.from_checkpoint(str(path))
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    model = create_fn()
    model.save(str(path))
    return model


def compare_outputs(actual: torch.Tensor, expected: torch.Tensor, **tol: Any) -> None:
    """Compare two tensors with detailed shape/value reporting."""
    if actual.shape != expected.shape:
        raise AssertionError(
            f"Shape mismatch: actual {actual.shape} vs expected {expected.shape}"
        )
    # Compare in float64 so dtype mismatches between reference and output
    # do not mask genuine value differences.
    a64 = actual.to(torch.float64)
    e64 = expected.to(device=actual.device, dtype=torch.float64)
    torch.testing.assert_close(a64, e64, **tol)


def make_input(shape: Tuple[int, ...], seed: int, device: str) -> torch.Tensor:
    """Create a deterministic random input tensor (generated on CPU, then moved)."""
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed)
    return torch.randn(*shape, generator=gen).to(device)
import random

import pytest
import torch
import torch._dynamo

_GLOBAL_SEED = 42


def _nop_backend(gm, _inputs):
    """torch.compile backend that simply calls the captured graph eagerly."""

    def forward(*args, **kwargs):
        return gm.forward(*args, **kwargs)

    return forward


@pytest.fixture(autouse=True)
def reset_dynamo():
    """Reset torch._dynamo state between tests to avoid cross-test recompile errors."""
    torch._dynamo.reset()
    torch._dynamo.config.error_on_recompile = False
    yield
    torch._dynamo.reset()
    torch._dynamo.config.error_on_recompile = False


@pytest.fixture
def nop_compile(monkeypatch):
    """Redirect torch.compile to a no-op backend for fast compile-shape tests."""
    original = torch.compile
    monkeypatch.setattr(
        torch,
        "compile",
        lambda fn, *args, backend=_nop_backend, **kwargs: original(
            fn, *args, backend=backend, **kwargs
        ),
    )


@pytest.fixture
def deterministic_settings():
    """Set deterministic settings for reproducibility, then restore old state."""
    old_cudnn_deterministic = torch.backends.cudnn.deterministic
    old_cudnn_benchmark = torch.backends.cudnn.benchmark
    old_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    old_cudnn_tf32 = torch.backends.cudnn.allow_tf32
    old_random_state = random.getstate()
    # Also capture torch RNG state: seeding below would otherwise leak into
    # tests that run after this fixture, breaking the "restore" contract.
    old_torch_rng = torch.get_rng_state()
    old_cuda_rng = (
        torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    )

    try:
        random.seed(_GLOBAL_SEED)
        torch.manual_seed(_GLOBAL_SEED)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(_GLOBAL_SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        yield
    finally:
        torch.backends.cudnn.deterministic = old_cudnn_deterministic
        torch.backends.cudnn.benchmark = old_cudnn_benchmark
        torch.backends.cuda.matmul.allow_tf32 = old_matmul_tf32
        torch.backends.cudnn.allow_tf32 = old_cudnn_tf32
        random.setstate(old_random_state)
        torch.set_rng_state(old_torch_rng)
        if old_cuda_rng is not None:
            torch.cuda.set_rng_state_all(old_cuda_rng)
diff --git a/test/models/diffusion/data/conv3d_down.mdlus b/test/models/diffusion/data/conv3d_down.mdlus new file mode 100644 index 0000000000..1f78d2fefd Binary files /dev/null and b/test/models/diffusion/data/conv3d_down.mdlus differ diff --git a/test/models/diffusion/data/conv3d_down_forward.pth b/test/models/diffusion/data/conv3d_down_forward.pth new file mode 100644 index 0000000000..e6aef52e21 Binary files /dev/null and b/test/models/diffusion/data/conv3d_down_forward.pth differ diff --git a/test/models/diffusion/data/conv3d_no_bias_xavier.mdlus b/test/models/diffusion/data/conv3d_no_bias_xavier.mdlus new file mode 100644 index 0000000000..6c46b78f3a Binary files /dev/null and b/test/models/diffusion/data/conv3d_no_bias_xavier.mdlus differ diff --git a/test/models/diffusion/data/conv3d_no_bias_xavier_forward.pth b/test/models/diffusion/data/conv3d_no_bias_xavier_forward.pth new file mode 100644 index 0000000000..8037bba83b Binary files /dev/null and b/test/models/diffusion/data/conv3d_no_bias_xavier_forward.pth differ diff --git a/test/models/diffusion/data/conv3d_plain.mdlus b/test/models/diffusion/data/conv3d_plain.mdlus new file mode 100644 index 0000000000..758ddb03bf Binary files /dev/null and b/test/models/diffusion/data/conv3d_plain.mdlus differ diff --git a/test/models/diffusion/data/conv3d_plain_forward.pth b/test/models/diffusion/data/conv3d_plain_forward.pth new file mode 100644 index 0000000000..3998517c02 Binary files /dev/null and b/test/models/diffusion/data/conv3d_plain_forward.pth differ diff --git a/test/models/diffusion/data/conv3d_up_ncsnpp.mdlus b/test/models/diffusion/data/conv3d_up_ncsnpp.mdlus new file mode 100644 index 0000000000..36198aaaf5 Binary files /dev/null and b/test/models/diffusion/data/conv3d_up_ncsnpp.mdlus differ diff --git a/test/models/diffusion/data/conv3d_up_ncsnpp_forward.pth b/test/models/diffusion/data/conv3d_up_ncsnpp_forward.pth new file mode 100644 index 0000000000..ba66ea1a9a Binary files /dev/null and 
b/test/models/diffusion/data/conv3d_up_ncsnpp_forward.pth differ diff --git a/test/models/diffusion/data/diffusion_unet_3d_advanced.mdlus b/test/models/diffusion/data/diffusion_unet_3d_advanced.mdlus new file mode 100644 index 0000000000..7578d9c022 Binary files /dev/null and b/test/models/diffusion/data/diffusion_unet_3d_advanced.mdlus differ diff --git a/test/models/diffusion/data/diffusion_unet_3d_advanced_forward.pth b/test/models/diffusion/data/diffusion_unet_3d_advanced_forward.pth new file mode 100644 index 0000000000..cc3fe32847 Binary files /dev/null and b/test/models/diffusion/data/diffusion_unet_3d_advanced_forward.pth differ diff --git a/test/models/diffusion/data/diffusion_unet_3d_conditional.mdlus b/test/models/diffusion/data/diffusion_unet_3d_conditional.mdlus new file mode 100644 index 0000000000..404e9885ae Binary files /dev/null and b/test/models/diffusion/data/diffusion_unet_3d_conditional.mdlus differ diff --git a/test/models/diffusion/data/diffusion_unet_3d_conditional_forward.pth b/test/models/diffusion/data/diffusion_unet_3d_conditional_forward.pth new file mode 100644 index 0000000000..1f142836af Binary files /dev/null and b/test/models/diffusion/data/diffusion_unet_3d_conditional_forward.pth differ diff --git a/test/models/diffusion/data/diffusion_unet_3d_default.mdlus b/test/models/diffusion/data/diffusion_unet_3d_default.mdlus new file mode 100644 index 0000000000..b0337c5bfc Binary files /dev/null and b/test/models/diffusion/data/diffusion_unet_3d_default.mdlus differ diff --git a/test/models/diffusion/data/diffusion_unet_3d_default_forward.pth b/test/models/diffusion/data/diffusion_unet_3d_default_forward.pth new file mode 100644 index 0000000000..98b3ef22d9 Binary files /dev/null and b/test/models/diffusion/data/diffusion_unet_3d_default_forward.pth differ diff --git a/test/models/diffusion/data/groupnorm3d_custom_groups.mdlus b/test/models/diffusion/data/groupnorm3d_custom_groups.mdlus new file mode 100644 index 0000000000..710f20bb32 
Binary files /dev/null and b/test/models/diffusion/data/groupnorm3d_custom_groups.mdlus differ diff --git a/test/models/diffusion/data/groupnorm3d_custom_groups_forward.pth b/test/models/diffusion/data/groupnorm3d_custom_groups_forward.pth new file mode 100644 index 0000000000..8e2a6e8886 Binary files /dev/null and b/test/models/diffusion/data/groupnorm3d_custom_groups_forward.pth differ diff --git a/test/models/diffusion/data/groupnorm3d_default.mdlus b/test/models/diffusion/data/groupnorm3d_default.mdlus new file mode 100644 index 0000000000..f39703358e Binary files /dev/null and b/test/models/diffusion/data/groupnorm3d_default.mdlus differ diff --git a/test/models/diffusion/data/groupnorm3d_default_forward.pth b/test/models/diffusion/data/groupnorm3d_default_forward.pth new file mode 100644 index 0000000000..172095f36f Binary files /dev/null and b/test/models/diffusion/data/groupnorm3d_default_forward.pth differ diff --git a/test/models/diffusion/data/groupnorm3d_min_per_group.mdlus b/test/models/diffusion/data/groupnorm3d_min_per_group.mdlus new file mode 100644 index 0000000000..44f53c3dc3 Binary files /dev/null and b/test/models/diffusion/data/groupnorm3d_min_per_group.mdlus differ diff --git a/test/models/diffusion/data/groupnorm3d_min_per_group_forward.pth b/test/models/diffusion/data/groupnorm3d_min_per_group_forward.pth new file mode 100644 index 0000000000..1a5ec58413 Binary files /dev/null and b/test/models/diffusion/data/groupnorm3d_min_per_group_forward.pth differ diff --git a/test/models/diffusion/data/unet_attention_3d_custom_eps.mdlus b/test/models/diffusion/data/unet_attention_3d_custom_eps.mdlus new file mode 100644 index 0000000000..99a7be65b2 Binary files /dev/null and b/test/models/diffusion/data/unet_attention_3d_custom_eps.mdlus differ diff --git a/test/models/diffusion/data/unet_attention_3d_custom_eps_forward.pth b/test/models/diffusion/data/unet_attention_3d_custom_eps_forward.pth new file mode 100644 index 0000000000..1f684c959c Binary 
files /dev/null and b/test/models/diffusion/data/unet_attention_3d_custom_eps_forward.pth differ diff --git a/test/models/diffusion/data/unet_attention_3d_multi_head.mdlus b/test/models/diffusion/data/unet_attention_3d_multi_head.mdlus new file mode 100644 index 0000000000..e090a61a39 Binary files /dev/null and b/test/models/diffusion/data/unet_attention_3d_multi_head.mdlus differ diff --git a/test/models/diffusion/data/unet_attention_3d_multi_head_forward.pth b/test/models/diffusion/data/unet_attention_3d_multi_head_forward.pth new file mode 100644 index 0000000000..aba359137d Binary files /dev/null and b/test/models/diffusion/data/unet_attention_3d_multi_head_forward.pth differ diff --git a/test/models/diffusion/data/unet_attention_3d_single_head.mdlus b/test/models/diffusion/data/unet_attention_3d_single_head.mdlus new file mode 100644 index 0000000000..02e7943e2a Binary files /dev/null and b/test/models/diffusion/data/unet_attention_3d_single_head.mdlus differ diff --git a/test/models/diffusion/data/unet_attention_3d_single_head_forward.pth b/test/models/diffusion/data/unet_attention_3d_single_head_forward.pth new file mode 100644 index 0000000000..09d5f12fc7 Binary files /dev/null and b/test/models/diffusion/data/unet_attention_3d_single_head_forward.pth differ diff --git a/test/models/diffusion/data/unet_block_3d_attention_multi_head.mdlus b/test/models/diffusion/data/unet_block_3d_attention_multi_head.mdlus new file mode 100644 index 0000000000..cc8f64bb0c Binary files /dev/null and b/test/models/diffusion/data/unet_block_3d_attention_multi_head.mdlus differ diff --git a/test/models/diffusion/data/unet_block_3d_attention_multi_head_forward.pth b/test/models/diffusion/data/unet_block_3d_attention_multi_head_forward.pth new file mode 100644 index 0000000000..0b8827e5ca Binary files /dev/null and b/test/models/diffusion/data/unet_block_3d_attention_multi_head_forward.pth differ diff --git a/test/models/diffusion/data/unet_block_3d_down_adaptive.mdlus 
b/test/models/diffusion/data/unet_block_3d_down_adaptive.mdlus new file mode 100644 index 0000000000..8c2400ad1d Binary files /dev/null and b/test/models/diffusion/data/unet_block_3d_down_adaptive.mdlus differ diff --git a/test/models/diffusion/data/unet_block_3d_down_adaptive_forward.pth b/test/models/diffusion/data/unet_block_3d_down_adaptive_forward.pth new file mode 100644 index 0000000000..2521db34bd Binary files /dev/null and b/test/models/diffusion/data/unet_block_3d_down_adaptive_forward.pth differ diff --git a/test/models/diffusion/data/unet_block_3d_plain.mdlus b/test/models/diffusion/data/unet_block_3d_plain.mdlus new file mode 100644 index 0000000000..45c56b9a20 Binary files /dev/null and b/test/models/diffusion/data/unet_block_3d_plain.mdlus differ diff --git a/test/models/diffusion/data/unet_block_3d_plain_forward.pth b/test/models/diffusion/data/unet_block_3d_plain_forward.pth new file mode 100644 index 0000000000..f29f818402 Binary files /dev/null and b/test/models/diffusion/data/unet_block_3d_plain_forward.pth differ diff --git a/test/models/diffusion/data/unet_block_3d_up_gelu.mdlus b/test/models/diffusion/data/unet_block_3d_up_gelu.mdlus new file mode 100644 index 0000000000..c8eb476474 Binary files /dev/null and b/test/models/diffusion/data/unet_block_3d_up_gelu.mdlus differ diff --git a/test/models/diffusion/data/unet_block_3d_up_gelu_forward.pth b/test/models/diffusion/data/unet_block_3d_up_gelu_forward.pth new file mode 100644 index 0000000000..5eda30e9c6 Binary files /dev/null and b/test/models/diffusion/data/unet_block_3d_up_gelu_forward.pth differ diff --git a/test/models/diffusion/test_diffusion_unet_3d.py b/test/models/diffusion/test_diffusion_unet_3d.py new file mode 100644 index 0000000000..e24081679b --- /dev/null +++ b/test/models/diffusion/test_diffusion_unet_3d.py @@ -0,0 +1,309 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for DiffusionUNet3D."""

from typing import Any, Dict, Tuple

import pytest
import torch
import torch._dynamo
from tensordict import TensorDict

from physicsnemo.experimental.models.diffusion_unets import DiffusionUNet3D
from test.models.diffusion._helpers import (
    GLOBAL_SEED,
    compare_outputs,
    instantiate_model_deterministic,
    load_or_create_checkpoint,
    load_or_create_reference,
)

# Loose GPU tolerances are needed here because attention via SDPA returns
# meaningfully different values on CPU vs GPU (and across GPU architectures),
# and the test models are initialized with purely-random weights and inputs.
# Scoped to this file so the looseness doesn't leak into sibling tests.
_CPU_TOLERANCES = {"atol": 1e-3, "rtol": 1e-3}
_GPU_TOLERANCES = {"atol": 1e-2, "rtol": 5e-2}


@pytest.fixture
def tolerances(device):
    return _CPU_TOLERANCES if device == "cpu" else _GPU_TOLERANCES


# =============================================================================
# Architecture configurations
# =============================================================================

# (name, arch_kwargs, x_shape) — minimal sizes that exercise every code path
ARCH_CONFIGS: Tuple[Tuple[str, Dict[str, Any], Tuple[int, int, int, int, int]], ...] = (
    (
        "default",
        dict(
            x_channels=2,
            num_levels=2,
            model_channels=8,
            channel_mult=[1, 2],
            num_blocks=1,
            dropout=0.0,
        ),
        (2, 2, 4, 8, 8),
    ),
    (
        "conditional",
        dict(
            x_channels=2,
            vol_cond_channels=2,
            vec_cond_dim=4,
            num_levels=2,
            model_channels=8,
            channel_mult=[1, 2],
            num_blocks=1,
            attention_levels=[1],
            dropout=0.0,
        ),
        (2, 2, 4, 8, 8),
    ),
    (
        "advanced",
        dict(
            x_channels=2,
            vol_cond_channels=1,
            vec_cond_dim=2,
            num_levels=3,
            model_channels=8,
            channel_mult=[1, 2, 2],
            num_blocks=1,
            attention_levels=[2],
            embedding_type="fourier",
            channel_mult_noise=2,
            encoder_type="residual",
            decoder_type="skip",
            resample_filter=[1, 3, 3, 1],
            bottleneck_attention=False,
            activation="gelu",
            dropout=0.0,
        ),
        (2, 2, 4, 8, 8),
    ),
)


def _generate_batch_data(
    arch_kwargs: Dict[str, Any],
    x_shape: Tuple[int, int, int, int, int],
    seed: int,
    device: str,
) -> Dict[str, Any]:
    """Generate a deterministic batch dict with keys ``x``, ``t`` and
    ``condition`` for the given architecture (``condition`` is ``None`` when
    the architecture is unconditional)."""
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed)

    B = x_shape[0]
    x = torch.randn(*x_shape, generator=gen).to(device)
    # t drawn from [0.4, 0.9) — away from the endpoints of the diffusion time.
    t = (torch.rand(B, generator=gen) * 0.5 + 0.4).to(device)

    cond_entries: Dict[str, torch.Tensor] = {}
    if arch_kwargs.get("vec_cond_dim", 0) > 0:
        cond_entries["vector"] = torch.randn(
            B, arch_kwargs["vec_cond_dim"], generator=gen
        ).to(device)
    if arch_kwargs.get("vol_cond_channels", 0) > 0:
        cond_entries["volume"] = torch.randn(
            B,
            arch_kwargs["vol_cond_channels"],
            *x_shape[2:],
            generator=gen,
        ).to(device)

    condition = (
        TensorDict(cond_entries, batch_size=[B]).to(device) if cond_entries else None
    )
    return {"x": x, "t": t, "condition": condition}


# =============================================================================
# Constructor tests (default + invalid cases, not parametrized)
# =============================================================================
class TestConstructor:
    """Constructor / attribute tests not tied to a specific architecture."""

    def test_default_attributes(self, device):
        """Default-construction values match the documented defaults."""
        model = DiffusionUNet3D(x_channels=4).to(device)
        assert model.x_channels == 4
        assert model.vol_cond_channels == 0
        assert model.vec_cond_dim == 0
        assert model.embedding_type == "positional"
        assert model.num_levels == 4
        assert model.checkpoint_level == 0
        assert model.emb_channels == 128 * 4  # model_channels * channel_mult_emb
        assert isinstance(model, DiffusionUNet3D)

    def test_invalid_channel_mult_raises(self):
        # channel_mult length must equal num_levels.
        with pytest.raises(ValueError, match="channel_mult"):
            DiffusionUNet3D(x_channels=2, num_levels=4, channel_mult=[1, 2, 3])

    def test_invalid_attention_level_raises(self):
        # Attention level indices must be < num_levels.
        with pytest.raises(ValueError, match="attention_levels"):
            DiffusionUNet3D(
                x_channels=2,
                num_levels=2,
                channel_mult=[1, 2],
                attention_levels=[5],
            )

    def test_zero_embedding_with_condition_raises(self):
        # Conditioning requires a non-trivial embedding path.
        with pytest.raises(ValueError, match="embedding_type='zero'"):
            DiffusionUNet3D(
                x_channels=2,
                vec_cond_dim=4,
                num_levels=2,
                channel_mult=[1, 2],
                embedding_type="zero",
            )


# =============================================================================
# Architecture tests (class-level parametrize over ARCH_CONFIGS)
# =============================================================================


@pytest.mark.parametrize(
    "arch_name,arch_kwargs,x_shape",
    ARCH_CONFIGS,
    ids=[c[0] for c in ARCH_CONFIGS],
)
class TestArchitecture:
    """Tests parameterized across every architecture configuration."""

    def test_attributes_match_kwargs(self, arch_name, arch_kwargs, x_shape, device):
        """Every public attribute reflects the kwargs (or its documented default)."""
        model = DiffusionUNet3D(**arch_kwargs).to(device)

        assert model.x_channels == arch_kwargs["x_channels"]
        assert model.vol_cond_channels == arch_kwargs.get("vol_cond_channels", 0)
        assert model.vec_cond_dim == arch_kwargs.get("vec_cond_dim", 0)
        assert model.num_levels == arch_kwargs["num_levels"]
        assert model.embedding_type == arch_kwargs.get("embedding_type", "positional")
        assert model.checkpoint_level == arch_kwargs.get("checkpoint_level", 0)

        expected_emb = arch_kwargs.get("model_channels", 128) * arch_kwargs.get(
            "channel_mult_emb", 4
        )
        assert model.emb_channels == expected_emb

    def test_forward_non_regression(
        self,
        deterministic_settings,
        arch_name,
        arch_kwargs,
        x_shape,
        device,
        tolerances,
    ):
        """Forward output matches a saved reference."""
        model = instantiate_model_deterministic(
            DiffusionUNet3D, seed=0, **arch_kwargs
        ).to(device)
        data = _generate_batch_data(arch_kwargs, x_shape, GLOBAL_SEED, device)
        out = model(data["x"], data["t"], condition=data["condition"])

        # Reference files are created on first run and committed under data/.
        ref_file = f"diffusion_unet_3d_{arch_name}_forward.pth"
        ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()})
        compare_outputs(out, ref["out"], **tolerances)

    def test_forward_from_checkpoint(
        self,
        deterministic_settings,
        arch_name,
        arch_kwargs,
        x_shape,
        device,
        tolerances,
    ):
        """Forward output from a loaded checkpoint matches the same reference."""

        def create_fn():
            return instantiate_model_deterministic(
                DiffusionUNet3D, seed=0, **arch_kwargs
            )

        ckpt_file = f"diffusion_unet_3d_{arch_name}.mdlus"
        model = load_or_create_checkpoint(ckpt_file, create_fn).to(device)
        data = _generate_batch_data(arch_kwargs, x_shape, GLOBAL_SEED, device)
        out = model(data["x"], data["t"], condition=data["condition"])

        # Same reference file as test_forward_non_regression: the checkpointed
        # model must reproduce the directly-instantiated model bit-for-bit.
        ref_file = f"diffusion_unet_3d_{arch_name}_forward.pth"
        ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()})
        compare_outputs(out, ref["out"], **tolerances)

    def test_forward_output(self, arch_name, arch_kwargs, x_shape, device):
        """Forward returns a tensor of the expected shape and dtype."""
        model = instantiate_model_deterministic(
            DiffusionUNet3D, seed=0, **arch_kwargs
        ).to(device)
        data = _generate_batch_data(arch_kwargs, x_shape, GLOBAL_SEED, device)
        out = model(data["x"], data["t"], condition=data["condition"])
        assert out.shape == x_shape
        assert out.dtype == data["x"].dtype

    def test_gradient_flow(self, arch_name, arch_kwargs, x_shape, device):
        """Gradients flow back through the model."""
        model = instantiate_model_deterministic(
            DiffusionUNet3D, seed=0, **arch_kwargs
        ).to(device)
        data = _generate_batch_data(arch_kwargs, x_shape, GLOBAL_SEED, device)
        x = data["x"].clone().requires_grad_(True)
        out = model(x, data["t"], condition=data["condition"])
        out.sum().backward()
        assert x.grad is not None
        assert not torch.isnan(x.grad).any()

    @pytest.mark.usefixtures("nop_compile")
    def test_compile(
        self,
        deterministic_settings,
        arch_name,
        arch_kwargs,
        x_shape,
        device,
    ):
        """Compiled forward matches eager and graph is reused on second call."""
        torch._dynamo.config.error_on_recompile = True

        # eval mode disables dropout so eager and compiled paths are deterministic
        model = (
            instantiate_model_deterministic(DiffusionUNet3D, seed=0, **arch_kwargs)
            .to(device)
            .eval()
        )
        data = _generate_batch_data(arch_kwargs, x_shape, GLOBAL_SEED, device)
        x, t, cond = data["x"], data["t"], data["condition"]

        compiled = torch.compile(model, fullgraph=True)

        with torch.no_grad():
            out_eager = model(x, t, condition=cond)
            out_compiled = compiled(x, t, condition=cond)
        torch.testing.assert_close(out_eager, out_compiled)

        # Second call with identical shapes: error_on_recompile=True makes a
        # recompilation fail loudly.
        with torch.no_grad():
            out_compiled_2 = compiled(x, t, condition=cond)
        torch.testing.assert_close(out_compiled, out_compiled_2)
b/test/models/diffusion/test_layers_diffusion_unet_3d_blocks.py @@ -0,0 +1,562 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the 3D building blocks (GroupNorm3D, Conv3D, UNetAttention3D, UNetBlock3D).""" + +from typing import Any, Dict, Tuple + +import pytest +import torch +import torch._dynamo + +from physicsnemo.experimental.nn import ( + Conv3D, + GroupNorm3D, + UNetAttention3D, + UNetBlock3D, +) +from test.models.diffusion._helpers import ( + GLOBAL_SEED, + compare_outputs, + instantiate_model_deterministic, + load_or_create_checkpoint, + load_or_create_reference, + make_input, +) + +# Loose GPU tolerances are needed here because attention via SDPA returns +# meaningfully different values on CPU vs GPU (and across GPU architectures), +# and the test blocks are initialized with purely-random weights and inputs. +# Scoped to this file so the looseness doesn't leak into sibling tests. 
_CPU_TOLERANCES = {"atol": 1e-3, "rtol": 1e-3}
_GPU_TOLERANCES = {"atol": 1e-2, "rtol": 5e-2}


@pytest.fixture
def tolerances(device):
    """Per-device comparison tolerances (SDPA attention differs across backends)."""
    return _CPU_TOLERANCES if device == "cpu" else _GPU_TOLERANCES


# =============================================================================
# GroupNorm3D
# =============================================================================

# (name, kwargs)
GROUPNORM_CONFIGS: Tuple[Tuple[str, Dict[str, Any]], ...] = (
    ("default", dict(num_channels=32)),
    ("custom_groups", dict(num_channels=16, num_groups=8, eps=1e-6)),
    # Requests more groups than channels allow; exercises the capping logic.
    ("min_per_group", dict(num_channels=8, num_groups=32)),
)


@pytest.mark.parametrize(
    "config_name,kwargs",
    GROUPNORM_CONFIGS,
    ids=[c[0] for c in GROUPNORM_CONFIGS],
)
class TestGroupNorm3D:
    """Tests for GroupNorm3D, parameterized over configurations."""

    def test_attributes_match_kwargs(self, config_name, kwargs, device):
        """Public attributes reflect the kwargs (or the documented defaults)."""
        gn = GroupNorm3D(**kwargs).to(device)
        # num_groups is capped to keep at least min_channels_per_group=4 channels per group
        expected_num_groups = min(
            kwargs.get("num_groups", 32),
            kwargs["num_channels"] // 4,
        )
        assert gn.num_groups == expected_num_groups
        assert gn.eps == kwargs.get("eps", 1e-5)
        assert gn.weight.shape == (kwargs["num_channels"],)
        assert gn.bias.shape == (kwargs["num_channels"],)

    def test_forward_non_regression(
        self,
        deterministic_settings,
        config_name,
        kwargs,
        device,
        tolerances,
    ):
        """Forward output matches a saved reference."""
        gn = instantiate_model_deterministic(GroupNorm3D, seed=0, **kwargs).to(device)
        x = make_input(
            (2, kwargs["num_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device
        )
        out = gn(x)
        ref_file = f"groupnorm3d_{config_name}_forward.pth"
        ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()})
        compare_outputs(out, ref["out"], **tolerances)

    def test_forward_from_checkpoint(
        self,
        deterministic_settings,
        config_name,
        kwargs,
        device,
        tolerances,
    ):
        """Forward output from a loaded checkpoint matches the same reference."""

        def create_fn():
            return instantiate_model_deterministic(GroupNorm3D, seed=0, **kwargs)

        ckpt_file = f"groupnorm3d_{config_name}.mdlus"
        gn = load_or_create_checkpoint(ckpt_file, create_fn).to(device)
        x = make_input(
            (2, kwargs["num_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device
        )
        out = gn(x)
        ref_file = f"groupnorm3d_{config_name}_forward.pth"
        ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()})
        compare_outputs(out, ref["out"], **tolerances)

    def test_gradient_flow(self, config_name, kwargs, device):
        """Gradients flow back through the layer without NaNs."""
        gn = GroupNorm3D(**kwargs).to(device)
        x = torch.randn(
            2, kwargs["num_channels"], 4, 8, 8, device=device, requires_grad=True
        )
        gn(x).sum().backward()
        assert x.grad is not None
        assert not torch.isnan(x.grad).any()

    @pytest.mark.usefixtures("nop_compile")
    def test_compile(self, deterministic_settings, config_name, kwargs, device):
        """Compiled forward matches eager and graph is reused on second call."""
        # FIX: save/restore the global dynamo flag instead of leaking
        # error_on_recompile=True into the rest of the pytest session.
        prev_flag = torch._dynamo.config.error_on_recompile
        torch._dynamo.config.error_on_recompile = True
        try:
            gn = GroupNorm3D(**kwargs).to(device).eval()
            x = make_input(
                (2, kwargs["num_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device
            )
            compiled = torch.compile(gn, fullgraph=True)
            with torch.no_grad():
                out_eager = gn(x)
                out_compiled = compiled(x)
            torch.testing.assert_close(out_eager, out_compiled)
            with torch.no_grad():
                out_compiled_2 = compiled(x)
            torch.testing.assert_close(out_compiled, out_compiled_2)
        finally:
            torch._dynamo.config.error_on_recompile = prev_flag


# =============================================================================
# Conv3D
# =============================================================================

# (name, kwargs)
CONV3D_CONFIGS: Tuple[Tuple[str, Dict[str, Any]], ...]
= ( + ( + "plain", + dict(in_channels=4, out_channels=8, kernel=3), + ), + ( + "down", + dict(in_channels=4, out_channels=8, kernel=3, down=True), + ), + ( + "up_ncsnpp", + dict( + in_channels=4, + out_channels=8, + kernel=3, + up=True, + resample_filter=[1, 3, 3, 1], + ), + ), + ( + "no_bias_xavier", + dict( + in_channels=4, + out_channels=8, + kernel=3, + bias=False, + init_mode="xavier_uniform", + ), + ), +) + + +class TestConv3DErrors: + """Constructor validation errors (not parametrized over configs).""" + + def test_up_down_both_raises(self): + with pytest.raises(ValueError, match="up.*down"): + Conv3D(in_channels=4, out_channels=8, kernel=3, up=True, down=True) + + def test_invalid_resample_filter_raises(self): + with pytest.raises(ValueError, match="resample_filter"): + Conv3D( + in_channels=4, + out_channels=8, + kernel=3, + down=True, + resample_filter=[], + ) + with pytest.raises(ValueError, match="resample_filter"): + Conv3D( + in_channels=4, + out_channels=8, + kernel=3, + down=True, + resample_filter=[1, 0], + ) + + +@pytest.mark.parametrize( + "config_name,kwargs", + CONV3D_CONFIGS, + ids=[c[0] for c in CONV3D_CONFIGS], +) +class TestConv3D: + """Tests for Conv3D, parameterized over configurations.""" + + def test_attributes_match_kwargs(self, config_name, kwargs, device): + conv = Conv3D(**kwargs).to(device) + assert conv.in_channels == kwargs["in_channels"] + assert conv.out_channels == kwargs["out_channels"] + assert conv.up == kwargs.get("up", False) + assert conv.down == kwargs.get("down", False) + if kwargs["kernel"] > 0: + assert conv.weight is not None + assert conv.weight.shape == ( + kwargs["out_channels"], + kwargs["in_channels"], + kwargs["kernel"], + kwargs["kernel"], + kwargs["kernel"], + ) + if kwargs.get("bias", True) and kwargs["kernel"] > 0: + assert conv.bias is not None + assert conv.bias.shape == (kwargs["out_channels"],) + else: + assert conv.bias is None + + def test_forward_non_regression( + self, + deterministic_settings, + 
config_name, + kwargs, + device, + tolerances, + ): + conv = instantiate_model_deterministic(Conv3D, seed=0, **kwargs).to(device) + x = make_input( + (2, kwargs["in_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device + ) + out = conv(x) + ref_file = f"conv3d_{config_name}_forward.pth" + ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()}) + compare_outputs(out, ref["out"], **tolerances) + + def test_forward_from_checkpoint( + self, + deterministic_settings, + config_name, + kwargs, + device, + tolerances, + ): + def create_fn(): + return instantiate_model_deterministic(Conv3D, seed=0, **kwargs) + + ckpt_file = f"conv3d_{config_name}.mdlus" + conv = load_or_create_checkpoint(ckpt_file, create_fn).to(device) + x = make_input( + (2, kwargs["in_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device + ) + out = conv(x) + ref_file = f"conv3d_{config_name}_forward.pth" + ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()}) + compare_outputs(out, ref["out"], **tolerances) + + def test_gradient_flow(self, config_name, kwargs, device): + conv = Conv3D(**kwargs).to(device) + x = torch.randn( + 2, kwargs["in_channels"], 4, 8, 8, device=device, requires_grad=True + ) + conv(x).sum().backward() + assert x.grad is not None + assert not torch.isnan(x.grad).any() + + @pytest.mark.usefixtures("nop_compile") + def test_compile(self, deterministic_settings, config_name, kwargs, device): + torch._dynamo.config.error_on_recompile = True + conv = Conv3D(**kwargs).to(device).eval() + x = make_input( + (2, kwargs["in_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device + ) + compiled = torch.compile(conv, fullgraph=True) + with torch.no_grad(): + out_eager = conv(x) + out_compiled = compiled(x) + torch.testing.assert_close(out_eager, out_compiled) + with torch.no_grad(): + out_compiled_2 = compiled(x) + torch.testing.assert_close(out_compiled, out_compiled_2) + + +# ============================================================================= +# 
UNetAttention3D +# ============================================================================= + +# (name, kwargs) +ATTENTION_CONFIGS: Tuple[Tuple[str, Dict[str, Any]], ...] = ( + ("single_head", dict(out_channels=16, num_heads=1)), + ("multi_head", dict(out_channels=16, num_heads=4)), + ("custom_eps", dict(out_channels=8, num_heads=2, eps=1e-6)), +) + + +class TestUNetAttention3DErrors: + """Constructor validation errors (not parametrized over configs).""" + + def test_invalid_num_heads_raises(self): + with pytest.raises(ValueError, match="num_heads"): + UNetAttention3D(out_channels=16, num_heads=0) + with pytest.raises(ValueError, match="num_heads"): + UNetAttention3D(out_channels=16, num_heads=-1) + + def test_indivisible_channels_raises(self): + with pytest.raises(ValueError, match="divisible"): + UNetAttention3D(out_channels=15, num_heads=4) + + +@pytest.mark.parametrize( + "config_name,kwargs", + ATTENTION_CONFIGS, + ids=[c[0] for c in ATTENTION_CONFIGS], +) +class TestUNetAttention3D: + """Tests for UNetAttention3D, parameterized over configurations.""" + + def test_attributes_match_kwargs(self, config_name, kwargs, device): + attn = UNetAttention3D(**kwargs).to(device) + assert attn.num_heads == kwargs["num_heads"] + + def test_forward_non_regression( + self, + deterministic_settings, + config_name, + kwargs, + device, + tolerances, + ): + attn = instantiate_model_deterministic(UNetAttention3D, seed=0, **kwargs).to( + device + ) + x = make_input( + (2, kwargs["out_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device + ) + out = attn(x) + ref_file = f"unet_attention_3d_{config_name}_forward.pth" + ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()}) + compare_outputs(out, ref["out"], **tolerances) + + def test_forward_from_checkpoint( + self, + deterministic_settings, + config_name, + kwargs, + device, + tolerances, + ): + def create_fn(): + return instantiate_model_deterministic(UNetAttention3D, seed=0, **kwargs) + + ckpt_file = 
f"unet_attention_3d_{config_name}.mdlus" + attn = load_or_create_checkpoint(ckpt_file, create_fn).to(device) + x = make_input( + (2, kwargs["out_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device + ) + out = attn(x) + ref_file = f"unet_attention_3d_{config_name}_forward.pth" + ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()}) + compare_outputs(out, ref["out"], **tolerances) + + def test_gradient_flow(self, config_name, kwargs, device): + attn = UNetAttention3D(**kwargs).to(device) + x = torch.randn( + 2, kwargs["out_channels"], 4, 8, 8, device=device, requires_grad=True + ) + attn(x).sum().backward() + assert x.grad is not None + assert not torch.isnan(x.grad).any() + + @pytest.mark.usefixtures("nop_compile") + def test_compile(self, deterministic_settings, config_name, kwargs, device): + torch._dynamo.config.error_on_recompile = True + attn = UNetAttention3D(**kwargs).to(device).eval() + x = make_input( + (2, kwargs["out_channels"], 4, 8, 8), seed=GLOBAL_SEED, device=device + ) + compiled = torch.compile(attn, fullgraph=True) + with torch.no_grad(): + out_eager = attn(x) + out_compiled = compiled(x) + torch.testing.assert_close(out_eager, out_compiled) + with torch.no_grad(): + out_compiled_2 = compiled(x) + torch.testing.assert_close(out_compiled, out_compiled_2) + + +# ============================================================================= +# UNetBlock3D +# ============================================================================= + +# (name, kwargs, x_shape, emb_shape) +BLOCK_CONFIGS: Tuple[ + Tuple[str, Dict[str, Any], Tuple[int, ...], Tuple[int, ...]], ... 
+] = ( + ( + "plain", + dict(in_channels=8, out_channels=16, emb_channels=32), + (2, 8, 4, 8, 8), + (2, 32), + ), + ( + "attention_multi_head", + dict( + in_channels=8, + out_channels=16, + emb_channels=32, + attention=True, + num_heads=4, + ), + (2, 8, 4, 8, 8), + (2, 32), + ), + ( + "down_adaptive", + dict( + in_channels=8, + out_channels=16, + emb_channels=32, + down=True, + adaptive_scale=True, + ), + (2, 8, 4, 8, 8), + (2, 32), + ), + ( + "up_gelu", + dict( + in_channels=8, + out_channels=16, + emb_channels=32, + up=True, + activation="gelu", + resample_filter=[1, 3, 3, 1], + ), + (2, 8, 4, 8, 8), + (2, 32), + ), +) + + +@pytest.mark.parametrize( + "config_name,kwargs,x_shape,emb_shape", + BLOCK_CONFIGS, + ids=[c[0] for c in BLOCK_CONFIGS], +) +class TestUNetBlock3D: + """Tests for UNetBlock3D, parameterized over configurations.""" + + def test_attributes_match_kwargs( + self, config_name, kwargs, x_shape, emb_shape, device + ): + block = UNetBlock3D(**kwargs).to(device) + assert block.in_channels == kwargs["in_channels"] + assert block.out_channels == kwargs["out_channels"] + assert block.emb_channels == kwargs["emb_channels"] + assert block.attention == kwargs.get("attention", False) + assert block.dropout == kwargs.get("dropout", 0.0) + assert block.skip_scale == kwargs.get("skip_scale", 1.0) + assert block.adaptive_scale == kwargs.get("adaptive_scale", True) + if kwargs.get("attention", False): + assert hasattr(block, "attn") + else: + assert not hasattr(block, "attn") + + def test_forward_non_regression( + self, + deterministic_settings, + config_name, + kwargs, + x_shape, + emb_shape, + device, + tolerances, + ): + block = instantiate_model_deterministic(UNetBlock3D, seed=0, **kwargs).to( + device + ) + x = make_input(x_shape, seed=GLOBAL_SEED, device=device) + emb = make_input(emb_shape, seed=GLOBAL_SEED + 1, device=device) + out = block(x, emb) + ref_file = f"unet_block_3d_{config_name}_forward.pth" + ref = load_or_create_reference(ref_file, lambda: 
{"out": out.cpu()}) + compare_outputs(out, ref["out"], **tolerances) + + def test_forward_from_checkpoint( + self, + deterministic_settings, + config_name, + kwargs, + x_shape, + emb_shape, + device, + tolerances, + ): + def create_fn(): + return instantiate_model_deterministic(UNetBlock3D, seed=0, **kwargs) + + ckpt_file = f"unet_block_3d_{config_name}.mdlus" + block = load_or_create_checkpoint(ckpt_file, create_fn).to(device) + x = make_input(x_shape, seed=GLOBAL_SEED, device=device) + emb = make_input(emb_shape, seed=GLOBAL_SEED + 1, device=device) + out = block(x, emb) + ref_file = f"unet_block_3d_{config_name}_forward.pth" + ref = load_or_create_reference(ref_file, lambda: {"out": out.cpu()}) + compare_outputs(out, ref["out"], **tolerances) + + def test_gradient_flow(self, config_name, kwargs, x_shape, emb_shape, device): + block = UNetBlock3D(**kwargs).to(device) + x = torch.randn(*x_shape, device=device, requires_grad=True) + emb = torch.randn(*emb_shape, device=device) + block(x, emb).sum().backward() + assert x.grad is not None + assert not torch.isnan(x.grad).any() + + @pytest.mark.usefixtures("nop_compile") + def test_compile( + self, + deterministic_settings, + config_name, + kwargs, + x_shape, + emb_shape, + device, + ): + torch._dynamo.config.error_on_recompile = True + block = ( + instantiate_model_deterministic(UNetBlock3D, seed=0, **kwargs) + .to(device) + .eval() + ) + x = make_input(x_shape, seed=GLOBAL_SEED, device=device) + emb = make_input(emb_shape, seed=GLOBAL_SEED + 1, device=device) + compiled = torch.compile(block, fullgraph=True) + with torch.no_grad(): + out_eager = block(x, emb) + out_compiled = compiled(x, emb) + torch.testing.assert_close(out_eager, out_compiled) + with torch.no_grad(): + out_compiled_2 = compiled(x, emb) + torch.testing.assert_close(out_compiled, out_compiled_2)