diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 2264805..04dee87 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -1,4 +1,6 @@ import logging +import typing as t +import numpy from phaser.utils.misc import mask_fraction_of_groups from phaser.utils.num import assert_dtype, cast_array_module, to_numpy, to_complex_dtype diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 489ac37..524ea2e 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -20,7 +20,7 @@ from phaser.hooks.solver import GradientSolver from phaser.hooks.regularization import CostRegularizer, GroupConstraint from phaser.plan import GradientEnginePlan -from phaser.types import process_flag, ReconsVar +from phaser.types import process_flag, flag_any_true, ReconsVar from ..common.simulation import GroupManager, make_propagators, tilt_propagators, slice_forwards, stream_patterns @@ -314,6 +314,7 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple (update, iter_solver_states[sol_i]) = solver.update( state, iter_solver_states[sol_i], filter_vars(iter_grads, solver.params), losses['total_loss'] ) + state = apply_update(state, update) if 'positions' in update: diff --git a/phaser/execute.py b/phaser/execute.py index 593a71c..2fb3a58 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -15,7 +15,7 @@ from .hooks import EngineHook, Hook, ObjectHook, RawData from .plan import GradientEnginePlan, ReconsPlan, EnginePlan, ScanHook, ProbeHook, TiltHook from .state import Patterns, ReconsState, PartialReconsState, IterState, PreparedRecons -from .observer import Observer, LoggingObserver, PatienceObserver, SaveObserver, ObserverSet +from .observer import Observer, LoggingObserver, PatienceObserver, RelMsSSIMObserver, SaveObserver, ObserverSet def execute_plan( @@ -57,10 +57,23 @@ def execute_engine( engine_i = recons.state.iter.engine_num 
- if plan.early_termination: - engine_observer = ObserverSet((recons.observer, PatienceObserver( - plan.early_termination, plan.early_termination_smoothing - ))) + extra_observers: t.List[Observer] = [] + + if plan.calc_rel_msssim is not False: + extra_observers.append(RelMsSSIMObserver(plan.calc_rel_msssim)) + + if any(v is not None for v in ( + plan.early_termination_loss, plan.early_termination_obj_rel_msssim, plan.early_termination_probe_rel_msssim + )): + extra_observers.append(PatienceObserver( + patience_loss=plan.early_termination_loss, + patience_obj_rel_msssim=plan.early_termination_obj_rel_msssim, + patience_probe_rel_msssim=plan.early_termination_probe_rel_msssim, + smoothing=plan.early_termination_smoothing, + )) + + if extra_observers: + engine_observer = ObserverSet((recons.observer, *extra_observers)) else: engine_observer = recons.observer @@ -307,6 +320,8 @@ def initialize_reconstruction( if init_state.scan is not None and plan.init.scan is None: logging.info("Re-using scan from initial state...") scan = init_state.scan + scan = scan.astype(dtype) + else: logging.info("Initializing scan...") scan = pane.from_data(scan_hook, ScanHook)( # type: ignore @@ -316,6 +331,8 @@ def initialize_reconstruction( if init_state.tilt is not None and plan.init.tilt is None: logging.info("Re-using tilt from initial state...") tilt = init_state.tilt + tilt = tilt.astype(dtype) + elif tilt_hook is not None: logging.info("Initializing tilt...") tilt = pane.from_data(tilt_hook, TiltHook)( # type: ignore diff --git a/phaser/observer.py b/phaser/observer.py index dbb9f53..6693736 100644 --- a/phaser/observer.py +++ b/phaser/observer.py @@ -10,7 +10,7 @@ from phaser.types import EarlyTermination, flag_any_true, process_flag if t.TYPE_CHECKING: - from phaser.hooks.schedule import FlagArgs + from phaser.hooks.schedule import FlagArgs, FlagLike from typing_extensions import Self P = t.ParamSpec('P') @@ -169,44 +169,143 @@ def finish_recons(self, state: ReconsState): class 
PatienceObserver(Observer): - def __init__(self, patience: int, smoothing: float = 0.1, continue_next_engine: bool = True): - self.patience: int = patience - self.no_improvement_iter: int = 0 - self.best_error: t.Optional[float] = None - self.smoothed_error: t.Optional[float] = None + # metrics where higher values indicate improvement + _HIGHER_IS_BETTER: t.FrozenSet[str] = frozenset({'obj_rel_msssim', 'probe_rel_msssim'}) + + def __init__( + self, + patience_loss: t.Optional[int] = None, + patience_obj_rel_msssim: t.Optional[int] = None, + patience_probe_rel_msssim: t.Optional[int] = None, + smoothing: float = 0.1, + continue_next_engine: bool = True, + ): self.smoothing: float = smoothing self.continue_next_engine: bool = continue_next_engine + # build active metric table: key -> patience + self._patience: t.Dict[str, int] = {} + if patience_loss is not None: + self._patience['total_loss'] = patience_loss + if patience_obj_rel_msssim is not None: + self._patience['obj_rel_msssim'] = patience_obj_rel_msssim + if patience_probe_rel_msssim is not None: + self._patience['probe_rel_msssim'] = patience_probe_rel_msssim + + self._best: t.Dict[str, float] = {} + self._last_improvement_iter: t.Dict[str, int] = {} + self._smoothed: t.Dict[str, float] = {} + def init_engine( self, init_state: ReconsState, *, recons_name: str, plan: EnginePlan, **kwargs: t.Any ): - self.no_improvement_iter = 0 + self._best = {} + self._last_improvement_iter = {} + self._smoothed = {} + + def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]): + current_iter = int(state.iter.total_iter) + + for key, patience in self._patience.items(): + # read value: loss from errors dict every iteration; + # ssim metrics only when a new value was computed this iteration + if key == 'total_loss': + value: t.Optional[float] = errors.get('total_loss') + else: + prog = state.progress.get(key) if state.progress else None + if prog is None or not len(prog.values): + continue + # 
skip if no new ssim value was produced this iteration + if not len(prog.iters) or prog.iters[-1] != current_iter: + continue + value = prog.values[-1] + + if value is None: + continue + + # exponential moving average + if key not in self._smoothed: + self._smoothed[key] = value + else: + self._smoothed[key] = (1 - self.smoothing) * self._smoothed[key] + self.smoothing * value + + higher_is_better = key in self._HIGHER_IS_BETTER + improved = ( + key not in self._best + or (higher_is_better and value > self._best[key]) + or (not higher_is_better and value < self._best[key]) + ) + + if improved: + self._best[key] = value + self._last_improvement_iter[key] = current_iter - def _error_from_state(self, state: t.Union[ReconsState, PartialReconsState]) -> t.Optional[float]: - if state.progress is None or (progress := state.progress['total_loss']) is None or not len(progress.values): - return None - return progress.values[-1] + iters_without_improvement = current_iter - self._last_improvement_iter.get(key, current_iter) + if iters_without_improvement >= patience: + logging.info( + f"Early termination: {key} no improvement for {iters_without_improvement} iterations" + ) + raise EarlyTermination(state, self.continue_next_engine) + + +class RelMsSSIMObserver(Observer): + """Computes obj_rel_msssim and probe_rel_msssim at each calc_rel_msssim flag firing.""" + + def __init__(self, calc_rel_msssim: 'FlagLike'): + from phaser.types import process_flag, flag_any_true + self._calc_rel_msssim_raw = calc_rel_msssim + self._calc_rel_msssim_flag = process_flag(calc_rel_msssim) + self._ssim_enabled: bool = False + # CPU-side snapshot: (total_iter, obj_phase, probe_abs) as numpy arrays + self._prev_snapshot: t.Optional[t.Tuple[int, 'numpy.ndarray', 'numpy.ndarray']] = None + + def init_engine( + self, init_state: ReconsState, *, recons_name: str, + plan: EnginePlan, **kwargs: t.Any + ): + from phaser.types import flag_any_true + self._ssim_enabled = 
flag_any_true(self._calc_rel_msssim_raw, plan.niter) + self._prev_snapshot = None + + if self._ssim_enabled: + for k in ('obj_rel_msssim', 'probe_rel_msssim'): + if k not in init_state.progress: + init_state.progress[k] = ProgressState() def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]): - if (error := errors.get('total_loss')) is None: + if not self._ssim_enabled: + return + if not self._calc_rel_msssim_flag({'state': state, 'niter': n}): return - if self.best_error is None or error < self.best_error: - self.best_error = error - self.no_improvement_iter = 0 - else: - self.no_improvement_iter += 1 + from phaser.utils.num import get_array_module, to_numpy + from phaser.utils.analysis import structural_similarity - # Exponential moving average - if self.smoothed_error is None: - self.smoothed_error = error - else: - self.smoothed_error = (1 - self.smoothing) * self.smoothed_error + self.smoothing * error + xp = get_array_module(state.object.data) + total_iter = int(state.iter.total_iter) + + # transfer only the two arrays needed; forces GPU→CPU sync here + obj_now = to_numpy(xp.angle(state.object.data)) + probe_now = to_numpy(xp.abs(state.probe.data)) + + if self._prev_snapshot is not None: + prev_iter, obj_prev, probe_prev = self._prev_snapshot + + ssim_o = structural_similarity(obj_now, obj_prev) + state.progress['obj_rel_msssim'].iters.append(total_iter) + state.progress['obj_rel_msssim'].values.append(ssim_o) + + ssim_p = structural_similarity(probe_now, probe_prev) + state.progress['probe_rel_msssim'].iters.append(total_iter) + state.progress['probe_rel_msssim'].values.append(ssim_p) + + logging.info( + f"Relative multiscale SSIM (iters {prev_iter}→{total_iter}): " + f"obj={ssim_o:.4f} probe={ssim_p:.4f}" + ) - if self.no_improvement_iter >= self.patience: - logging.info(f"Early termination: no improvement for {self.patience} iterations") - raise EarlyTermination(state, self.continue_next_engine) + self._prev_snapshot = 
(total_iter, obj_now, probe_now) class SaveObserver(Observer): diff --git a/phaser/plan.py b/phaser/plan.py index 0e68ad6..5e1247a 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -85,8 +85,12 @@ class EnginePlan(Dataclass, kw_only=True): save_images: FlagLike = False save_options: SaveOptions = SaveOptions() - early_termination: t.Optional[int] = None - """Terminate after n iterations without improvement""" + early_termination_loss: t.Optional[int] = None + """Terminate after n iterations without improvement in total_loss""" + early_termination_obj_rel_msssim: t.Optional[int] = None + """Terminate after n iterations without improvement in obj_rel_msssim (requires calc_rel_msssim to be enabled)""" + early_termination_probe_rel_msssim: t.Optional[int] = None + """Terminate after n iterations without improvement in probe_rel_msssim (requires calc_rel_msssim to be enabled)""" early_termination_smoothing: float = 0.9 """ Smoothing factor to apply to error measurement for early termination. @@ -94,6 +98,9 @@ class EnginePlan(Dataclass, kw_only=True): (smooths over ~1/smoothing iterations) """ + calc_rel_msssim: FlagLike = False + """Compute SSIM between consecutive iterations as a convergence metric. 
Use SimpleFlag(every=N) to compute every N iterations.""" + check_every_group: bool = False send_every_group: bool = False @@ -154,7 +161,7 @@ class GradientEnginePlan(EnginePlan): regularizers: t.List[CostRegularizerHook] group_constraints: t.List[GroupConstraintHook] iter_constraints: t.List[IterConstraintHook] - + class SGDSolverPlan(Dataclass, kw_only=True): learning_rate: ScheduleLike diff --git a/phaser/utils/analysis.py b/phaser/utils/analysis.py index 024c66c..9c3c9c5 100644 --- a/phaser/utils/analysis.py +++ b/phaser/utils/analysis.py @@ -294,4 +294,139 @@ def align_and_correlate(mat: NDArray[numpy.floating]) -> NDArray[numpy.floating] if return_crop: return upsamp_obj[(slice(None), *crop)], ground_truth[tuple(crop)], crop - return upsamp_obj[(slice(None), *crop)], ground_truth[tuple(crop)] \ No newline at end of file + return upsamp_obj[(slice(None), *crop)], ground_truth[tuple(crop)] + + +def _uniform_filter_spatial(im, size: int, xp: t.Any): + """ + Separable box filter over the last two spatial dims only (any ndim >= 2). + Accepts stacked inputs e.g. (N, H, W), filtering H and W only — enabling + fused multi-statistic computation in one call. 
+ + Dispatches to: + - scipy.ndimage.uniform_filter for numpy + - cupyx.scipy.ndimage.uniform_filter for cupy (GPU-native) + - cumsum-based separable filter for JAX / other backends (XLA-friendly) + """ + xp_name = getattr(xp, '__name__', '') + sizes = [1] * (im.ndim - 2) + [size, size] + + if xp_name == 'numpy': + from scipy.ndimage import uniform_filter + return uniform_filter(im, sizes) + + if 'cupy' in xp_name: + from cupyx.scipy.ndimage import uniform_filter + return uniform_filter(im, sizes) + + # JAX or other: cumsum box filter along axes -2 and -1 only (XLA-friendly) + def _along_axis(arr, axis: int): + pad = size // 2 + pad_config = [(0, 0)] * arr.ndim + pad_config[axis] = (pad, pad) + padded = xp.pad(arr, pad_config, mode='reflect') + zero_shape = list(padded.shape) + zero_shape[axis] = 1 + cs = xp.concatenate( + [xp.zeros(zero_shape, dtype=padded.dtype), xp.cumsum(padded, axis=axis)], + axis=axis, + ) + n = arr.shape[axis] + sl_end = [slice(None)] * arr.ndim + sl_end[axis] = slice(size, size + n) + sl_beg = [slice(None)] * arr.ndim + sl_beg[axis] = slice(0, n) + return (cs[tuple(sl_end)] - cs[tuple(sl_beg)]) / size + + return _along_axis(_along_axis(im, -2), -1) + + +def structural_similarity( + im1, + im2, + data_range=None, + win_size: int = 3, + num_scales: int = 3, + **kwargs, +) -> float: + """ + Multi-scale contrast-structure similarity (geometric mean across scales). + + Computes the contrast-structure (CS) component of SSIM at each scale of a + bilinear downsampling pyramid, then combines as a geometric mean: + result = (cs_1 * cs_2 * ... * cs_N)^(1/N) + + Luminance is omitted. Equal scale weights are used. + + Efficient implementation: + - fused filter pass: all statistics filtered in one call per scale + - bilinear downsampling pyramid via affine_transform + - fully on-device: only the final scalar crosses the device boundary + + Parameters + ---------- + im1, im2 : ndarray + Arrays from any supported backend (numpy, JAX, cupy). 
data_range : float, optional + Computed from im2 if not provided. + win_size : int + Box filter size in pixels (default 3). + num_scales : int + Number of pyramid levels (default 3). + + Returns + ------- + mssim : float + MS-SSIM value in [0, 1]. + """ + from phaser.utils.image import affine_transform as _affine_transform + + def _resample(im, target_shape): + scale_y = im.shape[-2] / target_shape[-2] + scale_x = im.shape[-1] / target_shape[-1] + matrix = numpy.array([[scale_y, 0.0], [0.0, scale_x]]) + offset = numpy.array([0.5 * (scale_y - 1.0), 0.5 * (scale_x - 1.0)]) + return _affine_transform(im, matrix, offset=offset, output_shape=target_shape[-2:], order=1) + + xp = get_array_module(im1, im2) + + im1 = im1.astype(numpy.float64) + im2 = im2.astype(numpy.float64) + + if im1.shape != im2.shape: + im2 = _resample(im2, im1.shape) + if data_range is None: + data_range = float(im2.max() - im2.min()) + + C2 = (0.03 * data_range) ** 2 + + pad = (win_size - 1) // 2 + weight = 1.0 / num_scales + + mssim = 1.0 + for scale in range(num_scales): + if min(im1.shape[-2:]) < win_size: + break + + # fused: stack [im1, im2, im1², im2², im1·im2] and filter in one pass + stacked = xp.stack([im1, im2, im1 * im1, im2 * im2, im1 * im2]) + f = _uniform_filter_spatial(stacked, win_size, xp) + ux, uy, uxx, uyy, uxy = f[0], f[1], f[2], f[3], f[4] + + vx = uxx - ux * ux + vy = uyy - uy * uy + vxy = uxy - ux * uy + + # crop boundary artifacts (spatial dims only) + s = (..., slice(pad, -pad), slice(pad, -pad)) + vx, vy, vxy = vx[s], vy[s], vxy[s] + + cs = float(xp.mean((2 * vxy + C2) / (vx + vy + C2))) + mssim *= cs ** weight + + if scale < num_scales - 1: + new_shape = (im1.shape[0], im1.shape[-2] // 2, im1.shape[-1] // 2) + im1 = _resample(im1, new_shape) + im2 = _resample(im2, new_shape) + + return mssim \ No newline at end of file