Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 77 additions & 20 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,43 @@ def _knn_impute(
edata.layers[layer][:, imputer_data_indices] = X_imputed


@singledispatch
def _miss_forest_impute_function(arr, num_initial_strategy, n_estimators, max_iter, random_state):
    """Dispatch miss-forest imputation on the runtime type of ``arr``.

    Fallback for array types with no registered handler: raises via
    ``_raise_array_type_not_implemented`` instead of silently failing.
    """
    _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr))


@_miss_forest_impute_function.register(DaskArray)
def _(arr: DaskArray, num_initial_strategy, n_estimators, max_iter, random_state):
    # Dask-backed (lazy, chunked) arrays are explicitly unsupported:
    # fail fast with the standard "array type not implemented" error.
    _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr))


@_miss_forest_impute_function.register(sp.coo_array)
def _(arr: sp.coo_array, num_initial_strategy, n_estimators, max_iter, random_state):
    # COO sparse arrays are not indexable the way the imputer needs;
    # only csr/csc sparse layouts have a registered handler below.
    _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr))


@_miss_forest_impute_function.register(np.ndarray)
@_miss_forest_impute_function.register(sp.csr_array)
@_miss_forest_impute_function.register(sp.csc_array)
@_apply_over_time_axis
def _(arr: np.ndarray, num_initial_strategy, n_estimators, max_iter, random_state):
    """Impute one 2D array with an extra-trees-driven iterative imputer.

    ``_apply_over_time_axis`` presumably maps this handler over the time
    axis for 3D inputs — semantics are defined by that helper (TODO confirm).
    """
    # Nothing to fit on if every selected column is non-numerical.
    non_numerical = _get_non_numerical_column_indices(arr)
    if set(range(arr.shape[1])).issubset(non_numerical):
        raise ValueError(
            "Can only impute numerical data. Try to restrict imputation to certain columns using var_names parameter."
        )
    from sklearn.ensemble import ExtraTreesRegressor
    from sklearn.impute import IterativeImputer

    forest = ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs)
    imputer = IterativeImputer(
        estimator=forest,
        initial_strategy=num_initial_strategy,
        max_iter=max_iter,
        random_state=random_state,
    )
    # NOTE(review): this handler is also registered for csr/csc sparse arrays —
    # verify IterativeImputer accepts sparse input; it typically requires dense.
    return imputer.fit_transform(arr)


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
@spinner("Performing miss-forest impute")
def miss_forest_impute(
edata: EHRData | AnnData,
Expand All @@ -465,6 +500,9 @@ def miss_forest_impute(
The strategy works by fitting a random forest model on each feature containing missing values,
and using the trained model to predict the missing values.

For 2D data, if layer is `None`, `edata.X` is used directly.
For 3D data, the layer is flattened along axis 0 before imputation and reshaped back to 3D afterwards.

See https://academic.oup.com/bioinformatics/article/28/1/112/219101.

If required, the data needs to be properly encoded as this imputation requires numerical data only.
Expand All @@ -478,7 +516,7 @@ def miss_forest_impute(
Decrease for faster computations.
random_state: The random seed for the initialization.
warning_threshold: Threshold of percentage of missing values to display a warning for.
layer: The layer to impute.
layer: The layer to impute. Required when input data is 3D.
copy: Whether to return a copy or act in place.

Returns:
Expand All @@ -488,13 +526,29 @@ def miss_forest_impute(
Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata = ep.pp.encode(edata, autodetect=True)
>>> ep.pp.miss_forest_impute(edata)
>>> edata_3d = ed.dt.ehrdata_blobs(n_variables=3, n_observations=3, base_timepoints=2, missing_values=0.3)
>>> edata_imputed = ep.pp.miss_forest_impute(edata_3d, layer="tem_data", copy=True)

Example Output:

>>> edata_3d.layers["tem_data"][0, :, :]
[[-12.12732884, -18.37304373],
[ nan, -0.91339411],
[ nan, -7.88514984]]
>>> edata_imputed.layers["tem_data"][0, :, :]
[[-12.12732884, -18.37304373],
[ -0.3278448 , -0.91339411],
[ -4.39722201, -7.88514984]]

"""
if copy:
edata = edata.copy()

if edata.X is None and layer is None: # if edata is 3D
raise ValueError(
"3D imputation requires a layer to be specified. Pass the layer containing the full temporal data."
)

if var_names is None:
_warn_imputation_threshold(edata, list(edata.var_names), threshold=warning_threshold, layer=layer)
elif isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
Expand All @@ -505,16 +559,11 @@ def miss_forest_impute(

patch_sklearn()

from sklearn.ensemble import ExtraTreesRegressor, RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import IterativeImputer

try:
imp_num = IterativeImputer(
estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy=num_initial_strategy,
max_iter=max_iter,
random_state=random_state,
)
# not sure if this should be kept?
# initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
IterativeImputer(
estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
Expand All @@ -527,18 +576,26 @@ def miss_forest_impute(
var_names = edata.var_names
var_indices = edata.var_names.get_indexer(var_names).tolist()

if set(var_indices).issubset(_get_non_numerical_column_indices(edata.X)):
raise ValueError(
"Can only impute numerical data. Try to restrict imputation to certain columns using "
"var_names parameter."
)
mtx = edata.X if layer is None else edata.layers[layer]
input_dtype = mtx.dtype if np.issubdtype(mtx.dtype, np.floating) else np.float64

if mtx.ndim == 3:
mtx_slice = mtx[:, var_indices, :].astype(input_dtype, copy=True)
else:
mtx_slice = mtx[:, var_indices].astype(input_dtype, copy=True)

# this step is the most expensive one and might extremely slow down the impute process
if var_indices:
if layer is None:
edata.X[::, var_indices] = imp_num.fit_transform(edata.X[::, var_indices])
X_imputed = _miss_forest_impute_function(
mtx_slice, num_initial_strategy, n_estimators, max_iter, random_state
)
if mtx.ndim == 3:
edata.layers[layer][:, var_indices, :] = X_imputed
else:
edata.layers[layer][::, var_indices] = imp_num.fit_transform(edata.layers[layer][::, var_indices])
if layer is None:
edata.X[:, var_indices] = X_imputed
else:
edata.layers[layer][:, var_indices] = X_imputed
else:
raise ValueError("Cannot find any feature to perform imputation")

Expand Down
32 changes: 28 additions & 4 deletions tests/preprocessing/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,34 @@ def test_knn_impute_numerical_data(impute_num_edata):
_base_check_imputation(impute_num_edata, edata_imputed)


def test_missforest_impute_3D_edata(edata_blob_small):
miss_forest_impute(edata_blob_small, layer="layer_2")
with pytest.raises(ValueError, match=r"only supports 2D data"):
miss_forest_impute(edata_blob_small, layer=DEFAULT_TEM_LAYER_NAME)
@pytest.mark.parametrize("edata_mini_3D_missing_values", [True], indirect=True)
def test_missforest_impute_3D_edata(edata_mini_3D_missing_values):
    """Miss-forest imputation on a 3D layer passes the standard imputation checks."""
    working = edata_mini_3D_missing_values.copy()
    imputed = miss_forest_impute(working, layer=DEFAULT_TEM_LAYER_NAME, copy=True)
    # Compare the imputed result against the untouched fixture.
    _base_check_imputation(
        edata_mini_3D_missing_values,
        imputed,
        before_imputation_layer=DEFAULT_TEM_LAYER_NAME,
        after_imputation_layer=DEFAULT_TEM_LAYER_NAME,
    )


def test_missforest_impute_3d_var_names_subset(edata_mini_3D_missing_values):
    """Imputing a var_names subset fills those variables and preserves the shape."""
    source = edata_mini_3D_missing_values.copy()
    result = miss_forest_impute(source, layer=DEFAULT_TEM_LAYER_NAME, var_names=["1", "2"], copy=True)
    # Only the first two variables were selected — check just that slice.
    imputed_subset = result[:, :2].copy()
    original_subset = edata_mini_3D_missing_values[:, :2]
    _base_check_imputation(
        original_subset,
        imputed_subset,
        before_imputation_layer=DEFAULT_TEM_LAYER_NAME,
        after_imputation_layer=DEFAULT_TEM_LAYER_NAME,
    )
    assert source.shape == result.shape


def test_missforest_impute_3d_layer_none(edata_mini_3D_missing_values):
    """3D input without an explicit layer must be rejected with a ValueError."""
    expected_message = "requires a layer"
    with pytest.raises(ValueError, match=expected_message):
        miss_forest_impute(edata_mini_3D_missing_values, copy=True)


def test_missforest_impute_non_numerical_data(impute_edata):
Expand Down
Loading