diff --git a/nslsii/tests/test_redis_integration.py b/nslsii/tests/test_redis_integration.py index 4b0110d..9e461ed 100644 --- a/nslsii/tests/test_redis_integration.py +++ b/nslsii/tests/test_redis_integration.py @@ -1,8 +1,11 @@ """Tests for redis-related parameter changes in utils, __init__, and sync_experiment.""" +import os from unittest.mock import MagicMock, patch, mock_open +from pathlib import Path import pytest +from pytest_mock import MockerFixture import sys @@ -155,3 +158,105 @@ def test_switch_redis_proposal_passes_redis_db(switch_mocks): call_kwargs = switch_mocks["open_redis_client"].call_args[1] assert call_kwargs["redis_db"] == 7 + + +@pytest.fixture +def secret_file(tmp_path: Path): + os.environ["REDIS_SECRET_FILE"] = str(tmp_path / "secret") + with open(tmp_path / "secret", "w") as fp: + fp.write("redis_secret") + + +@pytest.mark.parametrize( + ("url", "port", "ssl", "loc", "db", "es_acronym", "bl_acronym", "expected_url"), + [ + ( + None, + None, + True, + "xf28id2", + 0, + None, + "XPD", + "xf28id2-xpd-redis1.nsls2.bnl.gov", + ), + ( + None, + None, + True, + "xf28id2", + 0, + "XPDD", + "XPD", + "xf28id2-xpdd-redis1.nsls2.bnl.gov", + ), + (None, None, False, "xf28id2", 0, None, "XPD", "info.xpd.nsls2.bnl.gov"), + (None, None, False, "xf28id2", 0, "XPDD", "XPD", "info.xpd.nsls2.bnl.gov"), + ( + "xf28id2-xpd-redis1.nsls2.bnl.gov", + None, + True, + "xf28id2", + 0, + None, + None, + "xf28id2-xpd-redis1.nsls2.bnl.gov", + ), + ( + "xf28id2-xpd-redis1.nsls2.bnl.gov", + 1234, + True, + "xf28id2", + 0, + None, + None, + "xf28id2-xpd-redis1.nsls2.bnl.gov", + ), + ( + "info.xpd.nsls2.bnl.gov", + None, + False, + "xf28id2", + 0, + None, + None, + "info.xpd.nsls2.bnl.gov", + ), + ], +) +def test_open_redis_client_uses_es_bl_acronym_vars( + secret_file, + mocker: MockerFixture, + url: str | None, + port: int | None, + ssl: bool, + loc: str, + db: int, + es_acronym: str | None, + bl_acronym: str | None, + expected_url: str, +): + mock_redis = mocker.patch("nslsii.utils.Redis") + + expected_port = port or (6379 if not ssl else 6380) + if es_acronym: + os.environ["ENDSTATION_ACRONYM"] = es_acronym + if bl_acronym: + os.environ["BEAMLINE_ACRONYM"] = bl_acronym + + open_redis_client( + redis_url=url, redis_port=port, redis_ssl=ssl, redis_location=loc, redis_db=db + ) + mock_redis.assert_called_with( + host=expected_url, + port=expected_port, + ssl=ssl, + password="redis_secret" if ssl else None, + db=db, + ) + + +def test_cannot_find_client_location(): + os.environ["BEAMLINE_ACRONYM"] = "XPD" + with pytest.raises(RuntimeError, match="Failed to derive redis server url"): + open_redis_client(redis_location="xf27id1", redis_ssl=True) diff --git a/nslsii/utils.py b/nslsii/utils.py index f8c8ebc..66e7688 100644 --- a/nslsii/utils.py +++ b/nslsii/utils.py @@ -54,6 +54,9 @@ def open_redis_client( if os.getenv("REDIS_HOST"): redis_url = os.getenv("REDIS_HOST") if redis_url is None: + endstation_acronym = os.getenv("ENDSTATION_ACRONYM") + beamline_acronym = os.getenv("BEAMLINE_ACRONYM") + if redis_ssl: client_loc_id = ( redis_location if redis_location else socket.gethostname().split("-")[0] @@ -61,6 +64,14 @@ def open_redis_client( client_locations = [ location for location in redis_hosts if client_loc_id in location ] + redis_host_acronym = endstation_acronym or beamline_acronym + if redis_host_acronym is not None: + client_locations = [ + location + for location in client_locations + if f"-{redis_host_acronym.lower()}-" in location + ] + if len(client_locations) != 1: raise RuntimeError( "Failed to derive redis server url, please specify using the " @@ -69,8 +80,7 @@ def open_redis_client( else: redis_url = client_locations[0] else: - tla = os.getenv("BEAMLINE_ACRONYM").lower() - redis_url = f"info.{tla}.nsls2.bnl.gov" + redis_url = f"info.{beamline_acronym.lower()}.nsls2.bnl.gov" if redis_ssl: redis_pw = os.getenv("REDIS_PASSWORD") diff --git a/requirements-dev.txt b/requirements-dev.txt index 846bdc6..fb2ca3c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,4 +15,5 @@ pre-commit ruff doct pims -pytz \ No newline at end of file +pytz +pytest-mock