10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -102,6 +102,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
standard solvers). Patching primitives (`BasePatching2D`,
`GridPatching2D`, `RandomPatching2D`) are exposed under the same
subpackage and are `torch.compile`-friendly with `fullgraph=True`.
+`MultiDiffusionPredictor` supports memory-efficient inference on
+large domains via `chunk_size` and `use_checkpointing`. The
+subpackage also ships patch-local DPS guidance:
+`MultiDiffusionDPSScorePredictor` (drop-in score predictor that plugs
+into the standard sampling stack),
+`MultiDiffusionDataConsistencyDPSGuidance` for inpainting and sparse
+data assimilation, and `MultiDiffusionModelConsistencyDPSGuidance` for
+generic patch-local observation operators. Use these instead of the
+global `DPSScorePredictor` to run guided sampling on domains that
+would otherwise OOM.
- Adds `"epsilon"` as a supported prediction type throughout the diffusion
framework, alongside the existing `"x0"` and `"score"` modes. A new
`PredictorType = Literal["x0", "score", "epsilon"]` alias in
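Following up on the multi-diffusion changelog entry above, here is a minimal sketch of how the patch-local pieces might be wired together. The class names come from this PR; every constructor argument is an illustrative assumption rather than the exact physicsnemo signature, so treat this as pseudocode for the intended workflow, not a verified snippet.

```python
import torch

# Names below are exported by this PR (see the __init__.py diff further
# down); the keyword arguments are assumptions made for illustration.
from physicsnemo.diffusion.multi_diffusion import (
    GridPatching2D,
    MultiDiffusionDataConsistencyDPSGuidance,
    MultiDiffusionDPSScorePredictor,
)

# Sparse observations on a domain too large for global DPS guidance:
# `mask` is 1 where a pixel is observed, 0 elsewhere.
y = torch.randn(1, 3, 2048, 2048)                     # placeholder observations
mask = (torch.rand(1, 3, 2048, 2048) > 0.99).float()  # ~1% of pixels observed

patching = GridPatching2D(...)  # assumed: domain / patch shapes go here
guidance = MultiDiffusionDataConsistencyDPSGuidance(  # assumed arguments,
    mask=mask, y=y, std_y=0.05, norm=2,               # mirroring the global class
)
score_predictor = MultiDiffusionDPSScorePredictor(...)  # assumed: model, patching, guidance

# Because the wrapper is a drop-in score predictor, the standard sampler
# loop runs unchanged; memory stays bounded by the patch size rather than
# the full domain (optionally tuned via chunk_size / use_checkpointing).
```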
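The `"epsilon"` entry above adds a third prediction parameterization next to `"x0"` and `"score"`. For reference, the three are related by standard identities; the sketch below assumes the variance-exploding convention x_t = x_0 + sigma * eps used by EDM-style frameworks (whether physicsnemo's internals use exactly this convention is an assumption, not something stated in the diff).

```python
import torch
from torch import Tensor


def epsilon_to_x0(x_t: Tensor, eps: Tensor, sigma: Tensor) -> Tensor:
    # Denoised estimate under x_t = x_0 + sigma * eps:  x_0 = x_t - sigma * eps
    return x_t - sigma * eps


def epsilon_to_score(eps: Tensor, sigma: Tensor) -> Tensor:
    # Score of the perturbed density:
    #   grad_x log p(x_t) = -(x_t - x_0) / sigma**2 = -eps / sigma
    return -eps / sigma


# Consistency check on random data: both routes give the same score.
x_t, eps = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
sigma = torch.tensor(0.5)
assert torch.allclose(
    epsilon_to_score(eps, sigma),
    -(x_t - epsilon_to_x0(x_t, eps, sigma)) / sigma**2,
)
```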
44 changes: 21 additions & 23 deletions physicsnemo/diffusion/guidance/dps_guidance.py
@@ -25,6 +25,16 @@
from physicsnemo.diffusion.base import Predictor


+def _lp_loss_fn(p: int) -> Callable[[Tensor, Tensor], Tensor]:
> **@pzharrington** (Collaborator) commented on May 12, 2026:
>
> Add jaxtyping annotations to these Tensors if possible? (and where it's used below)

"""Return a per-batch-element Lp loss function with exponent ``p``."""

def _loss(y_pred: Tensor, y_true: Tensor) -> Tensor:
residual = (y_pred - y_true).reshape(y_pred.shape[0], -1)
return residual.abs().pow(p).sum(dim=1)

return _loss


@runtime_checkable
class DPSGuidance(Protocol):
r"""
@@ -658,7 +668,10 @@ def __init__(
        self.observation_operator = observation_operator
        self.y = y
        self.std_y = std_y
-        self.norm = norm
+        if isinstance(norm, int):
+            self._loss_fn: Callable[[Tensor, Tensor], Tensor] = _lp_loss_fn(norm)
+        else:
+            self._loss_fn = norm
        self.gamma = gamma
        self.sigma_fn = (
            sigma_fn if sigma_fn is not None else lambda t: torch.zeros_like(t)
@@ -708,17 +721,8 @@ def __call__(
        y = self.y.to(dtype=x.dtype, device=x.device)

        with torch.enable_grad():
-            # Compute predicted observations
            y_pred = self.observation_operator(x_0)
-
-            # Compute loss
-            if callable(self.norm):
-                loss = self.norm(y_pred, y)
-            else:
-                residual = (y_pred - y).reshape(y_pred.shape[0], -1)
-                loss = residual.abs().pow(self.norm).sum(dim=1)
-
-            # Compute gradient of loss w.r.t. x (backprop through x_0)
+            loss = self._loss_fn(y_pred, y)
            grad_x = torch.autograd.grad(
                outputs=loss.sum(),
                inputs=x,
@@ -956,7 +960,7 @@ def __init__(
        std_y: float,
        norm: int
        | Callable[
-            [Float[Tensor, "B *dims"], Float[Tensor, "B *dims"]],  # noqa: F821
+            [Float[Tensor, " B *dims"], Float[Tensor, " B *dims"]],  # noqa: F821
            Float[Tensor, " B"],
        ] = 2,
        gamma: float = 0.0,
@@ -972,7 +976,10 @@
        self.mask = mask.float()
        self.y = y
        self.std_y = std_y
-        self.norm = norm
+        if isinstance(norm, int):
+            self._loss_fn: Callable[[Tensor, Tensor], Tensor] = _lp_loss_fn(norm)
+        else:
+            self._loss_fn = norm
        self.gamma = gamma
        self.sigma_fn = (
            sigma_fn if sigma_fn is not None else lambda t: torch.zeros_like(t)
@@ -1023,18 +1030,9 @@ def __call__(
        y = self.y.to(dtype=x.dtype, device=x.device)

        with torch.enable_grad():
-            # Compute masked predicted and observed values
            y_pred = mask * x_0
            y_true = mask * y
-
-            # Compute loss
-            if callable(self.norm):
-                loss = self.norm(y_pred, y_true)
-            else:
-                residual = (y_pred - y_true).reshape(x_0.shape[0], -1)
-                loss = residual.abs().pow(self.norm).sum(dim=1)
-
-            # Compute gradient of loss w.r.t. x (backprop through x_0)
+            loss = self._loss_fn(y_pred, y_true)
            grad_x = torch.autograd.grad(
                outputs=loss.sum(),
                inputs=x,
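Regarding @pzharrington's inline comment on `_lp_loss_fn`: a minimal sketch of what the suggested jaxtyping annotations could look like, reusing the shape style already present in this file (`"B *dims"` for inputs, `" B"` for the per-batch output). This renders the reviewer's suggestion as code; it is not what the PR merged.

```python
from typing import Callable

from jaxtyping import Float
from torch import Tensor


def _lp_loss_fn(
    p: int,
) -> Callable[
    [Float[Tensor, "B *dims"], Float[Tensor, "B *dims"]],  # noqa: F821
    Float[Tensor, " B"],
]:
    """Return a per-batch-element Lp loss function with exponent ``p``."""

    def _loss(
        y_pred: Float[Tensor, "B *dims"],  # noqa: F821
        y_true: Float[Tensor, "B *dims"],  # noqa: F821
    ) -> Float[Tensor, " B"]:
        # Flatten all non-batch dims, then reduce |residual|**p per batch element.
        residual = (y_pred - y_true).reshape(y_pred.shape[0], -1)
        return residual.abs().pow(p).sum(dim=1)

    return _loss
```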
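Both guidance classes now route `norm` through `_loss_fn`: an integer selects the built-in Lp reduction, while a callable taking `(y_pred, y_true)` of shape `(B, *dims)` and returning a `(B,)` loss is used as-is. A short sketch of a custom callable follows; the `relative_l2` helper is hypothetical, not part of the PR.

```python
import torch
from torch import Tensor


def relative_l2(y_pred: Tensor, y_true: Tensor) -> Tensor:
    """Per-batch squared L2 residual, normalized by the target's energy."""
    pred = y_pred.reshape(y_pred.shape[0], -1)
    true = y_true.reshape(y_true.shape[0], -1)
    num = (pred - true).pow(2).sum(dim=1)
    den = true.pow(2).sum(dim=1).clamp_min(1e-12)  # guard against zero targets
    return num / den


# Shape check on dummy data: one scalar loss per batch element.
out = relative_l2(torch.randn(4, 3, 16, 16), torch.randn(4, 3, 16, 16))
assert out.shape == (4,)
# At construction time this would be passed as norm=relative_l2 in place
# of the default norm=2 (exact constructor signature per the diff above).
```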
6 changes: 6 additions & 0 deletions physicsnemo/diffusion/multi_diffusion/__init__.py
@@ -14,6 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from .dps_guidance import (
+    MultiDiffusionDataConsistencyDPSGuidance,
+    MultiDiffusionDPSGuidance,
+    MultiDiffusionDPSScorePredictor,
+    MultiDiffusionModelConsistencyDPSGuidance,
+)
from .losses import MultiDiffusionMSEDSMLoss, MultiDiffusionWeightedMSEDSMLoss
from .models import MultiDiffusionModel2D
from .patching import (