Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions circe/execution/engine/custom_era.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from __future__ import annotations

import ibis

from ..plan.schema import PERSON_ID, START_DATE
from .end_strategy import _replace_end_date, attach_observation_bounds


def _compute_exposure_end_date(table, *, days_supply_override: int | None):
    """Build the end-date expression for each row of a drug exposure table.

    Args:
        table: Table expression exposing ``drug_exposure_start_date`` and,
            optionally, ``drug_exposure_end_date`` and ``days_supply``.
        days_supply_override: When not ``None``, every exposure ends this
            many days after its start and the recorded columns are ignored.

    Returns:
        A date expression. Without an override it is the first non-null of:
        the recorded end date, start + ``days_supply`` days, start + 1 day.
    """
    begin = table["drug_exposure_start_date"].cast("date")

    if days_supply_override is not None:
        return begin + ibis.interval(days=days_supply_override)

    available = table.columns

    # A column that is absent from the table behaves like an all-null column,
    # so the coalesce below simply falls through to the next candidate.
    if "drug_exposure_end_date" in available:
        recorded_end = table["drug_exposure_end_date"].cast("date")
    else:
        recorded_end = ibis.null().cast("date")

    if "days_supply" in available:
        supply_days = table["days_supply"].cast("int64")
    else:
        supply_days = ibis.null().cast("int64")

    return ibis.coalesce(
        recorded_end,
        begin + supply_days.as_interval("D"),
        begin + ibis.interval(days=1),
    )


def _compute_eras(exposures, *, gap_days: int, offset: int):
    """Collapse per-person exposures into eras.

    Each exposure end is padded by ``gap_days + offset`` days. Within a
    person, rows are ordered by start date; a row opens a new era when no
    earlier row exists or when the largest padded end seen on earlier rows is
    strictly before the row's start date. Rows sharing an era are aggregated.

    Args:
        exposures: Table expression with ``person_id``, ``start_date`` and
            ``_exposure_end`` columns.
        gap_days: Number of days added to every exposure end before
            collapsing and removed again from the resulting era end.
        offset: Number of days added to every exposure end; it is not
            removed afterwards, so it remains part of ``era_end_date``.

    Returns:
        A table expression with the person id column (named ``PERSON_ID``),
        ``era_start_date`` and ``era_end_date``.
    """

    def _sort_keys(tbl):
        # Start date ascending; descending end columns break ties between
        # rows that share a start date.
        return [tbl.start_date, tbl._padded_end.desc(), tbl._exposure_end.desc()]

    padding = ibis.interval(days=int(gap_days + offset))
    padded = exposures.mutate(_padded_end=(exposures._exposure_end + padding))

    keys = _sort_keys(padded)
    running = ibis.cumulative_window(group_by=padded.person_id, order_by=keys)
    sequential = ibis.window(group_by=padded.person_id, order_by=keys)

    # Cumulative maximum of the padded end up to and including each row ...
    with_running_max = padded.mutate(
        _cummax_padded_end=padded._padded_end.max().over(running)
    )
    # ... lagged by one row, i.e. the maximum over strictly earlier rows.
    with_previous = with_running_max.mutate(
        _prev_max=with_running_max._cummax_padded_end.lag().over(sequential)
    )

    no_predecessor = with_previous._prev_max.isnull()
    beyond_reach = with_previous._prev_max < with_previous.start_date
    flagged = with_previous.mutate(
        _is_new=ibis.ifelse(
            no_predecessor | beyond_reach,
            ibis.literal(1, type="int64"),
            ibis.literal(0, type="int64"),
        )
    )

    # The running sum of the start flags yields an era index per person.
    # Ordering additionally by ``_is_new`` descending puts an era-opening row
    # ahead of rows that tie with it on all other keys.
    era_window = ibis.cumulative_window(
        group_by=flagged.person_id,
        order_by=[*_sort_keys(flagged), flagged._is_new.desc()],
    )
    numbered = flagged.mutate(_era_id=flagged._is_new.sum().over(era_window))

    merged = numbered.group_by(numbered.person_id, numbered._era_id).aggregate(
        era_start_date=numbered.start_date.min(),
        _max_padded_end=numbered._padded_end.max(),
    )

    gap = ibis.interval(days=int(gap_days))
    return merged.select(
        merged.person_id.cast("int64").name(PERSON_ID),
        merged.era_start_date.cast("date").name("era_start_date"),
        (merged._max_padded_end - gap).cast("date").name("era_end_date"),
    )


def compute_drug_eras(
    ctx,
    *,
    drug_codeset_id: int,
    gap_days: int,
    offset: int,
    days_supply_override: int | None,
    cohort_person_ids=None,
):
    """Build drug eras from ``drug_exposure`` rows matching a codeset.

    Args:
        ctx: Execution context providing ``concept_ids_for_codeset`` and
            ``table``.
        drug_codeset_id: Codeset whose concept ids select the exposures.
        gap_days: Forwarded to ``_compute_eras``.
        offset: Forwarded to ``_compute_eras``.
        days_supply_override: Forwarded to ``_compute_exposure_end_date``.
        cohort_person_ids: Optional table expression with a ``person_id``
            column; when given, only exposures of those persons are used.

    Returns:
        A table expression with the person id column (named ``PERSON_ID``),
        ``era_start_date`` and ``era_end_date``. It has no rows when the
        codeset resolves to no concept ids.
    """
    concept_ids = ctx.concept_ids_for_codeset(drug_codeset_id)
    drug_exposure = ctx.table("drug_exposure")

    if not concept_ids:
        # Nothing can match: return a zero-row relation with the era schema.
        null_date = ibis.null().cast("date")
        return drug_exposure.filter(ibis.literal(False)).select(
            drug_exposure.person_id.cast("int64").name(PERSON_ID),
            null_date.name("era_start_date"),
            null_date.name("era_end_date"),
        )

    if cohort_person_ids is not None:
        # Restrict to the cohort's persons before filtering on concepts.
        drug_exposure = drug_exposure.semi_join(
            cohort_person_ids,
            predicates=[drug_exposure.person_id == cohort_person_ids.person_id],
        )

    # Match on the standard concept and, when the column exists, also on the
    # source concept.
    concept_match = drug_exposure.drug_concept_id.isin(concept_ids)
    if "drug_source_concept_id" in drug_exposure.columns:
        concept_match = concept_match | drug_exposure.drug_source_concept_id.isin(concept_ids)
    matching = drug_exposure.filter(concept_match)

    exposure_end = _compute_exposure_end_date(
        matching, days_supply_override=days_supply_override
    )
    exposures = matching.select(
        matching.person_id.cast("int64").name("person_id"),
        matching.drug_exposure_start_date.cast("date").name("start_date"),
        exposure_end.name("_exposure_end"),
    )

    return _compute_eras(exposures, gap_days=gap_days, offset=offset)


def apply_custom_era_strategy(events, strategy, ctx):
    """Replace event end dates using drug eras built from a codeset.

    Each event is matched to an era of the same person whose
    ``[era_start_date, era_end_date]`` range contains the event start date.
    The new end date is the matched era's end, or ``op_end_date`` when no era
    matches, and is never later than ``op_end_date``.

    Args:
        events: Event table expression with at least ``person_id``,
            ``event_id`` and the ``START_DATE`` column.
        strategy: Object whose ``payload`` mapping holds ``drug_codeset_id``,
            ``gap_days``, ``offset`` and optionally ``days_supply_override``.
        ctx: Execution context passed to the helpers.

    Returns:
        Whatever ``_replace_end_date`` returns for the computed end date.
    """
    settings = strategy.payload
    codeset_id = settings["drug_codeset_id"]
    gap_days = settings["gap_days"]
    offset = settings["offset"]
    supply_override = settings.get("days_supply_override")

    if codeset_id is None:
        # Without a codeset there are no eras to build; use op_end_date.
        bounded = attach_observation_bounds(events, ctx)
        return _replace_end_date(events, bounded, bounded.op_end_date)

    # Only persons that actually have events need their exposures scanned.
    people = events.select(events.person_id).distinct()
    eras = compute_drug_eras(
        ctx,
        drug_codeset_id=codeset_id,
        gap_days=gap_days,
        offset=offset,
        days_supply_override=supply_override,
        cohort_person_ids=people,
    )

    # Rename the era person id so it does not collide with the event column
    # in the join result.
    era_lookup = eras.select(
        eras.person_id.name("_era_person_id"),
        eras.era_start_date,
        eras.era_end_date,
    )

    bounded = attach_observation_bounds(events, ctx)
    event_start = bounded[START_DATE]

    matched = bounded.left_join(
        era_lookup,
        predicates=[
            bounded.person_id == era_lookup._era_person_id,
            event_start >= era_lookup.era_start_date,
            event_start <= era_lookup.era_end_date,
        ],
    )

    # Keep a single row per event: the one with the earliest era end.
    # ibis row numbers start at 0.
    per_event = ibis.window(
        group_by=[matched.person_id, matched.event_id],
        order_by=[matched.era_end_date.asc()],
    )
    numbered = matched.mutate(_rn=ibis.row_number().over(per_event))
    first_match = numbered.filter(numbered._rn == 0)

    # Events without a matching era fall back to op_end_date; the result is
    # additionally capped at op_end_date.
    era_or_op_end = ibis.coalesce(
        first_match.era_end_date,
        first_match.op_end_date,
    )
    capped_end = ibis.least(era_or_op_end, first_match.op_end_date)

    return _replace_end_date(events, first_match, capped_end)
Comment on lines +1 to +171
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are filtering the whole drug_exposure table here for these concept_ids. Wouldn't it be better to first filter to people having primary/included events, as CIRCE-BE does?

https://github.com/OHDSI/circe-be/blob/498893689a9cf4f09c2a43cc893bb01116db7184/src/main/resources/resources/cohortdefinition/sql/customEraStrategy.sql#L3-L13

4 changes: 3 additions & 1 deletion circe/execution/engine/end_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def apply_end_strategy(events, strategy, ctx):
return _replace_end_date(events, with_bounds, end_date_expr)

if strategy.kind == "custom_era":
raise UnsupportedFeatureError("Ibis executor end-strategy error: custom_era is not supported.")
from .custom_era import apply_custom_era_strategy

return apply_custom_era_strategy(events, strategy, ctx)

# Fallback: preserve default semantics of op_end_date clipping.
return _replace_end_date(events, with_bounds, with_bounds.op_end_date)
6 changes: 1 addition & 5 deletions circe/execution/normalize/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ...cohortdefinition import CohortExpression
from ...vocabulary.concept import ConceptSet
from .._dataclass import frozen_slots_dataclass
from ..errors import ExecutionNormalizationError, UnsupportedFeatureError
from ..errors import ExecutionNormalizationError
from .collapse import NormalizedCollapseSettings, normalize_collapse_settings
from .criteria import NormalizedCriterion, normalize_criterion
from .end_strategy import NormalizedEndStrategy, normalize_end_strategy
Expand Down Expand Up @@ -149,10 +149,6 @@ def normalize_cohort(
)

normalized_end_strategy = normalize_end_strategy(expression.end_strategy)
if normalized_end_strategy is not None and normalized_end_strategy.kind == "custom_era":
raise UnsupportedFeatureError(
"Ibis executor normalization error: custom_era end strategy is not supported."
)

return NormalizedCohort(
title=expression.title,
Expand Down
1 change: 1 addition & 0 deletions circe/execution/normalize/end_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def normalize_end_strategy(
"drug_codeset_id": value.drug_codeset_id,
"offset": int(value.offset),
"gap_days": int(value.gap_days),
"days_supply_override": value.days_supply_override,
},
)
return NormalizedEndStrategy(kind="end_strategy", payload={})
19 changes: 10 additions & 9 deletions tests/execution/test_api_ibis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
VisitDetail,
VisitOccurrence,
)
from circe.cohortdefinition.core import CustomEraStrategy, NumericRange
from circe.execution.errors import UnsupportedFeatureError
from circe.cohortdefinition.core import NumericRange
from circe.vocabulary import Concept, ConceptSet, ConceptSetExpression, ConceptSetItem


Expand Down Expand Up @@ -1213,10 +1212,12 @@ def test_build_cohort_location_region_keeps_repeated_location_history_rows():
assert sorted(result.start_date.astype(str).tolist()) == ["2020-01-01", "2020-02-01"]


def test_build_cohort_rejects_unsupported_features():
expression = CohortExpression(
primary_criteria=PrimaryCriteria(criteria_list=[ConditionOccurrence()]),
end_strategy=CustomEraStrategy(drug_codeset_id=1, gap_days=30, offset=0),
)
with pytest.raises(UnsupportedFeatureError, match="custom_era"):
_ = build_cohort(expression, backend=object(), cdm_schema="main")
def test_build_cohort_rejects_unsupported_criteria():
    """Unsupported base criteria type is rejected at normalization time."""
    from circe.cohortdefinition.criteria import Criteria as RawCriteria
    from circe.execution.errors import UnsupportedCriterionError
    from circe.execution.normalize.criteria import normalize_criterion

    raw_criterion = RawCriteria()

    # Keep only the call under test inside ``pytest.raises`` so that a
    # failure during import or construction cannot be mistaken for (or mask)
    # the expected rejection.
    with pytest.raises(UnsupportedCriterionError):
        normalize_criterion(raw_criterion)
Loading
Loading