From a17953fb26d3b9f0cec103d8ffcf2064c59026ec Mon Sep 17 00:00:00 2001 From: martiinaina <106203695+martina-occhetta@users.noreply.github.com> Date: Mon, 24 Mar 2025 12:39:42 +0000 Subject: [PATCH] Update perturb_dataset.py --- src/data/perturb_dataset.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/data/perturb_dataset.py b/src/data/perturb_dataset.py index 6dbde58..75737c5 100644 --- a/src/data/perturb_dataset.py +++ b/src/data/perturb_dataset.py @@ -95,19 +95,25 @@ def __init__(self, adata, data_path, spectral_parameter, spectra_params, fm, sta self.all_perts_test = pkl.load(f) if self.data_name == "replogle_rpe1": - if not os.path.exists(f"{self.data_path}/input_features/train_data_{self.spectral_parameter}.pkl.gz"): - ctrl_adata, pert_adata, train, test, pert_list = self.preprocess_and_featurise(adata) - pp_data = self.featurise_replogle(pert_adata, pert_list, ctrl_adata, train, test) - self.X_train, self.train_target, self.X_val, self.val_target, self.X_test, self.test_target = pp_data + if not os.path.exists(f"{feature_path}/train_data_{self.spectral_parameter}.pkl.gz"): + (self.X_train, self.train_target, self.X_val, self.val_target, self.X_test, self.test_target, + self.ctrl_expr, _) = self.preprocess_and_featurise(adata) else: - with gzip.open(f"{self.data_path}/input_features/train_data_{self.spectral_parameter}.pkl.gz", - "rb") as f: + self.basal_ctrl_adata = sc.read_h5ad( + f"{self.data_path}/basal_ctrl_{self.data_name}_pp_filtered.h5ad") + with gzip.open(f"{feature_path}/train_data_{self.spectral_parameter}.pkl.gz", "rb") as f: self.X_train, self.train_target = pkl.load(f) - with gzip.open(f"{self.data_path}/input_features/val_data_{self.spectral_parameter}.pkl.gz", "rb") as f: + with gzip.open(f"{feature_path}/val_data_{self.spectral_parameter}.pkl.gz", "rb") as f: self.X_val, self.val_target = pkl.load(f) - with gzip.open(f"{self.data_path}/input_features/test_data_{self.spectral_parameter}.pkl.gz", - "rb") as f: + with gzip.open(f"{feature_path}/test_data_{self.spectral_parameter}.pkl.gz", "rb") as f: self.X_test, self.test_target = pkl.load(f) + if PerturbData.ctrl_expr_cache is None: + with open(f"{self.data_path}/raw_expression_{self.data_name}_pp_filtered.pkl", "rb") as f: + PerturbData.ctrl_expr_cache = pkl.load(f) + self.ctrl_expr = PerturbData.ctrl_expr_cache + else: + with open(f"{self.data_path}/raw_expression_{self.data_name}_pp_filtered.pkl", "rb") as f: + self.ctrl_expr = pkl.load(f) if self.data_name == "replogle_k562": if not os.path.exists(f"{feature_path}/train_data_{self.spectral_parameter}.pkl.gz"):