Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions phaser/engines/conventional/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import typing as t
import numpy

from phaser.utils.misc import mask_fraction_of_groups
from phaser.utils.num import assert_dtype, cast_array_module, to_numpy, to_complex_dtype
Expand Down
3 changes: 2 additions & 1 deletion phaser/engines/gradient/run.py
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these too

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from phaser.hooks.solver import GradientSolver
from phaser.hooks.regularization import CostRegularizer, GroupConstraint
from phaser.plan import GradientEnginePlan
from phaser.types import process_flag, ReconsVar
from phaser.types import process_flag, flag_any_true, ReconsVar
from ..common.simulation import GroupManager, make_propagators, tilt_propagators, slice_forwards, stream_patterns


Expand Down Expand Up @@ -314,6 +314,7 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple
(update, iter_solver_states[sol_i]) = solver.update(
state, iter_solver_states[sol_i], filter_vars(iter_grads, solver.params), losses['total_loss']
)

state = apply_update(state, update)

if 'positions' in update:
Expand Down
27 changes: 22 additions & 5 deletions phaser/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .hooks import EngineHook, Hook, ObjectHook, RawData
from .plan import GradientEnginePlan, ReconsPlan, EnginePlan, ScanHook, ProbeHook, TiltHook
from .state import Patterns, ReconsState, PartialReconsState, IterState, PreparedRecons
from .observer import Observer, LoggingObserver, PatienceObserver, SaveObserver, ObserverSet
from .observer import Observer, LoggingObserver, PatienceObserver, RelMsSSIMObserver, SaveObserver, ObserverSet


def execute_plan(
Expand Down Expand Up @@ -57,10 +57,23 @@ def execute_engine(

engine_i = recons.state.iter.engine_num

if plan.early_termination:
engine_observer = ObserverSet((recons.observer, PatienceObserver(
plan.early_termination, plan.early_termination_smoothing
)))
extra_observers: t.List[Observer] = []

if plan.calc_rel_msssim is not False:
extra_observers.append(RelMsSSIMObserver(plan.calc_rel_msssim))

if any(v is not None for v in (
plan.early_termination_loss, plan.early_termination_obj_rel_msssim, plan.early_termination_probe_rel_msssim
)):
extra_observers.append(PatienceObserver(
patience_loss=plan.early_termination_loss,
patience_obj_rel_msssim=plan.early_termination_obj_rel_msssim,
patience_probe_rel_msssim=plan.early_termination_probe_rel_msssim,
smoothing=plan.early_termination_smoothing,
))

if extra_observers:
engine_observer = ObserverSet((recons.observer, *extra_observers))
else:
engine_observer = recons.observer

Expand Down Expand Up @@ -307,6 +320,8 @@ def initialize_reconstruction(
if init_state.scan is not None and plan.init.scan is None:
logging.info("Re-using scan from initial state...")
scan = init_state.scan
scan = scan.astype(dtype)

else:
logging.info("Initializing scan...")
scan = pane.from_data(scan_hook, ScanHook)( # type: ignore
Expand All @@ -316,6 +331,8 @@ def initialize_reconstruction(
if init_state.tilt is not None and plan.init.tilt is None:
logging.info("Re-using tilt from initial state...")
tilt = init_state.tilt
tilt = tilt.astype(dtype)

elif tilt_hook is not None:
Comment on lines 323 to 336
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this changed? I don't think scan and tilt should have to be dtype, double-precision is fine

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've run into errors of dtype conflict when loading init state .h5 before, i am not sure if the recent changes have fixed that.

logging.info("Initializing tilt...")
tilt = pane.from_data(tilt_hook, TiltHook)( # type: ignore
Expand Down
149 changes: 124 additions & 25 deletions phaser/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from phaser.types import EarlyTermination, flag_any_true, process_flag

if t.TYPE_CHECKING:
from phaser.hooks.schedule import FlagArgs
from phaser.hooks.schedule import FlagArgs, FlagLike
from typing_extensions import Self

P = t.ParamSpec('P')
Expand Down Expand Up @@ -169,44 +169,143 @@ def finish_recons(self, state: ReconsState):


class PatienceObserver(Observer):
def __init__(self, patience: int, smoothing: float = 0.1, continue_next_engine: bool = True):
self.patience: int = patience
self.no_improvement_iter: int = 0
self.best_error: t.Optional[float] = None
self.smoothed_error: t.Optional[float] = None
# metrics where higher values indicate improvement
_HIGHER_IS_BETTER: t.FrozenSet[str] = frozenset({'obj_rel_msssim', 'probe_rel_msssim'})

def __init__(
self,
patience_loss: t.Optional[int] = None,
patience_obj_rel_msssim: t.Optional[int] = None,
patience_probe_rel_msssim: t.Optional[int] = None,
smoothing: float = 0.1,
continue_next_engine: bool = True,
):
self.smoothing: float = smoothing
self.continue_next_engine: bool = continue_next_engine

# build active metric table: key -> patience
self._patience: t.Dict[str, int] = {}
if patience_loss is not None:
self._patience['total_loss'] = patience_loss
if patience_obj_rel_msssim is not None:
self._patience['obj_rel_msssim'] = patience_obj_rel_msssim
if patience_probe_rel_msssim is not None:
self._patience['probe_rel_msssim'] = patience_probe_rel_msssim

self._best: t.Dict[str, float] = {}
self._last_improvement_iter: t.Dict[str, int] = {}
self._smoothed: t.Dict[str, float] = {}

def init_engine(
self, init_state: ReconsState, *, recons_name: str,
plan: EnginePlan, **kwargs: t.Any
):
self.no_improvement_iter = 0
self._best = {}
self._last_improvement_iter = {}
self._smoothed = {}

def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]):
current_iter = int(state.iter.total_iter)

for key, patience in self._patience.items():
# read value: loss from errors dict every iteration;
# ssim metrics only when a new value was computed this iteration
if key == 'total_loss':
value: t.Optional[float] = errors.get('total_loss')
else:
prog = state.progress.get(key) if state.progress else None
if prog is None or not len(prog.values):
continue
# skip if no new ssim value was produced this iteration
if not len(prog.iters) or prog.iters[-1] != current_iter:
continue
value = prog.values[-1]

if value is None:
continue

# exponential moving average
if key not in self._smoothed:
self._smoothed[key] = value
else:
self._smoothed[key] = (1 - self.smoothing) * self._smoothed[key] + self.smoothing * value

higher_is_better = key in self._HIGHER_IS_BETTER
improved = (
key not in self._best
or (higher_is_better and value > self._best[key])
or (not higher_is_better and value < self._best[key])
)

if improved:
self._best[key] = value
self._last_improvement_iter[key] = current_iter

def _error_from_state(self, state: t.Union[ReconsState, PartialReconsState]) -> t.Optional[float]:
if state.progress is None or (progress := state.progress['total_loss']) is None or not len(progress.values):
return None
return progress.values[-1]
iters_without_improvement = current_iter - self._last_improvement_iter.get(key, current_iter)
if iters_without_improvement >= patience:
logging.info(
f"Early termination: {key} no improvement for {iters_without_improvement} iterations"
)
raise EarlyTermination(state, self.continue_next_engine)
Comment on lines +244 to +249
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How should we handle the following case?

early_termination_obj_rel_msssim: 5
calc_rel_msssim: {'every': 10}

Right now it looks like it'll basically ignore patience; maybe an alternate way to specify is that the patience number is the number of evaluations, not the number of iterations



class RelMsSSIMObserver(Observer):
"""Computes obj_rel_msssim and probe_rel_msssim at each calc_rel_msssim flag firing."""

def __init__(self, calc_rel_msssim: 'FlagLike'):
from phaser.types import process_flag, flag_any_true
self._calc_rel_msssim_raw = calc_rel_msssim
self._calc_rel_msssim_flag = process_flag(calc_rel_msssim)
self._ssim_enabled: bool = False
# CPU-side snapshot: (total_iter, obj_phase, probe_abs) as numpy arrays
self._prev_snapshot: t.Optional[t.Tuple[int, 'numpy.ndarray', 'numpy.ndarray']] = None

def init_engine(
self, init_state: ReconsState, *, recons_name: str,
plan: EnginePlan, **kwargs: t.Any
):
from phaser.types import flag_any_true
self._ssim_enabled = flag_any_true(self._calc_rel_msssim_raw, plan.niter)
self._prev_snapshot = None

if self._ssim_enabled:
for k in ('obj_rel_msssim', 'probe_rel_msssim'):
if k not in init_state.progress:
init_state.progress[k] = ProgressState()

def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]):
if (error := errors.get('total_loss')) is None:
if not self._ssim_enabled:
return
if not self._calc_rel_msssim_flag({'state': state, 'niter': n}):
return

if self.best_error is None or error < self.best_error:
self.best_error = error
self.no_improvement_iter = 0
else:
self.no_improvement_iter += 1
from phaser.utils.num import get_array_module, to_numpy
from phaser.utils.analysis import structural_similarity

# Exponential moving average
if self.smoothed_error is None:
self.smoothed_error = error
else:
self.smoothed_error = (1 - self.smoothing) * self.smoothed_error + self.smoothing * error
xp = get_array_module(state.object.data)
total_iter = int(state.iter.total_iter)

# transfer only the two arrays needed; forces GPU→CPU sync here
obj_now = to_numpy(xp.angle(state.object.data))
probe_now = to_numpy(xp.abs(state.probe.data))
Comment on lines +289 to +290
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to be doing just the phase of object and amplitude of probe? Have you compared the approaches?


if self._prev_snapshot is not None:
prev_iter, obj_prev, probe_prev = self._prev_snapshot

ssim_o = structural_similarity(obj_now, obj_prev)
state.progress['obj_rel_msssim'].iters.append(total_iter)
state.progress['obj_rel_msssim'].values.append(ssim_o)

ssim_p = structural_similarity(probe_now, probe_prev)
state.progress['probe_rel_msssim'].iters.append(total_iter)
state.progress['probe_rel_msssim'].values.append(ssim_p)

logging.info(
f"Relative multiscale SSIM (iters {prev_iter}→{total_iter}): "
f"obj={ssim_o:.4f} probe={ssim_p:.4f}"
)

if self.no_improvement_iter >= self.patience:
logging.info(f"Early termination: no improvement for {self.patience} iterations")
raise EarlyTermination(state, self.continue_next_engine)
self._prev_snapshot = (total_iter, obj_now, probe_now)


class SaveObserver(Observer):
Expand Down
13 changes: 10 additions & 3 deletions phaser/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,22 @@ class EnginePlan(Dataclass, kw_only=True):
save_images: FlagLike = False
save_options: SaveOptions = SaveOptions()

early_termination: t.Optional[int] = None
"""Terminate after n iterations without improvement"""
early_termination_loss: t.Optional[int] = None
"""Terminate after n iterations without improvement in total_loss"""
early_termination_obj_rel_msssim: t.Optional[int] = None
"""Terminate after n iterations without improvement in obj_rel_msssim (requires calc_rel_msssim to be enabled)"""
early_termination_probe_rel_msssim: t.Optional[int] = None
"""Terminate after n iterations without improvement in probe_rel_msssim (requires calc_rel_msssim to be enabled)"""
early_termination_smoothing: float = 0.9
"""
Smoothing factor to apply to error measurement for early termination.
NOTE: Low smoothing factor means a large amount of smoothing!
(smooths over ~1/smoothing iterations)
"""

calc_rel_msssim: FlagLike = False
"""Compute SSIM between consecutive iterations as a convergence metric. Use SimpleFlag(every=N) to compute every N iterations."""

check_every_group: bool = False
send_every_group: bool = False

Expand Down Expand Up @@ -154,7 +161,7 @@ class GradientEnginePlan(EnginePlan):
regularizers: t.List[CostRegularizerHook]
group_constraints: t.List[GroupConstraintHook]
iter_constraints: t.List[IterConstraintHook]


class SGDSolverPlan(Dataclass, kw_only=True):
learning_rate: ScheduleLike
Expand Down
Loading
Loading