-
Notifications
You must be signed in to change notification settings - Fork 45
Add 3D support to miss_forest_impute
#1052
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f943c76
731d61a
8e839ea
8a4412e
dd73521
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -444,8 +444,43 @@ def _knn_impute( | |
| edata.layers[layer][:, imputer_data_indices] = X_imputed | ||
|
|
||
|
|
||
| @singledispatch | ||
| def _miss_forest_impute_function(arr, num_initial_strategy, n_estimators, max_iter, random_state): | ||
| _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr)) | ||
|
|
||
|
|
||
| @_miss_forest_impute_function.register(DaskArray) | ||
| def _(arr: DaskArray, num_initial_strategy, n_estimators, max_iter, random_state): | ||
| _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr)) | ||
|
|
||
|
|
||
| @_miss_forest_impute_function.register(sp.coo_array) | ||
| def _(arr: sp.coo_array, num_initial_strategy, n_estimators, max_iter, random_state): | ||
| _raise_array_type_not_implemented(_miss_forest_impute_function, type(arr)) | ||
|
|
||
|
|
||
| @_miss_forest_impute_function.register(np.ndarray) | ||
| @_miss_forest_impute_function.register(sp.csr_array) | ||
| @_miss_forest_impute_function.register(sp.csc_array) | ||
| @_apply_over_time_axis | ||
| def _(arr: np.ndarray, num_initial_strategy, n_estimators, max_iter, random_state): | ||
|
|
||
| if set(range(arr.shape[1])).issubset(_get_non_numerical_column_indices(arr)): | ||
| raise ValueError( | ||
| "Can only impute numerical data. Try to restrict imputation to certain columns using var_names parameter." | ||
| ) | ||
| from sklearn.ensemble import ExtraTreesRegressor | ||
| from sklearn.impute import IterativeImputer | ||
|
|
||
| return IterativeImputer( | ||
| estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs), | ||
| initial_strategy=num_initial_strategy, | ||
| max_iter=max_iter, | ||
| random_state=random_state, | ||
| ).fit_transform(arr) | ||
|
|
||
|
|
||
| @use_ehrdata(deprecated_after="1.0.0") | ||
| @function_2D_only() | ||
| @spinner("Performing miss-forest impute") | ||
| def miss_forest_impute( | ||
| edata: EHRData | AnnData, | ||
|
|
@@ -465,6 +500,9 @@ def miss_forest_impute( | |
| The strategy works by fitting a random forest model on each feature containing missing values, | ||
| and using the trained model to predict the missing values. | ||
|
|
||
| For 2D data, if layer is `None`, `edata.X` is used directly. | ||
| For 3D data, the layer is flattened along axis 0 before imputation and reshaped back to 3D afterwards. | ||
|
|
||
| See https://academic.oup.com/bioinformatics/article/28/1/112/219101. | ||
|
|
||
| If required, the data needs to be properly encoded as this imputation requires numerical data only. | ||
|
|
@@ -478,7 +516,7 @@ def miss_forest_impute( | |
| Decrease for faster computations. | ||
| random_state: The random seed for the initialization. | ||
| warning_threshold: Threshold of percentage of missing values to display a warning for. | ||
| layer: The layer to impute. | ||
| layer: The layer to impute. Required when input data is 3D. | ||
| copy: Whether to return a copy or act in place. | ||
|
|
||
| Returns: | ||
|
|
@@ -488,13 +526,29 @@ def miss_forest_impute( | |
| Examples: | ||
| >>> import ehrdata as ed | ||
| >>> import ehrapy as ep | ||
| >>> edata = ed.dt.mimic_2() | ||
| >>> edata = ep.pp.encode(edata, autodetect=True) | ||
| >>> ep.pp.miss_forest_impute(edata) | ||
| >>> edata_3d = ed.dt.ehrdata_blobs(n_variables=3, n_observations=3, base_timepoints=2, missing_values=0.3) | ||
| >>> edata_imputed = ep.pp.miss_forest_impute(edata_3d, layer="tem_data", copy=True) | ||
|
|
||
| Example Output: | ||
|
|
||
| >>> edata_3d.layers["tem_data"][0, :, :] | ||
| [[-12.12732884, -18.37304373], | ||
| [ nan, -0.91339411], | ||
| [ nan, -7.88514984]] | ||
| >>> edata_imputed.layers["tem_data"][0, :, :] | ||
| [[-12.12732884, -18.37304373], | ||
| [ -0.3278448 , -0.91339411], | ||
| [ -4.39722201, -7.88514984]] | ||
|
|
||
| """ | ||
| if copy: | ||
| edata = edata.copy() | ||
|
|
||
| if edata.X is None and layer is None: # if edata is 3D | ||
| raise ValueError( | ||
| "3D imputation requires a layer to be specified. Pass the layer containing the full temporal data." | ||
| ) | ||
|
|
||
| if var_names is None: | ||
| _warn_imputation_threshold(edata, list(edata.var_names), threshold=warning_threshold, layer=layer) | ||
| elif isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): | ||
|
|
@@ -505,16 +559,11 @@ def miss_forest_impute( | |
|
|
||
| patch_sklearn() | ||
|
|
||
| from sklearn.ensemble import ExtraTreesRegressor, RandomForestClassifier | ||
| from sklearn.ensemble import RandomForestClassifier | ||
| from sklearn.impute import IterativeImputer | ||
|
|
||
| try: | ||
| imp_num = IterativeImputer( | ||
| estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs), | ||
| initial_strategy=num_initial_strategy, | ||
| max_iter=max_iter, | ||
| random_state=random_state, | ||
| ) | ||
| # not sure if this should be kept? | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can indeed be removed, since you're defining the imputer in the single-dispatch function. Furthermore, line 568 contains an unused definition of RandomForestClassifier, which you can also throw out. |
||
| # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data | ||
| IterativeImputer( | ||
| estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs), | ||
|
|
@@ -527,18 +576,26 @@ def miss_forest_impute( | |
| var_names = edata.var_names | ||
| var_indices = edata.var_names.get_indexer(var_names).tolist() | ||
|
|
||
| if set(var_indices).issubset(_get_non_numerical_column_indices(edata.X)): | ||
| raise ValueError( | ||
| "Can only impute numerical data. Try to restrict imputation to certain columns using " | ||
| "var_names parameter." | ||
| ) | ||
| mtx = edata.X if layer is None else edata.layers[layer] | ||
| input_dtype = mtx.dtype if np.issubdtype(mtx.dtype, np.floating) else np.float64 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a quick comment on why this is needed here? :) |
||
|
|
||
| if mtx.ndim == 3: | ||
| mtx_slice = mtx[:, var_indices, :].astype(input_dtype, copy=True) | ||
| else: | ||
| mtx_slice = mtx[:, var_indices].astype(input_dtype, copy=True) | ||
|
|
||
| # this step is the most expensive one and might extremely slow down the impute process | ||
| if var_indices: | ||
| if layer is None: | ||
| edata.X[::, var_indices] = imp_num.fit_transform(edata.X[::, var_indices]) | ||
| X_imputed = _miss_forest_impute_function( | ||
| mtx_slice, num_initial_strategy, n_estimators, max_iter, random_state | ||
| ) | ||
| if mtx.ndim == 3: | ||
| edata.layers[layer][:, var_indices, :] = X_imputed | ||
| else: | ||
| edata.layers[layer][::, var_indices] = imp_num.fit_transform(edata.layers[layer][::, var_indices]) | ||
| if layer is None: | ||
| edata.X[:, var_indices] = X_imputed | ||
| else: | ||
| edata.layers[layer][:, var_indices] = X_imputed | ||
| else: | ||
| raise ValueError("Cannot find any feature to perform imputation") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -354,10 +354,34 @@ def test_knn_impute_numerical_data(impute_num_edata): | |
| _base_check_imputation(impute_num_edata, edata_imputed) | ||
|
|
||
|
|
||
| def test_missforest_impute_3D_edata(edata_blob_small): | ||
| miss_forest_impute(edata_blob_small, layer="layer_2") | ||
| with pytest.raises(ValueError, match=r"only supports 2D data"): | ||
| miss_forest_impute(edata_blob_small, layer=DEFAULT_TEM_LAYER_NAME) | ||
| @pytest.mark.parametrize("edata_mini_3D_missing_values", [True], indirect=True) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add for the "basic" test a parametrization which also checks the array types, where dask raises a valueerror is checked, and it is also checked that this works with sparse (at least in the 2D case then)? |
||
| def test_missforest_impute_3D_edata(edata_mini_3D_missing_values): | ||
| edata = edata_mini_3D_missing_values.copy() | ||
| edata_imputed = miss_forest_impute(edata, layer=DEFAULT_TEM_LAYER_NAME, copy=True) | ||
| _base_check_imputation( | ||
| edata_mini_3D_missing_values, | ||
| edata_imputed, | ||
| before_imputation_layer=DEFAULT_TEM_LAYER_NAME, | ||
| after_imputation_layer=DEFAULT_TEM_LAYER_NAME, | ||
| ) | ||
|
|
||
|
|
||
| def test_missforest_impute_3d_var_names_subset(edata_mini_3D_missing_values): | ||
| edata = edata_mini_3D_missing_values.copy() | ||
| imputed = miss_forest_impute(edata, layer=DEFAULT_TEM_LAYER_NAME, var_names=["1", "2"], copy=True) | ||
| edata_imputed = imputed[:, :2].copy() | ||
| _base_check_imputation( | ||
| edata_mini_3D_missing_values[:, :2], | ||
| edata_imputed, | ||
| before_imputation_layer=DEFAULT_TEM_LAYER_NAME, | ||
| after_imputation_layer=DEFAULT_TEM_LAYER_NAME, | ||
| ) | ||
| assert edata.shape == imputed.shape | ||
|
|
||
|
|
||
| def test_missforest_impute_3d_layer_none(edata_mini_3D_missing_values): | ||
| with pytest.raises(ValueError, match="requires a layer"): | ||
| miss_forest_impute(edata_mini_3D_missing_values, copy=True) | ||
|
|
||
|
|
||
| def test_missforest_impute_non_numerical_data(impute_edata): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is legitimate to accept sparse arrays for imputation and densify them.
Could you mention this in the function docstring?