diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py index 04e890bf..92d46eb6 100644 --- a/ehrapy/preprocessing/_imputation.py +++ b/ehrapy/preprocessing/_imputation.py @@ -444,8 +444,43 @@ def _knn_impute( edata.layers[layer][:, imputer_data_indices] = X_imputed +@singledispatch +def _miss_forest_impute_function(arr, num_initial_strategy, n_estimators, max_iter, random_state): + _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr)) + + +@_miss_forest_impute_function.register(DaskArray) +def _(arr: DaskArray, num_initial_strategy, n_estimators, max_iter, random_state): + _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr)) + + +@_miss_forest_impute_function.register(sp.coo_array) +def _(arr: sp.coo_array, num_initial_strategy, n_estimators, max_iter, random_state): + _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr)) + + +@_miss_forest_impute_function.register(np.ndarray) +@_miss_forest_impute_function.register(sp.csr_array) +@_miss_forest_impute_function.register(sp.csc_array) +@_apply_over_time_axis +def _(arr: np.ndarray, num_initial_strategy, n_estimators, max_iter, random_state): + + if set(range(arr.shape[1])).issubset(_get_non_numerical_column_indices(arr)): + raise ValueError( + "Can only impute numerical data. Try to restrict imputation to certain columns using var_names parameter." + ) + from sklearn.ensemble import ExtraTreesRegressor + from sklearn.impute import IterativeImputer + + return IterativeImputer( + estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs), + initial_strategy=num_initial_strategy, + max_iter=max_iter, + random_state=random_state, + ).fit_transform(arr) + + @use_ehrdata(deprecated_after="1.0.0") -@function_2D_only() @spinner("Performing miss-forest impute") def miss_forest_impute( edata: EHRData | AnnData, @@ -465,6 +500,9 @@ def miss_forest_impute( The strategy works by fitting a random forest model on each feature containing missing values, and using the trained model to predict the missing values. + For 2D data, if layer is `None`, `edata.X` is used directly. + For 3D data, the layer is flattened along axis 0 before imputation and reshaped back to 3D afterwards. + See https://academic.oup.com/bioinformatics/article/28/1/112/219101. If required, the data needs to be properly encoded as this imputation requires numerical data only. @@ -478,7 +516,7 @@ def miss_forest_impute( Decrease for faster computations. random_state: The random seed for the initialization. warning_threshold: Threshold of percentage of missing values to display a warning for. - layer: The layer to impute. + layer: The layer to impute. Required when input data is 3D. copy: Whether to return a copy or act in place. Returns: @@ -488,13 +526,29 @@ def miss_forest_impute( Examples: >>> import ehrdata as ed >>> import ehrapy as ep - >>> edata = ed.dt.mimic_2() - >>> edata = ep.pp.encode(edata, autodetect=True) - >>> ep.pp.miss_forest_impute(edata) + >>> edata_3d = ed.dt.ehrdata_blobs(n_variables=3, n_observations=3, base_timepoints=2, missing_values=0.3) + >>> edata_imputed = ep.pp.knn_impute(edata_3d, layer="tem_data", copy=True) + + Example Output: + + >>> edata_3d.layers["tem_data"][0, :, :] + [[-12.12732884, -18.37304373], + [ nan, -0.91339411], + [ nan, -7.88514984]] + >>> edata_imputed.layers["tem_data"][0, :, :] + [[-12.12732884, -18.37304373], + [ -0.3278448 , -0.91339411], + [ -4.39722201, -7.88514984]] + """ if copy: edata = edata.copy() + if edata.X is None and layer is None: # if edata is 3D + raise ValueError( + "3D imputation requires a layer to be specified. Pass the layer containing the full temporal data." + ) + if var_names is None: _warn_imputation_threshold(edata, list(edata.var_names), threshold=warning_threshold, layer=layer) elif isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): @@ -505,16 +559,11 @@ def miss_forest_impute( patch_sklearn() - from sklearn.ensemble import ExtraTreesRegressor, RandomForestClassifier + from sklearn.ensemble import RandomForestClassifier from sklearn.impute import IterativeImputer try: - imp_num = IterativeImputer( - estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs), - initial_strategy=num_initial_strategy, - max_iter=max_iter, - random_state=random_state, - ) + # not sure if this should be kept? # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data IterativeImputer( estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs), @@ -527,18 +576,26 @@ def miss_forest_impute( var_names = edata.var_names var_indices = edata.var_names.get_indexer(var_names).tolist() - if set(var_indices).issubset(_get_non_numerical_column_indices(edata.X)): - raise ValueError( - "Can only impute numerical data. Try to restrict imputation to certain columns using " - "var_names parameter." - ) + mtx = edata.X if layer is None else edata.layers[layer] + input_dtype = mtx.dtype if np.issubdtype(mtx.dtype, np.floating) else np.float64 + + if mtx.ndim == 3: + mtx_slice = mtx[:, var_indices, :].astype(input_dtype, copy=True) + else: + mtx_slice = mtx[:, var_indices].astype(input_dtype, copy=True) # this step is the most expensive one and might extremely slow down the impute process if var_indices: - if layer is None: - edata.X[::, var_indices] = imp_num.fit_transform(edata.X[::, var_indices]) + X_imputed = _miss_forest_impute_function( + mtx_slice, num_initial_strategy, n_estimators, max_iter, random_state + ) + if mtx.ndim == 3: + edata.layers[layer][:, var_indices, :] = X_imputed else: - edata.layers[layer][::, var_indices] = imp_num.fit_transform(edata.layers[layer][::, var_indices]) + if layer is None: + edata.X[:, var_indices] = X_imputed + else: + edata.layers[layer][:, var_indices] = X_imputed else: raise ValueError("Cannot find any feature to perform imputation") diff --git a/tests/preprocessing/test_imputation.py b/tests/preprocessing/test_imputation.py index 11f64474..31bcb48f 100644 --- a/tests/preprocessing/test_imputation.py +++ b/tests/preprocessing/test_imputation.py @@ -354,10 +354,34 @@ def test_knn_impute_numerical_data(impute_num_edata): _base_check_imputation(impute_num_edata, edata_imputed) -def test_missforest_impute_3D_edata(edata_blob_small): - miss_forest_impute(edata_blob_small, layer="layer_2") - with pytest.raises(ValueError, match=r"only supports 2D data"): - miss_forest_impute(edata_blob_small, layer=DEFAULT_TEM_LAYER_NAME) +@pytest.mark.parametrize("edata_mini_3D_missing_values", [True], indirect=True) +def test_missforest_impute_3D_edata(edata_mini_3D_missing_values): + edata = edata_mini_3D_missing_values.copy() + edata_imputed = miss_forest_impute(edata, layer=DEFAULT_TEM_LAYER_NAME, copy=True) + _base_check_imputation( + edata_mini_3D_missing_values, + edata_imputed, + before_imputation_layer=DEFAULT_TEM_LAYER_NAME, + after_imputation_layer=DEFAULT_TEM_LAYER_NAME, + ) + + +def test_missforest_impute_3d_var_names_subset(edata_mini_3D_missing_values): + edata = edata_mini_3D_missing_values.copy() + imputed = miss_forest_impute(edata, layer=DEFAULT_TEM_LAYER_NAME, var_names=["1", "2"], copy=True) + edata_imputed = imputed[:, :2].copy() + _base_check_imputation( + edata_mini_3D_missing_values[:, :2], + edata_imputed, + before_imputation_layer=DEFAULT_TEM_LAYER_NAME, + after_imputation_layer=DEFAULT_TEM_LAYER_NAME, + ) + assert edata.shape == imputed.shape + + +def test_missforest_impute_3d_layer_none(edata_mini_3D_missing_values): + with pytest.raises(ValueError, match="requires a layer"): + miss_forest_impute(edata_mini_3D_missing_values, copy=True) def test_missforest_impute_non_numerical_data(impute_edata):