-
Notifications
You must be signed in to change notification settings - Fork 18
add ssim observer #50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
4c351be
49e53bd
23ff7df
6ee250f
6c3f51f
8c48b09
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| from .hooks import EngineHook, Hook, ObjectHook, RawData | ||
| from .plan import GradientEnginePlan, ReconsPlan, EnginePlan, ScanHook, ProbeHook, TiltHook | ||
| from .state import Patterns, ReconsState, PartialReconsState, IterState, PreparedRecons | ||
| from .observer import Observer, LoggingObserver, PatienceObserver, SaveObserver, ObserverSet | ||
| from .observer import Observer, LoggingObserver, PatienceObserver, RelMsSSIMObserver, SaveObserver, ObserverSet | ||
|
|
||
|
|
||
| def execute_plan( | ||
|
|
@@ -57,10 +57,23 @@ def execute_engine( | |
|
|
||
| engine_i = recons.state.iter.engine_num | ||
|
|
||
| if plan.early_termination: | ||
| engine_observer = ObserverSet((recons.observer, PatienceObserver( | ||
| plan.early_termination, plan.early_termination_smoothing | ||
| ))) | ||
| extra_observers: t.List[Observer] = [] | ||
|
|
||
| if plan.calc_rel_msssim is not False: | ||
| extra_observers.append(RelMsSSIMObserver(plan.calc_rel_msssim)) | ||
|
|
||
| if any(v is not None for v in ( | ||
| plan.early_termination_loss, plan.early_termination_obj_rel_msssim, plan.early_termination_probe_rel_msssim | ||
| )): | ||
| extra_observers.append(PatienceObserver( | ||
| patience_loss=plan.early_termination_loss, | ||
| patience_obj_rel_msssim=plan.early_termination_obj_rel_msssim, | ||
| patience_probe_rel_msssim=plan.early_termination_probe_rel_msssim, | ||
| smoothing=plan.early_termination_smoothing, | ||
| )) | ||
|
|
||
| if extra_observers: | ||
| engine_observer = ObserverSet((recons.observer, *extra_observers)) | ||
| else: | ||
| engine_observer = recons.observer | ||
|
|
||
|
|
@@ -307,6 +320,8 @@ def initialize_reconstruction( | |
| if init_state.scan is not None and plan.init.scan is None: | ||
| logging.info("Re-using scan from initial state...") | ||
| scan = init_state.scan | ||
| scan = scan.astype(dtype) | ||
|
|
||
| else: | ||
| logging.info("Initializing scan...") | ||
| scan = pane.from_data(scan_hook, ScanHook)( # type: ignore | ||
|
|
@@ -316,6 +331,8 @@ def initialize_reconstruction( | |
| if init_state.tilt is not None and plan.init.tilt is None: | ||
| logging.info("Re-using tilt from initial state...") | ||
| tilt = init_state.tilt | ||
| tilt = tilt.astype(dtype) | ||
|
|
||
| elif tilt_hook is not None: | ||
|
Comment on lines
323
to
336
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why was this changed? I don't think scan and tilt should have to be
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've run into errors of dtype conflict when loading init state .h5 before, i am not sure if the recent changes have fixed that. |
||
| logging.info("Initializing tilt...") | ||
| tilt = pane.from_data(tilt_hook, TiltHook)( # type: ignore | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
| from phaser.types import EarlyTermination, flag_any_true, process_flag | ||
|
|
||
| if t.TYPE_CHECKING: | ||
| from phaser.hooks.schedule import FlagArgs | ||
| from phaser.hooks.schedule import FlagArgs, FlagLike | ||
| from typing_extensions import Self | ||
|
|
||
| P = t.ParamSpec('P') | ||
|
|
@@ -169,44 +169,143 @@ def finish_recons(self, state: ReconsState): | |
|
|
||
|
|
||
| class PatienceObserver(Observer): | ||
| def __init__(self, patience: int, smoothing: float = 0.1, continue_next_engine: bool = True): | ||
| self.patience: int = patience | ||
| self.no_improvement_iter: int = 0 | ||
| self.best_error: t.Optional[float] = None | ||
| self.smoothed_error: t.Optional[float] = None | ||
| # metrics where higher values indicate improvement | ||
| _HIGHER_IS_BETTER: t.FrozenSet[str] = frozenset({'obj_rel_msssim', 'probe_rel_msssim'}) | ||
|
|
||
| def __init__( | ||
| self, | ||
| patience_loss: t.Optional[int] = None, | ||
| patience_obj_rel_msssim: t.Optional[int] = None, | ||
| patience_probe_rel_msssim: t.Optional[int] = None, | ||
| smoothing: float = 0.1, | ||
| continue_next_engine: bool = True, | ||
| ): | ||
| self.smoothing: float = smoothing | ||
| self.continue_next_engine: bool = continue_next_engine | ||
|
|
||
| # build active metric table: key -> patience | ||
| self._patience: t.Dict[str, int] = {} | ||
| if patience_loss is not None: | ||
| self._patience['total_loss'] = patience_loss | ||
| if patience_obj_rel_msssim is not None: | ||
| self._patience['obj_rel_msssim'] = patience_obj_rel_msssim | ||
| if patience_probe_rel_msssim is not None: | ||
| self._patience['probe_rel_msssim'] = patience_probe_rel_msssim | ||
|
|
||
| self._best: t.Dict[str, float] = {} | ||
| self._last_improvement_iter: t.Dict[str, int] = {} | ||
| self._smoothed: t.Dict[str, float] = {} | ||
|
|
||
| def init_engine( | ||
| self, init_state: ReconsState, *, recons_name: str, | ||
| plan: EnginePlan, **kwargs: t.Any | ||
| ): | ||
| self.no_improvement_iter = 0 | ||
| self._best = {} | ||
| self._last_improvement_iter = {} | ||
| self._smoothed = {} | ||
|
|
||
| def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]): | ||
| current_iter = int(state.iter.total_iter) | ||
|
|
||
| for key, patience in self._patience.items(): | ||
| # read value: loss from errors dict every iteration; | ||
| # ssim metrics only when a new value was computed this iteration | ||
| if key == 'total_loss': | ||
| value: t.Optional[float] = errors.get('total_loss') | ||
| else: | ||
| prog = state.progress.get(key) if state.progress else None | ||
| if prog is None or not len(prog.values): | ||
| continue | ||
| # skip if no new ssim value was produced this iteration | ||
| if not len(prog.iters) or prog.iters[-1] != current_iter: | ||
| continue | ||
| value = prog.values[-1] | ||
|
|
||
| if value is None: | ||
| continue | ||
|
|
||
| # exponential moving average | ||
| if key not in self._smoothed: | ||
| self._smoothed[key] = value | ||
| else: | ||
| self._smoothed[key] = (1 - self.smoothing) * self._smoothed[key] + self.smoothing * value | ||
|
|
||
| higher_is_better = key in self._HIGHER_IS_BETTER | ||
| improved = ( | ||
| key not in self._best | ||
| or (higher_is_better and value > self._best[key]) | ||
| or (not higher_is_better and value < self._best[key]) | ||
| ) | ||
|
|
||
| if improved: | ||
| self._best[key] = value | ||
| self._last_improvement_iter[key] = current_iter | ||
|
|
||
| def _error_from_state(self, state: t.Union[ReconsState, PartialReconsState]) -> t.Optional[float]: | ||
| if state.progress is None or (progress := state.progress['total_loss']) is None or not len(progress.values): | ||
| return None | ||
| return progress.values[-1] | ||
| iters_without_improvement = current_iter - self._last_improvement_iter.get(key, current_iter) | ||
| if iters_without_improvement >= patience: | ||
| logging.info( | ||
| f"Early termination: {key} no improvement for {iters_without_improvement} iterations" | ||
| ) | ||
| raise EarlyTermination(state, self.continue_next_engine) | ||
|
Comment on lines
+244
to
+249
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How should we handle the following case? early_termination_obj_rel_msssim: 5
calc_rel_msssim: {'every': 10}Right now it looks like it'll basically ignore patience; maybe an alternate way to specify is that the patience number is the number of evaluations, not the number of iterations |
||
|
|
||
|
|
||
| class RelMsSSIMObserver(Observer): | ||
| """Computes obj_rel_msssim and probe_rel_msssim at each calc_rel_msssim flag firing.""" | ||
|
|
||
| def __init__(self, calc_rel_msssim: 'FlagLike'): | ||
| from phaser.types import process_flag, flag_any_true | ||
| self._calc_rel_msssim_raw = calc_rel_msssim | ||
| self._calc_rel_msssim_flag = process_flag(calc_rel_msssim) | ||
| self._ssim_enabled: bool = False | ||
| # CPU-side snapshot: (total_iter, obj_phase, probe_abs) as numpy arrays | ||
| self._prev_snapshot: t.Optional[t.Tuple[int, 'numpy.ndarray', 'numpy.ndarray']] = None | ||
|
|
||
| def init_engine( | ||
| self, init_state: ReconsState, *, recons_name: str, | ||
| plan: EnginePlan, **kwargs: t.Any | ||
| ): | ||
| from phaser.types import flag_any_true | ||
| self._ssim_enabled = flag_any_true(self._calc_rel_msssim_raw, plan.niter) | ||
| self._prev_snapshot = None | ||
|
|
||
| if self._ssim_enabled: | ||
| for k in ('obj_rel_msssim', 'probe_rel_msssim'): | ||
| if k not in init_state.progress: | ||
| init_state.progress[k] = ProgressState() | ||
|
|
||
| def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]): | ||
| if (error := errors.get('total_loss')) is None: | ||
| if not self._ssim_enabled: | ||
| return | ||
| if not self._calc_rel_msssim_flag({'state': state, 'niter': n}): | ||
| return | ||
|
|
||
| if self.best_error is None or error < self.best_error: | ||
| self.best_error = error | ||
| self.no_improvement_iter = 0 | ||
| else: | ||
| self.no_improvement_iter += 1 | ||
| from phaser.utils.num import get_array_module, to_numpy | ||
| from phaser.utils.analysis import structural_similarity | ||
|
|
||
| # Exponential moving average | ||
| if self.smoothed_error is None: | ||
| self.smoothed_error = error | ||
| else: | ||
| self.smoothed_error = (1 - self.smoothing) * self.smoothed_error + self.smoothing * error | ||
| xp = get_array_module(state.object.data) | ||
| total_iter = int(state.iter.total_iter) | ||
|
|
||
| # transfer only the two arrays needed; forces GPU→CPU sync here | ||
| obj_now = to_numpy(xp.angle(state.object.data)) | ||
| probe_now = to_numpy(xp.abs(state.probe.data)) | ||
|
Comment on lines
+289
to
+290
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you want to be doing just the phase of object and amplitude of probe? Have you compared the approaches? |
||
|
|
||
| if self._prev_snapshot is not None: | ||
| prev_iter, obj_prev, probe_prev = self._prev_snapshot | ||
|
|
||
| ssim_o = structural_similarity(obj_now, obj_prev) | ||
| state.progress['obj_rel_msssim'].iters.append(total_iter) | ||
| state.progress['obj_rel_msssim'].values.append(ssim_o) | ||
|
|
||
| ssim_p = structural_similarity(probe_now, probe_prev) | ||
| state.progress['probe_rel_msssim'].iters.append(total_iter) | ||
| state.progress['probe_rel_msssim'].values.append(ssim_p) | ||
|
|
||
| logging.info( | ||
| f"Relative multiscale SSIM (iters {prev_iter}→{total_iter}): " | ||
| f"obj={ssim_o:.4f} probe={ssim_p:.4f}" | ||
| ) | ||
|
|
||
| if self.no_improvement_iter >= self.patience: | ||
| logging.info(f"Early termination: no improvement for {self.patience} iterations") | ||
| raise EarlyTermination(state, self.continue_next_engine) | ||
| self._prev_snapshot = (total_iter, obj_now, probe_now) | ||
|
|
||
|
|
||
| class SaveObserver(Observer): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these too