
Multi diffusion guidance #1632

Open
CharlelieLrt wants to merge 5 commits into NVIDIA:main from CharlelieLrt:Multi-diffusion-guidance

Conversation

@CharlelieLrt
Collaborator

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an
indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment; the comments are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@CharlelieLrt CharlelieLrt requested a review from pzharrington May 9, 2026 02:49
@copy-pr-bot

copy-pr-bot Bot commented May 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@CharlelieLrt CharlelieLrt added the ! - Release label (PRs or Issues relating to a release) May 9, 2026
@greptile-apps
Contributor

greptile-apps Bot commented May 9, 2026

Greptile Summary

This PR adds patch-local DPS (Diffusion Posterior Sampling) guidance to the multi-diffusion module, enabling memory-efficient guided sampling on large spatial domains by streaming per-chunk model outputs and guidance terms without materialising the full (P×B, …) activation tensor.

  • New dps_guidance.py: introduces MultiDiffusionDPSScorePredictor, MultiDiffusionModelConsistencyDPSGuidance, and MultiDiffusionDataConsistencyDPSGuidance, each documented with extensive doctests and math.
  • Extended predictor.py: adds chunk_size, use_checkpointing, prediction_type, set_patching(), chunks(), patch_fn(), and fuse_fn() to MultiDiffusionPredictor, decoupling patching configuration from model construction and exposing a streaming iterator for downstream guidance consumers.
  • Extended models.py: adds a persistent _patch_shape buffer (survives checkpoint round-trips) and a patch_shape property, and relaxes the batch-divisibility check to allow partial chunk tensors to pass through the model.
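The streaming idea described above can be sketched in a few lines. This is a hedged, self-contained illustration with made-up names (`chunks`, the stand-in `Conv2d` denoiser), not the actual PhysicsNeMo API: instead of stacking all P patches into one (P×B, …) batch, the predictor yields per-chunk model outputs so a consumer can fuse them incrementally. Note the last chunk may be partial, which is why the batch-divisibility check in models.py had to be relaxed.

```python
import torch

# Hypothetical sketch of chunk streaming (names are illustrative only):
# yield model outputs chunk_size patches at a time instead of running the
# full (P*B, C, h, w) batch through the model at once.
def chunks(patches, model, chunk_size):
    """Yield model outputs for `chunk_size` patches at a time."""
    for start in range(0, patches.shape[0], chunk_size):
        yield model(patches[start:start + chunk_size])

patches = torch.randn(8, 2, 16, 16)          # P*B = 8 patches
model = torch.nn.Conv2d(2, 2, 3, padding=1)  # stand-in for the denoiser

# Fuse (here: a simple sum over patches) the streamed outputs without ever
# materialising all 8 activations together; chunk_size=3 leaves a partial
# final chunk of 2 patches.
fused = sum(out.sum(dim=0) for out in chunks(patches, model, chunk_size=3))
print(tuple(fused.shape))  # (2, 16, 16)
```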

Important Files Changed

Filename Overview
physicsnemo/diffusion/multi_diffusion/dps_guidance.py New file introducing patch-local DPS guidance classes; contains a dead-code None guard after torch.autograd.grad in both guidance __call__ methods (the guard can never trigger without allow_unused=True).
physicsnemo/diffusion/multi_diffusion/predictor.py Significantly extended with chunk_size, use_checkpointing, prediction_type, set_patching(), chunks(), patch_fn(), and fuse_fn(). The constructor permanently mutates private fields (_fuse, _skip_positional_embedding_injection, global_spatial_shape) on the shared wrapped model, which can silently break other users of the same model instance.
physicsnemo/diffusion/multi_diffusion/models.py Adds _patch_shape persistent buffer and patch_shape property; relaxes the patched-batch-size divisibility check for chunk streaming. Minor: sentinel check uses or instead of and, which could return a partially-initialized shape.
physicsnemo/diffusion/multi_diffusion/__init__.py Exports the four new DPS guidance symbols from dps_guidance.py; no issues.
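The dead-guard remark on dps_guidance.py above comes down to a property of `torch.autograd.grad`. A minimal standalone example (not the PR's code) showing why a post-hoc `None` check is unreachable without `allow_unused=True`:

```python
import torch

x = torch.randn(3, requires_grad=True)
unused = torch.randn(3, requires_grad=True)

# Default (allow_unused=False): asking for the grad of an input that the
# loss never touched raises a RuntimeError, so a later `if grad is None`
# guard can never trigger.
try:
    torch.autograd.grad((x ** 2).sum(), [unused])
    guard_reachable = True
except RuntimeError:
    guard_reachable = False
print(guard_reachable)  # False

# Only with allow_unused=True does grad() return None for unused inputs,
# which is the one case such a guard is written for.
(g,) = torch.autograd.grad((x ** 2).sum(), [unused], allow_unused=True)
print(g)  # None
```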

Comments Outside Diff (2)

  1. physicsnemo/diffusion/multi_diffusion/models.py

    P2 Sentinel check uses or instead of and

    The zero-sentinel for "not yet configured" is [0, 0]. Using or returns a patch shape even when only one dimension is non-zero (e.g., a hypothetical corrupted buffer [8, 0]), which would later propagate as an invalid shape. Using and correctly requires both dimensions to be positive before declaring the shape valid.

  2. physicsnemo/diffusion/multi_diffusion/predictor.py, lines 1337-1341

    P2 MultiDiffusionPredictor permanently mutates the wrapped model's private state

    __init__ unconditionally sets self._md_model._fuse = False and self._md_model._skip_positional_embedding_injection = True, and set_patching() overwrites self._md_model.global_spatial_shape. Because _md_model is the unwrapped model instance (possibly shared with a training loop or another predictor), these side-effects persist after the predictor is constructed. If the same MultiDiffusionModel2D is used for both training and inference, or if two predictors wrap the same model, the mutations from one predictor silently break the other. Consider either operating on a copy, or explicitly documenting that the caller must not use the model independently after wrapping it in a predictor.
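The shared-state hazard in point 2 can be demonstrated in isolation. This is an illustrative sketch (the classes are stand-ins, though `_md_model`, `_fuse`, and `global_spatial_shape` are the attribute names the review refers to): because the predictor stores and mutates the same model instance rather than a copy, constructing a second predictor silently reconfigures the first.

```python
# Stand-in for a MultiDiffusionModel2D-like model (names illustrative).
class Model:
    def __init__(self):
        self._fuse = True
        self.global_spatial_shape = (448, 448)

# Stand-in predictor reproducing the pattern the review flags: __init__
# mutates the wrapped model in place instead of operating on a copy.
class Predictor:
    def __init__(self, model, spatial_shape):
        self._md_model = model              # SAME instance, no copy
        self._md_model._fuse = False        # side effect outlives __init__
        self._md_model.global_spatial_shape = spatial_shape

m = Model()
p1 = Predictor(m, (448, 448))
p2 = Predictor(m, (224, 224))  # reconfigures m out from under p1
print(p1._md_model.global_spatial_shape)  # (224, 224), not (448, 448)
```

Operating on a `copy.deepcopy` of the model (or documenting the ownership transfer, as the review suggests) avoids this aliasing.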

Reviews (1): Last reviewed commit: "Add missing DPS guidance module"

Comment thread physicsnemo/diffusion/multi_diffusion/dps_guidance.py Outdated
Collaborator

@pzharrington pzharrington left a comment


Approving to get into rc, will add comments later

