diff --git a/README.md b/README.md index db0fc9a..b04bab9 100644 --- a/README.md +++ b/README.md @@ -24,10 +24,72 @@ mpiexec -np -bind-to hwthread python -u runscript.py Dict[str, Any]: + """Load a YAML config, resolving ``_extends`` chains and validating. + + ``_extends`` may be a single path or a list of paths. Each parent is + loaded recursively (parents may themselves ``_extends``) and merged + in order, with the current file's keys taking precedence. + + Relative paths in ``_extends`` resolve relative to the file that + declares them. + """ + cfg = _load_with_inheritance(path, _seen=set()) + validate(cfg) + return cfg + + +def _load_with_inheritance(path: str, *, _seen: set) -> Dict[str, Any]: + abspath = os.path.abspath(path) + if abspath in _seen: + raise ConfigError( + f"Circular `_extends` chain detected involving {abspath}" + ) + _seen = _seen | {abspath} + + with open(abspath, 'r') as f: + raw = yaml.safe_load(f) or {} + + extends = raw.pop('_extends', None) + if extends is None: + return raw + + if isinstance(extends, str): + parents: List[str] = [extends] + elif isinstance(extends, list): + parents = list(extends) + else: + raise ConfigError( + f"`_extends` in {abspath} must be a string or list, got {type(extends).__name__}" + ) + + here = os.path.dirname(abspath) + merged: Dict[str, Any] = {} + for parent in parents: + parent_path = parent if os.path.isabs(parent) else os.path.join(here, parent) + merged = deep_merge(merged, _load_with_inheritance(parent_path, _seen=_seen)) + return deep_merge(merged, raw) + + +def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: + """Recursively merge ``override`` onto ``base``, preferring ``override``. + + Nested dicts merge key-by-key. Lists and scalars are replaced, not + appended — this matches operator intuition ("override X" means + "replace X," not "extend X"). + """ + out = dict(base) + for k, v in override.items(): + if k in out and isinstance(out[k], dict) and isinstance(v, dict): + out[k] = deep_merge(out[k], v) + else: + out[k] = v + return out + + +# --------------------------------------------------------------------------- +# validation +# --------------------------------------------------------------------------- + + +# Required top-level keys and the rough shape we expect. Kept as plain +# code rather than a third-party schema lib so this module has no new +# install-time dependencies; the checks below are cheap and the error +# messages are deliberately operator-friendly. +_REQUIRED_TOP: Tuple[str, ...] = ( + 'rank', + 'algorithm', + 'sampling_rate', + 'files', + 'encoder', + 'decoder', + 'ripples', + 'kinematics', +) +_KNOWN_ALGORITHMS = ('clusterless_decoder', 'clusterless_classifier') +_KNOWN_DATASOURCES = ('trodes', 'synthetic') +_REQUIRED_RANK_ROLES = ('supervisor', 'decoders', 'encoders', 'ripples', 'gui') +_REQUIRED_SAMPLING = ('spikes', 'lfp', 'position') +_REQUIRED_FILES = ('output_dir', 'prefix') +_REQUIRED_ENCODER = ('mark_dim', 'bufsize', 'spk_amp', 'position') +_REQUIRED_ENCODER_POSITION = ('lower', 'upper', 'num_bins', 'arm_ids', 'arm_coords') +_REQUIRED_DECODER = ('bufsize', 'time_bin', 'cred_int_bufsize') + + +def validate(cfg: Dict[str, Any]) -> None: + """Raise ``ConfigError`` with a clear message if ``cfg`` is malformed.""" + errors: List[str] = [] + + for k in _REQUIRED_TOP: + if k not in cfg: + errors.append(f"missing required top-level key '{k}'") + + if cfg.get('algorithm') and cfg['algorithm'] not in _KNOWN_ALGORITHMS: + errors.append( + f"algorithm={cfg['algorithm']!r} is not one of {_KNOWN_ALGORITHMS}" + ) + + ds = cfg.get('datasource', 'trodes') + if ds not in _KNOWN_DATASOURCES: + errors.append( + f"datasource={ds!r} is not one of {_KNOWN_DATASOURCES}" + ) + + rank = cfg.get('rank', {}) + if isinstance(rank, dict): + for role in _REQUIRED_RANK_ROLES: + v = rank.get(role) + if v is None: + errors.append(f"rank.{role} is missing") + elif not isinstance(v, list) or not v: + errors.append(f"rank.{role} must be a non-empty list of ints, got {v!r}") + for role in ('supervisor', 'gui'): + if isinstance(rank.get(role), list) and len(rank[role]) != 1: + errors.append( + f"rank.{role} must contain exactly one rank, got {rank[role]!r}" + ) + else: + errors.append(f"rank must be a mapping, got {type(rank).__name__}") + + sr = cfg.get('sampling_rate', {}) + if isinstance(sr, dict): + for k in _REQUIRED_SAMPLING: + if k not in sr: + errors.append(f"sampling_rate.{k} is missing") + elif not isinstance(sr[k], (int, float)) or sr[k] <= 0: + errors.append(f"sampling_rate.{k} must be a positive number, got {sr[k]!r}") + + files = cfg.get('files', {}) + if isinstance(files, dict): + for k in _REQUIRED_FILES: + if not files.get(k): + errors.append(f"files.{k} is missing or empty") + + enc = cfg.get('encoder', {}) + if isinstance(enc, dict): + for k in _REQUIRED_ENCODER: + if k not in enc: + errors.append(f"encoder.{k} is missing") + pos = enc.get('position', {}) + if isinstance(pos, dict): + for k in _REQUIRED_ENCODER_POSITION: + if k not in pos: + errors.append(f"encoder.position.{k} is missing") + + dec = cfg.get('decoder', {}) + if isinstance(dec, dict): + for k in _REQUIRED_DECODER: + if k not in dec: + errors.append(f"decoder.{k} is missing") + tb = dec.get('time_bin', {}) + if isinstance(tb, dict): + for k in ('samples', 'delay_samples'): + if k not in tb: + errors.append(f"decoder.time_bin.{k} is missing") + + # Cross-field: each decoder rank must be a key in decoder_assignment. + dec_ranks = (cfg.get('rank') or {}).get('decoders') or [] + assignment = cfg.get('decoder_assignment') or {} + if isinstance(assignment, dict): + for r in dec_ranks: + if r not in assignment: + errors.append( + f"decoder_assignment is missing an entry for rank {r}" + ) + + # Cross-field: encoder.mark_dim must match across encoder and any + # synthetic-source override. + syn = cfg.get('synthetic') or {} + if isinstance(enc, dict) and isinstance(syn, dict): + if 'mark_dim' in syn and 'mark_dim' in enc and syn['mark_dim'] != enc['mark_dim']: + errors.append( + f"synthetic.mark_dim ({syn['mark_dim']}) " + f"!= encoder.mark_dim ({enc['mark_dim']})" + ) + + if errors: + bullet = '\n - '.join(errors) + raise ConfigError(f"config validation failed:\n - {bullet}") diff --git a/realtime_decoder/synthetic.py b/realtime_decoder/synthetic.py new file mode 100644 index 0000000..f3c8e30 --- /dev/null +++ b/realtime_decoder/synthetic.py @@ -0,0 +1,305 @@ +"""Synthetic data source for the realtime_decoder. + +Lets you install the package and run the full MPI pipeline end-to-end +without any acquisition hardware (Trodes, SpikeGLX, Open Ephys, etc.). + +This is intended for: + * smoke-testing a fresh install + * developing/debugging the decoder loop on a laptop + * regression testing in CI + +It is NOT intended to be biologically realistic. The synthetic spikes are +Poisson with a fixed rate, marks are gaussian, and the synthetic position +walks back and forth on a simple linear track. The goal is to exercise the +data path and message plumbing, not to validate decoding accuracy. + +Wiring is parallel to ``trodesnet.py``: a ``SyntheticDataReceiver`` that +mirrors ``TrodesDataReceiver`` and a ``SyntheticClient`` that mirrors +``TrodesClient``. The dispatch happens in ``runscript.py`` based on +``config['datasource']``. + +Config block expected under the top-level ``synthetic`` key (all optional): + + synthetic: + spike_rate_hz: 20 # per ntrode, Poisson + mark_dim: 8 # must match encoder.mark_dim + mark_amplitude_uv: 80 # mean spike amplitude + track_length_cm: 200 # walk distance + walk_speed_cm_s: 15 # synthetic animal speed + startup_delay_s: 1.0 # delay before firing the startup callback + run_duration_s: 60 # auto-terminate after this long + voltage_scaling_factor: 0.195 +""" + +import time + +import numpy as np + +from realtime_decoder import utils +from realtime_decoder.base import DataSourceReceiver +from realtime_decoder.datatypes import ( + Datatypes, + LFPPoint, + SpikePoint, + CameraModulePoint, +) + + +_DEFAULTS = { + 'spike_rate_hz': 20.0, + 'mark_dim': 4, + 'mark_amplitude_uv': 80.0, + 'track_length_cm': 200.0, + 'walk_speed_cm_s': 15.0, + 'startup_delay_s': 1.0, + 'run_duration_s': 60.0, + 'voltage_scaling_factor': 1.0, +} + + +def _params(config): + """Read the ``synthetic`` config block, applying defaults.""" + p = dict(_DEFAULTS) + p.update(config.get('synthetic') or {}) + return p + + +class SyntheticDataReceiver(DataSourceReceiver): + """Drop-in synthetic replacement for ``trodesnet.TrodesDataReceiver``. + + Generates LFP / spike / position samples on demand at clock-driven + rates. ``__next__`` returns None when no sample is due yet, matching + the non-blocking semantics of the Trodes receiver — the polling main + loops do not need to know they are reading synthetic data. + """ + + def __init__(self, comm, rank, config, datatype): + if datatype not in ( + Datatypes.LFP, + Datatypes.SPIKES, + Datatypes.LINEAR_POSITION, + ): + raise TypeError(f"Invalid datatype {datatype}") + super().__init__(comm, rank, config, datatype) + + self._p = _params(config) + self._started = False + self._stopped = False + + self.ntrode_ids = [] + + # Per-stream pacing: we advance a deterministic virtual clock + # (sample index) from t=0 at activate(), and emit samples as + # wall-clock catches up. This gives roughly the same delivery + # cadence as a live acquisition system at the configured rates. + self._t0_wall = None + self._next_sample_idx = 0 + if datatype == Datatypes.LFP: + self._fs = config['sampling_rate']['lfp'] + elif datatype == Datatypes.SPIKES: + self._fs = config['sampling_rate']['spikes'] + else: # LINEAR_POSITION + self._fs = config['sampling_rate']['position'] + + self._spike_clock = config['sampling_rate']['spikes'] + + # Spike-stream specific: independent Poisson process per ntrode. + # ``_next_spike_sample[ntid]`` stores the spike-clock sample index + # at which that ntrode's next spike will fire. + self._rng = np.random.default_rng(seed=rank * 1009 + int(datatype)) + self._next_spike_sample = {} + self._mark_dim = self._p['mark_dim'] + self._amp = self._p['mark_amplitude_uv'] + + # ------------------------------------------------------------------ + # DataSourceReceiver contract + # ------------------------------------------------------------------ + + def register_datatype_channel(self, channel): + ntrode_id = int(channel) + if self.datatype in (Datatypes.LFP, Datatypes.SPIKES): + if ntrode_id not in self.ntrode_ids: + self.ntrode_ids.append(ntrode_id) + # position has no channels + + def activate(self): + self._t0_wall = time.time() + self._next_sample_idx = 0 + if self.datatype == Datatypes.SPIKES: + for ntid in self.ntrode_ids: + self._schedule_next_spike(ntid, sample_now=0) + self._started = True + self.class_log.debug( + f"Synthetic {self.datatype.name} datastream activated " + f"({len(self.ntrode_ids)} ntrodes)" + ) + + def deactivate(self): + self._started = False + + def stop_iterator(self): + raise StopIteration() + + def __next__(self): + if not self._started: + return None + + elapsed = time.time() - self._t0_wall + if ( + self._p['run_duration_s'] > 0 + and elapsed > self._p['run_duration_s'] + and not self._stopped + ): + # one-time log; the supervisor's termination is wired + # through SyntheticClient.receive() below. + self._stopped = True + + if self.datatype == Datatypes.LFP: + return self._next_lfp(elapsed) + elif self.datatype == Datatypes.SPIKES: + return self._next_spike(elapsed) + else: + return self._next_position(elapsed) + + # ------------------------------------------------------------------ + # Per-datatype generators + # ------------------------------------------------------------------ + + def _next_lfp(self, elapsed): + target_idx = int(elapsed * self._fs) + if target_idx < self._next_sample_idx: + return None + idx = self._next_sample_idx + self._next_sample_idx += 1 + # white-ish noise sized to (num_channels,), scaled the same way + # TrodesDataReceiver does (raw * voltage_scaling_factor) + n = max(1, len(self.ntrode_ids)) + raw = self._rng.standard_normal(n) * 200.0 # ~uV range pre-scale + data = raw * self._p['voltage_scaling_factor'] + local_ts = idx # LFP uses spike-clock timestamps in real Trodes; + # at fs_lfp=1500, fs_spike=30000 the ratio is 20, but downstream + # only cares about monotonicity within a stream, so use idx. + system_ts = time.time_ns() + return LFPPoint( + local_ts, + list(self.ntrode_ids), + data, + system_ts, + time.time_ns(), + ) + + def _next_spike(self, elapsed): + if not self.ntrode_ids: + return None + spike_sample_now = int(elapsed * self._spike_clock) + # Find any ntrode whose next-spike sample has arrived. + for ntid in self.ntrode_ids: + if self._next_spike_sample[ntid] <= spike_sample_now: + ts = self._next_spike_sample[ntid] + self._schedule_next_spike(ntid, sample_now=spike_sample_now) + # mark vector: gaussian around _amp, all channels positive + samples = ( + self._rng.standard_normal(self._mark_dim) * 8.0 + self._amp + ) / self._p['voltage_scaling_factor'] + # SpikePoint.data is later multiplied by voltage_scaling_factor + # in real Trodes; the encoder reads `max(mark_vec)` so we just + # need the post-scaling magnitudes to clear `encoder.spk_amp`. + return SpikePoint( + ts, + ntid, + samples * self._p['voltage_scaling_factor'], + time.time_ns(), + time.time_ns(), + ) + return None + + def _next_position(self, elapsed): + target_idx = int(elapsed * self._fs) + if target_idx < self._next_sample_idx: + return None + idx = self._next_sample_idx + self._next_sample_idx += 1 + + # triangle-wave walk along a single linear segment between 0 and + # track_length_cm + L = self._p['track_length_cm'] + v = self._p['walk_speed_cm_s'] + t = elapsed + period = 2.0 * L / max(v, 1e-6) + phase = (t % period) / period # 0..1 + pos_cm = L * (1.0 - abs(2.0 * phase - 1.0)) + # x/y/x2/y2 in "pixel" units — kinematics.scale_factor converts back + sf = self.config['kinematics']['scale_factor'] + x = pos_cm / sf + y = 100.0 # constant + x2 = x + 5.0 + y2 = y + return CameraModulePoint( + idx, + segment=0, + position=pos_cm, + x=x, + y=y, + x2=x2, + y2=y2, + t_recv_data=time.time_ns(), + ) + + # ------------------------------------------------------------------ + # internals + # ------------------------------------------------------------------ + + def _schedule_next_spike(self, ntid, *, sample_now): + rate = max(self._p['spike_rate_hz'], 1e-6) + # exponential inter-arrival in seconds → samples + gap_s = self._rng.exponential(1.0 / rate) + gap_samples = max(1, int(gap_s * self._spike_clock)) + self._next_spike_sample[ntid] = sample_now + gap_samples + + +class SyntheticClient(object): + """Drop-in synthetic replacement for ``trodesnet.TrodesClient``. + + Exposes the same surface used by the supervisor and stim decider: + * ``set_startup_callback`` / ``set_termination_callback`` + * ``receive`` (called from the supervisor main loop) + * ``send_statescript_shortcut_message`` (called from stim_decider) + + ``receive`` fires the startup callback once after ``startup_delay_s`` + of wall clock has elapsed, and fires termination once ``run_duration_s`` + has elapsed. + """ + + def __init__(self, config): + self._startup_callback = utils.nop + self._termination_callback = utils.nop + self._p = _params(config) + self._t0_wall = time.time() + self._started = False + self._terminated = False + # log-only buffer of "shortcut messages" the stim decider would + # have sent to ECU; useful for asserting in tests later. + self.sent_shortcuts = [] + + def receive(self): + elapsed = time.time() - self._t0_wall + if not self._started and elapsed >= self._p['startup_delay_s']: + self._started = True + self._startup_callback() + if ( + self._started + and not self._terminated + and self._p['run_duration_s'] > 0 + and elapsed >= self._p['run_duration_s'] + self._p['startup_delay_s'] + ): + self._terminated = True + self._termination_callback() + + def send_statescript_shortcut_message(self, val): + self.sent_shortcuts.append((time.time_ns(), int(val))) + + def set_startup_callback(self, callback): + self._startup_callback = callback + + def set_termination_callback(self, callback): + self._termination_callback = callback diff --git a/runscript.py b/runscript.py index 3e920c8..2f374f5 100644 --- a/runscript.py +++ b/runscript.py @@ -1,5 +1,6 @@ import os import argparse +import sys import time import datetime import logging @@ -11,12 +12,30 @@ from mpi4py import MPI from realtime_decoder import ( - datatypes, position, trodesnet, stimulation, + datatypes, position, trodesnet, synthetic, stimulation, main_process, ripple_process, encoder_process, decoder_process, gui_process, base, messages, - merge_rec + merge_rec, config_loader ) + +def _data_source_factory(config): + """Pick the (receiver_class, client_class) pair for the configured + data source. + + `datasource: trodes` (default) uses the live Trodes streams. + `datasource: synthetic` uses the in-process generator from + `realtime_decoder.synthetic` — install-and-run with no hardware. + """ + ds = config.get('datasource', 'trodes') + if ds == 'trodes': + return trodesnet.TrodesDataReceiver, trodesnet.TrodesClient + if ds == 'synthetic': + return synthetic.SyntheticDataReceiver, synthetic.SyntheticClient + raise ValueError( + f"Unknown datasource {ds!r}; expected 'trodes' or 'synthetic'" + ) + # from line_profiler import LineProfiler class GuiProcessStub(base.RealtimeProcess, base.MessageHandler): @@ -101,8 +120,17 @@ def setup(config_path, numprocs): num_digits = len(str(comm.Get_size())) - with open(config_path, 'r') as f: - config = yaml.safe_load(f) + # Load via the resolver: handles `_extends` inheritance and runs + # validation up front so missing required keys produce one clear + # error before the MPI run starts, instead of an IndexError / + # KeyError deep inside a worker. + try: + config = config_loader.load_config(config_path) + except config_loader.ConfigError as exc: + if rank == 0: + print(f"[config] {exc}", file=sys.stderr, flush=True) + comm.Barrier() + sys.exit(2) os.makedirs(os.path.dirname(config['files']['output_dir']), exist_ok=True) prefix = config['files']['prefix'] @@ -169,21 +197,23 @@ def setup(config_path, numprocs): regloop = True ################################################# + DataReceiver, Client = _data_source_factory(config) + if rank in config['rank']['supervisor']: - trodes_client = trodesnet.TrodesClient(config) + net_client = Client(config) stim_decider = stimulation.TwoArmTrodesStimDecider( - comm, rank, config, trodes_client + comm, rank, config, net_client ) process = main_process.MainProcess( - comm, rank, config, stim_decider, trodes_client + comm, rank, config, stim_decider, net_client ) - trodes_client.set_startup_callback(process.startup) - trodes_client.set_termination_callback(process.trigger_termination) + net_client.set_startup_callback(process.startup) + net_client.set_termination_callback(process.trigger_termination) elif rank in config['rank']['ripples']: - lfp_interface = trodesnet.TrodesDataReceiver( + lfp_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LFP ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) process = ripple_process.RippleProcess( @@ -196,10 +226,10 @@ def setup(config_path, numprocs): # prof.print_stats() # regloop = False elif rank in config['rank']['encoders']: - spikes_interface = trodesnet.TrodesDataReceiver( + spikes_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.SPIKES ) - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper( @@ -211,7 +241,7 @@ def setup(config_path, numprocs): pos_mapper ) elif rank in config['rank']['decoders']: - pos_interface = trodesnet.TrodesDataReceiver( + pos_interface = DataReceiver( comm, rank, config, datatypes.Datatypes.LINEAR_POSITION ) pos_mapper = position.TrodesPositionMapper( diff --git a/setup.py b/setup.py index 6631171..c6840dd 100644 --- a/setup.py +++ b/setup.py @@ -21,11 +21,15 @@ def get_version_string(): 'Cython', 'trodesnetwork', 'pyqtgraph', - 'oyaml' + 'oyaml', + 'pyyaml', ], + extras_require={ + 'test': ['pytest'], + }, author_email='jpc6@rice.edu', description='Realtime clusterless decoding', - packages=find_packages(), + packages=find_packages(exclude=['tests', 'tests.*']), keywords="neuroscience clusterless decoding", include_package_data=True, platforms='any', diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3322996 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,67 @@ +"""Shared test setup. + +Several realtime_decoder modules (base, position, synthetic, ...) import +`mpi4py` at module top. mpi4py is non-trivial to install in a CI image +(needs a system MPI build), and none of the deterministic logic we want +to unit-test actually exercises it. We stub it here so importing +realtime_decoder modules works on a plain `pip install pytest` env. + +If real mpi4py is already importable (developer machine with MPI +installed), we skip the stub and use the real module. +""" + +import sys +import types + + +def _install_mpi4py_stub(): + if 'mpi4py' in sys.modules: + return + try: + import mpi4py # noqa: F401 + return + except ImportError: + pass + + mpi4py_mod = types.ModuleType('mpi4py') + mpi_mod = types.ModuleType('mpi4py.MPI') + + class _Status: # minimal stand-in for MPI.Status() + source = 0 + tag = 0 + + class _Comm: + def Get_rank(self): + return 0 + + def Get_size(self): + return 1 + + def Barrier(self): + pass + + # the data-path methods are not exercised by any tests that + # touch this stub; leave them as raising stubs to surface + # accidental hot-path use. + def Send(self, *a, **k): + raise RuntimeError("MPI stub: Send called in a unit test") + + def Irecv(self, *a, **k): + raise RuntimeError("MPI stub: Irecv called in a unit test") + + def send(self, *a, **k): + raise RuntimeError("MPI stub: send called in a unit test") + + def irecv(self, *a, **k): + raise RuntimeError("MPI stub: irecv called in a unit test") + + mpi_mod.Status = _Status + mpi_mod.Comm = _Comm + mpi_mod.BYTE = object() # sentinel; tests don't inspect it + mpi_mod.COMM_WORLD = _Comm() + mpi4py_mod.MPI = mpi_mod + sys.modules['mpi4py'] = mpi4py_mod + sys.modules['mpi4py.MPI'] = mpi_mod + + +_install_mpi4py_stub() diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py new file mode 100644 index 0000000..b36a738 --- /dev/null +++ b/tests/test_config_loader.py @@ -0,0 +1,151 @@ +"""Unit tests for realtime_decoder.config_loader.""" + +import os +import textwrap + +import pytest + +from realtime_decoder import config_loader + + +# --------------------------------------------------------------------------- +# deep_merge +# --------------------------------------------------------------------------- + + +def test_deep_merge_nested_dicts(): + base = {'a': 1, 'b': {'c': 2, 'd': 3}} + over = {'b': {'d': 99, 'e': 4}, 'f': 5} + out = config_loader.deep_merge(base, over) + assert out == {'a': 1, 'b': {'c': 2, 'd': 99, 'e': 4}, 'f': 5} + # base must not be mutated + assert base == {'a': 1, 'b': {'c': 2, 'd': 3}} + + +def test_deep_merge_lists_are_replaced_not_extended(): + base = {'xs': [1, 2, 3]} + over = {'xs': [9]} + assert config_loader.deep_merge(base, over) == {'xs': [9]} + + +def test_deep_merge_scalar_overrides_dict(): + base = {'k': {'nested': True}} + over = {'k': 'flat'} + assert config_loader.deep_merge(base, over) == {'k': 'flat'} + + +# --------------------------------------------------------------------------- +# load_config (extends, validation) +# --------------------------------------------------------------------------- + + +def _write(p, contents): + p.write_text(textwrap.dedent(contents).lstrip()) + + +@pytest.fixture +def minimal_valid_config(): + """Returns a function that writes a minimal valid config to a path.""" + def _write_to(path): + _write(path, """ + algorithm: clusterless_decoder + datasource: synthetic + sampling_rate: {spikes: 30000, lfp: 1500, position: 30} + files: {output_dir: /tmp/x, prefix: x} + kinematics: {smoothing_filter: [1.0]} + rank: + supervisor: [0] + decoders: [1] + encoders: [2] + ripples: [3] + gui: [4] + decoder_assignment: {1: [1]} + ripples: {filter: {}, smoothing_filter: {}, threshold: {}} + encoder: + mark_dim: 4 + bufsize: 100 + spk_amp: 60 + position: + lower: 0 + upper: 10 + num_bins: 10 + arm_ids: [0] + arm_coords: [[0, 9]] + decoder: + bufsize: 100 + cred_int_bufsize: 10 + time_bin: {samples: 180, delay_samples: 180} + """) + return _write_to + + +def test_load_config_without_extends(tmp_path, minimal_valid_config): + p = tmp_path / 'cfg.yml' + minimal_valid_config(p) + cfg = config_loader.load_config(str(p)) + assert cfg['algorithm'] == 'clusterless_decoder' + + +def test_load_config_with_extends_merges_parent(tmp_path, minimal_valid_config): + parent = tmp_path / 'base.yml' + child = tmp_path / 'child.yml' + minimal_valid_config(parent) + _write(child, """ + _extends: base.yml + algorithm: clusterless_classifier + files: {output_dir: /tmp/child, prefix: child} + """) + cfg = config_loader.load_config(str(child)) + # child overrides + assert cfg['algorithm'] == 'clusterless_classifier' + assert cfg['files']['output_dir'] == '/tmp/child' + # inherited from parent + assert cfg['sampling_rate']['spikes'] == 30000 + + +def test_load_config_circular_extends_raises(tmp_path): + a = tmp_path / 'a.yml' + b = tmp_path / 'b.yml' + _write(a, '_extends: b.yml\n') + _write(b, '_extends: a.yml\n') + with pytest.raises(config_loader.ConfigError, match='Circular'): + config_loader.load_config(str(a)) + + +def test_validate_missing_required_key_raises(tmp_path): + p = tmp_path / 'cfg.yml' + _write(p, 'algorithm: clusterless_decoder\n') + with pytest.raises(config_loader.ConfigError) as exc: + config_loader.load_config(str(p)) + msg = str(exc.value) + assert 'rank' in msg + assert 'sampling_rate' in msg + + +def test_validate_unknown_algorithm_raises(tmp_path, minimal_valid_config): + p = tmp_path / 'cfg.yml' + minimal_valid_config(p) + # mutate the file to inject a bad algorithm + src = p.read_text().replace('clusterless_decoder', 'bogus_algo') + p.write_text(src) + with pytest.raises(config_loader.ConfigError, match='bogus_algo'): + config_loader.load_config(str(p)) + + +def test_validate_decoder_rank_missing_assignment(tmp_path, minimal_valid_config): + p = tmp_path / 'cfg.yml' + minimal_valid_config(p) + # remove the decoder_assignment line so rank 1 has no entry + src = p.read_text().replace('decoder_assignment: {1: [1]}', 'decoder_assignment: {}') + p.write_text(src) + with pytest.raises(config_loader.ConfigError, match='decoder_assignment'): + config_loader.load_config(str(p)) + + +def test_validate_mark_dim_mismatch_between_encoder_and_synthetic(tmp_path, minimal_valid_config): + p = tmp_path / 'cfg.yml' + minimal_valid_config(p) + src = p.read_text().rstrip() + '\nsynthetic: {mark_dim: 99}\n' + p.write_text(src) + with pytest.raises(config_loader.ConfigError, match='mark_dim'): + config_loader.load_config(str(p)) diff --git a/tests/test_position.py b/tests/test_position.py new file mode 100644 index 0000000..5e8c171 --- /dev/null +++ b/tests/test_position.py @@ -0,0 +1,114 @@ +"""Unit tests for realtime_decoder.position.""" + +import numpy as np +import pytest + +from realtime_decoder import position, datatypes + + +# --------------------------------------------------------------------------- +# PositionBinStruct +# --------------------------------------------------------------------------- + + +def test_position_bin_struct_edges_and_centers(): + s = position.PositionBinStruct(0, 10, 5) + np.testing.assert_allclose(s.pos_bin_edges, [0, 2, 4, 6, 8, 10]) + np.testing.assert_allclose(s.pos_bin_centers, [1, 3, 5, 7, 9]) + assert s.pos_bin_delta == 2.0 + assert s.num_bins == 5 + + +@pytest.mark.parametrize("pos,expected_bin", [ + (0.0, 0), + (1.99, 0), + (2.0, 1), + (5.5, 2), + (9.99, 4), +]) +def test_get_bin_within_range(pos, expected_bin): + s = position.PositionBinStruct(0, 10, 5) + assert s.get_bin(pos) == expected_bin + + +# --------------------------------------------------------------------------- +# TrodesPositionMapper +# --------------------------------------------------------------------------- + + +def _camera_point(*, segment, position_on_segment): + """Build a CameraModulePoint with the bare minimum the mapper reads.""" + return datatypes.CameraModulePoint( + timestamp=0, + segment=segment, + position=position_on_segment, + x=0.0, y=0.0, x2=0.0, y2=0.0, + t_recv_data=0, + ) + + +def test_position_mapper_basic(): + # 2 arms: segment 0 -> arm 0 (bins 0..3), segment 1 -> arm 1 (bins 5..8) + mapper = position.TrodesPositionMapper( + arm_ids=[0, 1], + arm_coords=[[0, 3], [5, 8]], + ) + # arm 0 has 4 bins; normalized edges [0, .25, .5, .75, 1] + assert mapper.map_position(_camera_point(segment=0, position_on_segment=0.0)) == 0 + assert mapper.map_position(_camera_point(segment=0, position_on_segment=0.5)) == 2 + # exact upper edge clamps to the last bin per the inclusive-upper rule + assert mapper.map_position(_camera_point(segment=0, position_on_segment=1.0)) == 3 + # arm 1 starts at bin 5 + assert mapper.map_position(_camera_point(segment=1, position_on_segment=0.0)) == 5 + assert mapper.map_position(_camera_point(segment=1, position_on_segment=1.0)) == 8 + + +def test_position_mapper_above_one_clamps(): + # numerical noise that pushes segment position slightly above 1.0 + # should still land in the last bin rather than crashing. + mapper = position.TrodesPositionMapper( + arm_ids=[0], + arm_coords=[[0, 4]], + ) + assert mapper.map_position(_camera_point(segment=0, position_on_segment=1.0001)) == 4 + + +# --------------------------------------------------------------------------- +# KinematicsEstimator +# --------------------------------------------------------------------------- + + +def test_kinematics_first_sample_returns_zero_speed(): + est = position.KinematicsEstimator( + scale_factor=1.0, dt=1.0, + xfilter=[1.0], yfilter=[1.0], speedfilter=[1.0], + ) + x, y, s = est.compute_kinematics(10.0, 20.0) + assert (x, y, s) == (10.0, 20.0, 0) + + +def test_kinematics_speed_unsmoothed_matches_euclid_distance(): + est = position.KinematicsEstimator( + scale_factor=1.0, dt=0.5, + xfilter=[1.0], yfilter=[1.0], speedfilter=[1.0], + ) + est.compute_kinematics(0.0, 0.0) # prime + x, y, s = est.compute_kinematics(3.0, 4.0) + # 5 units over 0.5s -> 10 units/sec + assert (x, y) == (3.0, 4.0) + assert np.isclose(s, 10.0) + + +def test_kinematics_smoothing_applies_fir(): + # 3-tap moving average: smoothed value of the third sample should + # equal the average of the last three inputs (* scale). + est = position.KinematicsEstimator( + scale_factor=1.0, dt=1.0, + xfilter=[1/3, 1/3, 1/3], + yfilter=[1/3, 1/3, 1/3], + speedfilter=[1.0], + ) + est.compute_kinematics(0.0, 0.0) # prime, returned raw + est.compute_kinematics(6.0, 0.0, smooth_x=True) # buf=[6,0,0] + x, _, _ = est.compute_kinematics(9.0, 0.0, smooth_x=True) # buf=[9,6,0] + assert np.isclose(x, 5.0) # (9+6+0)/3 diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py new file mode 100644 index 0000000..c228468 --- /dev/null +++ b/tests/test_synthetic.py @@ -0,0 +1,127 @@ +"""Unit tests for realtime_decoder.synthetic. + +These tests don't spin up MPI; they construct receivers directly with a +fake comm and verify the per-datatype generators produce the right +shapes and types. +""" + +import time + +import numpy as np +import pytest + +from realtime_decoder import synthetic, datatypes + + +@pytest.fixture +def base_config(): + return { + 'sampling_rate': {'spikes': 30000, 'lfp': 1500, 'position': 30}, + 'kinematics': {'scale_factor': 0.25}, + 'synthetic': { + 'spike_rate_hz': 1000.0, # high so we see spikes fast + 'mark_dim': 4, + 'mark_amplitude_uv': 120.0, + 'track_length_cm': 40.0, + 'walk_speed_cm_s': 20.0, + 'startup_delay_s': 0.0, + 'run_duration_s': 5.0, + 'voltage_scaling_factor': 0.195, + }, + } + + +class _FakeComm: + """Minimal stub: receivers only need .Get_rank if anything; not used in __next__.""" + pass + + +def test_receiver_rejects_unknown_datatype(base_config): + with pytest.raises(TypeError): + synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatype=999) + + +def test_lfp_receiver_emits_correct_shape(base_config): + rx = synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatypes.Datatypes.LFP) + rx.register_datatype_channel(1) + rx.register_datatype_channel(2) + rx.activate() + # poll until we get a sample (LFP at 1500hz, so first sample ~immediate) + deadline = time.time() + 1.0 + sample = None + while time.time() < deadline: + sample = rx.__next__() + if sample is not None: + break + assert sample is not None, "no LFP sample within 1s" + assert isinstance(sample, datatypes.LFPPoint) + assert sample.data.shape == (2,) + + +def test_lfp_receiver_returns_none_before_activate(base_config): + rx = synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatypes.Datatypes.LFP) + rx.register_datatype_channel(1) + assert rx.__next__() is None + + +def test_spike_receiver_emits_marks_above_amp_threshold(base_config): + rx = synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatypes.Datatypes.SPIKES) + rx.register_datatype_channel(7) + rx.activate() + deadline = time.time() + 1.0 + spike = None + while time.time() < deadline: + spike = rx.__next__() + if spike is not None: + break + assert spike is not None, "no spike within 1s at 1khz rate" + assert isinstance(spike, datatypes.SpikePoint) + assert spike.elec_grp_id == 7 + assert spike.data.shape == (base_config['synthetic']['mark_dim'],) + # marks should clear a reasonable amplitude threshold post-scaling + assert float(np.max(spike.data)) > 50.0 + + +def test_position_receiver_walks_within_bounds(base_config): + base_config['synthetic']['walk_speed_cm_s'] = 200 # fast walk so we cover range quickly + rx = synthetic.SyntheticDataReceiver(_FakeComm(), 0, base_config, datatypes.Datatypes.LINEAR_POSITION) + rx.activate() + L = base_config['synthetic']['track_length_cm'] + deadline = time.time() + 2.0 + positions = [] + while time.time() < deadline and len(positions) < 30: + p = rx.__next__() + if p is not None: + positions.append(p.position) + assert isinstance(p, datatypes.CameraModulePoint) + assert positions, "no position samples emitted" + assert min(positions) >= 0.0 + assert max(positions) <= L + 1e-6 + + +def test_synthetic_client_fires_startup_callback(base_config): + base_config['synthetic']['startup_delay_s'] = 0.05 + base_config['synthetic']['run_duration_s'] = 0 # disable auto-term + client = synthetic.SyntheticClient(base_config) + + calls = {'startup': 0, 'term': 0} + client.set_startup_callback(lambda: calls.__setitem__('startup', calls['startup'] + 1)) + client.set_termination_callback(lambda: calls.__setitem__('term', calls['term'] + 1)) + + # before delay elapses, no callback + client.receive() + assert calls['startup'] == 0 + + time.sleep(0.1) + client.receive() + assert calls['startup'] == 1 + # subsequent calls should not refire startup + client.receive() + assert calls['startup'] == 1 + + +def test_synthetic_client_records_shortcut_messages(base_config): + client = synthetic.SyntheticClient(base_config) + client.send_statescript_shortcut_message(22) + client.send_statescript_shortcut_message(14) + assert [v for _, v in client.sent_shortcuts] == [22, 14] diff --git a/tests/test_transitions.py b/tests/test_transitions.py new file mode 100644 index 0000000..77affc5 --- /dev/null +++ b/tests/test_transitions.py @@ -0,0 +1,45 @@ +"""Unit tests for the transition model builders.""" + +import numpy as np + +from realtime_decoder import transitions + + +def test_sungod_transition_matrix_shape_and_row_sums(): + pos_bins = np.arange(10) + arm_coords = [[0, 3], [6, 9]] + bias = 1 + T = transitions.sungod_transition_matrix(pos_bins, arm_coords, bias) + + # square, sized to the number of position bins + assert T.shape == (len(pos_bins), len(pos_bins)) + + # rows that are not entirely zero should sum to 1 (within float + # tolerance). The gap rows between arms are masked to NaN by + # apply_no_anim_boundary and then zeroed by the function, so they + # legitimately sum to 0. + row_sums = T.sum(axis=1) + for s in row_sums: + assert np.isclose(s, 0.0) or np.isclose(s, 1.0) + + # at least the in-arm rows must sum to 1 + in_arm_rows = [r for arm in arm_coords for r in range(arm[0], arm[1] + 1)] + for r in in_arm_rows: + assert np.isclose(row_sums[r], 1.0), f"row {r} sums to {row_sums[r]}" + + +def test_sungod_transition_matrix_gap_rows_are_zero(): + pos_bins = np.arange(10) + arm_coords = [[0, 3], [6, 9]] + T = transitions.sungod_transition_matrix(pos_bins, arm_coords, bias=1) + # bins 4 and 5 are gaps; the corresponding rows and columns should + # be all zero so transition mass cannot flow through them. + assert np.all(T[4:6, :] == 0) + assert np.all(T[:, 4:6] == 0) + + +def test_sungod_transition_matrix_no_nans(): + pos_bins = np.arange(8) + arm_coords = [[0, 7]] + T = transitions.sungod_transition_matrix(pos_bins, arm_coords, bias=1) + assert not np.any(np.isnan(T)) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..6206c97 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,62 @@ +"""Unit tests for realtime_decoder.utils helpers.""" + +import numpy as np +import pytest + +from realtime_decoder import utils + + +def test_normalize_to_probability_simple(): + dist = np.array([1.0, 2.0, 3.0, 4.0]) + out = utils.normalize_to_probability(dist) + assert np.isclose(out.sum(), 1.0) + np.testing.assert_allclose(out, dist / dist.sum()) + + +def test_normalize_to_probability_ignores_nan(): + # np.nansum is used internally, so NaN entries should not bias + # the normalization of the other entries. + dist = np.array([1.0, np.nan, 3.0]) + out = utils.normalize_to_probability(dist) + # the two finite entries should sum to 1 between them (NaN propagates + # to its own bin but the divisor was sum(1+3)=4) + finite = out[np.isfinite(out)] + assert np.isclose(finite.sum(), 1.0) + + +def test_estimate_new_stats_matches_numpy(): + # Welford's online stats should converge to numpy's batch stats. + rng = np.random.default_rng(0) + values = rng.standard_normal(500) + mean = 0.0 + M2 = 0.0 + count = 0 + for v in values: + mean, M2, count = utils.estimate_new_stats(v, mean, M2, count) + # variance from M2; compare to numpy population variance (ddof=0) + var = M2 / count + assert np.isclose(mean, values.mean(), atol=1e-12) + assert np.isclose(var, values.var(), atol=1e-12) + + +def test_apply_no_anim_boundary_2d_fills_gaps(): + # arm_coords [[0,3],[6,9]] => bins 4 and 5 are "no animal" gaps. + x_bins = np.arange(10) + arm_coords = [[0, 3], [6, 9]] + image = np.ones((10, 10)) + out = utils.apply_no_anim_boundary(x_bins, arm_coords, image, fill=0) + # gap rows and columns should be zeroed + assert np.all(out[4:6, :] == 0) + assert np.all(out[:, 4:6] == 0) + # non-gap interior should still be 1 + assert out[1, 1] == 1 + assert out[7, 7] == 1 + + +def test_apply_no_anim_boundary_1d_fills_gaps(): + x_bins = np.arange(10) + arm_coords = [[0, 3], [6, 9]] + image = np.ones(10) + out = utils.apply_no_anim_boundary(x_bins, arm_coords, image, fill=-1) + assert np.all(out[4:6] == -1) + assert out[0] == 1 and out[9] == 1