diff --git a/circe/execution/engine/custom_era.py b/circe/execution/engine/custom_era.py new file mode 100644 index 0000000..7e69db7 --- /dev/null +++ b/circe/execution/engine/custom_era.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import ibis + +from ..plan.schema import PERSON_ID, START_DATE +from .end_strategy import _replace_end_date, attach_observation_bounds + + +def _compute_exposure_end_date(table, *, days_supply_override: int | None): + start = table["drug_exposure_start_date"].cast("date") + + if days_supply_override is not None: + return start + ibis.interval(days=days_supply_override) + + raw_end = ( + table["drug_exposure_end_date"].cast("date") + if "drug_exposure_end_date" in table.columns + else ibis.null().cast("date") + ) + days_supply = ( + table["days_supply"].cast("int64") if "days_supply" in table.columns else ibis.null().cast("int64") + ) + supply_end = start + days_supply.as_interval("D") + + return ibis.coalesce(raw_end, supply_end, start + ibis.interval(days=1)) + + +def _compute_eras(exposures, *, gap_days: int, offset: int): + padded = exposures.mutate( + _padded_end=(exposures._exposure_end + ibis.interval(days=int(gap_days + offset))) + ) + + ordering = [ + padded.start_date, + padded._padded_end.desc(), + padded._exposure_end.desc(), + ] + + cumulative_window = ibis.cumulative_window(group_by=padded.person_id, order_by=ordering) + ordered_window = ibis.window(group_by=padded.person_id, order_by=ordering) + + with_cummax = padded.mutate(_cummax_padded_end=padded._padded_end.max().over(cumulative_window)) + + with_prev = with_cummax.mutate(_prev_max=with_cummax._cummax_padded_end.lag().over(ordered_window)) + + marked = with_prev.mutate( + _is_new=ibis.ifelse( + with_prev._prev_max.isnull() | (with_prev._prev_max < with_prev.start_date), + ibis.literal(1, type="int64"), + ibis.literal(0, type="int64"), + ) + ) + + group_window = ibis.cumulative_window( + group_by=marked.person_id, + order_by=[ + marked.start_date, + marked._padded_end.desc(), + marked._exposure_end.desc(), + marked._is_new.desc(), + ], + ) + era_indexed = marked.mutate(_era_id=marked._is_new.sum().over(group_window)) + + collapsed = era_indexed.group_by(era_indexed.person_id, era_indexed._era_id).aggregate( + era_start_date=era_indexed.start_date.min(), + _max_padded_end=era_indexed._padded_end.max(), + ) + + return collapsed.select( + collapsed.person_id.cast("int64").name(PERSON_ID), + collapsed.era_start_date.cast("date").name("era_start_date"), + (collapsed._max_padded_end - ibis.interval(days=int(gap_days))).cast("date").name("era_end_date"), + ) + + +def compute_drug_eras( + ctx, + *, + drug_codeset_id: int, + gap_days: int, + offset: int, + days_supply_override: int | None, + cohort_person_ids=None, +): + concept_ids = ctx.concept_ids_for_codeset(drug_codeset_id) + + if not concept_ids: + de = ctx.table("drug_exposure") + return de.filter(ibis.literal(False)).select( + de.person_id.cast("int64").name(PERSON_ID), + ibis.null().cast("date").name("era_start_date"), + ibis.null().cast("date").name("era_end_date"), + ) + + de = ctx.table("drug_exposure") + if cohort_person_ids is not None: + de = de.semi_join( + cohort_person_ids, + predicates=[de.person_id == cohort_person_ids.person_id], + ) + + if "drug_source_concept_id" in de.columns: + filtered = de.filter( + de.drug_concept_id.isin(concept_ids) | de.drug_source_concept_id.isin(concept_ids) + ) + else: + filtered = de.filter(de.drug_concept_id.isin(concept_ids)) + + prepared = filtered.select( + filtered.person_id.cast("int64").name("person_id"), + filtered.drug_exposure_start_date.cast("date").name("start_date"), + _compute_exposure_end_date(filtered, days_supply_override=days_supply_override).name("_exposure_end"), + ) + + return _compute_eras(prepared, gap_days=gap_days, offset=offset) + + +def apply_custom_era_strategy(events, strategy, ctx): + payload = strategy.payload + drug_codeset_id = payload["drug_codeset_id"] + gap_days = payload["gap_days"] + offset = payload["offset"] + days_supply_override = payload.get("days_supply_override") + + if drug_codeset_id is None: + with_bounds = attach_observation_bounds(events, ctx) + return _replace_end_date(events, with_bounds, with_bounds.op_end_date) + + cohort_person_ids = events.select(events.person_id).distinct() + + eras = compute_drug_eras( + ctx, + drug_codeset_id=drug_codeset_id, + gap_days=gap_days, + offset=offset, + days_supply_override=days_supply_override, + cohort_person_ids=cohort_person_ids, + ) + + eras_for_join = eras.select( + eras.person_id.name("_era_person_id"), + eras.era_start_date, + eras.era_end_date, + ) + + with_bounds = attach_observation_bounds(events, ctx) + + joined = with_bounds.left_join( + eras_for_join, + predicates=[ + with_bounds.person_id == eras_for_join._era_person_id, + with_bounds[START_DATE] >= eras_for_join.era_start_date, + with_bounds[START_DATE] <= eras_for_join.era_end_date, + ], + ) + + event_window = ibis.window( + group_by=[joined.person_id, joined.event_id], + order_by=[joined.era_end_date.asc()], + ) + ranked = joined.mutate(_rn=ibis.row_number().over(event_window)) + one_per_event = ranked.filter(ranked._rn == 0) + + effective_end = ibis.coalesce( + one_per_event.era_end_date, + one_per_event.op_end_date, + ) + final_end = ibis.least(effective_end, one_per_event.op_end_date) + + return _replace_end_date(events, one_per_event, final_end) diff --git a/circe/execution/engine/end_strategy.py b/circe/execution/engine/end_strategy.py index a099985..4b8e5b9 100644 --- a/circe/execution/engine/end_strategy.py +++ b/circe/execution/engine/end_strategy.py @@ -64,7 +64,9 @@ def apply_end_strategy(events, strategy, ctx): return _replace_end_date(events, with_bounds, end_date_expr) if strategy.kind == "custom_era": - raise UnsupportedFeatureError("Ibis executor end-strategy error: custom_era is not supported.") + from .custom_era import apply_custom_era_strategy + + return apply_custom_era_strategy(events, strategy, ctx) # Fallback: preserve default semantics of op_end_date clipping. return _replace_end_date(events, with_bounds, with_bounds.op_end_date) diff --git a/circe/execution/normalize/cohort.py b/circe/execution/normalize/cohort.py index b2f657f..61765b4 100644 --- a/circe/execution/normalize/cohort.py +++ b/circe/execution/normalize/cohort.py @@ -3,7 +3,7 @@ from ...cohortdefinition import CohortExpression from ...vocabulary.concept import ConceptSet from .._dataclass import frozen_slots_dataclass -from ..errors import ExecutionNormalizationError, UnsupportedFeatureError +from ..errors import ExecutionNormalizationError from .collapse import NormalizedCollapseSettings, normalize_collapse_settings from .criteria import NormalizedCriterion, normalize_criterion from .end_strategy import NormalizedEndStrategy, normalize_end_strategy @@ -149,10 +149,6 @@ def normalize_cohort( ) normalized_end_strategy = normalize_end_strategy(expression.end_strategy) - if normalized_end_strategy is not None and normalized_end_strategy.kind == "custom_era": - raise UnsupportedFeatureError( - "Ibis executor normalization error: custom_era end strategy is not supported." - ) return NormalizedCohort( title=expression.title, diff --git a/circe/execution/normalize/end_strategy.py b/circe/execution/normalize/end_strategy.py index 62ff666..8e03409 100644 --- a/circe/execution/normalize/end_strategy.py +++ b/circe/execution/normalize/end_strategy.py @@ -32,6 +32,7 @@ def normalize_end_strategy( "drug_codeset_id": value.drug_codeset_id, "offset": int(value.offset), "gap_days": int(value.gap_days), + "days_supply_override": value.days_supply_override, }, ) return NormalizedEndStrategy(kind="end_strategy", payload={}) diff --git a/tests/execution/test_api_ibis.py b/tests/execution/test_api_ibis.py index ef0a73e..db55e52 100644 --- a/tests/execution/test_api_ibis.py +++ b/tests/execution/test_api_ibis.py @@ -26,8 +26,7 @@ VisitDetail, VisitOccurrence, ) -from circe.cohortdefinition.core import CustomEraStrategy, NumericRange -from circe.execution.errors import UnsupportedFeatureError +from circe.cohortdefinition.core import NumericRange from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem @@ -1213,10 +1212,12 @@ def test_build_cohort_location_region_keeps_repeated_location_history_rows(): assert sorted(result.start_date.astype(str).tolist()) == ["2020-01-01", "2020-02-01"] -def test_build_cohort_rejects_unsupported_features(): - expression = CohortExpression( - primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence()]), - end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), - ) - with pytest.raises(UnsupportedFeatureError, match="custom_era"): - _ = build_cohort(expression, backend=object(), cdm_schema="main") +def test_build_cohort_rejects_unsupported_criteria(): + """Unsupported base criteria type is rejected at normalization time.""" + from circe.cohortdefinition.criteria import Criteria as RawCriteria + from circe.execution.errors import UnsupportedCriterionError + + with pytest.raises(UnsupportedCriterionError): + from circe.execution.normalize.criteria import normalize_criterion + + normalize_criterion(RawCriteria()) diff --git a/tests/execution/test_custom_era.py b/tests/execution/test_custom_era.py new file mode 100644 index 0000000..ab327d4 --- /dev/null +++ b/tests/execution/test_custom_era.py @@ -0,0 +1,713 @@ +from __future__ import annotations + +from datetime import date + +import pytest + +from circe.api import build_cohort +from circe.cohortdefinition import ( + CohortExpression, + ConditionOccurrence, + DrugExposure, + PrimaryCriteria, +) +from circe.cohortdefinition.core import CustomEraStrategy, ResultLimit +from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem + + +def _make_concept_set(set_id: int, concept_id: int) -> ConceptSet: + return ConceptSet( + id=set_id, + expression=ConceptSetExpression(items=[ConceptSetItem(concept=Concept(conceptId=concept_id))]), + ) + + +def _seed_common_tables(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1], + "year_of_birth": [1980], + "gender_concept_id": [8507], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1], + "observation_period_id": [10], + "observation_period_start_date": [date(2019, 1, 1)], + "observation_period_end_date": [date(2021, 12, 31)], + } + ), + overwrite=True, + ) + + +def test_custom_era_merges_drugs_within_gap(): + """Drug exposures within gap_days merge into one era; cohort end_date reflects it.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1)], + "drug_exposure_end_date": [date(2020, 1, 31), date(2020, 3, 3)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + # exp 1: end=2020-01-31, exp 2: end=2020-03-03 + # gap = 1 <= 30 -> merged era: start=2020-01-01, end=2020-03-03 + assert str(result.iloc[0]["end_date"])[:10] == "2020-03-03" + + +def test_custom_era_no_merge_across_large_gap(): + """Drug exposures beyond gap_days form separate eras; cohort uses nearest era.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1)], + "drug_exposure_end_date": [date(2020, 1, 6), date(2020, 3, 3)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=5, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + # exp 1: end=2020-01-06, exp 2: end=2020-03-03 + # gap = 26 > 5 -> separate eras + # cohort start 2020-01-01 matches era 1: end 2020-01-06 + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-06" + + +def test_custom_era_offset_applied(): + """Offset days are added to the drug era end_date.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1], + "drug_exposure_id": [1], + "drug_concept_id": [222], + "drug_exposure_start_date": [date(2020, 1, 1)], + "drug_exposure_end_date": [date(2020, 1, 10)], + "days_supply": [0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=7), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + # drug effective end: 2020-01-10 (end_date override) + # era: start=2020-01-01, end=2020-01-10+7=2020-01-17 + assert str(result.iloc[0]["end_date"])[:10] == "2020-01-17" + + +def test_custom_era_no_matching_drugs(): + """No matching drug exposures -> fall back to observation_period_end_date.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 15)], + "condition_end_date": [date(2020, 1, 15)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [], + "drug_exposure_id": [], + "drug_concept_id": [], + "drug_exposure_start_date": [], + "drug_exposure_end_date": [], + "days_supply": [], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 999), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-15" + # No matching drugs -> end_date = observation_period_end_date = 2021-12-31 + assert str(result.iloc[0]["end_date"])[:10] == "2021-12-31" + + +def test_custom_era_with_drug_exposure_as_primary(): + """Custom era works with DrugExposure as the primary criterion.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1)], + "drug_exposure_end_date": [date(2020, 1, 31), date(2020, 3, 3)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[DrugExposure(codeset_id=1)]), + end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + # With primary_limit_type="all", both drug exposures produce cohort entries. + # Both entries get end_date from the merged drug era (2020-03-03). + assert len(result) == 2 + start_dates = sorted(result["start_date"].astype(str).tolist()) + assert start_dates == ["2020-01-01", "2020-02-01"] + assert all(str(d)[:10] == "2020-03-03" for d in result["end_date"]) + + +def test_compute_drug_eras_matches_java_sql_logic(): + """compute_drug_eras ibis output matches equivalent raw SQL (Java template translated to DuckDB).""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + from types import SimpleNamespace + + from circe.execution.engine.custom_era import compute_drug_eras + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + # 5 exposures for person 1, with gap_days=7, offset=3. + # Exposure end_dates are set explicitly so COALESCE is predictable. + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1, 1, 1, 1], + "drug_exposure_id": [1, 2, 3, 4, 5], + "drug_concept_id": [222, 222, 222, 222, 222], + "drug_exposure_start_date": [ + date(2020, 1, 1), + date(2020, 1, 10), + date(2020, 3, 1), + date(2020, 3, 20), + date(2020, 5, 1), + ], + "drug_exposure_end_date": [ + date(2020, 1, 6), + date(2020, 2, 9), + date(2020, 3, 21), + date(2020, 3, 30), + date(2020, 5, 15), + ], + "days_supply": [0, 0, 0, 0, 0], + } + ), + overwrite=True, + ) + + ctx = SimpleNamespace( + table=lambda name: conn.table(name), + concept_ids_for_codeset=lambda cid: (222,) if cid == 2 else (), + ) + + # --- ibis path --- + ibis_result = compute_drug_eras( + ctx, drug_codeset_id=2, gap_days=7, offset=3, days_supply_override=None + ).execute() + ibis_result = ibis_result.sort_values(["person_id", "era_start_date"]).reset_index(drop=True) + + # --- raw SQL path (Java template core logic, DuckDB dialect) --- + # Java template uses: COALESCE(end, start+days_supply, start+1) + # then pads by (gap_days + offset), groups by cumulative-max-over-preceding, + # and finally subtracts gap_days from max(end) to leave only offset. + gap = 7 + off = 3 + + sql = f""" + WITH exposures AS ( + SELECT + person_id::INTEGER AS person_id, + drug_exposure_start_date::DATE AS start_date, + COALESCE( + drug_exposure_end_date::DATE, + drug_exposure_start_date::DATE + days_supply::INTEGER, + drug_exposure_start_date::DATE + 1 + ) + {gap + off} AS padded_end + FROM drug_exposure + WHERE drug_concept_id IN (222) + ), + with_prev_max AS ( + SELECT *, + MAX(padded_end) OVER ( + PARTITION BY person_id ORDER BY start_date, padded_end DESC + ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING + ) AS prev_max + FROM exposures + ), + with_markers AS ( + SELECT *, + CASE WHEN prev_max IS NULL OR prev_max < start_date THEN 1 ELSE 0 END AS is_new + FROM with_prev_max + ), + with_era AS ( + SELECT *, + SUM(is_new) OVER ( + PARTITION BY person_id + ORDER BY start_date, is_new DESC, padded_end DESC + ) AS era_id + FROM with_markers + ) + SELECT + person_id, + MIN(start_date)::DATE AS era_start_date, + (MAX(padded_end) - {gap})::DATE AS era_end_date + FROM with_era + GROUP BY person_id, era_id + ORDER BY person_id, MIN(start_date) + """ + + raw_conn = conn.con + sql_result = raw_conn.sql(sql).fetchdf() + + # --- compare --- + pd = pytest.importorskip("pandas") + pd.testing.assert_frame_equal( + ibis_result, + sql_result, + check_dtype=False, + check_column_type=False, + ) + + +def test_custom_era_offset_affects_era_grouping(): + """Offset included in padded_end changes which exposures merge into eras. + + With gap_days=0, offset=30: + exp1: end=2020-01-10, exp2: start=2020-01-12 (gap=2 days) + + Without offset in padded_end: padded_end1=2020-01-10 (< start 2020-01-12) + → separate eras, cohort end=2020-01-10+30=2020-02-09 + + With offset in padded_end (Circe BE: DATEADD(day, gap+offset, end)): + padded_end1=2020-01-10+30=2020-02-09 (>= start 2020-01-12) + → merged era, cohort end=2020-01-20+30=2020-02-19 + """ + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 1, 12)], + "drug_exposure_end_date": [date(2020, 1, 10), date(2020, 1, 20)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=0, offset=30), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 1 + assert str(result.iloc[0]["start_date"])[:10] == "2020-01-01" + # Both exposures merge because padded_end1=2020-02-09 >= start 2020-01-12 + # era end = max(end) + offset = 2020-01-20 + 30 = 2020-02-19 + assert str(result.iloc[0]["end_date"])[:10] == "2020-02-19" + + +def test_full_cohort_custom_era_matches_sql_end_dates(): + """Full cohort pipeline with CustomEraStrategy produces same end_dates as raw SQL.""" + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 1], + "drug_exposure_id": [1, 2], + "drug_concept_id": [222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1)], + "drug_exposure_end_date": [date(2020, 1, 31), date(2020, 3, 3)], + "days_supply": [0, 0], + } + ), + overwrite=True, + ) + conn.create_table( + "condition_occurrence", + obj=ibis.memtable( + { + "person_id": [1], + "condition_occurrence_id": [100], + "condition_concept_id": [111], + "condition_start_date": [date(2020, 1, 1)], + "condition_end_date": [date(2020, 1, 1)], + "visit_occurrence_id": [10], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[ + _make_concept_set(1, 111), + _make_concept_set(2, 222), + ], + primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence(codeset_id=1)]), + end_strategy=CustomEraStrategy(drug_codeset_id=2, gap_days=30, offset=0), + ) + + # --- ibis pipeline --- + cohort_result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + # --- raw SQL pipeline (Java CUSTOM_ERA_STRATEGY_TEMPLATE logic, DuckDB dialect) --- + # Mirrors Circe BE's generateCohort.sql end-date selection: + # ROW_NUMBER() PARTITION BY person_id, event_id ORDER BY era_end_date ASC + # picks the earliest strategy end per event, matching Circe BE's + # MIN(end_date) across #strategy_ends union. + gap = 30 + sql = f""" + WITH drug_eras AS ( + SELECT + person_id, + MIN(start_date) AS era_start_date, + MAX(padded_end) - {gap} AS era_end_date + FROM ( + SELECT + person_id, start_date, padded_end, + SUM(is_new) OVER ( + PARTITION BY person_id + ORDER BY start_date, is_new DESC, padded_end DESC + ) AS era_id + FROM ( + SELECT + person_id, start_date, padded_end, + CASE WHEN prev_max IS NULL OR prev_max < start_date THEN 1 ELSE 0 END AS is_new + FROM ( + SELECT + person_id, start_date, padded_end, + MAX(padded_end) OVER ( + PARTITION BY person_id ORDER BY start_date, padded_end DESC + ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING + ) AS prev_max + FROM ( + SELECT + de.person_id, + de.drug_exposure_start_date::DATE AS start_date, + COALESCE( + de.drug_exposure_end_date::DATE, + de.drug_exposure_start_date::DATE + de.days_supply::INTEGER, + de.drug_exposure_start_date::DATE + 1 + ) + {gap} AS padded_end + FROM drug_exposure de + WHERE de.drug_concept_id = 222 + ) raw_ends + ) maxes + ) marked + ) indexed + GROUP BY person_id, era_id + ), + events_with_obs AS ( + SELECT + e.person_id, + e.condition_occurrence_id AS event_id, + e.condition_start_date::DATE AS start_date, + op.observation_period_end_date::DATE AS op_end_date + FROM condition_occurrence e + JOIN observation_period op ON e.person_id = op.person_id + ), + ranked_ends AS ( + SELECT + ev.person_id, + ev.event_id, + ev.start_date, + ev.op_end_date, + er.era_end_date, + ROW_NUMBER() OVER ( + PARTITION BY ev.person_id, ev.event_id + ORDER BY er.era_end_date + ) AS rn + FROM events_with_obs ev + LEFT JOIN drug_eras er + ON ev.person_id = er.person_id + AND ev.start_date BETWEEN er.era_start_date AND er.era_end_date + ) + SELECT + person_id, + start_date, + LEAST( + COALESCE(era_end_date, op_end_date), + op_end_date + )::DATE AS end_date + FROM ranked_ends + WHERE rn = 1 + ORDER BY person_id, start_date + """ + + sql_result = conn.con.sql(sql).fetchdf() + + # Compare end_dates and start_dates after sorting + ibis_ends = sorted(cohort_result["end_date"].astype(str).tolist()) + sql_ends = sorted(sql_result["end_date"].astype(str).tolist()) + assert ibis_ends == sql_ends + + ibis_starts = sorted(cohort_result["start_date"].astype(str).tolist()) + sql_starts = sorted(sql_result["start_date"].astype(str).tolist()) + assert ibis_starts == sql_starts + + +# --------------------------------------------------------------------------- +# Regression: CustomEra must preserve all events when event_id is shared +# +# After ``first=True`` + ``QualifiedLimit=First`` + ``ExpressionLimit=First`` +# every person contributes at most one event, and ``_assign_primary_event_ids`` +# assigns ``event_id=1`` to all of them. The CustomEra window that selects +# one matching era per event must therefore partition on *(person_id, event_id)* +# — otherwise all rows collapse into a single partition and only one survives. +# --------------------------------------------------------------------------- + + +def _seed_common_tables_multi_person(conn, ibis): + conn.create_table( + "person", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "year_of_birth": [1980, 1985, 1990], + "gender_concept_id": [8507, 8507, 8507], + } + ), + overwrite=True, + ) + conn.create_table( + "observation_period", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "observation_period_id": [10, 11, 12], + "observation_period_start_date": [date(2019, 1, 1), date(2019, 1, 1), date(2019, 1, 1)], + "observation_period_end_date": [date(2021, 12, 31), date(2021, 12, 31), date(2021, 12, 31)], + } + ), + overwrite=True, + ) + + +def test_custom_era_preserves_all_persons_with_first_true(): + """All persons survive when DrugExposure(first=True) + CustomEra + limits. + + The window ``group_by=joined.event_id`` previously collapsed every row + into a single partition because all events had ``event_id=1`` (assigned + by ``_assign_primary_event_ids`` — each person has exactly 1 event after + ``first=True`` and the per-person limits). + """ + ibis = pytest.importorskip("ibis") + _ = pytest.importorskip("duckdb") + + conn = ibis.duckdb.connect() + _seed_common_tables_multi_person(conn, ibis) + + conn.create_table( + "drug_exposure", + obj=ibis.memtable( + { + "person_id": [1, 2, 3], + "drug_exposure_id": [100, 200, 300], + "drug_concept_id": [222, 222, 222], + "drug_exposure_start_date": [date(2020, 1, 1), date(2020, 2, 1), date(2020, 3, 1)], + "drug_exposure_end_date": [date(2020, 1, 31), date(2020, 2, 28), date(2020, 3, 31)], + "days_supply": [0, 0, 0], + } + ), + overwrite=True, + ) + + expression = CohortExpression( + concept_sets=[_make_concept_set(1, 222)], + primary_criteria=PrimaryCriteria(criteria_list=[DrugExposure(codeset_id=1, first=True)]), + qualified_limit=ResultLimit(Type="First"), + expression_limit=ResultLimit(Type="First"), + end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), + ) + + result = build_cohort(expression, backend=conn, cdm_schema="main").execute() + + assert len(result) == 3, f"expected 3 rows, got {len(result)}" + assert set(result["person_id"]) == {1, 2, 3} diff --git a/tests/execution/test_error_messages.py b/tests/execution/test_error_messages.py index 80133b4..8e71481 100644 --- a/tests/execution/test_error_messages.py +++ b/tests/execution/test_error_messages.py @@ -14,7 +14,7 @@ Occurrence, PrimaryCriteria, ) -from circe.cohortdefinition.core import CustomEraStrategy, NumericRange +from circe.cohortdefinition.core import NumericRange from circe.execution.errors import CompilationError, UnsupportedCriterionError, UnsupportedFeatureError from circe.execution.normalize.criteria import normalize_criterion from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem @@ -55,16 +55,6 @@ def _concept_set(set_id: int, concept_id: int) -> ConceptSet: ) -def test_error_message_for_custom_era_end_strategy(): - expression = CohortExpression( - primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence()]), - end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0), - ) - - with pytest.raises(UnsupportedFeatureError, match="custom_era end strategy"): - _ = build_cohort(expression, backend=object(), cdm_schema="main") - - def test_error_message_for_unsupported_criterion_type(): with pytest.raises( UnsupportedCriterionError,