diff --git a/CHANGELOG.md b/CHANGELOG.md index caaf17aad1..3018c44668 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 standard solvers). Patching primitives (`BasePatching2D`, `GridPatching2D`, `RandomPatching2D`) are exposed under the same subpackage and are `torch.compile`-friendly with `fullgraph=True`. + `MultiDiffusionPredictor` supports memory-efficient inference on + large domains via `chunk_size` and `use_checkpointing`. The + subpackage also ships patch-local DPS guidance: + `MultiDiffusionDPSScorePredictor` (drop-in score predictor that plugs + into the standard sampling stack), + `MultiDiffusionDataConsistencyDPSGuidance` for inpainting and sparse + data assimilation, and `MultiDiffusionModelConsistencyDPSGuidance` for + generic patch-local observation operators. Use these instead of the + global `DPSScorePredictor` to run guided sampling on domains that + would otherwise OOM. - Adds `"epsilon"` as a supported prediction type throughout the diffusion framework, alongside the existing `"x0"` and `"score"` modes. A new `PredictorType = Literal["x0", "score", "epsilon"]` alias in diff --git a/physicsnemo/diffusion/guidance/dps_guidance.py b/physicsnemo/diffusion/guidance/dps_guidance.py index 58d57d87b6..f5999c5a05 100644 --- a/physicsnemo/diffusion/guidance/dps_guidance.py +++ b/physicsnemo/diffusion/guidance/dps_guidance.py @@ -25,6 +25,16 @@ from physicsnemo.diffusion.base import Predictor +def _lp_loss_fn(p: int) -> Callable[[Tensor, Tensor], Tensor]: + """Return a per-batch-element Lp loss function with exponent ``p``.""" + + def _loss(y_pred: Tensor, y_true: Tensor) -> Tensor: + residual = (y_pred - y_true).reshape(y_pred.shape[0], -1) + return residual.abs().pow(p).sum(dim=1) + + return _loss + + @runtime_checkable class DPSGuidance(Protocol): r""" @@ -658,7 +668,10 @@ def __init__( self.observation_operator = observation_operator self.y = y self.std_y = std_y - self.norm = norm + if isinstance(norm, int): + self._loss_fn: Callable[[Tensor, Tensor], Tensor] = _lp_loss_fn(norm) + else: + self._loss_fn = norm self.gamma = gamma self.sigma_fn = ( sigma_fn if sigma_fn is not None else lambda t: torch.zeros_like(t) @@ -708,17 +721,8 @@ def __call__( y = self.y.to(dtype=x.dtype, device=x.device) with torch.enable_grad(): - # Compute predicted observations y_pred = self.observation_operator(x_0) - - # Compute loss - if callable(self.norm): - loss = self.norm(y_pred, y) - else: - residual = (y_pred - y).reshape(y_pred.shape[0], -1) - loss = residual.abs().pow(self.norm).sum(dim=1) - - # Compute gradient of loss w.r.t. x (backprop through x_0) + loss = self._loss_fn(y_pred, y) grad_x = torch.autograd.grad( outputs=loss.sum(), inputs=x, @@ -956,7 +960,7 @@ def __init__( std_y: float, norm: int | Callable[ - [Float[Tensor, "B *dims"], Float[Tensor, "B *dims"]], # noqa: F821 + [Float[Tensor, " B *dims"], Float[Tensor, " B *dims"]], # noqa: F821 Float[Tensor, " B"], ] = 2, gamma: float = 0.0, @@ -972,7 +976,10 @@ def __init__( self.mask = mask.float() self.y = y self.std_y = std_y - self.norm = norm + if isinstance(norm, int): + self._loss_fn: Callable[[Tensor, Tensor], Tensor] = _lp_loss_fn(norm) + else: + self._loss_fn = norm self.gamma = gamma self.sigma_fn = ( sigma_fn if sigma_fn is not None else lambda t: torch.zeros_like(t) @@ -1023,18 +1030,9 @@ def __call__( y = self.y.to(dtype=x.dtype, device=x.device) with torch.enable_grad(): - # Compute masked predicted and observed values y_pred = mask * x_0 y_true = mask * y - - # Compute loss - if callable(self.norm): - loss = self.norm(y_pred, y_true) - else: - residual = (y_pred - y_true).reshape(x_0.shape[0], -1) - loss = residual.abs().pow(self.norm).sum(dim=1) - - # Compute gradient of loss w.r.t. x (backprop through x_0) + loss = self._loss_fn(y_pred, y_true) grad_x = torch.autograd.grad( outputs=loss.sum(), inputs=x, diff --git a/physicsnemo/diffusion/multi_diffusion/__init__.py b/physicsnemo/diffusion/multi_diffusion/__init__.py index ce09708e93..a3e530daa1 100644 --- a/physicsnemo/diffusion/multi_diffusion/__init__.py +++ b/physicsnemo/diffusion/multi_diffusion/__init__.py @@ -14,6 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .dps_guidance import ( + MultiDiffusionDataConsistencyDPSGuidance, + MultiDiffusionDPSGuidance, + MultiDiffusionDPSScorePredictor, + MultiDiffusionModelConsistencyDPSGuidance, +) from .losses import MultiDiffusionMSEDSMLoss, MultiDiffusionWeightedMSEDSMLoss from .models import MultiDiffusionModel2D from .patching import ( diff --git a/physicsnemo/diffusion/multi_diffusion/dps_guidance.py b/physicsnemo/diffusion/multi_diffusion/dps_guidance.py new file mode 100644 index 0000000000..09b31b145e --- /dev/null +++ b/physicsnemo/diffusion/multi_diffusion/dps_guidance.py @@ -0,0 +1,1015 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Patch-local DPS guidance for multi-diffusion sampling.""" + +from typing import Callable, Protocol, Sequence, runtime_checkable + +import torch +from jaxtyping import Bool, Float +from torch import Tensor + +from physicsnemo.diffusion.base import Predictor +from physicsnemo.diffusion.multi_diffusion.predictor import MultiDiffusionPredictor + + +def _lp_loss_fn(p: int) -> Callable[[Tensor, Tensor], Tensor]: + """Return a per-batch-element Lp loss function with exponent ``p``.""" + + def _loss(y_pred: Tensor, y_true: Tensor) -> Tensor: + residual = (y_pred - y_true).reshape(y_pred.shape[0], -1) + return residual.abs().pow(p).sum(dim=1) + + return _loss + + +@runtime_checkable +class MultiDiffusionDPSGuidance(Protocol): + r"""Protocol for **patch-local** DPS guidance compatible with + :class:`MultiDiffusionDPSScorePredictor`. + + A guidance is **patch-local** when its computation decomposes along + the multi-diffusion patch grid: the guidance value at each patch + depends only on the data of that patch. This protocol is **not + applicable** to globally-coupled guidances (e.g. ones that mix + information across patches), use + :class:`~physicsnemo.diffusion.guidance.DPSGuidance` for those. + + Identical to the standard + :class:`~physicsnemo.diffusion.guidance.DPSGuidance` protocol, plus an + optional ``slice_start`` argument that enables chunked evaluation: + + - **Full batch mode** (``slice_start=None``, the default): the call + processes the full :math:`P \times B` batch of patches at once. + Inputs match the size of the pre-patched data stored on the + guidance. The implementation may optionally fuse the result back to + the global resolution. + - **Chunked batch mode** (``slice_start=s``): the call processes a + single chunk of :math:`K \leq \text{chunk\_size}` patches starting + at row ``s``. The implementation slices its pre-patched data with + ``[s : s + K]`` and returns a chunk-sized guidance term (no fusing). + + Chunked batch mode is the key memory-efficiency knob, the per-chunk + activations are released between iterations, so peak GPU memory stays + proportional to ``chunk_size`` rather than to the full + :math:`P \times B`. Use it for large global domains where the + full-batch counterpart from :class:`~physicsnemo.diffusion.guidance.DPSGuidance` + would OOM. + + A guidance satisfying this protocol also satisfies + :class:`~physicsnemo.diffusion.guidance.DPSGuidance` because the extra + argument is optional. + + Examples + -------- + Implementing a simple patch-local guidance from scratch. The mask and + observations are pre-patched once at construction time and sliced per + chunk based on ``slice_start``: + + >>> import torch + >>> from physicsnemo.diffusion.multi_diffusion import ( + ... MultiDiffusionDPSGuidance, + ... ) + >>> + >>> class InpaintGuidance: + ... def __init__(self, mask_patched, y_patched, gamma=1.0): + ... self.mask = mask_patched + ... self.y = y_patched + ... self.gamma = gamma + ... + ... def __call__(self, x, t, x_0, slice_start=None): + ... if slice_start is None: + ... mask, y = self.mask, self.y + ... else: + ... K = x.shape[0] + ... mask = self.mask[slice_start : slice_start + K] + ... y = self.y[slice_start : slice_start + K] + ... return -self.gamma * mask * (x_0 - y) + ... + >>> mask = torch.ones(8, 3, 8, 8) # (P*B, C, Hp, Wp) + >>> y = torch.randn(8, 3, 8, 8) + >>> guidance = InpaintGuidance(mask, y) + >>> isinstance(guidance, MultiDiffusionDPSGuidance) + True + >>> + >>> # Full batch mode: process all P*B = 8 patches at once + >>> x = torch.randn(8, 3, 8, 8) + >>> t = torch.full((8,), 1.0) + >>> x_0 = x * 0.9 + >>> guidance(x, t, x_0).shape + torch.Size([8, 3, 8, 8]) + >>> + >>> # Chunked batch mode: process a chunk of 2 patches starting at row 0 + >>> guidance(x[:2], t[:2], x_0[:2], slice_start=0).shape + torch.Size([2, 3, 8, 8]) + """ + + def __call__( + self, + x: Float[Tensor, "K C Hp Wp"], + t: Float[Tensor, " K"], + x_0: Float[Tensor, "K C Hp Wp"], + slice_start: int | None = None, + ) -> Float[Tensor, "K C Hp Wp"]: ... + + +class MultiDiffusionDPSScorePredictor(Predictor): + r"""Score predictor that combines a + :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionPredictor` + with one or more patch-local DPS guidances for guided sampling on + large multi-diffusion domains. + + Implements the same :class:`~physicsnemo.diffusion.Predictor` + interface as :class:`~physicsnemo.diffusion.guidance.DPSScorePredictor` + and slots into the standard sampling stack: pass it to + :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser` + to obtain a :class:`~physicsnemo.diffusion.Denoiser` that can be used + with :func:`~physicsnemo.diffusion.samplers.sample` or any sampling + utility that consumes a denoiser. + + Use this class instead of + :class:`~physicsnemo.diffusion.guidance.DPSScorePredictor` when every + guidance is patch-local (see :class:`MultiDiffusionDPSGuidance`) and + the global domain is too large for the full + :math:`(P \times B, \dots)` activation tensor to fit in memory. The + predictor streams score and guidance contributions chunk by chunk in + patch space and fuses once at the end: + + .. math:: + + \nabla_{\mathbf{x}} \log p(\mathbf{x}) + + \sum_i g_i(\mathbf{x}, t, \hat{\mathbf{x}}_0) + \;=\; + \mathrm{Fuse}\!\left[\, s^k + \sum_i g_i^k\, \right]_{k=1..P} + + where the superscript :math:`k` denotes the :math:`k`-th patch chunk + and :math:`\mathrm{Fuse}` is the multi-diffusion fusing operator. The + full :math:`(P \times B, \dots)` activation tensor is never + materialized. + + .. important:: + + Use :class:`~physicsnemo.diffusion.guidance.DPSScorePredictor` for + guidances that do **not** decompose patch-locally. Passing a + globally-coupled guidance to this class produces incorrect results. + + Each guidance must implement the :class:`MultiDiffusionDPSGuidance` + protocol: + + .. code-block:: python + + def guidance( + x: Tensor, # shape: (K, C, Hp, Wp) + t: Tensor, # shape: (K,) + x_0: Tensor, # shape: (K, C, Hp, Wp) + slice_start: int | None, # row index of the chunk in (P*B); + # None means full-batch mode + ) -> Tensor: ... # shape: (K, C, Hp, Wp) + + where :math:`K` is the number of patches in the current chunk + (:math:`K = P \times B` in full batch mode, :math:`K \leq + \text{chunk\_size}` in chunked batch mode). The predictor forwards + each chunk's ``slice_start`` from + :meth:`MultiDiffusionPredictor.chunks` directly to every guidance, so + each guidance reads the corresponding slice of its own pre-patched + observations without any internal state. + + The ``x0_to_score_fn`` callback must be an elementwise conversion + with the signature: + + .. code-block:: python + + def x0_to_score_fn( + x_0: Tensor, # shape: (K, C, Hp, Wp) + x_t: Tensor, # shape: (K, C, Hp, Wp) + t: Tensor, # shape: (K,) + ) -> Tensor: ... # shape: (K, C, Hp, Wp) + + Parameters + ---------- + x0_predictor : MultiDiffusionPredictor + A trained predictor with ``chunk_size`` set, returning x0 + estimates. + x0_to_score_fn : callable + Elementwise conversion ``(x_0, x_t, t) -> score`` (see the + signature above). Typically obtained from a noise scheduler, + e.g. + :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.x0_to_score`. + guidances : MultiDiffusionDPSGuidance or sequence of MultiDiffusionDPSGuidance + One or more patch-local guidance objects implementing the + :class:`MultiDiffusionDPSGuidance` protocol. + + See Also + -------- + :class:`MultiDiffusionDPSGuidance` : Protocol that guidances must satisfy. + :class:`MultiDiffusionDataConsistencyDPSGuidance` : Patch-local + guidance for masked observations. + :class:`MultiDiffusionModelConsistencyDPSGuidance` : Patch-local + guidance for generic patch-local observation operators. + :class:`~physicsnemo.diffusion.guidance.DPSScorePredictor` : Use for + non-patch-local guidances. + + Examples + -------- + **Example 1:** Basic usage with a single inpainting guidance: + + >>> import torch + >>> from physicsnemo.core import Module + >>> from physicsnemo.diffusion.multi_diffusion import ( + ... MultiDiffusionModel2D, MultiDiffusionPredictor, + ... MultiDiffusionDPSScorePredictor, + ... ) + >>> + >>> class Backbone(Module): + ... def __init__(self): + ... super().__init__() + ... self.net = torch.nn.Conv2d(3, 3, 1) + ... def forward(self, x, t, condition=None): + ... return self.net(x) + >>> + >>> md = MultiDiffusionModel2D(Backbone(), global_spatial_shape=(16, 16)) + >>> md.set_random_patching(patch_shape=(8, 8), patch_num=4) + >>> _ = md.eval() + >>> predictor = MultiDiffusionPredictor(md, chunk_size=2) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) + >>> + >>> # x0-to-score for EDM: score = (x_0 - x) / t^2 + >>> def x0_to_score_fn(x_0, x, t): + ... t_bc = t.reshape((-1,) + (1,) * (x.ndim - 1)) + ... return (x_0 - x) / (t_bc ** 2) + >>> + >>> # Inline inpainting guidance; mask and observations are pre-patched + >>> # by the user via predictor.patch_fn so all patching uses the same + >>> # grid as the predictor. + >>> class InpaintGuidance: + ... def __init__(self, mask_patched, y_patched, gamma=0.1): + ... self.mask = mask_patched + ... self.y = y_patched + ... self.gamma = gamma + ... def __call__(self, x, t, x_0, slice_start=None): + ... if slice_start is None: + ... mask, y = self.mask, self.y + ... else: + ... K = x.shape[0] + ... mask = self.mask[slice_start : slice_start + K] + ... y = self.y[slice_start : slice_start + K] + ... return -self.gamma * mask * (x_0 - y) + >>> + >>> mask_patched = predictor.patch_fn(torch.ones(2, 3, 16, 16)) + >>> y_patched = predictor.patch_fn(torch.randn(2, 3, 16, 16)) + >>> guidance = InpaintGuidance(mask_patched, y_patched) + >>> + >>> dps = MultiDiffusionDPSScorePredictor( + ... x0_predictor=predictor, + ... x0_to_score_fn=x0_to_score_fn, + ... guidances=guidance, + ... ) + >>> x = torch.randn(2, 3, 16, 16) + >>> t = torch.tensor([1.0, 1.0]) + >>> dps(x, t).shape + torch.Size([2, 3, 16, 16]) + + **Example 2:** Multiple guidances for multi-constraint problems. The + predictor returned by this class is a drop-in score predictor that + plugs into any sampling utility (here + :func:`~physicsnemo.diffusion.samplers.sample`): + + >>> from physicsnemo.diffusion.multi_diffusion import ( + ... MultiDiffusionDataConsistencyDPSGuidance, + ... MultiDiffusionModelConsistencyDPSGuidance, + ... ) + >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler + >>> from physicsnemo.diffusion.samplers import sample + >>> + >>> scheduler = EDMNoiseScheduler() + >>> + >>> # First guidance: masked observations (inpainting) + >>> mask = torch.zeros(2, 3, 16, 16, dtype=torch.bool) + >>> mask[:, :, 4:, :] = True + >>> y_obs1 = torch.randn(2, 3, 16, 16) + >>> g1 = MultiDiffusionDataConsistencyDPSGuidance( + ... predictor=predictor, mask=mask, y=y_obs1, std_y=0.1, + ... retain_graph=True, # required: not the last autograd guidance + ... ) + >>> + >>> # Second guidance: nonlinear patch-local channel response + >>> A = lambda x_0: torch.sigmoid(x_0[:, :1]) + >>> y_obs2 = torch.rand(2, 1, 16, 16) + >>> g2 = MultiDiffusionModelConsistencyDPSGuidance( + ... predictor=predictor, observation_operator=A, + ... y=y_obs2, std_y=0.1, + ... ) + >>> + >>> dps = MultiDiffusionDPSScorePredictor( + ... x0_predictor=predictor, + ... x0_to_score_fn=scheduler.x0_to_score, + ... guidances=[g1, g2], + ... ) + >>> denoiser = scheduler.get_denoiser(score_predictor=dps) + >>> xN = torch.randn(2, 3, 16, 16) + >>> x0 = sample(denoiser, xN, scheduler, num_steps=4) + >>> x0.shape + torch.Size([2, 3, 16, 16]) + """ + + def __init__( + self, + x0_predictor: MultiDiffusionPredictor, + x0_to_score_fn: Callable[ + [ + Float[Tensor, "K C Hp Wp"], + Float[Tensor, "K C Hp Wp"], + Float[Tensor, " K"], + ], + Float[Tensor, "K C Hp Wp"], + ], + guidances: MultiDiffusionDPSGuidance | Sequence[MultiDiffusionDPSGuidance], + ) -> None: + if not isinstance(x0_predictor, MultiDiffusionPredictor): + raise TypeError( + f"x0_predictor must be a MultiDiffusionPredictor, " + f"got {type(x0_predictor).__name__}." + ) + if x0_predictor._chunk_size is None: + raise ValueError( + "x0_predictor must have chunk_size set. " + "Pass chunk_size= to MultiDiffusionPredictor.__init__." + ) + self.x0_predictor = x0_predictor + self.x0_to_score_fn = x0_to_score_fn + if isinstance(guidances, Sequence) and not isinstance(guidances, str): + self.guidances: list[MultiDiffusionDPSGuidance] = list(guidances) + else: + self.guidances = [guidances] # type: ignore[list-item] + + def __call__( + self, + x: Float[Tensor, "B C H W"], + t: Float[Tensor, " B"], + ) -> Float[Tensor, "B C H W"]: + r"""Compute the guided score at the global resolution. + + Parameters + ---------- + x : Tensor + Noisy latent at global resolution, shape :math:`(B, C, H, W)`. + t : Tensor + Diffusion time, shape :math:`(B,)`. + + Returns + ------- + Tensor + Guided score at global resolution, shape :math:`(B, C, H, W)`. + """ + if not torch.compiler.is_compiling() and torch.is_inference_mode_enabled(): + raise RuntimeError( + "MultiDiffusionDPSScorePredictor requires autograd but torch " + "inference mode is enabled. Wrap the calling code with " + "'with torch.inference_mode(False):' or 'with torch.no_grad():' " + "instead." + ) + + x = x.detach().requires_grad_(True) + combined_list: list[Tensor] = [] + + with torch.enable_grad(): + for s, x0_chunk, x_chunk, t_chunk in self.x0_predictor.chunks(x, t): + g_chunk = torch.zeros_like(x0_chunk) + for g in self.guidances: + g_chunk = g_chunk + g(x_chunk, t_chunk, x0_chunk, slice_start=s) + score_chunk = self.x0_to_score_fn(x0_chunk, x_chunk, t_chunk) + combined_list.append(score_chunk + g_chunk) + + combined_patched = torch.cat(combined_list, dim=0) # (P*B, C, Hp, Wp) + return self.x0_predictor.fuse_fn(combined_patched) + + +class MultiDiffusionModelConsistencyDPSGuidance: + r"""Patch-local DPS guidance for generic observation operators with + Gaussian noise. + + Multi-diffusion counterpart of + :class:`~physicsnemo.diffusion.guidance.ModelConsistencyDPSGuidance`, + intended for cases where the observation operator :math:`A` + decomposes along the multi-diffusion patch grid. Implements the + :class:`MultiDiffusionDPSGuidance` protocol, see it for the two-mode + (``slice_start``) semantics and the :math:`K` chunk-size convention. + + Computes the likelihood score assuming Gaussian measurement noise + with standard deviation :math:`\sigma_y`. Letting :math:`k` index the + current patch chunk: + + .. math:: + + \nabla_{\mathbf{x}} \log p(\mathbf{y}^k | \mathbf{x}_t^k) + = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} + \right)} \nabla_{\mathbf{x}^k} + \| A(\hat{\mathbf{x}}_0^k) - \mathbf{y}^k \|^2 + + where the scaling incorporates a Score-Based Data Assimilation (SDA) + correction through :math:`\Gamma`. The L2 norm can be replaced by + other Lp norms or a custom loss function via the ``norm`` parameter. + + Observations ``y`` are pre-patched once at construction; calling the + guidance many times during sampling never re-patches them. + + .. important:: + + ``y`` must be **patcheable** in the same way as the latent state + :math:`\mathbf{x}`, so its spatial dimensions must equal the + global resolution :math:`(H, W)`. This is a stronger requirement + than the global counterpart + :class:`~physicsnemo.diffusion.guidance.ModelConsistencyDPSGuidance`, + which allows arbitrary observation shapes. The operator + :math:`A` must therefore produce observations matching the input + spatial resolution (e.g. channel-selection, pointwise + nonlinearities, local convolutions within an overlap region). + + The ``observation_operator`` must be a differentiable callable with + the following signature: + + .. code-block:: python + + def observation_operator( + x_0: Tensor, # shape: (K, C, Hp, Wp) + ) -> Tensor: ... # shape: (K, C_obs, Hp, Wp) + + When ``norm`` is a callable, it must have the signature: + + .. code-block:: python + + def norm( + y_pred: Tensor, # shape: (K, C_obs, Hp, Wp) + y_true: Tensor, # shape: (K, C_obs, Hp, Wp) + ) -> Tensor: ... # shape: (K,) scalar loss per batch element + + Parameters + ---------- + predictor : MultiDiffusionPredictor + Predictor used to pre-patch ``y`` and (optionally) fuse the + guidance. Stored on ``self.predictor`` for later access. + observation_operator : callable + Differentiable patch-local observation operator :math:`A`. See + the signature above. + y : Tensor + Global observations of shape :math:`(B, C_{obs}, H, W)` matching + the latent's global spatial shape. + std_y : float + Standard deviation of the measurement noise :math:`\sigma_y`. + norm : int or callable, default=2 + Loss applied to the residual. An ``int`` selects the + corresponding Lp norm; a callable replaces it with a custom loss + of the signature above. + gamma : float, default=0.0 + SDA covariance scaling factor :math:`\Gamma`. Set to ``0`` for + classical DPS without SDA scaling. + sigma_fn : callable or None, default=None + Function mapping diffusion time to noise level :math:`\sigma(t)`. + Required when ``gamma > 0``. Typically obtained from a noise + scheduler, e.g. + :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.sigma`. + alpha_fn : callable or None, default=None + Function mapping diffusion time to signal coefficient + :math:`\alpha(t)`. Defaults to :math:`\alpha(t) = 1`. Typically + obtained from a noise scheduler, e.g. + :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.alpha`. + fuse : bool, default=False + Whether :meth:`__call__` fuses the guidance term to the global + resolution in full batch mode (``slice_start=None``). Ignored in + chunked batch mode. + retain_graph : bool, default=False + Retain the computation graph after the gradient call. Required + on all but the last guidance when combining multiple + autograd-based guidances in a single + :class:`MultiDiffusionDPSScorePredictor`. + create_graph : bool, default=False + Allow higher-order derivatives. + + Note + ---- + References: + + - DPS: `Diffusion Posterior Sampling for General Noisy Inverse Problems + `_ + - SDA: `Score-based Data Assimilation `_ + + See Also + -------- + :class:`~physicsnemo.diffusion.guidance.ModelConsistencyDPSGuidance` : + Global counterpart for non-patch-local operators. + :class:`MultiDiffusionDPSScorePredictor` : + Score predictor that consumes this guidance. + + Examples + -------- + **Example 1:** Patch-local channel selection. The operator selects + the first channel of each patch, clearly patch-local. Inputs are + chunk-sized patched tensors: + + >>> import torch + >>> from physicsnemo.core import Module + >>> from physicsnemo.diffusion.multi_diffusion import ( + ... MultiDiffusionModel2D, MultiDiffusionPredictor, + ... MultiDiffusionModelConsistencyDPSGuidance, + ... ) + >>> + >>> class Backbone(Module): + ... def __init__(self): + ... super().__init__() + ... self.net = torch.nn.Conv2d(3, 3, 1) + ... def forward(self, x, t, condition=None): + ... return self.net(x) + >>> + >>> md = MultiDiffusionModel2D(Backbone(), global_spatial_shape=(16, 16)) + >>> md.set_random_patching(patch_shape=(8, 8), patch_num=4) + >>> _ = md.eval() + >>> predictor = MultiDiffusionPredictor(md, chunk_size=2) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) + >>> + >>> A = lambda x: x[:, :1] + >>> y_obs = torch.randn(2, 1, 16, 16) + >>> + >>> guidance = MultiDiffusionModelConsistencyDPSGuidance( + ... predictor=predictor, observation_operator=A, y=y_obs, std_y=0.1, + ... ) + >>> x_chunk = torch.randn(2, 3, 8, 8, requires_grad=True) + >>> t_chunk = torch.tensor([1.0, 1.0]) + >>> x0_chunk = x_chunk * 0.9 + >>> guidance(x_chunk, t_chunk, x0_chunk, slice_start=0).shape + torch.Size([2, 3, 8, 8]) + + **Example 2:** SDA-scaled guidance with a nonlinear patch-local + operator (here a sigmoid response on the first channel), plugged + into the full sampling stack: + + >>> from physicsnemo.diffusion.multi_diffusion import ( + ... MultiDiffusionDPSScorePredictor, + ... ) + >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler + >>> from physicsnemo.diffusion.samplers import sample + >>> + >>> scheduler = EDMNoiseScheduler() + >>> A_nl = lambda x_0: torch.sigmoid(x_0[:, :1]) + >>> y_obs_nl = torch.rand(2, 1, 16, 16) + >>> + >>> guidance_sda = MultiDiffusionModelConsistencyDPSGuidance( + ... predictor=predictor, + ... observation_operator=A_nl, + ... y=y_obs_nl, + ... std_y=0.075, + ... gamma=0.05, # enable SDA scaling + ... sigma_fn=scheduler.sigma, + ... alpha_fn=scheduler.alpha, + ... ) + >>> dps = MultiDiffusionDPSScorePredictor( + ... x0_predictor=predictor, + ... x0_to_score_fn=scheduler.x0_to_score, + ... guidances=guidance_sda, + ... ) + >>> denoiser = scheduler.get_denoiser(score_predictor=dps) + >>> xN = torch.randn(2, 3, 16, 16) + >>> x0 = sample(denoiser, xN, scheduler, num_steps=4) + >>> x0.shape + torch.Size([2, 3, 16, 16]) + """ + + def __init__( + self, + predictor: MultiDiffusionPredictor, + observation_operator: Callable[ + [Float[Tensor, "K C Hp Wp"]], Float[Tensor, "K C_obs Hp Wp"] + ], + y: Float[Tensor, "B C_obs H W"], + std_y: float, + norm: int + | Callable[ + [Float[Tensor, "K C_obs Hp Wp"], Float[Tensor, "K C_obs Hp Wp"]], + Float[Tensor, " K"], + ] = 2, + gamma: float = 0.0, + sigma_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] + | None = None, + alpha_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] + | None = None, + fuse: bool = False, + retain_graph: bool = False, + create_graph: bool = False, + ) -> None: + if gamma > 0 and sigma_fn is None: + raise ValueError("sigma_fn must be provided when gamma > 0") + self.predictor = predictor + # Pre-patch observations once via the predictor's patch_fn. + self._y_patched: Tensor = predictor.patch_fn(y) + self.observation_operator = observation_operator + self.std_y = std_y + # Resolve the loss callable at construction so __call__ has no branch. + if isinstance(norm, int): + self._loss_fn: Callable[[Tensor, Tensor], Tensor] = _lp_loss_fn(norm) + else: + self._loss_fn = norm + self.gamma = gamma + self.sigma_fn = ( + sigma_fn if sigma_fn is not None else lambda t: torch.zeros_like(t) + ) + self.alpha_fn = ( + alpha_fn if alpha_fn is not None else lambda t: torch.ones_like(t) + ) + self.fuse = fuse + self.retain_graph = retain_graph + self.create_graph = create_graph + + def __call__( + self, + x: Float[Tensor, "K C Hp Wp"], + t: Float[Tensor, " K"], + x_0: Float[Tensor, "K C Hp Wp"], + slice_start: int | None = None, + ) -> Float[Tensor, "K C Hp Wp"] | Float[Tensor, "B C H W"]: + r"""Compute the patch-local likelihood score guidance term. + + See :class:`MultiDiffusionDPSGuidance` for the meaning of + ``slice_start`` (full vs chunked batch mode) and the :math:`K` + chunk-size convention. + + Parameters + ---------- + x : Tensor + Noisy patched latent slice :math:`\mathbf{x}_t^k`, of shape + :math:`(K, C, H_p, W_p)`. Must have ``requires_grad=True`` + and be part of a computational graph connecting to ``x_0``. + Its ``dtype`` and ``device`` determine those of all internal + computations. + t : Tensor + Patched diffusion time slice, shape :math:`(K,)`. + x_0 : Tensor + Estimate of the patched clean state + :math:`\hat{\mathbf{x}}_0^k(\mathbf{x}_t^k, t)`, of shape + :math:`(K, C, H_p, W_p)`. Must be computed from ``x`` so + gradients can backpropagate. + slice_start : int or None, default=None + Chunk offset along the :math:`(P \times B)` dimension. See + class docstring. + + Returns + ------- + Tensor + Patch-local guidance term of shape :math:`(K, C, H_p, W_p)`. + Fused to the global resolution :math:`(B, C, H, W)` when + ``slice_start=None`` and ``fuse=True`` was passed at + construction. + """ + if not torch.compiler.is_compiling() and torch.is_inference_mode_enabled(): + raise RuntimeError( + "MultiDiffusionModelConsistencyDPSGuidance requires autograd " + "but torch inference mode is enabled." + ) + + if slice_start is None: + y_chunk = self._y_patched.to(dtype=x.dtype, device=x.device) + else: + K = x.shape[0] + y_chunk = self._y_patched[slice_start : slice_start + K].to( + dtype=x.dtype, device=x.device + ) + + with torch.enable_grad(): + y_pred = self.observation_operator(x_0) + loss = self._loss_fn(y_pred, y_chunk) + grad_x = torch.autograd.grad( + outputs=loss.sum(), + inputs=x, + retain_graph=self.retain_graph, + create_graph=self.create_graph, + )[0] + + expected_shape = (-1,) + (1,) * (x.ndim - 1) + t_bc = t.reshape(expected_shape) + sigma_t = self.sigma_fn(t_bc) + alpha_t = self.alpha_fn(t_bc) + variance = self.std_y**2 + self.gamma * (sigma_t**2) / (alpha_t**2) + + g = -grad_x / (2 * variance) + if slice_start is None and self.fuse: + return self.predictor.fuse_fn(g) + return g + + +class MultiDiffusionDataConsistencyDPSGuidance: + r"""Patch-local DPS guidance for masked observations with Gaussian + noise. + + Multi-diffusion counterpart of + :class:`~physicsnemo.diffusion.guidance.DataConsistencyDPSGuidance`, + intended for masked observations whose mask decomposes along the + multi-diffusion patch grid. Use cases: inpainting, sparse pointwise + data assimilation on large domains. Implements the + :class:`MultiDiffusionDPSGuidance` protocol, see it for the two-mode + (``slice_start``) semantics and the :math:`K` chunk-size convention. + + Computes the likelihood score assuming Gaussian measurement noise + with standard deviation :math:`\sigma_y`. Letting :math:`k` index + the current patch chunk: + + .. math:: + + \nabla_{\mathbf{x}} \log p(\mathbf{y}^k | \mathbf{x}_t^k) + = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} + \right)} \nabla_{\mathbf{x}^k} + \| \mathbf{M}^k \odot (\hat{\mathbf{x}}_0^k - \mathbf{y}^k) \|^2 + + where :math:`\mathbf{M}` is a binary mask (1 = observed, 0 = missing) + and :math:`\odot` denotes element-wise multiplication. The scaling + incorporates an SDA correction through :math:`\Gamma`. The L2 norm + can be replaced by other Lp norms or a custom loss function via the + ``norm`` parameter. + + Both ``mask`` and ``y`` are pre-patched once at construction; + calling the guidance many times during sampling never re-patches + them. + + .. important:: + + ``mask`` and ``y`` must be **patcheable** in the same way as the + latent state :math:`\mathbf{x}`, so their spatial dimensions + must equal the global resolution :math:`(H, W)`. The mask + defines per-pixel observability within the global spatial + domain. + + When ``norm`` is a callable, it must have the signature: + + .. code-block:: python + + def norm( + y_pred: Tensor, # shape: (K, C, Hp, Wp) + y_true: Tensor, # shape: (K, C, Hp, Wp) + ) -> Tensor: ... # shape: (K,) scalar loss per batch element + + Parameters + ---------- + predictor : MultiDiffusionPredictor + Predictor used to pre-patch ``mask`` and ``y`` and (optionally) + fuse the guidance. Stored on ``self.predictor`` for later access. + mask : Tensor + Boolean mask of shape :math:`(B, C, H, W)`. ``True`` marks + observed locations, ``False`` marks missing. + y : Tensor + Observed values of shape :math:`(B, C, H, W)`. Values at + unobserved locations are ignored. + std_y : float + Standard deviation of the measurement noise :math:`\sigma_y`. + norm : int or callable, default=2 + Loss applied to the masked residual. An ``int`` selects the + corresponding Lp norm; a callable replaces it with a custom loss + of the signature above. + gamma : float, default=0.0 + SDA covariance scaling factor :math:`\Gamma`. Set to ``0`` for + classical DPS without SDA scaling. + sigma_fn : callable or None, default=None + Function mapping diffusion time to noise level :math:`\sigma(t)`. + Required when ``gamma > 0``. Typically obtained from a noise + scheduler, e.g. + :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.sigma`. + alpha_fn : callable or None, default=None + Function mapping diffusion time to signal coefficient + :math:`\alpha(t)`. Defaults to :math:`\alpha(t) = 1`. Typically + obtained from a noise scheduler, e.g. + :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.alpha`. + fuse : bool, default=False + Whether :meth:`__call__` fuses the guidance term to the global + resolution in full batch mode (``slice_start=None``). Ignored in + chunked batch mode. + retain_graph : bool, default=False + Retain the computation graph after the gradient call. Required + on all but the last guidance when combining multiple + autograd-based guidances in a single + :class:`MultiDiffusionDPSScorePredictor`. + create_graph : bool, default=False + Allow higher-order derivatives. + + Note + ---- + References: + + - DPS: `Diffusion Posterior Sampling for General Noisy Inverse Problems + `_ + - SDA: `Score-based Data Assimilation `_ + + See Also + -------- + :class:`~physicsnemo.diffusion.guidance.DataConsistencyDPSGuidance` : + Global counterpart for non-patch-local masks. + :class:`MultiDiffusionDPSScorePredictor` : + Score predictor that consumes this guidance. + + Examples + -------- + **Example 1:** Inpainting on a large domain. The mask is a spatial + pattern, so it decomposes along the patch grid: + + >>> import torch + >>> from physicsnemo.core import Module + >>> from physicsnemo.diffusion.multi_diffusion import ( + ... MultiDiffusionModel2D, MultiDiffusionPredictor, + ... MultiDiffusionDataConsistencyDPSGuidance, + ... ) + >>> + >>> class Backbone(Module): + ... def __init__(self): + ... super().__init__() + ... self.net = torch.nn.Conv2d(3, 3, 1) + ... def forward(self, x, t, condition=None): + ... return self.net(x) + >>> + >>> md = MultiDiffusionModel2D(Backbone(), global_spatial_shape=(16, 16)) + >>> md.set_random_patching(patch_shape=(8, 8), patch_num=4) + >>> _ = md.eval() + >>> predictor = MultiDiffusionPredictor(md, chunk_size=2) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) + >>> + >>> mask = torch.zeros(2, 3, 16, 16, dtype=torch.bool) + >>> mask[:, :, 4:, :] = True + >>> y_obs = torch.randn(2, 3, 16, 16) + >>> + >>> guidance = MultiDiffusionDataConsistencyDPSGuidance( + ... predictor=predictor, mask=mask, y=y_obs, std_y=0.1, + ... ) + >>> x_chunk = torch.randn(2, 3, 8, 8, requires_grad=True) + >>> t_chunk = torch.tensor([1.0, 1.0]) + >>> x0_chunk = x_chunk * 0.9 + >>> guidance(x_chunk, t_chunk, x0_chunk, slice_start=0).shape + torch.Size([2, 3, 8, 8]) + + **Example 2:** SDA-scaled guidance with the L1 norm for robustness, + plugged into the full sampling stack: + + >>> from physicsnemo.diffusion.multi_diffusion import ( + ... MultiDiffusionDPSScorePredictor, + ... ) + >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler + >>> from physicsnemo.diffusion.samplers import sample + >>> + >>> scheduler = EDMNoiseScheduler() + >>> + >>> mask = torch.zeros(2, 3, 16, 16, dtype=torch.bool) + >>> mask[:, :, 2, 3] = True + >>> mask[:, :, 5, 6] = True + >>> y_obs = torch.randn(2, 3, 16, 16) + >>> + >>> guidance_sda = MultiDiffusionDataConsistencyDPSGuidance( + ... predictor=predictor, + ... mask=mask, + ... y=y_obs, + ... std_y=0.075, + ... norm=1, # L1 norm for robustness + ... gamma=1.0, # enable SDA scaling + ... sigma_fn=scheduler.sigma, + ... alpha_fn=scheduler.alpha, + ... ) + >>> dps = MultiDiffusionDPSScorePredictor( + ... x0_predictor=predictor, + ... x0_to_score_fn=scheduler.x0_to_score, + ... guidances=guidance_sda, + ... ) + >>> denoiser = scheduler.get_denoiser(score_predictor=dps) + >>> xN = torch.randn(2, 3, 16, 16) + >>> x0 = sample(denoiser, xN, scheduler, num_steps=4) + >>> x0.shape + torch.Size([2, 3, 16, 16]) + """ + + def __init__( + self, + predictor: MultiDiffusionPredictor, + mask: Bool[Tensor, "B C H W"], + y: Float[Tensor, "B C H W"], + std_y: float, + norm: int + | Callable[ + [Float[Tensor, "K C Hp Wp"], Float[Tensor, "K C Hp Wp"]], + Float[Tensor, " K"], + ] = 2, + gamma: float = 0.0, + sigma_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] + | None = None, + alpha_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] + | None = None, + fuse: bool = False, + retain_graph: bool = False, + create_graph: bool = False, + ) -> None: + if gamma > 0 and sigma_fn is None: + raise ValueError("sigma_fn must be provided when gamma > 0") + self.predictor = predictor + # Pre-patch mask and observations once via the predictor's patch_fn. + patch = predictor.patch_fn + self._mask_patched: Tensor = patch(mask.float()) + self._y_patched: Tensor = patch(y) + self.std_y = std_y + # Resolve the loss callable at construction so __call__ has no branch. + if isinstance(norm, int): + self._loss_fn: Callable[[Tensor, Tensor], Tensor] = _lp_loss_fn(norm) + else: + self._loss_fn = norm + self.gamma = gamma + self.sigma_fn = ( + sigma_fn if sigma_fn is not None else lambda t: torch.zeros_like(t) + ) + self.alpha_fn = ( + alpha_fn if alpha_fn is not None else lambda t: torch.ones_like(t) + ) + self.fuse = fuse + self.retain_graph = retain_graph + self.create_graph = create_graph + + def __call__( + self, + x: Float[Tensor, "K C Hp Wp"], + t: Float[Tensor, " K"], + x_0: Float[Tensor, "K C Hp Wp"], + slice_start: int | None = None, + ) -> Float[Tensor, "K C Hp Wp"] | Float[Tensor, "B C H W"]: + r"""Compute the patch-local likelihood score guidance term. + + See :class:`MultiDiffusionDPSGuidance` for the meaning of + ``slice_start`` (full vs chunked batch mode) and the :math:`K` + chunk-size convention. + + Parameters + ---------- + x : Tensor + Noisy patched latent slice :math:`\mathbf{x}_t^k`, of shape + :math:`(K, C, H_p, W_p)`. Must have ``requires_grad=True`` + and be part of a computational graph connecting to ``x_0``. + Its ``dtype`` and ``device`` determine those of all internal + computations. + t : Tensor + Patched diffusion time slice, shape :math:`(K,)`. + x_0 : Tensor + Estimate of the patched clean state + :math:`\hat{\mathbf{x}}_0^k(\mathbf{x}_t^k, t)`, of shape + :math:`(K, C, H_p, W_p)`. Must be computed from ``x`` so + gradients can backpropagate. + slice_start : int or None, default=None + Chunk offset along the :math:`(P \times B)` dimension. See + class docstring. + + Returns + ------- + Tensor + Patch-local guidance term of shape :math:`(K, C, H_p, W_p)`. + Fused to the global resolution :math:`(B, C, H, W)` when + ``slice_start=None`` and ``fuse=True`` was passed at + construction. + """ + if not torch.compiler.is_compiling() and torch.is_inference_mode_enabled(): + raise RuntimeError( + "MultiDiffusionDataConsistencyDPSGuidance requires autograd " + "but torch inference mode is enabled." + ) + + if slice_start is None: + mask_chunk = self._mask_patched.to(dtype=x.dtype, device=x.device) + y_chunk = self._y_patched.to(dtype=x.dtype, device=x.device) + else: + K = x.shape[0] + mask_chunk = self._mask_patched[slice_start : slice_start + K].to( + dtype=x.dtype, device=x.device + ) + y_chunk = self._y_patched[slice_start : slice_start + K].to( + dtype=x.dtype, device=x.device + ) + + with torch.enable_grad(): + y_pred = mask_chunk * x_0 + y_true = mask_chunk * y_chunk + loss = self._loss_fn(y_pred, y_true) + grad_x = torch.autograd.grad( + outputs=loss.sum(), + inputs=x, + retain_graph=self.retain_graph, + create_graph=self.create_graph, + )[0] + + expected_shape = (-1,) + (1,) * (x.ndim - 1) + t_bc = t.reshape(expected_shape) + sigma_t = self.sigma_fn(t_bc) + alpha_t = self.alpha_fn(t_bc) + variance = self.std_y**2 + self.gamma * (sigma_t**2) / (alpha_t**2) + + g = -grad_x / (2 * variance) + if slice_start is None and self.fuse: + return self.predictor.fuse_fn(g) + return g diff --git a/physicsnemo/diffusion/multi_diffusion/models.py b/physicsnemo/diffusion/multi_diffusion/models.py index a9238fe0cf..29211b2435 100644 --- a/physicsnemo/diffusion/multi_diffusion/models.py +++ b/physicsnemo/diffusion/multi_diffusion/models.py @@ -384,6 +384,10 @@ class MultiDiffusionModel2D(Module): torch.Size([2, 3, 16, 16]) """ + # Class-level type annotation so static type checkers resolve the subscript + # operations on _patch_shape (set in __init__ via register_buffer). + _patch_shape: Tensor + def __init__( self, model: Module, @@ -401,6 +405,11 @@ def __init__( self._patching: RandomPatching2D | GridPatching2D | None = None self._fuse: bool = False self._skip_positional_embedding_injection: bool = False + # Persistent buffer so that patch_shape survives checkpoint save/load. + # Zeros sentinel means "not yet configured". + self.register_buffer( + "_patch_shape", torch.zeros(2, dtype=torch.long), persistent=True + ) # Normalise condition flags to defaultdict for uniform access if not isinstance(condition_patch, (bool, dict)): raise TypeError( @@ -482,6 +491,39 @@ def condition_interp(self) -> bool | Dict[str, bool]: """Whether conditioning tensors are interpolated to patch resolution.""" return self._condition_interp + @property + def patch_shape(self) -> tuple[int, int] | None: + r"""Spatial shape :math:`(H_p, W_p)` of each patch, or ``None`` if no + patching strategy has been configured yet. + + The value is read from the live patching object when available, and + falls back to the persistent checkpoint buffer when the model was + loaded from a checkpoint but ``set_grid_patching`` / ``set_random_patching`` + have not been called yet. + + Examples + -------- + >>> import torch + >>> from physicsnemo.core import Module + >>> from physicsnemo.diffusion.multi_diffusion import MultiDiffusionModel2D + >>> class M(Module): + ... def __init__(self): super().__init__(); self.net = torch.nn.Conv2d(3,3,1) + ... def forward(self, x, t, condition=None): return self.net(x) + >>> md = MultiDiffusionModel2D(M(), global_spatial_shape=(16, 16)) + >>> md.patch_shape is None + True + >>> md.set_grid_patching(patch_shape=(8, 8)) + >>> md.patch_shape + (8, 8) + """ + patching = self._patching + if patching is not None: + return patching.patch_shape + ps = self._patch_shape + if int(ps[0]) > 0 or int(ps[1]) > 0: + return (int(ps[0]), int(ps[1])) + return None + # ------------------------------------------------------------------ # Patching strategy configuration # ------------------------------------------------------------------ @@ -534,6 +576,8 @@ def set_random_patching( patch_num=patch_num, ) self._fuse = False + self._patch_shape[0] = patch_shape[0] + self._patch_shape[1] = patch_shape[1] def reset_patch_indices(self) -> None: r"""Re-draw random patch positions for the current random patching @@ -601,6 +645,8 @@ def set_grid_patching( boundary_pix=boundary_pix, ) self._fuse = fuse + self._patch_shape[0] = patch_shape[0] + self._patch_shape[1] = patch_shape[1] # ------------------------------------------------------------------ # Public patching utilities @@ -904,16 +950,25 @@ def forward( P = patching.patch_num - # Determine original batch size B + # B is only consumed by PE injection and fusing. When neither runs + # (e.g., MultiDiffusionPredictor calls into this method with both + # disabled to stream partial chunks), B is unused — skip computing + # and validating it so partial (K, C, Hp, Wp) tensors with K < P + # can be passed through. + _b_consumed = ( + self.pos_embd is not None and not self._skip_positional_embedding_injection + ) or self._fuse if x_is_patched: - if not torch.compiler.is_compiling(): - if x.shape[0] % P != 0: - raise ValueError( - f"x_is_patched=True but x batch dim " - f"({x.shape[0]}) is not divisible by patch_num " - f"({P})." - ) - B = x.shape[0] // P + if ( + _b_consumed + and not torch.compiler.is_compiling() + and x.shape[0] % P != 0 + ): + raise ValueError( + f"x_is_patched=True but x batch dim ({x.shape[0]}) is " + f"not divisible by patch_num ({P})." + ) + B = x.shape[0] // P if _b_consumed else 0 else: B = x.shape[0] diff --git a/physicsnemo/diffusion/multi_diffusion/predictor.py b/physicsnemo/diffusion/multi_diffusion/predictor.py index 641be3a1aa..209fdb1438 100644 --- a/physicsnemo/diffusion/multi_diffusion/predictor.py +++ b/physicsnemo/diffusion/multi_diffusion/predictor.py @@ -16,13 +16,16 @@ """Multi-diffusion predictor wrapper for patch-based diffusion sampling.""" -from typing import Any +import warnings +from typing import Any, Callable, Iterator, cast +import torch from jaxtyping import Float from tensordict import TensorDict from torch import Tensor +from torch.utils.checkpoint import checkpoint -from physicsnemo.diffusion.base import Predictor +from physicsnemo.diffusion.base import Predictor, PredictorType from physicsnemo.diffusion.multi_diffusion.models import MultiDiffusionModel2D from physicsnemo.diffusion.multi_diffusion.patching import GridPatching2D from physicsnemo.diffusion.utils.utils import _unwrap_module @@ -36,13 +39,35 @@ class MultiDiffusionPredictor(Predictor): plugs into any sampling utility that accepts a ``Predictor`` (:func:`~physicsnemo.diffusion.samplers.sample`, :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`, - and all standard solvers) with no other changes. All patch-based logic - — patching the state, running the per-patch predictions and fusing them - back to the global domain — is handled internally. + and all standard solvers) with no other changes. All patch-based logic + (patching the state, running the per-patch predictions, fusing them back + to the global domain) is handled internally. - The wrapped model must have grid patching configured via - :meth:`~MultiDiffusionModel2D.set_grid_patching` before constructing the - predictor. + The patching strategy must be configured by calling :meth:`set_patching` + before any other method. ``patch_shape`` and ``global_shape`` default + to the values saved on the wrapped model (typically restored from the + training-time checkpoint), while ``overlap_pix`` and ``boundary_pix`` + must be provided explicitly because they cannot be inferred from a + model trained with random patching. + + On very large global domains the full :math:`P \times B` activation + tensor may not fit in GPU memory, leading to OOM errors. Two + independent strategies mitigate this: + + - ``chunk_size`` processes the :math:`P \times B` patches in + consecutive chunks of at most ``chunk_size`` rows. Useful for + both plain inference and gradient-based use cases. Trades batch + parallelism for memory. + - ``use_checkpointing`` recomputes activations on demand during + backpropagation instead of storing them. Only meaningful when + gradients flow through the predictor, the typical use case being + DPS guidance (see + :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionDPSScorePredictor`). + Trades compute for memory. + + The two options can be combined; for use cases that need explicit + control over chunk-level processing, the streaming iterator + :meth:`chunks` exposes the per-chunk model outputs directly. .. warning:: @@ -54,15 +79,40 @@ class MultiDiffusionPredictor(Predictor): Parameters ---------- model : MultiDiffusionModel2D - A trained multi-diffusion model with grid patching already configured. + A trained multi-diffusion model. The grid patching configuration + must be supplied through :meth:`set_patching` after construction. condition : torch.Tensor, TensorDict, or None, optional, default=None - When provided, the shape should be :math:`(B, *cond_dims)`. - Conditioning information at the global resolution, bound once at - construction and reused at every diffusion step. Pass ``None`` for - unconditional models. + Conditioning at the global resolution, bound once at construction + and reused at every diffusion step. Shape :math:`(B, *cond\_dims)`. + Pass ``None`` for unconditional models. fuse : bool, default=True Whether to fuse per-patch outputs back to the global resolution before returning. + chunk_size : int or None, default=None + Number of patch rows along the :math:`P \times B` dimension + processed per model call. ``None`` runs all patches in a single + call. Set to a small integer to reduce peak GPU memory when the + full :math:`P \times B` activation tensor does not fit at once. + use_checkpointing : bool, default=False + Trade compute for memory: activations are recomputed on demand + during backpropagation instead of being stored from the forward + pass. Useful when differentiating through the predictor on large + domains. Works with or without ``chunk_size``. + prediction_type : PredictorType, default="x0" + Output type of the wrapped model. One of ``"x0"``, ``"score"``, or + ``"epsilon"``. The predictor always exposes an x0-compatible + output; pass the appropriate conversion function below when the + model does not directly predict x0. + score_to_x0_fn : callable, optional + Conversion ``(score, x_t, t) -> x0`` applied to the model output. + Required when ``prediction_type="score"``. Typically obtained from + a noise scheduler, e.g. + :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.score_to_x0`. + epsilon_to_x0_fn : callable, optional + Conversion ``(epsilon, x_t, t) -> x0`` applied to the model output. + Required when ``prediction_type="epsilon"``. Typically obtained from + a noise scheduler, e.g. + :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.epsilon_to_x0`. **model_kwargs : Any Additional keyword arguments bound once at construction and forwarded to the wrapped model at every call. @@ -70,17 +120,19 @@ class MultiDiffusionPredictor(Predictor): See Also -------- :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionModel2D` : - The multi-diffusion wrapper used for training; its grid patching - must be configured before creating the predictor. + The multi-diffusion wrapper used for training. :class:`~physicsnemo.diffusion.Predictor` : The protocol this class implements. :func:`~physicsnemo.diffusion.samplers.sample` : The main sampling entry point. + :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionDPSScorePredictor` : + Patch-local DPS score predictor that consumes + :meth:`chunks` for memory-efficient guided sampling. Examples -------- **Example 1:** Predictor in isolation. Input and output live at the - global resolution; patching, per-patch prediction and fusing are all + global resolution; patching, per-patch prediction, and fusing are all handled internally: >>> import torch @@ -96,19 +148,19 @@ class MultiDiffusionPredictor(Predictor): ... def forward(self, x, t, condition=None): ... return self.net(x) >>> - >>> # Create a trained multi-diffusion model (training omitted here) + >>> # Train and save the model (training omitted here) >>> md = MultiDiffusionModel2D(Backbone(), global_spatial_shape=(16, 16)) - >>> md.set_grid_patching(patch_shape=(8, 8)) # P = 4 patches per sample + >>> md.set_random_patching(patch_shape=(8, 8), patch_num=4) # training config >>> _ = md.eval() >>> >>> predictor = MultiDiffusionPredictor(md) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) # P = 4 patches per sample >>> x = torch.randn(2, 3, 16, 16) # global-resolution state >>> t = 0.5 * torch.ones(2) >>> predictor(x, t).shape # fused output at global resolution torch.Size([2, 3, 16, 16]) >>> - >>> # fuse=False returns raw per-patch predictions — (P*B, C, Hp, Wp) - >>> predictor.fuse = False + >>> predictor.fuse = False # raw per-patch predictions instead >>> predictor(x, t).shape torch.Size([8, 3, 8, 8]) @@ -119,12 +171,12 @@ class MultiDiffusionPredictor(Predictor): >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler >>> from physicsnemo.diffusion.samplers import sample >>> - >>> # The wrapped model must have grid patching + fuse=True for sampling >>> md = MultiDiffusionModel2D(Backbone(), global_spatial_shape=(16, 16)) - >>> md.set_grid_patching(patch_shape=(8, 8), overlap_pix=2, fuse=True) + >>> md.set_random_patching(patch_shape=(8, 8), patch_num=4) >>> _ = md.eval() >>> >>> predictor = MultiDiffusionPredictor(md) + >>> predictor.set_patching(overlap_pix=2, boundary_pix=0) >>> scheduler = EDMNoiseScheduler() >>> denoiser = scheduler.get_denoiser(x0_predictor=predictor) >>> xN = torch.randn(2, 3, 16, 16) # initial noise at global resolution @@ -132,8 +184,8 @@ class MultiDiffusionPredictor(Predictor): >>> x0.shape torch.Size([2, 3, 16, 16]) - **Example 3:** Conditional sampling with mixed conditioning — an image - that shares the spatial resolution (patched like the state) and a vector + **Example 3:** Conditional sampling with mixed conditioning, an image + sharing the spatial resolution (patched like the state) and a vector (repeated across patches). Both kinds are bound once at construction and handled internally: @@ -153,9 +205,8 @@ class MultiDiffusionPredictor(Predictor): ... MultiCondBackbone(), ... global_spatial_shape=(16, 16), ... condition_patch={"image": True}, # image is patched - ... # vector has no flag: default is repeat across patches ... ) - >>> md.set_grid_patching(patch_shape=(8, 8), fuse=True) + >>> md.set_random_patching(patch_shape=(8, 8), patch_num=4) >>> _ = md.eval() >>> >>> condition = TensorDict({ @@ -163,11 +214,80 @@ class MultiDiffusionPredictor(Predictor): ... "vector": torch.randn(2, 5), ... }, batch_size=[2]) >>> predictor = MultiDiffusionPredictor(md, condition=condition) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) >>> denoiser = scheduler.get_denoiser(x0_predictor=predictor) >>> xN = torch.randn(2, 3, 16, 16) >>> x0 = sample(denoiser, xN, scheduler, num_steps=4) >>> x0.shape torch.Size([2, 3, 16, 16]) + + **Example 4:** Memory-efficient inference on large domains. + ``chunk_size`` and ``use_checkpointing`` are independent and address + different bottlenecks. ``chunk_size`` is useful for any kind of + inference (with or without gradients) and reduces peak memory by + sacrificing some batch parallelism. ``use_checkpointing`` is only + meaningful when gradients flow through the predictor (see + :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionDPSScorePredictor` + for the DPS guidance use case) and trades compute for memory by + replaying the forward during backpropagation. + + Set ``chunk_size`` to process patches in chunks instead of all at once + (helpful for plain inference on a domain that would otherwise OOM): + + >>> md = MultiDiffusionModel2D(Backbone(), global_spatial_shape=(16, 16)) + >>> md.set_random_patching(patch_shape=(8, 8), patch_num=4) + >>> _ = md.eval() + >>> predictor = MultiDiffusionPredictor(md, chunk_size=2) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) + >>> x = torch.randn(3, 3, 16, 16) # B = 3 + >>> t = 0.5 * torch.ones(3) + >>> predictor(x, t).shape # fused (B, C, H, W) + torch.Size([3, 3, 16, 16]) + + Set ``use_checkpointing=True`` (independent of chunking) when + differentiating through the predictor on a large domain: + + >>> predictor = MultiDiffusionPredictor(md, use_checkpointing=True) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) + >>> x = torch.randn(3, 3, 16, 16, requires_grad=True) + >>> y = predictor(x, t) + >>> grad = torch.autograd.grad(y.sum(), x)[0] + >>> grad.shape + torch.Size([3, 3, 16, 16]) + + Combine both for differentiable inference on very large domains. + A typical use case is computing the gradient of a DPS guidance + likelihood through the predictor (see + :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionDPSScorePredictor`): + + >>> predictor = MultiDiffusionPredictor(md, chunk_size=2, use_checkpointing=True) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) + >>> x = torch.randn(3, 3, 16, 16, requires_grad=True) + >>> y = predictor(x, t) + >>> grad = torch.autograd.grad(y.sum(), x)[0] + >>> grad.shape + torch.Size([3, 3, 16, 16]) + + Use :meth:`chunks` for explicit per-chunk control. The iterator yields + ``(slice_start, x0_chunk, x_chunk, t_chunk)`` tuples; the caller is + responsible for fusing via :meth:`fuse_fn` after the loop. This is + functionally equivalent to ``__call__`` with ``chunk_size`` set, but + exposes the intermediate per-chunk values so callers can interleave + their own per-chunk processing (e.g. accumulating patch-local guidance + terms): + + >>> predictor = MultiDiffusionPredictor(md, chunk_size=4, use_checkpointing=True) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) + >>> x = torch.randn(3, 3, 16, 16, requires_grad=True) # B = 3, P*B = 12 + >>> t = 0.5 * torch.ones(3) + >>> outs = [] + >>> for s, x0_c, x_c, t_c in predictor.chunks(x, t): + ... outs.append(x0_c) # (chunk_size=4, C, Hp, Wp) + >>> x0_patched = torch.cat(outs, dim=0) # (P*B=12, C, Hp, Wp) + >>> x0_global = predictor.fuse_fn(x0_patched) # (B=3, C, H, W) + >>> grad = torch.autograd.grad(x0_global.sum(), x)[0] + >>> grad.shape + torch.Size([3, 3, 16, 16]) """ def __init__( @@ -175,61 +295,442 @@ def __init__( model: MultiDiffusionModel2D, condition: Float[Tensor, " B *cond_dims"] | TensorDict | None = None, fuse: bool = True, + chunk_size: int | None = None, + use_checkpointing: bool = False, + prediction_type: PredictorType = "x0", + score_to_x0_fn: Callable[ + [Float[Tensor, " B *dims"], Float[Tensor, " B *dims"], Float[Tensor, " B"]], + Float[Tensor, " B *dims"], + ] + | None = None, + epsilon_to_x0_fn: Callable[ + [Float[Tensor, " B *dims"], Float[Tensor, " B *dims"], Float[Tensor, " B"]], + Float[Tensor, " B *dims"], + ] + | None = None, **model_kwargs: Any, ) -> None: self._md_model: MultiDiffusionModel2D = _unwrap_module( model, MultiDiffusionModel2D ) + self.model = model + self._model_kwargs = model_kwargs + self._chunk_size = chunk_size + self._use_checkpointing = use_checkpointing + self._fuse: bool = fuse + self._cond_input = condition - if not isinstance(self._md_model._patching, GridPatching2D): - raise RuntimeError( - "MultiDiffusionPredictor requires grid patching to be configured. " - "Call model.set_grid_patching() before creating the predictor." + # Predictor-owned patching parameters; default to the values saved on + # the wrapped model. set_patching() overrides them at call time and + # warns when the override differs from the saved value. + self._patch_shape: tuple[int, int] | None = self._md_model.patch_shape + self._global_shape: tuple[int, int] = tuple(self._md_model.global_spatial_shape) + + # Caches populated by set_patching(); guarded against use before then. + self._patching: GridPatching2D | None = None + self._P: int | None = None + self._cond_patched: Tensor | TensorDict | None = None + self._pos_embd_patched: Tensor | None = None + + # PE injection is handled by this class from a pre-patched cache; + # suppress the wrapper's per-step PE injection. + self._md_model._skip_positional_embedding_injection = True + # Internal model fusing is handled externally by this predictor (via + # fuse_fn); keep it disabled so gradient checkpointing replays observe + # a stable model state across forward and backward passes. + self._md_model._fuse = False + + # Prediction-type conversion (same pattern as MultiDiffusionMSEDSMLoss). + match prediction_type: + case "x0": + self._to_x0: Callable[[Tensor, Tensor, Tensor], Tensor] = ( + lambda pred, _x, _t: pred + ) + case "score": + if score_to_x0_fn is None: + raise ValueError( + "score_to_x0_fn must be provided when prediction_type='score'." + ) + self._to_x0 = score_to_x0_fn + case "epsilon": + if epsilon_to_x0_fn is None: + raise ValueError( + "epsilon_to_x0_fn must be provided when prediction_type='epsilon'." + ) + self._to_x0 = epsilon_to_x0_fn + case _: + raise ValueError( + f"prediction_type must be 'x0', 'score', or 'epsilon', " + f"got '{prediction_type}'." + ) + + # Bind the model-call helper once so the use_checkpointing branch is + # resolved at construction rather than on every forward call. + self._call_model: Callable[[Tensor, Tensor, Any], Tensor] = ( + self._call_model_with_checkpoint + if self._use_checkpointing + else self._call_model_direct + ) + + @property + def fuse(self) -> bool: + """Whether per-patch outputs are fused back to the global resolution + before being returned.""" + return self._fuse + + @fuse.setter + def fuse(self, value: bool) -> None: + """Set whether per-patch outputs are fused before being returned.""" + self._fuse = value + + @property + def patch_shape(self) -> tuple[int, int] | None: + r"""Spatial shape :math:`(H_p, W_p)` of each patch.""" + return self._md_model.patch_shape + + def patch_fn( + self, + x: Float[Tensor, "B C H W"], + ) -> Float[Tensor, "P_times_B C Hp Wp"]: + r"""Patch a global-resolution spatial tensor. + + Forwards to + :meth:`physicsnemo.diffusion.multi_diffusion.MultiDiffusionModel2D.patch_x` + on the wrapped model. Useful for pre-patching auxiliary tensors + (observations, masks) outside the predictor — for example in the + constructor of a DPS guidance — so that all patching uses the same + grid configuration as the predictor. + + Parameters + ---------- + x : Tensor + Global-resolution tensor of shape :math:`(B, C, H, W)`. + + Returns + ------- + Tensor + Patched tensor of shape :math:`(P \times B, C, H_p, W_p)`. + """ + self._check_patching_set() + return self._md_model.patch_x(x) + + def fuse_fn( + self, + patched: Float[Tensor, "P_times_B C Hp Wp"], + ) -> Float[Tensor, "B C H W"]: + r"""Fuse a complete patched tensor back to the global resolution. + + General-purpose fusing utility: takes a patched tensor of shape + :math:`(P \times B, C, H_p, W_p)` and returns the corresponding + global-resolution tensor :math:`(B, C, H, W)`. The original batch + size :math:`B` is inferred as ``patched.shape[0] // patch_num``; + the input must therefore contain the full :math:`P \times B` rows. + + Parameters + ---------- + patched : Tensor + Full patched tensor of shape :math:`(P \times B, C, H_p, W_p)`. + + Returns + ------- + Tensor + Fused tensor of shape :math:`(B, C, H, W)`. + + Raises + ------ + ValueError + If ``patched.shape[0]`` is not divisible by ``patch_num``. + """ + self._check_patching_set() + P = cast(int, self._P) + if not torch.compiler.is_compiling() and patched.shape[0] % P != 0: + raise ValueError( + f"patched.shape[0] ({patched.shape[0]}) is not divisible by " + f"patch_num ({P}); fuse_fn requires the full " + f"(P*B, …) tensor." ) + B = patched.shape[0] // P + return self._md_model.fuse(patched, batch_size=B) - self._patching: GridPatching2D = self._md_model._patching + def set_patching( + self, + overlap_pix: int, + boundary_pix: int, + *, + patch_shape: tuple[int, int] | None = None, + global_shape: tuple[int, int] | None = None, + ) -> None: + r"""Set the grid patching configuration. - self.model = model - self._model_kwargs = model_kwargs - self._P: int = self._patching.patch_num + Must be called once after construction and before any other method + (:meth:`__call__`, :meth:`chunks`, :meth:`patch_fn`, :meth:`fuse_fn`, + ...). Calling it again reconfigures the patching and rebuilds the + internal pre-patched caches. + + ``overlap_pix`` and ``boundary_pix`` are required: they cannot be + inferred from the wrapped model (which is typically trained with a + random patching strategy that has neither concept). + ``patch_shape`` and ``global_shape`` default to the values saved in + the wrapped model. Overriding either emits a warning, since + mismatching the training-time geometry can produce unexpected + results (in particular, positional embeddings baked at model + construction will no longer match a different ``global_shape``). + + Parameters + ---------- + overlap_pix : int + Overlapping pixels between adjacent patches. + boundary_pix : int + Boundary pixels padded on each side. + patch_shape : tuple[int, int], optional + Override for the patch spatial shape :math:`(H_p, W_p)`. Default + to the value saved on the wrapped model. + global_shape : tuple[int, int], optional + Override for the global spatial shape :math:`(H, W)`. Default to + the value saved on the wrapped model. + """ + if patch_shape is not None: + new_ps = tuple(patch_shape) + if self._patch_shape is not None and new_ps != self._patch_shape: + warnings.warn( + f"Overriding saved patch_shape {self._patch_shape} with " + f"{new_ps}. Inference-time patching that differs from " + f"training may produce unexpected results.", + stacklevel=2, + ) + self._patch_shape = new_ps + + if global_shape is not None: + new_gs = tuple(global_shape) + if new_gs != self._global_shape: + warnings.warn( + f"Overriding saved global_spatial_shape " + f"{self._global_shape} with {new_gs}. Positional " + f"embeddings baked at model construction will no " + f"longer match the new shape.", + stacklevel=2, + ) + self._global_shape = new_gs + + if self._patch_shape is None: + raise RuntimeError( + "patch_shape is not available on the wrapped model and was " + "not provided. Pass patch_shape=(Hp, Wp) explicitly." + ) - # Pre-patch condition once (without PE) - self._cond_patched: Tensor | TensorDict | None = self._md_model.patch_condition( - condition + # Sync the underlying model's global_spatial_shape so its patching + # logic uses the same geometry as the predictor. + self._md_model.global_spatial_shape = self._global_shape + self._md_model.set_grid_patching( + patch_shape=self._patch_shape, + overlap_pix=overlap_pix, + boundary_pix=boundary_pix, + fuse=False, ) + self._patching = cast(GridPatching2D, self._md_model._patching) + self._P = self._patching.patch_num - # Pre-patch PE for B=1, expanded to (P*B) at call time + # Rebuild pre-patched condition / PE caches now that patching is set. + self._cond_patched = self._md_model.patch_condition(self._cond_input) if self._md_model.pos_embd is not None: - self._pos_embd_patched: Tensor | None = self._md_model.patch_x( + self._pos_embd_patched = self._md_model.patch_x( self._md_model.pos_embd.unsqueeze(0) ) # (P, C_PE, Hp, Wp) else: self._pos_embd_patched = None - # PE will be injected by this class from the pre-patched cache; - # suppress the wrapper's own per-step PE injection to avoid redundant work - self._md_model._skip_positional_embedding_injection = True + def _build_cond(self, B: int) -> Tensor | TensorDict | None: + # Expand the pre-patched condition to batch size B and inject PE. + cond = self._cond_patched + if self._pos_embd_patched is not None: + P = cast(int, self._P) + pe = self._pos_embd_patched.repeat_interleave(B, dim=0) + cond = self._md_model._inject_patched_pos_embd(cond, pe, P * B) + return cond - self.fuse = fuse + def chunks( + self, + x: Float[Tensor, "B C H W"], + t: Float[Tensor, " B"], + ) -> Iterator[ + tuple[ + int, # slice_start: row index of x0_chunk along (P*B) + Float[Tensor, "K C Hp Wp"], # x0_chunk: model output (converted to x0) + Float[Tensor, "K C Hp Wp"], # x_chunk: noisy input slice + Float[Tensor, " K"], # t_chunk: time slice + ] + ]: + r"""Stream the per-chunk model outputs alongside their inputs. - @property - def fuse(self) -> bool: - """Whether the predictor fuses per-patch outputs back to the global - resolution at each call.""" - return self._md_model._fuse + Always returns patched outputs and does not fuse them. This makes + the iterator particularly useful for use cases that need to combine + the per-chunk model output with auxiliary patched data before + fusing, such as patch-local DPS guidance (see + :class:`~physicsnemo.diffusion.multi_diffusion.MultiDiffusionDPSScorePredictor`). + On large global domains, the equivalent non-chunked call + :meth:`__call__` may run out of GPU memory; iterating the patches + in chunks of size ``chunk_size`` keeps the peak activation + footprint bounded. Use :meth:`fuse_fn` after the loop to fuse the + concatenated chunks back to the global resolution. - @fuse.setter - def fuse(self, value: bool) -> None: - """Enable or disable fusing at each call.""" - self._md_model._fuse = value + Functionally equivalent to ``__call__`` with ``chunk_size`` set + (which streams internally and fuses at the end). Use this iterator + when explicit per-chunk control is needed; otherwise prefer the + higher-level ``__call__``. + + Patching is performed once at the start of the iteration; subsequent + iterations only slice the pre-patched tensors. Requires + ``chunk_size`` to be set at construction. + + Parameters + ---------- + x : Tensor + Noisy latent at global resolution, shape :math:`(B, C, H, W)`. + t : Tensor + Diffusion time, shape :math:`(B,)`. + + Returns + ------- + Iterator yielding tuples ``(slice_start, x0_chunk, x_chunk, t_chunk)``: + + - ``slice_start`` (``int``): row index of the chunk along the + :math:`(P \times B)` dimension. Allows downstream consumers (e.g. + patch-local DPS guidances) to align their own pre-patched data + with the current chunk. + - ``x0_chunk`` (``Tensor``): model output for this chunk (already + converted to x0 when a conversion was configured), shape + :math:`(K, C, H_p, W_p)` with :math:`K \leq chunk\_size`. All + chunks have ``chunk_size`` rows except possibly the last, which + may be smaller when :math:`P \times B` is not divisible by + ``chunk_size``. + - ``x_chunk`` (``Tensor``): noisy input slice corresponding to + ``x0_chunk``, same shape. + - ``t_chunk`` (``Tensor``): time slice corresponding to + ``x0_chunk``, shape :math:`(K,)`. + + Raises + ------ + RuntimeError + If ``chunk_size`` was not set at construction, or if + :meth:`set_patching` has not been called. + + Examples + -------- + Streaming inference with a per-chunk computation in the loop body + (here just collecting a per-chunk statistic, but in practice this + is where DPS guidance terms or other patch-local processing would + plug in): + + >>> import torch + >>> from physicsnemo.core import Module + >>> from physicsnemo.diffusion.multi_diffusion import ( + ... MultiDiffusionModel2D, MultiDiffusionPredictor, + ... ) + >>> class Backbone(Module): + ... def __init__(self): + ... super().__init__() + ... self.net = torch.nn.Conv2d(3, 3, 1) + ... def forward(self, x, t, condition=None): + ... return self.net(x) + >>> md = MultiDiffusionModel2D(Backbone(), global_spatial_shape=(16, 16)) + >>> md.set_random_patching(patch_shape=(8, 8), patch_num=4) + >>> _ = md.eval() + >>> predictor = MultiDiffusionPredictor(md, chunk_size=2) + >>> predictor.set_patching(overlap_pix=0, boundary_pix=0) + >>> x = torch.randn(2, 3, 16, 16) + >>> t = 0.5 * torch.ones(2) + >>> chunks_list, chunk_norms = [], [] + >>> for s, x0_c, x_c, t_c in predictor.chunks(x, t): + ... chunk_norms.append(x0_c.pow(2).sum()) # per-chunk computation + ... chunks_list.append(x0_c) + >>> x0_global = predictor.fuse_fn(torch.cat(chunks_list, dim=0)) + >>> x0_global.shape + torch.Size([2, 3, 16, 16]) + """ + self._check_patching_set() + if self._chunk_size is None: + raise RuntimeError( + "chunk_size must be set at construction to use chunks(). " + "Pass chunk_size= to MultiDiffusionPredictor.__init__." + ) + + B = x.shape[0] + x_patched = self._md_model.patch_x(x) # (P*B, C, Hp, Wp) + t_patched = self._md_model.patch_t(t) # (P*B,) + cond = self._build_cond(B) + + K = self._chunk_size + PB = x_patched.shape[0] + for s in range(0, PB, K): + e = min(s + K, PB) + x_c = x_patched[s:e] + t_c = t_patched[s:e] + c_c = cond[s:e] if cond is not None else None + out = self._call_model(x_c, t_c, c_c) + x0_c = self._to_x0(out, x_c, t_c) + yield s, x0_c, x_c, t_c + + def _check_patching_set(self) -> None: + """Raise if :meth:`set_patching` has not been called yet. + + Guarded by ``torch.compiler.is_compiling`` so it is a no-op under + ``torch.compile``. + """ + if not torch.compiler.is_compiling() and self._patching is None: + raise RuntimeError( + "Grid patching is not configured. Call set_patching(" + "overlap_pix, boundary_pix, ...) before any other method." + ) + + def _call_model_direct( + self, + x_p: Tensor, + t_p: Tensor, + cond: Any, + ) -> Tensor: + """Call the wrapped model on the patched inputs, no checkpointing.""" + return self._md_model( + x_p, + t_p, + condition=cond, + x_is_patched=True, + t_is_patched=True, + condition_is_patched=True, + **self._model_kwargs, + ) + + def _call_model_with_checkpoint( + self, + x_p: Tensor, + t_p: Tensor, + cond: Any, + ) -> Tensor: + """Call the wrapped model under :func:`torch.utils.checkpoint.checkpoint`. + + ``cond`` is captured as a default argument so each invocation binds + its own value, which is required for correct backward replay when + called inside a loop. + """ + + def _inner(xc: Tensor, tc: Tensor, _cond=cond) -> Tensor: # noqa: ANN001 + return self._md_model( + xc, + tc, + condition=_cond, + x_is_patched=True, + t_is_patched=True, + condition_is_patched=True, + **self._model_kwargs, + ) + + return checkpoint(_inner, x_p, t_p, use_reentrant=False) def __call__( self, x: Float[Tensor, "B C H W"], t: Float[Tensor, " B"], ) -> Float[Tensor, "B C H W"] | Float[Tensor, "P_times_B C Hp Wp"]: - r"""Run the predictor on a noisy latent and diffusion time at the - global resolution. + r"""Run the predictor on a noisy latent and diffusion time. Parameters ---------- @@ -241,25 +742,21 @@ def __call__( Returns ------- torch.Tensor - If ``self.fuse`` is ``True``: prediction at the global - resolution, shape :math:`(B, C, H, W)`. + If ``self.fuse=True``: prediction at the global resolution, + shape :math:`(B, C, H, W)`. Otherwise: per-patch predictions, shape :math:`(P \times B, C, H_p, W_p)`. """ + self._check_patching_set() + if self._chunk_size is not None: + x0_chunks = [x0_c for _, x0_c, _, _ in self.chunks(x, t)] + output = torch.cat(x0_chunks, dim=0) # (P*B, C, Hp, Wp) + return self.fuse_fn(output) if self._fuse else output + B = x.shape[0] x_patched = self._md_model.patch_x(x) # (P*B, C, Hp, Wp) t_patched = self._md_model.patch_t(t) # (P*B,) - cond = self._cond_patched - if self._pos_embd_patched is not None: - # Expand cached PE from (P, ...) to (P*B, ...) and inject - pe = self._pos_embd_patched.repeat_interleave(B, dim=0) - cond = self._md_model._inject_patched_pos_embd(cond, pe, self._P * B) - return self._md_model( - x_patched, - t_patched, - condition=cond, - x_is_patched=True, - t_is_patched=True, - condition_is_patched=True, - **self._model_kwargs, - ) + cond = self._build_cond(B) + result = self._call_model(x_patched, t_patched, cond) + result = self._to_x0(result, x_patched, t_patched) + return self.fuse_fn(result) if self._fuse else result diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_interp.mdlus b/test/diffusion/data/test_multi_diffusion_models_cond_interp.mdlus index c4baf2dae0..02f7dc4001 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_interp.mdlus and b/test/diffusion/data/test_multi_diffusion_models_cond_interp.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_patch.mdlus b/test/diffusion/data/test_multi_diffusion_models_cond_patch.mdlus index dd54061dd0..c88b2bf26e 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_patch.mdlus and b/test/diffusion/data/test_multi_diffusion_models_cond_patch.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img.mdlus b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img.mdlus index 602261f791..5932084f9e 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img.mdlus and b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_edm_precond.mdlus b/test/diffusion/data/test_multi_diffusion_models_edm_precond.mdlus index 688df1c037..47052f01c1 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_edm_precond.mdlus and b/test/diffusion/data/test_multi_diffusion_models_edm_precond.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_learn.mdlus b/test/diffusion/data/test_multi_diffusion_models_posembd_learn.mdlus index 553360c5fd..2a04b7773a 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_learn.mdlus and b/test/diffusion/data/test_multi_diffusion_models_posembd_learn.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_sin.mdlus b/test/diffusion/data/test_multi_diffusion_models_posembd_sin.mdlus index ece97562fc..1be2991550 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_sin.mdlus and b/test/diffusion/data/test_multi_diffusion_models_posembd_sin.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_uncond.mdlus b/test/diffusion/data/test_multi_diffusion_models_uncond.mdlus index e90f1d7328..c3897dba69 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_uncond.mdlus and b/test/diffusion/data/test_multi_diffusion_models_uncond.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_predictor_cond_interp.mdlus b/test/diffusion/data/test_multi_diffusion_predictor_cond_interp.mdlus index 64d959616c..fd55bd3971 100644 Binary files a/test/diffusion/data/test_multi_diffusion_predictor_cond_interp.mdlus and b/test/diffusion/data/test_multi_diffusion_predictor_cond_interp.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_predictor_cond_patch.mdlus b/test/diffusion/data/test_multi_diffusion_predictor_cond_patch.mdlus index e018b5ef03..ce866b8b5b 100644 Binary files a/test/diffusion/data/test_multi_diffusion_predictor_cond_patch.mdlus and b/test/diffusion/data/test_multi_diffusion_predictor_cond_patch.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_predictor_cond_vec_img.mdlus b/test/diffusion/data/test_multi_diffusion_predictor_cond_vec_img.mdlus index 3fdf0f40ae..33744c9d78 100644 Binary files a/test/diffusion/data/test_multi_diffusion_predictor_cond_vec_img.mdlus and b/test/diffusion/data/test_multi_diffusion_predictor_cond_vec_img.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_predictor_posembd_learn.mdlus b/test/diffusion/data/test_multi_diffusion_predictor_posembd_learn.mdlus index a0f3f4fa0c..de2d6eb8ca 100644 Binary files a/test/diffusion/data/test_multi_diffusion_predictor_posembd_learn.mdlus and b/test/diffusion/data/test_multi_diffusion_predictor_posembd_learn.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_predictor_posembd_sin.mdlus b/test/diffusion/data/test_multi_diffusion_predictor_posembd_sin.mdlus index f6914a8334..9c96c8e239 100644 Binary files a/test/diffusion/data/test_multi_diffusion_predictor_posembd_sin.mdlus and b/test/diffusion/data/test_multi_diffusion_predictor_posembd_sin.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_predictor_uncond.mdlus b/test/diffusion/data/test_multi_diffusion_predictor_uncond.mdlus index 10aa08a0d8..f2e6eb1716 100644 Binary files a/test/diffusion/data/test_multi_diffusion_predictor_uncond.mdlus and b/test/diffusion/data/test_multi_diffusion_predictor_uncond.mdlus differ diff --git a/test/diffusion/test_dps_guidance.py b/test/diffusion/test_dps_guidance.py index d141c45662..fc1351fd78 100644 --- a/test/diffusion/test_dps_guidance.py +++ b/test/diffusion/test_dps_guidance.py @@ -353,7 +353,6 @@ def test_custom_attributes(self): ) assert g.std_y == pytest.approx(0.5) assert g.gamma == pytest.approx(2.0) - assert g.norm == 1 assert g.retain_graph is True assert g.create_graph is True @@ -395,7 +394,6 @@ def test_custom_attributes(self): ) assert g.std_y == pytest.approx(0.5) assert g.gamma == pytest.approx(2.0) - assert g.norm == 1 assert g.retain_graph is True assert g.create_graph is True diff --git a/test/diffusion/test_multi_diffusion_predictor.py b/test/diffusion/test_multi_diffusion_predictor.py index dad9d2b580..d45afc2122 100644 --- a/test/diffusion/test_multi_diffusion_predictor.py +++ b/test/diffusion/test_multi_diffusion_predictor.py @@ -38,7 +38,6 @@ IMG_W, INPUT_SHAPE, MD_CONFIGS, - PATCH_NUM, PATCH_SHAPE, _create_md_model, _create_md_model_edm_precond, @@ -72,7 +71,9 @@ def _create_predictor( fuse=fuse, ) condition = _make_condition(config_name, img_shape=img_shape, device=device) - return MultiDiffusionPredictor(md, condition=condition, fuse=fuse) + pred = MultiDiffusionPredictor(md, condition=condition, fuse=fuse) + pred.set_patching(overlap_pix=overlap_pix, boundary_pix=boundary_pix) + return pred # ============================================================================= @@ -88,18 +89,22 @@ class TestConstructor: to end by the non-regression tests. """ - @pytest.mark.parametrize( - "setup", - ["none", "random"], - ids=["no_patching", "random_patching"], - ) - def test_requires_grid_patching(self, setup): - """Predictor construction raises when grid patching is not active.""" + def test_set_patching_requires_patch_shape(self): + """set_patching raises when no patch_shape is available on the model + and none is provided explicitly.""" + md = _create_md_model("uncond") + pred = MultiDiffusionPredictor(md) + with pytest.raises(RuntimeError, match="patch_shape"): + pred.set_patching(overlap_pix=0, boundary_pix=0) + + def test_methods_require_set_patching(self): + """Predictor methods raise when set_patching has not been called.""" md = _create_md_model("uncond") - if setup == "random": - md.set_random_patching(patch_shape=PATCH_SHAPE, patch_num=PATCH_NUM) - with pytest.raises(RuntimeError, match="grid patching"): - MultiDiffusionPredictor(md) + pred = MultiDiffusionPredictor(md) + x = make_input(INPUT_SHAPE, seed=GLOBAL_SEED) + t = torch.rand(BATCH) + with pytest.raises(RuntimeError, match="set_patching"): + pred(x, t) @pytest.mark.parametrize("fuse", [True, False], ids=["fuse_true", "fuse_false"]) @pytest.mark.parametrize( @@ -114,10 +119,9 @@ def test_public_api(self, config_name, fuse): assert pred.fuse is fuse # .model is the MultiDiffusionModel2D the predictor wraps assert isinstance(pred.model, MultiDiffusionModel2D) - # .fuse setter round-trips and is reflected on the wrapped model + # .fuse setter round-trips on the predictor pred.fuse = not fuse assert pred.fuse is (not fuse) - assert pred.model._fuse is (not fuse) # ============================================================================= @@ -219,6 +223,7 @@ def create_fn(): md.set_grid_patching(patch_shape=PATCH_SHAPE, fuse=True) condition = _make_condition(config_name, device=device) pred = MultiDiffusionPredictor(md, condition=condition, fuse=True) + pred.set_patching(overlap_pix=0, boundary_pix=0) x = make_input(INPUT_SHAPE, seed=GLOBAL_SEED, device=device) t = make_input((BATCH,), seed=GLOBAL_SEED + 1, device=device).abs() + 0.1 @@ -263,6 +268,7 @@ def test_gradient_flow_conditional(self, device): cond_img = make_input(INPUT_SHAPE, seed=99, device=device).requires_grad_(True) condition = TensorDict({"image": cond_img}, batch_size=[BATCH]) pred = MultiDiffusionPredictor(md, condition=condition, fuse=True) + pred.set_patching(overlap_pix=0, boundary_pix=0) x = make_input(INPUT_SHAPE, seed=GLOBAL_SEED, device=device).requires_grad_( True @@ -277,6 +283,7 @@ def test_gradient_flow_posembd(self, device): md.set_grid_patching(patch_shape=PATCH_SHAPE, fuse=True) condition = _make_condition("posembd_learn", device=device) pred = MultiDiffusionPredictor(md, condition=condition, fuse=True) + pred.set_patching(overlap_pix=0, boundary_pix=0) x = make_input(INPUT_SHAPE, seed=GLOBAL_SEED, device=device).requires_grad_( True @@ -346,6 +353,7 @@ def test_forward_non_regression(self, deterministic_settings, device, tolerances md = _create_md_model_edm_precond().to(device) md.set_grid_patching(patch_shape=PATCH_SHAPE, fuse=True) pred = MultiDiffusionPredictor(md, fuse=True) + pred.set_patching(overlap_pix=0, boundary_pix=0) x = make_input(INPUT_SHAPE, seed=GLOBAL_SEED, device=device) t = make_input((BATCH,), seed=GLOBAL_SEED + 1, device=device).abs() + 0.1 diff --git a/test/diffusion/test_multi_diffusion_sampling.py b/test/diffusion/test_multi_diffusion_sampling.py index 2e63b03a8f..13be807da1 100644 --- a/test/diffusion/test_multi_diffusion_sampling.py +++ b/test/diffusion/test_multi_diffusion_sampling.py @@ -106,6 +106,7 @@ def _make_sampling_components( ) condition = _make_condition(md_config, img_shape=img_shape, device=device) predictor = MultiDiffusionPredictor(md, condition=condition, fuse=True) + predictor.set_patching(overlap_pix=overlap_pix, boundary_pix=boundary_pix) denoiser = scheduler.get_denoiser(x0_predictor=predictor) H, W = img_shape