From b349dee3301efbf4d0eae3a4713b837bb6c798f4 Mon Sep 17 00:00:00 2001 From: Oli Wenman Date: Wed, 11 Mar 2026 17:25:38 +0000 Subject: [PATCH] Concept of checkbeam device --- src/dodal/beamlines/i09.py | 22 +++++ src/dodal/devices/pause_plan_device.py | 92 ++++++++++++++++++ tests/devices/test_pause_plan_device.py | 119 ++++++++++++++++++++++++ 3 files changed, 233 insertions(+) create mode 100644 src/dodal/devices/pause_plan_device.py create mode 100644 tests/devices/test_pause_plan_device.py diff --git a/src/dodal/beamlines/i09.py b/src/dodal/beamlines/i09.py index aad9259dbdc..82a8901f74a 100644 --- a/src/dodal/beamlines/i09.py +++ b/src/dodal/beamlines/i09.py @@ -11,6 +11,7 @@ from dodal.devices.fast_shutter import DualFastShutter, GenericFastShutter from dodal.devices.hutch_shutter import EXP_SHUTTER_2_INFIX, HutchShutter from dodal.devices.motors import XYZAzimuthPolarStage +from dodal.devices.pause_plan_device import PausePlanDevice from dodal.devices.pgm import PlaneGratingMonochromator from dodal.devices.selectable_source import SourceSelector from dodal.devices.synchrotron import Synchrotron @@ -126,3 +127,24 @@ def lakeshore() -> Lakeshore336: def smpm() -> XYZAzimuthPolarStage: """Sample Manipulator.""" return XYZAzimuthPolarStage(prefix=f"{I_PREFIX.beamline_prefix}-MO-SMPM-01:") + + +@devices.factory() +def checkbeam( + synchrotron: Synchrotron, dual_fast_shutter: DualFastShutter +) -> PausePlanDevice: + async def _close_shutters(): + await dual_fast_shutter.set(dual_fast_shutter.close_state) + + async def _open_shutters(): + await dual_fast_shutter.set(dual_fast_shutter.open_state) + + checkbeam = PausePlanDevice( + signals_to_condition={ + synchrotron.current: lambda rc: rc > 190, + synchrotron.top_up_start_countdown: lambda topup: topup < 5, + }, + callable_when_paused=_close_shutters, + callable_on_resume=_open_shutters, + ) + return checkbeam diff --git a/src/dodal/devices/pause_plan_device.py b/src/dodal/devices/pause_plan_device.py new file mode 100644 index 00000000000..bf2e5622abf --- /dev/null +++ b/src/dodal/devices/pause_plan_device.py @@ -0,0 +1,92 @@ +import asyncio +import contextlib +from collections.abc import Awaitable, Callable +from typing import Any + +from bluesky.protocols import Readable, Stageable +from ophyd_async.core import ( + AsyncStatus, + Device, + SignalR, + observe_value, +) + + +class PausePlanDevice(Device, Stageable, Readable): + def __init__( + self, + signals_to_condition: dict[SignalR[Any], Callable[[Any], bool]], + callable_when_paused: Callable[[], Awaitable[None]] | None = None, + callable_on_resume: Callable[[], Awaitable[None]] | None = None, + seconds_to_wait_before_resume: float = 5, + name: str = "", + ): + self._signals_to_condition = signals_to_condition + self._callable_when_paused = callable_when_paused + self._callable_on_resume = callable_on_resume + self._seconds_to_wait_before_resume = seconds_to_wait_before_resume + super().__init__(name) + + async def _pause(self): + """Pause until all signal conditions are met, calling hooks as needed.""" + # Check if we actually need to pause + values = await asyncio.gather( + *(sig.get_value() for sig in self._signals_to_condition) + ) + all_met = all( + pred(value) + for value, (_, pred) in zip( + values, self._signals_to_condition.items(), strict=True + ) + ) + if all_met: + return # no need to pause + + # Call pause hook + if self._callable_when_paused: + await self._callable_when_paused() + + latest = {} + event = asyncio.Event() + + async def watch(signal, predicate): + async for value in observe_value(signal): + latest[signal] = predicate(value) + if len(latest) == len(self._signals_to_condition) and all( + latest.values() + ): + event.set() + return + + tasks = [ + asyncio.create_task(watch(sig, pred)) + for sig, pred in self._signals_to_condition.items() + ] + + await event.wait() + + # Cancel watchers + for task in tasks: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + await asyncio.sleep(self._seconds_to_wait_before_resume) + # Call resume hook + if self._callable_on_resume: + await self._callable_on_resume() + + @AsyncStatus.wrap + async def stage(self): + await self._pause() + + @AsyncStatus.wrap + async def unstage(self): + pass + + async def read(self): + await self._pause() + return {} + + async def describe(self): + return {} diff --git a/tests/devices/test_pause_plan_device.py b/tests/devices/test_pause_plan_device.py new file mode 100644 index 00000000000..241e464447e --- /dev/null +++ b/tests/devices/test_pause_plan_device.py @@ -0,0 +1,119 @@ +import asyncio + +import pytest +from bluesky import RunEngine +from bluesky import plan_stubs as bps +from ophyd_async.core import InOut, SignalRW, init_devices, soft_signal_rw + +from dodal.devices.fast_shutter import GenericFastShutter +from dodal.devices.pause_plan_device import PausePlanDevice + + +@pytest.fixture +def shutter() -> GenericFastShutter: + with init_devices(mock=True): + shutter = GenericFastShutter( + "TEST:", open_state=InOut.OUT, close_state=InOut.IN + ) + return shutter + + +@pytest.fixture +def sig1() -> SignalRW[int]: + with init_devices(mock=True): + sig1 = soft_signal_rw(int, initial_value=0) + return sig1 + + +@pytest.fixture +def sig2() -> SignalRW[int]: + with init_devices(mock=True): + sig2 = soft_signal_rw(int, initial_value=0) + return sig2 + + +@pytest.fixture +async def pause_plan_device( + sig1: SignalRW[float], + sig2: SignalRW[float], + shutter: GenericFastShutter, +) -> PausePlanDevice: + + async def _close_shutter(): + await shutter.set(shutter.close_state) + + async def _open_shutter(): + await shutter.set(shutter.open_state) + + with init_devices(mock=True): + pause_plan_device = PausePlanDevice( + { + sig1: lambda v: v == 1, + sig2: lambda v: v > 5, + }, + callable_when_paused=_close_shutter, + callable_on_resume=_open_shutter, + seconds_to_wait_before_resume=0, + ) + return pause_plan_device + + +async def test_conditions_can_arrive_in_any_order( + pause_plan_device: PausePlanDevice, + sig1: SignalRW[float], + sig2: SignalRW[float], + shutter: GenericFastShutter, +): + await shutter.set(shutter.open_state) + + status = pause_plan_device.stage() + + await asyncio.sleep(0.1) + + assert not status.done + assert await shutter.shutter_state.get_value() == shutter.close_state + + await sig1.set(1) + await sig2.set(10) + + await status + assert status.success + assert await shutter.shutter_state.get_value() == shutter.open_state + + +@pytest.mark.asyncio +async def test_conditions_already_met( + pause_plan_device: PausePlanDevice, sig1: SignalRW[float], sig2: SignalRW[float] +): + await sig1.set(1) + await sig2.set(10) + + status = pause_plan_device.stage() + + await status + + assert status.success + + +async def test_pause_device_blocks_plan_until_conditions_met( + run_engine: RunEngine, + pause_plan_device: PausePlanDevice, + sig1: SignalRW[float], + sig2: SignalRW[float], +): + + start = asyncio.Event() + + async def update_signals(): + await start.wait() + await sig1.set(1) + await sig2.set(6) + + asyncio.create_task(update_signals()) + + def plan(): + start.set() + yield from bps.stage(pause_plan_device) + yield from bps.null() + + run_engine(plan())