diff --git a/src/ophyd_async/epics/pmac/_pmac_trajectory.py b/src/ophyd_async/epics/pmac/_pmac_trajectory.py index 235b3adf90..7f1403f3e8 100644 --- a/src/ophyd_async/epics/pmac/_pmac_trajectory.py +++ b/src/ophyd_async/epics/pmac/_pmac_trajectory.py @@ -14,6 +14,7 @@ DEFAULT_TIMEOUT, AsyncStatus, Device, + Reference, error_if_none, gather_dict, observe_value, @@ -52,7 +53,7 @@ class PmacTrajectoryTriggerLogic( Flyable, ): def __init__(self, pmac: PmacIO, name: str = "") -> None: - self.pmac = pmac + self.pmac_ref = Reference(pmac) self._next_pvt: PVT | None self._loaded: int = 0 self._trajectory_status: AsyncStatus | None = None @@ -65,7 +66,7 @@ async def prepare(self, value: Spec[Motor]): slice = path.consume(SLICE_SIZE) path_length = len(path) motors = slice.axes() - motor_info = await _PmacMotorInfo.from_motors(self.pmac, motors) + motor_info = await _PmacMotorInfo.from_motors(self.pmac_ref(), motors) ramp_up_pos, ramp_up_time = calculate_ramp_position_and_duration( slice, motor_info, True ) @@ -88,7 +89,7 @@ async def kickoff(self): ) # Wait for the ramp up to happen await wait_for_value( - self.pmac.trajectory.total_points, + self.pmac_ref().trajectory.total_points, lambda v: v >= 1, prepare_context.ramp_up_time + DEFAULT_TIMEOUT, ) @@ -108,16 +109,16 @@ async def stage(self) -> None: await self._stop_if_running() # Run an empty fly scan to reset EQU on Panda Brick - for use_axis in self.pmac.trajectory.use_axis.values(): + for use_axis in self.pmac_ref().trajectory.use_axis.values(): await use_axis.set(False) await asyncio.gather( - self.pmac.trajectory.time_array.set(np.array(0)), - self.pmac.trajectory.user_array.set(np.array(UserProgram.END)), - self.pmac.trajectory.points_to_build.set(1), + self.pmac_ref().trajectory.time_array.set(np.array(0)), + self.pmac_ref().trajectory.user_array.set(np.array(UserProgram.END)), + self.pmac_ref().trajectory.points_to_build.set(1), ) - await self.pmac.trajectory.build_profile.trigger() - await self.pmac.trajectory.execute_profile.set(True) + await self.pmac_ref().trajectory.build_profile.trigger() + await self.pmac_ref().trajectory.execute_profile.set(True) @AsyncStatus.wrap async def unstage(self) -> None: @@ -126,21 +127,23 @@ async def unstage(self) -> None: async def _stop_if_running(self): # Abort current trajectory, if one is running if ( - await self.pmac.trajectory.execute_state.get_value() + await self.pmac_ref().trajectory.execute_state.get_value() == PmacExecuteState.EXECUTING ): - await self.pmac.trajectory.abort_profile.trigger() + await self.pmac_ref().trajectory.abort_profile.trigger() @AsyncStatus.wrap async def _execute_trajectory(self, path: Path, motor_info: _PmacMotorInfo): - execute_status = self.pmac.trajectory.execute_profile.set(True, timeout=None) + execute_status = self.pmac_ref().trajectory.execute_profile.set( + True, timeout=None + ) # We consume SLICE_SIZE from self.path and parse a trajectory # containing at least 2 * SLICE_SIZE, as a gapless trajectory # will contain 2 points per slice frame. If gaps are present, # additional points are inserted, overfilling the buffer. min_buffer_size = SLICE_SIZE * 2 async for current_point in observe_value( - self.pmac.trajectory.total_points, + self.pmac_ref().trajectory.total_points, done_status=execute_status, timeout=DEFAULT_TIMEOUT, ): @@ -156,7 +159,7 @@ async def _append_trajectory( ): trajectory = await self._parse_trajectory(slice, path_length, motor_info) await self._set_trajectory_arrays(trajectory, motor_info) - await self.pmac.trajectory.append_profile.trigger() + await self.pmac_ref().trajectory.append_profile.trigger() async def _build_trajectory( self, @@ -173,16 +176,16 @@ async def _build_trajectory( } coros = [ - self.pmac.trajectory.profile_cs_name.set(motor_info.cs_port), - self.pmac.trajectory.calculate_velocities.set(False), + self.pmac_ref().trajectory.profile_cs_name.set(motor_info.cs_port), + self.pmac_ref().trajectory.calculate_velocities.set(False), self._set_trajectory_arrays(trajectory, motor_info), ] + [ - self.pmac.trajectory.use_axis[number].set(use) + self.pmac_ref().trajectory.use_axis[number].set(use) for number, use in use_axis.items() ] await asyncio.gather(*coros) - await self.pmac.trajectory.build_profile.trigger() + await self.pmac_ref().trajectory.build_profile.trigger() async def _parse_trajectory( self, @@ -216,20 +219,24 @@ async def _set_trajectory_arrays( coros = [] for motor, cs_index in motor_info.motor_cs_index.items(): coros.append( - self.pmac.trajectory.positions[cs_index].set( - trajectory.positions[motor] - ) + self.pmac_ref() + .trajectory.positions[cs_index] + .set(trajectory.positions[motor]) ) coros.append( - self.pmac.trajectory.velocities[cs_index].set( - trajectory.velocities[motor] - ) + self.pmac_ref() + .trajectory.velocities[cs_index] + .set(trajectory.velocities[motor]) ) coros.extend( [ - self.pmac.trajectory.time_array.set(trajectory.durations / TICK_S), - self.pmac.trajectory.user_array.set(trajectory.user_programs), - self.pmac.trajectory.points_to_build.set(len(trajectory.durations)), + self.pmac_ref().trajectory.time_array.set( + trajectory.durations / TICK_S + ), + self.pmac_ref().trajectory.user_array.set(trajectory.user_programs), + self.pmac_ref().trajectory.points_to_build.set( + len(trajectory.durations) + ), ] ) await asyncio.gather(*coros) @@ -237,7 +244,7 @@ async def _set_trajectory_arrays( async def _move_to_start( self, motor_info: _PmacMotorInfo, ramp_up_position: dict[Motor, np.float64] ): - coord = self.pmac.coord[motor_info.cs_number] + coord = self.pmac_ref().coord[motor_info.cs_number] coros = [] await coord.defer_moves.set(True) diff --git a/tests/unit_tests/epics/pmac/test_pmac_trajectory.py b/tests/unit_tests/epics/pmac/test_pmac_trajectory.py index 334ba7f23e..c63b5345f3 100644 --- a/tests/unit_tests/epics/pmac/test_pmac_trajectory.py +++ b/tests/unit_tests/epics/pmac/test_pmac_trajectory.py @@ -227,18 +227,18 @@ async def test_pmac_trajectory_complete(sim_motors: tuple[PmacIO, Motor, Motor]) async def test_pmac_trajectory_stage(sim_motors: tuple[PmacIO, Motor, Motor]): pmac_io, _, _ = sim_motors pmac_trajectory = PmacTrajectoryTriggerLogic(pmac_io) - mock_pmac_trajectory_io = get_mock(pmac_trajectory.pmac.trajectory) + mock_pmac_trajectory_io = get_mock(pmac_trajectory.pmac_ref().trajectory) await pmac_trajectory.stage() # Check that all axes are then set not be used assert all( get_mock(axis).put.assert_called_once_with(False) is None - for axis in pmac_trajectory.pmac.trajectory.use_axis.values() + for axis in pmac_trajectory.pmac_ref().trajectory.use_axis.values() ) # Check that an empty trajectory is then executed assert mock_pmac_trajectory_io.mock_calls[ - len(pmac_trajectory.pmac.trajectory.use_axis) : + len(pmac_trajectory.pmac_ref().trajectory.use_axis) : ] == [ call.time_array.put(np.array(0)), call.user_array.put(np.array(8)), @@ -268,7 +268,7 @@ async def test_trajectory_stop_if_running(sim_motors: tuple[PmacIO, Motor, Motor # Mocking that trajectory is executing set_mock_value( - pmac_trajectory.pmac.trajectory.execute_state, PmacExecuteState.EXECUTING + pmac_trajectory.pmac_ref().trajectory.execute_state, PmacExecuteState.EXECUTING ) # Method called as there is now a running trajectory