From 4b89967a916eea1f9de6fd0125af490bf72f2220 Mon Sep 17 00:00:00 2001 From: Jamie Gilbert Date: Thu, 14 May 2026 11:39:42 -0700 Subject: [PATCH 1/5] implementation of custom end era logic in ibis layer --- circe/execution/engine/custom_era.py | 149 ++++++++++++++++++++++ circe/execution/engine/end_strategy.py | 4 +- circe/execution/normalize/cohort.py | 6 +- circe/execution/normalize/end_strategy.py | 1 + tests/execution/test_api_ibis.py | 19 +-- tests/execution/test_error_messages.py | 12 +- 6 files changed, 165 insertions(+), 26 deletions(-) create mode 100644 circe/execution/engine/custom_era.py diff --git a/circe/execution/engine/custom_era.py b/circe/execution/engine/custom_era.py new file mode 100644 index 00000000..e9a93e39 --- /dev/null +++ b/circe/execution/engine/custom_era.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import ibis + +from ..plan.schema import PERSON_ID, START_DATE +from .end_strategy import _replace_end_date, attach_observation_bounds + + +def _compute_exposure_end_date(table, *, days_supply_override: int | None): + start = table["drug_exposure_start_date"].cast("date") + + if days_supply_override is not None: + return start + ibis.interval(days=days_supply_override) + + raw_end = ( + table["drug_exposure_end_date"].cast("date") + if "drug_exposure_end_date" in table.columns + else ibis.null().cast("date") + ) + days_supply = ( + table["days_supply"].cast("int64") if "days_supply" in table.columns else ibis.null().cast("int64") + ) + supply_end = start + days_supply.as_interval("D") + + return ibis.coalesce(raw_end, supply_end, start + ibis.interval(days=1)) + + +def _compute_eras(exposures, *, gap_days: int, offset: int): + padded = exposures.mutate(_padded_end=(exposures._exposure_end + ibis.interval(days=int(gap_days)))) + + ordering = [ + padded.start_date, + padded._padded_end.desc(), + padded._exposure_end.desc(), + ] + + cumulative_window = ibis.cumulative_window(group_by=padded.person_id, order_by=ordering) + ordered_window = ibis.window(group_by=padded.person_id, order_by=ordering) + + with_cummax = padded.mutate(_cummax_padded_end=padded._padded_end.max().over(cumulative_window)) + + with_prev = with_cummax.mutate(_prev_max=with_cummax._cummax_padded_end.lag().over(ordered_window)) + + marked = with_prev.mutate( + _is_new=ibis.ifelse( + with_prev._prev_max.isnull() | (with_prev._prev_max < with_prev.start_date), + ibis.literal(1, type="int64"), + ibis.literal(0, type="int64"), + ) + ) + + group_window = ibis.cumulative_window( + group_by=marked.person_id, + order_by=[ + marked.start_date, + marked._padded_end.desc(), + marked._exposure_end.desc(), + marked._is_new.desc(), + ], + ) + era_indexed = marked.mutate(_era_id=marked._is_new.sum().over(group_window)) + + collapsed = era_indexed.group_by(era_indexed.person_id, era_indexed._era_id).aggregate( + era_start_date=era_indexed.start_date.min(), + _max_exposure_end=era_indexed._exposure_end.max(), + ) + + return collapsed.select( + collapsed.person_id.cast("int64").name(PERSON_ID), + collapsed.era_start_date.cast("date").name("era_start_date"), + (collapsed._max_exposure_end + ibis.interval(days=int(offset))).cast("date").name("era_end_date"), + ) + + +def compute_drug_eras( + ctx, *, drug_codeset_id: int, gap_days: int, offset: int, days_supply_override: int | None +): + concept_ids = ctx.concept_ids_for_codeset(drug_codeset_id) + + if not concept_ids: + de = ctx.table("drug_exposure") + return de.filter(ibis.literal(False)).select( + de.person_id.cast("int64").name(PERSON_ID), + ibis.null().cast("date").name("era_start_date"), + ibis.null().cast("date").name("era_end_date"), + ) + + de = ctx.table("drug_exposure") + filtered = de.filter(de.drug_concept_id.isin(concept_ids)) + + prepared = filtered.select( + filtered.person_id.cast("int64").name("person_id"), + filtered.drug_exposure_start_date.cast("date").name("start_date"), + _compute_exposure_end_date(filtered, days_supply_override=days_supply_override).name("_exposure_end"), + ) + + return _compute_eras(prepared, gap_days=gap_days, offset=offset) + + +def apply_custom_era_strategy(events, strategy, ctx): + payload = strategy.payload + drug_codeset_id = payload["drug_codeset_id"] + gap_days = payload["gap_days"] + offset = payload["offset"] + days_supply_override = payload.get("days_supply_override") + + if drug_codeset_id is None: + with_bounds = attach_observation_bounds(events, ctx) + return _replace_end_date(events, with_bounds, with_bounds.op_end_date) + + eras = compute_drug_eras( + ctx, + drug_codeset_id=drug_codeset_id, + gap_days=gap_days, + offset=offset, + days_supply_override=days_supply_override, + ) + + eras_for_join = eras.select( + eras.person_id.name("_era_person_id"), + eras.era_start_date, + eras.era_end_date, + ) + + with_bounds = attach_observation_bounds(events, ctx) + + joined = with_bounds.left_join( + eras_for_join, + predicates=[ + with_bounds.person_id == eras_for_join._era_person_id, + with_bounds[START_DATE] >= eras_for_join.era_start_date, + with_bounds[START_DATE] <= eras_for_join.era_end_date, + ], + ) + + event_window = ibis.window( + group_by=joined.event_id, + order_by=[joined.era_end_date.desc()], + ) + ranked = joined.mutate(_rn=ibis.row_number().over(event_window)) + one_per_event = ranked.filter(ranked._rn == 0) + + effective_end = ibis.coalesce( + one_per_event.era_end_date, + one_per_event.op_end_date, + ) + final_end = ibis.least(effective_end, one_per_event.op_end_date) + + return _replace_end_date(events, one_per_event, final_end) diff --git a/circe/execution/engine/end_strategy.py b/circe/execution/engine/end_strategy.py index a099985b..4b8e5b9e 100644 --- a/circe/execution/engine/end_strategy.py +++ b/circe/execution/engine/end_strategy.py @@ -64,7 +64,9 @@ def apply_end_strategy(events, strategy, ctx): return _replace_end_date(events, with_bounds, end_date_expr) if strategy.kind == "custom_era": - raise UnsupportedFeatureError("Ibis executor end-strategy error: custom_era is not supported.") + from .custom_era import apply_custom_era_strategy + + return apply_custom_era_strategy(events, strategy, ctx) # Fallback: preserve default semantics of op_end_date clipping. return _replace_end_date(events, with_bounds, with_bounds.op_end_date) diff --git a/circe/execution/normalize/cohort.py b/circe/execution/normalize/cohort.py index b2f657ff..61765b47 100644 --- a/circe/execution/normalize/cohort.py +++ b/circe/execution/normalize/cohort.py @@ -3,7 +3,7 @@ from ...cohortdefinition import CohortExpression from ...vocabulary.concept import ConceptSet from .._dataclass import frozen_slots_dataclass -from ..errors import ExecutionNormalizationError, UnsupportedFeatureError +from ..errors import ExecutionNormalizationError from .collapse import NormalizedCollapseSettings, normalize_collapse_settings from .criteria import NormalizedCriterion, normalize_criterion from .end_strategy import NormalizedEndStrategy, normalize_end_strategy @@ -149,10 +149,6 @@ def normalize_cohort( ) normalized_end_strategy = normalize_end_strategy(expression.end_strategy) - if normalized_end_strategy is not None and normalized_end_strategy.kind == "custom_era": - raise UnsupportedFeatureError( - "Ibis executor normalization error: custom_era end strategy is not supported." - ) return NormalizedCohort( title=expression.title, diff --git a/circe/execution/normalize/end_strategy.py b/circe/execution/normalize/end_strategy.py index 62ff666b..8e034091 100644 --- a/circe/execution/normalize/end_strategy.py +++ b/circe/execution/normalize/end_strategy.py @@ -32,6 +32,7 @@ def normalize_end_strategy( "drug_codeset_id": value.drug_codeset_id, "offset": int(value.offset), "gap_days": int(value.gap_days), + "days_supply_override": value.days_supply_override, }, ) return NormalizedEndStrategy(kind="end_strategy", payload={}) diff --git a/tests/execution/test_api_ibis.py b/tests/execution/test_api_ibis.py index ef0a73e4..db55e52e 100644 --- a/tests/execution/test_api_ibis.py +++ b/tests/execution/test_api_ibis.py @@ -26,8 +26,7 @@ VisitDetail, VisitOccurrence, ) -from circe.cohortdefinition.core import CustomEraStrategy, NumericRange -from circe.execution.errors import UnsupportedFeatureError +from circe.cohortdefinition.core import NumericRange from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem @@ -1213,10 +1212,12 @@ def test_build_cohort_location_region_keeps_repeated_location_history_rows(): assert sorted(result.start_date.astype(str).tolist()) == ["2020-01-01", "2020-02-01"] -def test_build_cohort_rejects_unsupported_features(): - expression = CohortExpression( - primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence()]), - end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), - ) - with pytest.raises(UnsupportedFeatureError, match="custom_era"): - _ = build_cohort(expression, backend=object(), cdm_schema="main") +def test_build_cohort_rejects_unsupported_criteria(): + """Unsupported base criteria type is rejected at normalization time.""" + from circe.cohortdefinition.criteria import Criteria as RawCriteria + from circe.execution.errors import UnsupportedCriterionError + + with pytest.raises(UnsupportedCriterionError): + from circe.execution.normalize.criteria import normalize_criterion + + normalize_criterion(RawCriteria()) diff --git a/tests/execution/test_error_messages.py b/tests/execution/test_error_messages.py index 80133b45..8e71481b 100644 --- a/tests/execution/test_error_messages.py +++ b/tests/execution/test_error_messages.py @@ -14,7 +14,7 @@ Occurrence, PrimaryCriteria, ) -from circe.cohortdefinition.core import CustomEraStrategy, NumericRange +from circe.cohortdefinition.core import NumericRange from circe.execution.errors import CompilationError, UnsupportedCriterionError, UnsupportedFeatureError from circe.execution.normalize.criteria import normalize_criterion from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem @@ -55,16 +55,6 @@ def _concept_set(set_id: int, concept_id: int) -> ConceptSet: ) -def test_error_message_for_custom_era_end_strategy(): - expression = CohortExpression( - primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence()]), - end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), - ) - - with pytest.raises(UnsupportedFeatureError, match="custom_era end strategy"): - _ = build_cohort(expression, backend=object(), cdm_schema="main") - - def test_error_message_for_unsupported_criterion_type(): with pytest.raises( UnsupportedCriterionError, From 09748f2fca4a5e3370dc50a7a7627fdd60c86463 Mon Sep 17 00:00:00 2001 From: Jamie Gilbert Date: Thu, 14 May 2026 12:48:25 -0700 Subject: [PATCH 2/5] more tests around custom era logic for parity with java --- tests/execution/test_custom_era.py | 566 +++++++++++++++++++++++++++++ 1 file changed, 566 insertions(+) create mode 100644 tests/execution/test_custom_era.py diff --git a/tests/execution/test_custom_era.py b/tests/execution/test_custom_era.py new file mode 100644 index 00000000..def64960 --- /dev/null +++ b/tests/execution/test_custom_era.py @@ -0,0 +1,566 @@ +from __future__ import annotations + +from datetime import date + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionOccurrence, + DrugExposure, + PrimaryCriteria, +) +from circe.cohortdefinition.core import CustomEraStrategy +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression( + items=[ConceptSetItem(concept=Concept(conceptId=concept_id))] + ), + ) + + +def _seed_common_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1], + "year_of_birth": [1980], + "gender_concept_id": [8507], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1], + "observation_period_id": [10], + "observation_period_start_date": [date(2019, 1, 1)], + "observation_period_end_date": [date(2021, 12, 31)], + } + ), + overwrite=True, + ) + + +def test_custom_era_merges_drugs_within_gap(): + """Drug exposures within gap_days merge into one era; cohort end_date reflects it.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1)], + "drug_exposure_end_date": [date(2020, 1, 31), date(2020, 3, 3)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria( + criteria_list=[ConditionOccurrence(codeset_id=1)] + ), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + # exp 1: end=2020-01-31, exp 2: end=2020-03-03 + # gap = 1 <= 30 -> merged era: start=2020-01-01, end=2020-03-03 + assert str(result.iloc[0]["end_date"])[:10] == "2020-03-03" + + +def test_custom_era_no_merge_across_large_gap(): + """Drug exposures beyond gap_days form separate eras; cohort uses nearest era.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1)], + "drug_exposure_end_date": [date(2020, 1, 6), date(2020, 3, 3)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria( + criteria_list=[ConditionOccurrence(codeset_id=1)] + ), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=5, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + # exp 1: end=2020-01-06, exp 2: end=2020-03-03 + # gap = 26 > 5 -> separate eras + # cohort start 2020-01-01 matches era 1: end 2020-01-06 + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-06" + + +def test_custom_era_offset_applied(): + """Offset days are added to the drug era end_date.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1], + "drug_exposure_id": [1], + "drug_concept_id": [222], + "drug_exposure_start_date": [date(2020, 1, 1)], + "drug_exposure_end_date": [date(2020, 1, 10)], + "days_supply": [0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria( + criteria_list=[ConditionOccurrence(codeset_id=1)] + ), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=7), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + # drug effective end: 2020-01-10 (end_date override) + # era: start=2020-01-01, end=2020-01-10+7=2020-01-17 + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-17" + + +def test_custom_era_no_matching_drugs(): + """No matching drug exposures -> fall back to observation_period_end_date.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 15)], + "condition_end_date": [date(2020, 1, 15)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [], + "drug_exposure_id": [], + "drug_concept_id": [], + "drug_exposure_start_date": [], + "drug_exposure_end_date": [], + "days_supply": [], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 999), + ], + primary_criteria=PrimaryCriteria( + criteria_list=[ConditionOccurrence(codeset_id=1)] + ), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-15" + # No matching drugs -> end_date = observation_period_end_date = 2021-12-31 + assert str(result.iloc[0]["end_date"])[:10] == "2021-12-31" + + +def test_custom_era_with_drug_exposure_as_primary(): + """Custom era works with DrugExposure as the primary criterion.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1)], + "drug_exposure_end_date": [date(2020, 1, 31), date(2020, 3, 3)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 222)], + primary_criteria=PrimaryCriteria( + criteria_list=[DrugExposure(codeset_id=1)] + ), + end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + # With primary_limit_type="all", both drug exposures produce cohort entries. + # Both entries get end_date from the merged drug era (2020-03-03). + assert len(result) == 2 + start_dates = sorted(result["start_date"].astype(str).tolist()) + assert start_dates == ["2020-01-01", "2020-02-01"] + assert all( + str(d)[:10] == "2020-03-03" for d in result["end_date"] + ) + + +def test_compute_drug_eras_matches_java_sql_logic(): + """compute_drug_eras ibis output matches equivalent raw SQL (Java template translated to DuckDB).""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + from types import SimpleNamespace + + from circe.execution.engine.custom_era import compute_drug_eras + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + # 5 exposures for person 1, with gap_days=7, offset=3. + # Exposure end_dates are set explicitly so COALESCE is predictable. + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 1, 1], + "drug_exposure_id": [1, 2, 3, 4, 5], + "drug_concept_id": [222, 222, 222, 222, 222], + "drug_exposure_start_date": [ + date(2020, 1, 1), + date(2020, 1, 10), + date(2020, 3, 1), + date(2020, 3, 20), + date(2020, 5, 1), + ], + "drug_exposure_end_date": [ + date(2020, 1, 6), + date(2020, 2, 9), + date(2020, 3, 21), + date(2020, 3, 30), + date(2020, 5, 15), + ], + "days_supply": [0, 0, 0, 0, 0], + } + ), + overwrite=True, + ) + + ctx = SimpleNamespace( + table=lambda name: conn.table(name), + concept_ids_for_codeset=lambda cid: (222,) if cid == 2 else (), + ) + + # --- ibis path --- + ibis_result = compute_drug_eras( + ctx, drug_codeset_id=2, gap_days=7, offset=3, days_supply_override=None + ).execute() + ibis_result = ibis_result.sort_values(["person_id", "era_start_date"]).reset_index(drop=True) + + # --- raw SQL path (Java template core logic, DuckDB dialect) --- + # Java template uses: COALESCE(end, start+days_supply, start+1) + # then pads by (gap_days + offset), groups by cumulative-max-over-preceding, + # and finally subtracts gap_days from max(end) to leave only offset. + gap = 7 + off = 3 + + sql = f""" + WITH exposures AS ( + SELECT + person_id::INTEGER AS person_id, + drug_exposure_start_date::DATE AS start_date, + COALESCE( + drug_exposure_end_date::DATE, + drug_exposure_start_date::DATE + days_supply::INTEGER, + drug_exposure_start_date::DATE + 1 + ) + {gap + off} AS padded_end + FROM drug_exposure + WHERE drug_concept_id IN (222) + ), + with_prev_max AS ( + SELECT *, + MAX(padded_end) OVER ( + PARTITION BY person_id ORDER BY start_date, padded_end DESC + ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING + ) AS prev_max + FROM exposures + ), + with_markers AS ( + SELECT *, + CASE WHEN prev_max IS NULL OR prev_max < start_date THEN 1 ELSE 0 END AS is_new + FROM with_prev_max + ), + with_era AS ( + SELECT *, + SUM(is_new) OVER ( + PARTITION BY person_id + ORDER BY start_date, is_new DESC, padded_end DESC + ) AS era_id + FROM with_markers + ) + SELECT + person_id, + MIN(start_date)::DATE AS era_start_date, + (MAX(padded_end) - {gap})::DATE AS era_end_date + FROM with_era + GROUP BY person_id, era_id + ORDER BY person_id, MIN(start_date) + """ + + raw_conn = conn.con + sql_result = raw_conn.sql(sql).fetchdf() + + # --- compare --- + pd = pytest.importorskip("pandas") + pd.testing.assert_frame_equal( + ibis_result, + sql_result, + check_dtype=False, + check_column_type=False, + ) + + +def test_full_cohort_custom_era_matches_sql_end_dates(): + """Full cohort pipeline with CustomEraStrategy produces same end_dates as raw SQL.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1)], + "drug_exposure_end_date": [date(2020, 1, 31), date(2020, 3, 3)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria( + criteria_list=[ConditionOccurrence(codeset_id=1)] + ), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), + ) + + # --- ibis pipeline --- + cohort_result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + # --- raw SQL pipeline (Java CUSTOM_ERA_STRATEGY_TEMPLATE logic, DuckDB dialect) --- + # Computes drug eras, then matches era end_dates to events via start_date overlap. + sql = f""" + WITH drug_eras AS ( + SELECT + person_id, + MIN(start_date) AS era_start_date, + MAX(padded_end) - 30 AS era_end_date + FROM ( + SELECT + person_id, start_date, padded_end, + SUM(is_new) OVER ( + PARTITION BY person_id + ORDER BY start_date, is_new DESC, padded_end DESC + ) AS era_id + FROM ( + SELECT + person_id, start_date, padded_end, + CASE WHEN prev_max IS NULL OR prev_max < start_date THEN 1 ELSE 0 END AS is_new + FROM ( + SELECT + person_id, start_date, padded_end, + MAX(padded_end) OVER ( + PARTITION BY person_id ORDER BY start_date, padded_end DESC + ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING + ) AS prev_max + FROM ( + SELECT + de.person_id, + de.drug_exposure_start_date::DATE AS start_date, + COALESCE( + de.drug_exposure_end_date::DATE, + de.drug_exposure_start_date::DATE + de.days_supply::INTEGER, + de.drug_exposure_start_date::DATE + 1 + ) + 30 AS padded_end + FROM drug_exposure de + WHERE de.drug_concept_id = 222 + ) raw_ends + ) maxes + ) marked + ) indexed + GROUP BY person_id, era_id + ), + events_with_obs AS ( + SELECT + e.person_id, + e.condition_occurrence_id AS event_id, + e.condition_start_date::DATE AS start_date, + op.observation_period_end_date::DATE AS op_end_date + FROM condition_occurrence e + JOIN observation_period op ON e.person_id = op.person_id + ) + SELECT + ev.person_id, + ev.start_date, + LEAST( + COALESCE(MAX(er.era_end_date), ev.op_end_date), + ev.op_end_date + )::DATE AS end_date + FROM events_with_obs ev + LEFT JOIN drug_eras er + ON ev.person_id = er.person_id + AND ev.start_date BETWEEN er.era_start_date AND er.era_end_date + GROUP BY ev.person_id, ev.event_id, ev.start_date, ev.op_end_date + ORDER BY ev.person_id, ev.start_date + """ + + sql_result = conn.con.sql(sql).fetchdf() + + # Compare end_dates and start_dates after sorting + ibis_ends = sorted(cohort_result["end_date"].astype(str).tolist()) + sql_ends = sorted(sql_result["end_date"].astype(str).tolist()) + assert ibis_ends == sql_ends + + ibis_starts = sorted(cohort_result["start_date"].astype(str).tolist()) + sql_starts = sorted(sql_result["start_date"].astype(str).tolist()) + assert ibis_starts == sql_starts From 8f7c4d2051e19d59ddef2e10a190583ea8b87a81 Mon Sep 17 00:00:00 2001 From: Jamie Gilbert Date: Thu, 14 May 2026 18:55:07 -0700 Subject: [PATCH 3/5] fix(execution): CustomEra window partitions on (person_id, event_id) to preserve all events When DrugExposure(first=True) and QualifiedLimit=First, every person has exactly 1 event with event_id=1 (assigned by _assign_primary_event_ids). The CustomEra window previously grouped by event_id alone, collapsing all rows into 1 partition and dropping N-1 rows with _rn==0. Grouping by (person_id, event_id) gives each row its own partition, preserving all events. Adds regression test via build_cohort. --- circe/execution/engine/custom_era.py | 2 +- tests/execution/test_custom_era.py | 82 +++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/circe/execution/engine/custom_era.py b/circe/execution/engine/custom_era.py index e9a93e39..bf53091d 100644 --- a/circe/execution/engine/custom_era.py +++ b/circe/execution/engine/custom_era.py @@ -134,7 +134,7 @@ def apply_custom_era_strategy(events, strategy, ctx): ) event_window = ibis.window( - group_by=joined.event_id, + group_by=[joined.person_id, joined.event_id], order_by=[joined.era_end_date.desc()], ) ranked = joined.mutate(_rn=ibis.row_number().over(event_window)) diff --git a/tests/execution/test_custom_era.py b/tests/execution/test_custom_era.py index def64960..6a3a85b5 100644 --- a/tests/execution/test_custom_era.py +++ b/tests/execution/test_custom_era.py @@ -11,7 +11,7 @@ DrugExposure, PrimaryCriteria, ) -from circe.cohortdefinition.core import CustomEraStrategy +from circe.cohortdefinition.core import CustomEraStrategy, ResultLimit from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem @@ -564,3 +564,83 @@ def test_full_cohort_custom_era_matches_sql_end_dates(): ibis_starts = sorted(cohort_result["start_date"].astype(str).tolist()) sql_starts = sorted(sql_result["start_date"].astype(str).tolist()) assert ibis_starts == sql_starts + + +# --------------------------------------------------------------------------- +# Regression: CustomEra must preserve all events when event_id is shared +# +# After ``first=True`` + ``QualifiedLimit=First`` + ``ExpressionLimit=First`` +# every person contributes at most one event, and ``_assign_primary_event_ids`` +# assigns ``event_id=1`` to all of them. The CustomEra window that selects +# one matching era per event must therefore partition on *(person_id, event_id)* +# — otherwise all rows collapse into a single partition and only one survives. +# --------------------------------------------------------------------------- + + +def _seed_common_tables_multi_person(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "year_of_birth": [1980, 1985, 1990], + "gender_concept_id": [8507, 8507, 8507], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "observation_period_id": [10, 11, 12], + "observation_period_start_date": [date(2019, 1, 1), date(2019, 1, 1), date(2019, 1, 1)], + "observation_period_end_date": [date(2021, 12, 31), date(2021, 12, 31), date(2021, 12, 31)], + } + ), + overwrite=True, + ) + + +def test_custom_era_preserves_all_persons_with_first_true(): + """All persons survive when DrugExposure(first=True) + CustomEra + limits. + + The window ``group_by=joined.event_id`` previously collapsed every row + into a single partition because all events had ``event_id=1`` (assigned + by ``_assign_primary_event_ids`` — each person has exactly 1 event after + ``first=True`` and the per-person limits). + """ + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables_multi_person(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "drug_exposure_id": [100, 200, 300], + "drug_concept_id": [222, 222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1), date(2020, 3, 1)], + "drug_exposure_end_date": [date(2020, 1, 31), date(2020, 2, 28), date(2020, 3, 31)], + "days_supply": [0, 0, 0], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[DrugExposure(codeset_id=1, first=True)]), + qualified_limit=ResultLimit(Type="First"), + expression_limit=ResultLimit(Type="First"), + end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 3, f"expected 3 rows, got {len(result)}" + assert set(result["person_id"]) == {1, 2, 3} From dd20010a84cf5e76b234ab509cd98d62999b36ff Mon Sep 17 00:00:00 2001 From: Jamie Gilbert Date: Sat, 16 May 2026 16:11:17 -0700 Subject: [PATCH 4/5] fixes and improved tests; MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ccrce/execution/engine/custom_era.py: - Issue 2 ✓ — _padded_end includes gap_days + offset; era end uses max(padded_end) - gap_days matching Circe BE - Issue 3 ✓ — Filters both drug_concept_id and drug_source_concept_id (with column existence guard) - Issue 5 ✓ — compute_drug_eras accepts cohort_person_ids; apply_custom_era_strategy semi-joins drug_exposure to cohort persons --- circe/execution/engine/custom_era.py | 32 +++++- tests/execution/test_custom_era.py | 141 ++++++++++++++++++++------- 2 files changed, 131 insertions(+), 42 deletions(-) diff --git a/circe/execution/engine/custom_era.py b/circe/execution/engine/custom_era.py index bf53091d..8a805236 100644 --- a/circe/execution/engine/custom_era.py +++ b/circe/execution/engine/custom_era.py @@ -26,7 +26,9 @@ def _compute_exposure_end_date(table, *, days_supply_override: int | None): def _compute_eras(exposures, *, gap_days: int, offset: int): - padded = exposures.mutate(_padded_end=(exposures._exposure_end + ibis.interval(days=int(gap_days)))) + padded = exposures.mutate( + _padded_end=(exposures._exposure_end + ibis.interval(days=int(gap_days + offset))) + ) ordering = [ padded.start_date, @@ -62,18 +64,24 @@ def _compute_eras(exposures, *, gap_days: int, offset: int): collapsed = era_indexed.group_by(era_indexed.person_id, era_indexed._era_id).aggregate( era_start_date=era_indexed.start_date.min(), - _max_exposure_end=era_indexed._exposure_end.max(), + _max_padded_end=era_indexed._padded_end.max(), ) return collapsed.select( collapsed.person_id.cast("int64").name(PERSON_ID), collapsed.era_start_date.cast("date").name("era_start_date"), - (collapsed._max_exposure_end + ibis.interval(days=int(offset))).cast("date").name("era_end_date"), + (collapsed._max_padded_end - ibis.interval(days=int(gap_days))).cast("date").name("era_end_date"), ) def compute_drug_eras( - ctx, *, drug_codeset_id: int, gap_days: int, offset: int, days_supply_override: int | None + ctx, + *, + drug_codeset_id: int, + gap_days: int, + offset: int, + days_supply_override: int | None, + cohort_person_ids=None, ): concept_ids = ctx.concept_ids_for_codeset(drug_codeset_id) @@ -86,7 +94,18 @@ def compute_drug_eras( ) de = ctx.table("drug_exposure") - filtered = de.filter(de.drug_concept_id.isin(concept_ids)) + if cohort_person_ids is not None: + de = de.semi_join( + cohort_person_ids.select(cohort_person_ids.person_id).distinct(), + predicates=[de.person_id == cohort_person_ids.person_id], + ) + + if "drug_source_concept_id" in de.columns: + filtered = de.filter( + de.drug_concept_id.isin(concept_ids) | de.drug_source_concept_id.isin(concept_ids) + ) + else: + filtered = de.filter(de.drug_concept_id.isin(concept_ids)) prepared = filtered.select( filtered.person_id.cast("int64").name("person_id"), @@ -108,12 +127,15 @@ def apply_custom_era_strategy(events, strategy, ctx): with_bounds = attach_observation_bounds(events, ctx) return _replace_end_date(events, with_bounds, with_bounds.op_end_date) + cohort_person_ids = events.select(events.person_id).distinct() + eras = compute_drug_eras( ctx, drug_codeset_id=drug_codeset_id, gap_days=gap_days, offset=offset, days_supply_override=days_supply_override, + cohort_person_ids=cohort_person_ids, ) eras_for_join = eras.select( diff --git a/tests/execution/test_custom_era.py b/tests/execution/test_custom_era.py index 6a3a85b5..ab327d41 100644 --- a/tests/execution/test_custom_era.py +++ b/tests/execution/test_custom_era.py @@ -18,9 +18,7 @@ def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: return ConceptSet( id=set_id, - expression=ConceptSetExpression( - items=[ConceptSetItem(concept=Concept(conceptId=concept_id))] - ), + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), ) @@ -92,9 +90,7 @@ def test_custom_era_merges_drugs_within_gap(): _make_concept_set(1, 111), _make_concept_set(2, 222), ], - primary_criteria=PrimaryCriteria( - criteria_list=[ConditionOccurrence(codeset_id=1)] - ), + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), ) @@ -149,9 +145,7 @@ def test_custom_era_no_merge_across_large_gap(): _make_concept_set(1, 111), _make_concept_set(2, 222), ], - primary_criteria=PrimaryCriteria( - criteria_list=[ConditionOccurrence(codeset_id=1)] - ), + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=5, offset=0), ) @@ -207,9 +201,7 @@ def test_custom_era_offset_applied(): _make_concept_set(1, 111), _make_concept_set(2, 222), ], - primary_criteria=PrimaryCriteria( - criteria_list=[ConditionOccurrence(codeset_id=1)] - ), + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=7), ) @@ -264,9 +256,7 @@ def test_custom_era_no_matching_drugs(): _make_concept_set(1, 111), _make_concept_set(2, 999), ], - primary_criteria=PrimaryCriteria( - criteria_list=[ConditionOccurrence(codeset_id=1)] - ), + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), ) @@ -303,9 +293,7 @@ def test_custom_era_with_drug_exposure_as_primary(): expression = CohortExpression( concept_sets=[_make_concept_set(1, 222)], - primary_criteria=PrimaryCriteria( - criteria_list=[DrugExposure(codeset_id=1)] - ), + primary_criteria=PrimaryCriteria(criteria_list=[DrugExposure(codeset_id=1)]), end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), ) @@ -316,9 +304,7 @@ def test_custom_era_with_drug_exposure_as_primary(): assert len(result) == 2 start_dates = sorted(result["start_date"].astype(str).tolist()) assert start_dates == ["2020-01-01", "2020-02-01"] - assert all( - str(d)[:10] == "2020-03-03" for d in result["end_date"] - ) + assert all(str(d)[:10] == "2020-03-03" for d in result["end_date"]) def test_compute_drug_eras_matches_java_sql_logic(): @@ -436,6 +422,72 @@ def test_compute_drug_eras_matches_java_sql_logic(): ) +def test_custom_era_offset_affects_era_grouping(): + """Offset included in padded_end changes which exposures merge into eras. + + With gap_days=0, offset=30: + exp1: end=2020-01-10, exp2: start=2020-01-12 (gap=2 days) + + Without offset in padded_end: padded_end1=2020-01-10 (< start 2020-01-12) + → separate eras, cohort end=2020-01-10+30=2020-02-09 + + With offset in padded_end (Circe BE: DATEADD(day, gap+offset, end)): + padded_end1=2020-01-10+30=2020-02-09 (>= start 2020-01-12) + → merged era, cohort end=2020-01-20+30=2020-02-19 + """ + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 1, 12)], + "drug_exposure_end_date": [date(2020, 1, 10), date(2020, 1, 20)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=0, offset=30), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + # Both exposures merge because padded_end1=2020-02-09 >= start 2020-01-12 + # era end = max(end) + offset = 2020-01-20 + 30 = 2020-02-19 + assert str(result.iloc[0]["end_date"])[:10] == "2020-02-19" + + def test_full_cohort_custom_era_matches_sql_end_dates(): """Full cohort pipeline with CustomEraStrategy produces same end_dates as raw SQL.""" ibis = pytest.importorskip("ibis") @@ -478,9 +530,7 @@ def test_full_cohort_custom_era_matches_sql_end_dates(): _make_concept_set(1, 111), _make_concept_set(2, 222), ], - primary_criteria=PrimaryCriteria( - criteria_list=[ConditionOccurrence(codeset_id=1)] - ), + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), ) @@ -488,13 +538,17 @@ def test_full_cohort_custom_era_matches_sql_end_dates(): cohort_result = build_cohort(expression, backend=conn, cdm_schema="main").execute() # --- raw SQL pipeline (Java CUSTOM_ERA_STRATEGY_TEMPLATE logic, DuckDB dialect) --- - # Computes drug eras, then matches era end_dates to events via start_date overlap. + # Mirrors Circe BE's generateCohort.sql end-date selection: + # ROW_NUMBER() PARTITION BY person_id, event_id ORDER BY era_end_date ASC + # picks the earliest strategy end per event, matching Circe BE's + # MIN(end_date) across #strategy_ends union. + gap = 30 sql = f""" WITH drug_eras AS ( SELECT person_id, MIN(start_date) AS era_start_date, - MAX(padded_end) - 30 AS era_end_date + MAX(padded_end) - {gap} AS era_end_date FROM ( SELECT person_id, start_date, padded_end, @@ -521,7 +575,7 @@ def test_full_cohort_custom_era_matches_sql_end_dates(): de.drug_exposure_end_date::DATE, de.drug_exposure_start_date::DATE + de.days_supply::INTEGER, de.drug_exposure_start_date::DATE + 1 - ) + 30 AS padded_end + ) + {gap} AS padded_end FROM drug_exposure de WHERE de.drug_concept_id = 222 ) raw_ends @@ -538,20 +592,33 @@ def test_full_cohort_custom_era_matches_sql_end_dates(): op.observation_period_end_date::DATE AS op_end_date FROM condition_occurrence e JOIN observation_period op ON e.person_id = op.person_id + ), + ranked_ends AS ( + SELECT + ev.person_id, + ev.event_id, + ev.start_date, + ev.op_end_date, + er.era_end_date, + ROW_NUMBER() OVER ( + PARTITION BY ev.person_id, ev.event_id + ORDER BY er.era_end_date + ) AS rn + FROM events_with_obs ev + LEFT JOIN drug_eras er + ON ev.person_id = er.person_id + AND ev.start_date BETWEEN er.era_start_date AND er.era_end_date ) SELECT - ev.person_id, - ev.start_date, + person_id, + start_date, LEAST( - COALESCE(MAX(er.era_end_date), ev.op_end_date), - ev.op_end_date + COALESCE(era_end_date, op_end_date), + op_end_date )::DATE AS end_date - FROM events_with_obs ev - LEFT JOIN drug_eras er - ON ev.person_id = er.person_id - AND ev.start_date BETWEEN er.era_start_date AND er.era_end_date - GROUP BY ev.person_id, ev.event_id, ev.start_date, ev.op_end_date - ORDER BY ev.person_id, ev.start_date + FROM ranked_ends + WHERE rn = 1 + ORDER BY person_id, start_date """ sql_result = conn.con.sql(sql).fetchdf() From f3331bf81f32077f558bb61c5de7a370da0e697c Mon Sep 17 00:00:00 2001 From: Jamie Gilbert Date: Mon, 18 May 2026 16:55:47 -0700 Subject: [PATCH 5/5] More fixes from PR feedback --- circe/execution/engine/custom_era.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/circe/execution/engine/custom_era.py b/circe/execution/engine/custom_era.py index 8a805236..7e69db7b 100644 --- a/circe/execution/engine/custom_era.py +++ b/circe/execution/engine/custom_era.py @@ -96,7 +96,7 @@ def compute_drug_eras( de = ctx.table("drug_exposure") if cohort_person_ids is not None: de = de.semi_join( - cohort_person_ids.select(cohort_person_ids.person_id).distinct(), + cohort_person_ids, predicates=[de.person_id == cohort_person_ids.person_id], ) @@ -157,7 +157,7 @@ def apply_custom_era_strategy(events, strategy, ctx): event_window = ibis.window( group_by=[joined.person_id, joined.event_id], - order_by=[joined.era_end_date.desc()], + order_by=[joined.era_end_date.asc()], ) ranked = joined.mutate(_rn=ibis.row_number().over(event_window)) one_per_event = ranked.filter(ranked._rn == 0)