From fb27d992c9e0aae0bbcf0c1b86d1d9be14b1599a Mon Sep 17 00:00:00 2001 From: Jemma Daniel <134346753+JemmaLDaniel@users.noreply.github.com> Date: Sat, 4 Apr 2026 17:46:57 +0100 Subject: [PATCH 1/4] feat: support mgf file loading and auto-addition of spectrum ID for MZTabDatasetLoader --- winnow/configs/data_loader/mztab.yaml | 6 + winnow/datasets/data_loaders.py | 236 ++++++++++++++------------ 2 files changed, 135 insertions(+), 107 deletions(-) diff --git a/winnow/configs/data_loader/mztab.yaml b/winnow/configs/data_loader/mztab.yaml index 681cc625..95b3ba27 100644 --- a/winnow/configs/data_loader/mztab.yaml +++ b/winnow/configs/data_loader/mztab.yaml @@ -8,6 +8,12 @@ residue_masses: ${residue_masses} # Defaults to true. load_beams: true +# When true, add experiment_name + spectrum_id to spectrum data. This maps input spectra +# to search engine output predictions (e.g. Casanovo, InstaNovo). MGF files always receive +# these columns regardless of this setting. Set to false if your parquet/ipc files already +# contain a spectrum_id column. +add_index_cols: true + residue_remapping: # Used to map Casanovo-specific notations to UNIMOD tokens. "M+15.995": "M[UNIMOD:35]" # Oxidation "Q+0.984": "Q[UNIMOD:7]" # Deamidation diff --git a/winnow/datasets/data_loaders.py b/winnow/datasets/data_loaders.py index b3c3358b..dc523046 100644 --- a/winnow/datasets/data_loaders.py +++ b/winnow/datasets/data_loaders.py @@ -29,9 +29,88 @@ ) +def _df_from_matchms(spectra: list[Spectrum]) -> pl.DataFrame: + """Convert a list of Matchms spectra to a polars DataFrame. + + Includes only metadata columns that matchms exposes for at least one spectrum. + ``scan_number`` is always a 0-based enumerate index. + + Args: + spectra: List of Matchms spectrum objects. + + Returns: + The polars DataFrame. + """ + metadata_map = { + "precursor_mz": "precursor_mz", + "charge": "precursor_charge", + "retention_time": "retention_time", + } + sequence_keys = ("seq", "peptide_sequence") + + all_metadata_keys: set[str] = set() + for spectrum in spectra: + all_metadata_keys.update(spectrum.metadata.keys()) + + active_columns = { + mgf_key: col_name + for mgf_key, col_name in metadata_map.items() + if mgf_key in all_metadata_keys + } + + sequence_key = next((k for k in sequence_keys if k in all_metadata_keys), None) + + data: dict[str, list[Any]] = {"scan_number": []} + for col_name in active_columns.values(): + data[col_name] = [] + if sequence_key: + data["sequence"] = [] + data["mz_array"] = [] + data["intensity_array"] = [] + + for i, spectrum in enumerate(spectra): + data["scan_number"].append(i) + for mgf_key, col_name in active_columns.items(): + data[col_name].append(spectrum.metadata.get(mgf_key)) + if sequence_key: + data["sequence"].append(spectrum.metadata.get(sequence_key)) + data["mz_array"].append(spectrum.peaks.mz) + data["intensity_array"].append(spectrum.peaks.intensities) + + return pl.DataFrame(data) + + +def _add_index_cols(df: pl.DataFrame, fp: Path | str) -> pl.DataFrame: + """Add ``experiment_name`` and ``spectrum_id`` columns. + + If ``scan_number`` is present, ``spectrum_id`` is ``experiment_name:scan_number``. + Otherwise uses a row index, matching InstaNovo's data_handler fallback. + """ + exp_name = Path(fp).stem + df = df.with_columns(pl.lit(exp_name).alias("experiment_name").cast(pl.Utf8)) + if "scan_number" in df.columns: + df = df.with_columns( + ( + pl.col("experiment_name") + ":" + pl.col("scan_number").cast(pl.Utf8) + ).alias("spectrum_id") + ) + else: + df = df.with_row_index("idx") + df = df.with_columns( + (pl.col("experiment_name") + ":" + pl.col("idx").cast(pl.Utf8)).alias( + "spectrum_id" + ) + ) + df = df.drop("idx") + return df + + class InstaNovoDatasetLoader(DatasetLoader): """Loader for InstaNovo predictions in CSV format.""" + _df_from_matchms = staticmethod(_df_from_matchms) + _add_index_cols = staticmethod(_add_index_cols) + def __init__( self, residue_masses: dict[str, float], @@ -59,84 +138,6 @@ def __init__( self.beam_columns = beam_columns self.add_index_cols = add_index_cols - @staticmethod - def _df_from_matchms(spectra: list[Spectrum]) -> pl.DataFrame: - """Convert a list of Matchms spectra to a polars DataFrame. - - Includes only metadata columns that matchms exposes for at least one spectrum. - ``scan_number`` is always a 0-based enumerate index. - - Args: - spectra: List of Matchms spectrum objects. - - Returns: - The polars DataFrame. - """ - metadata_map = { - "precursor_mz": "precursor_mz", - "charge": "precursor_charge", - "retention_time": "retention_time", - } - sequence_keys = ("seq", "peptide_sequence") - - all_metadata_keys: set[str] = set() - for spectrum in spectra: - all_metadata_keys.update(spectrum.metadata.keys()) - - active_columns = { - mgf_key: col_name - for mgf_key, col_name in metadata_map.items() - if mgf_key in all_metadata_keys - } - - sequence_key = next((k for k in sequence_keys if k in all_metadata_keys), None) - - data: dict[str, list[Any]] = {"scan_number": []} - for col_name in active_columns.values(): - data[col_name] = [] - if sequence_key: - data["sequence"] = [] - data["mz_array"] = [] - data["intensity_array"] = [] - - for i, spectrum in enumerate(spectra): - data["scan_number"].append(i) - for mgf_key, col_name in active_columns.items(): - data[col_name].append(spectrum.metadata.get(mgf_key)) - if sequence_key: - data["sequence"].append(spectrum.metadata.get(sequence_key)) - data["mz_array"].append(spectrum.peaks.mz) - data["intensity_array"].append(spectrum.peaks.intensities) - - return pl.DataFrame(data) - - @staticmethod - def _add_index_cols(df: pl.DataFrame, fp: Path | str) -> pl.DataFrame: - """Add ``experiment_name`` and ``spectrum_id`` to align with InstaNovo CSV output. - - If ``scan_number`` is present, ``spectrum_id`` is ``experiment_name:scan_number``. - Otherwise uses a row index, matching InstaNovo's data_handler fallback. - """ - exp_name = Path(fp).stem - df = df.with_columns(pl.lit(exp_name).alias("experiment_name").cast(pl.Utf8)) - if "scan_number" in df.columns: - df = df.with_columns( - ( - pl.col("experiment_name") - + ":" - + pl.col("scan_number").cast(pl.Utf8) - ).alias("spectrum_id") - ) - else: - df = df.with_row_index("idx") - df = df.with_columns( - (pl.col("experiment_name") + ":" + pl.col("idx").cast(pl.Utf8)).alias( - "spectrum_id" - ) - ) - df = df.drop("idx") - return df - @staticmethod def _merge_spectrum_data( preds_dataset: pd.DataFrame, spectrum_dataset: pd.DataFrame @@ -536,13 +537,19 @@ class MZTabDatasetLoader(DatasetLoader): If missing (traditional search engines), token_log_probabilities will be set to None Expected Spectrum Data Format: - - Parquet or IPC file with spectrum metadata + - Parquet, IPC, or MGF file with spectrum metadata - Row indices should match the extracted indices from MZTab spectra_ref - Optional 'sequence' column for ground truth labels Note: The loader handles both single prediction per spectrum and multiple predictions per spectrum, creating beam predictions with List[ScoredSequence] structure. Works with - both traditional database search engines and Casanovo outputs, returning a single beam prediction if only one prediction is present. + both traditional database search engines and Casanovo outputs, returning a single beam + prediction if only one prediction is present. + + When ``add_index_cols`` is enabled (or for MGF inputs), ``experiment_name`` and + ``spectrum_id`` columns are added to map input spectra to search engine output + predictions. The ``spectrum_id`` uses the format ``{file_stem}:{row_index}``, + which aligns with the 0-based index extracted from MZTab ``spectra_ref``. """ def __init__( @@ -551,6 +558,7 @@ def __init__( residue_remapping: dict[str, str], isotope_error_range: Tuple[int, int] = (0, 1), load_beams: bool = True, + add_index_cols: bool = False, ) -> None: """Initialise the MZTabDatasetLoader. @@ -561,6 +569,9 @@ def __init__( load_beams: Whether to load beam predictions. If False, dataset.predictions will be None. Set to False if you only need metadata features and want to skip beam processing or do not have beams. Defaults to True. + add_index_cols: If True, add ``experiment_name`` and ``spectrum_id`` to + parquet/ipc inputs. MGF inputs always get these columns regardless of + this flag. Defaults to False. """ self.metrics = Metrics( residue_set=ResidueSet( @@ -569,33 +580,7 @@ def __init__( isotope_error_range=isotope_error_range, ) self.load_beams = load_beams - - @staticmethod - def _load_spectrum_data(spectrum_path: Path | str) -> Tuple[pl.DataFrame, bool]: - """Load spectrum data from either a Parquet or IPC file. - - Args: - spectrum_path: Path to spectrum data file - - Returns: - DataFrame containing spectrum data - """ - spectrum_path = Path(spectrum_path) - has_labels = False - - if spectrum_path.suffix == ".parquet": - df = pl.read_parquet(spectrum_path) - elif spectrum_path.suffix == ".ipc": - df = pl.read_ipc(spectrum_path) - else: - raise ValueError( - f"Unsupported file format for spectrum data: {spectrum_path.suffix}. Supported formats are .parquet and .ipc." - ) - - if "sequence" in df.columns: - has_labels = True - - return df, has_labels + self.add_index_cols = add_index_cols @staticmethod def _load_dataset(predictions_path: Path | str) -> pl.DataFrame: @@ -669,6 +654,43 @@ def load( return CalibrationDataset(metadata=metadata_pd, predictions=beam_predictions) + def _load_spectrum_data( + self, spectrum_path: Path | str + ) -> Tuple[pl.DataFrame, bool]: + """Load spectrum data from a Parquet, IPC, or MGF file. + + Args: + spectrum_path: Path to spectrum data file (.parquet, .ipc, or .mgf) + + Returns: + Tuple of (DataFrame containing spectrum data, whether ground truth labels exist) + """ + spectrum_path = Path(spectrum_path) + has_labels = False + + if spectrum_path.suffix == ".parquet": + df = pl.read_parquet(spectrum_path) + elif spectrum_path.suffix == ".ipc": + df = pl.read_ipc(spectrum_path) + elif spectrum_path.suffix == ".mgf": + from matchms.importing import load_from_mgf + + spectra = list(load_from_mgf(str(spectrum_path))) + df = _df_from_matchms(spectra) + else: + raise ValueError( + f"Unsupported file format for spectrum data: {spectrum_path.suffix}. " + "Supported formats are .parquet, .ipc and .mgf." + ) + + if spectrum_path.suffix == ".mgf" or self.add_index_cols: + df = _add_index_cols(df, spectrum_path) + + if "sequence" in df.columns: + has_labels = True + + return df, has_labels + def _process_predictions( self, predictions: pl.DataFrame, spectrum_data_columns: List[str] ) -> pl.DataFrame: From abcdd5bb9649f75c8cbf7dd4dee81f94c9271e62 Mon Sep 17 00:00:00 2001 From: Jemma Daniel <134346753+JemmaLDaniel@users.noreply.github.com> Date: Sat, 4 Apr 2026 17:47:26 +0100 Subject: [PATCH 2/4] test: add tests for auto-addition of spectrum ID and mgf file loading for MZTabDatasetLoader --- tests/datasets/test_data_loaders.py | 120 +++++++++++++++++++++++++--- 1 file changed, 110 insertions(+), 10 deletions(-) diff --git a/tests/datasets/test_data_loaders.py b/tests/datasets/test_data_loaders.py index caa43747..9db1d61d 100644 --- a/tests/datasets/test_data_loaders.py +++ b/tests/datasets/test_data_loaders.py @@ -1150,44 +1150,144 @@ def test_only_unimod_notation_needed_in_invalid_list( # _load_spectrum_data # ------------------------------------------------------------------ - def test_load_spectrum_data_raises_for_unsupported_extension(self, tmp_path): + def test_load_spectrum_data_raises_for_unsupported_extension( + self, loader, tmp_path + ): path = tmp_path / "data.tsv" path.touch() with pytest.raises(ValueError, match="Unsupported file format"): - MZTabDatasetLoader._load_spectrum_data(path) + loader._load_spectrum_data(path) - def test_load_spectrum_data_reads_parquet(self, tmp_path): + def test_load_spectrum_data_reads_parquet(self, loader, tmp_path): df = pl.DataFrame({"charge": [2], "mz_array": [[100.0]]}) path = tmp_path / "data.parquet" df.write_parquet(path) - result_df, _ = MZTabDatasetLoader._load_spectrum_data(path) + result_df, _ = loader._load_spectrum_data(path) assert "charge" in result_df.columns - def test_load_spectrum_data_reads_ipc(self, tmp_path): + def test_load_spectrum_data_reads_ipc(self, loader, tmp_path): df = pl.DataFrame({"charge": [2], "mz_array": [[100.0]]}) path = tmp_path / "data.ipc" df.write_ipc(path) - result_df, _ = MZTabDatasetLoader._load_spectrum_data(path) + result_df, _ = loader._load_spectrum_data(path) assert "charge" in result_df.columns - def test_load_spectrum_data_detects_labels_when_sequence_present(self, tmp_path): + def test_load_spectrum_data_detects_labels_when_sequence_present( + self, loader, tmp_path + ): df = pl.DataFrame({"sequence": ["PEPTIDE"], "charge": [2]}) path = tmp_path / "data.parquet" df.write_parquet(path) - _, has_labels = MZTabDatasetLoader._load_spectrum_data(path) + _, has_labels = loader._load_spectrum_data(path) assert has_labels is True - def test_load_spectrum_data_no_labels_when_sequence_absent(self, tmp_path): + def test_load_spectrum_data_no_labels_when_sequence_absent(self, loader, tmp_path): df = pl.DataFrame({"charge": [2]}) path = tmp_path / "data.parquet" df.write_parquet(path) - _, has_labels = MZTabDatasetLoader._load_spectrum_data(path) + _, has_labels = loader._load_spectrum_data(path) + assert has_labels is False + + def test_load_spectrum_data_reads_mgf(self, loader, tmp_path): + mgf_path = tmp_path / "spectra.mgf" + mgf_path.write_text( + "BEGIN IONS\n" + "PEPMASS=500.0\n" + "CHARGE=2+\n" + "RTINSECONDS=100.0\n" + "100.0 1.0\n" + "END IONS\n", + encoding="utf-8", + ) + result_df, has_labels = loader._load_spectrum_data(mgf_path) + assert "mz_array" in result_df.columns + assert "intensity_array" in result_df.columns + assert "precursor_mz" in result_df.columns + assert "precursor_charge" in result_df.columns + assert has_labels is False + + def test_load_spectrum_data_mgf_always_adds_index_cols(self, tmp_path): + """MGF inputs get experiment_name and spectrum_id regardless of add_index_cols.""" + loader = MZTabDatasetLoader( + residue_masses=_FULL_RESIDUE_MASSES, + residue_remapping=_STANDARD_REMAPPING, + add_index_cols=False, + ) + mgf_path = tmp_path / "spectra.mgf" + mgf_path.write_text( + "BEGIN IONS\n" "PEPMASS=500.0\n" "CHARGE=2+\n" "100.0 1.0\n" "END IONS\n", + encoding="utf-8", + ) + result_df, _ = loader._load_spectrum_data(mgf_path) + assert "experiment_name" in result_df.columns + assert "spectrum_id" in result_df.columns + assert result_df["spectrum_id"][0] == "spectra:0" + + def test_load_spectrum_data_mgf_has_labels(self, loader, tmp_path): + mgf_path = tmp_path / "labeled.mgf" + mgf_path.write_text( + "BEGIN IONS\n" + "PEPMASS=500.0\n" + "CHARGE=2+\n" + "SEQ=PEPTIDE\n" + "100.0 1.0\n" + "END IONS\n", + encoding="utf-8", + ) + _, has_labels = loader._load_spectrum_data(mgf_path) + assert has_labels is True + + def test_load_spectrum_data_mgf_no_labels(self, loader, tmp_path): + mgf_path = tmp_path / "unlabeled.mgf" + mgf_path.write_text( + "BEGIN IONS\n" "PEPMASS=500.0\n" "CHARGE=2+\n" "100.0 1.0\n" "END IONS\n", + encoding="utf-8", + ) + _, has_labels = loader._load_spectrum_data(mgf_path) assert has_labels is False + def test_load_spectrum_data_parquet_adds_index_cols_when_enabled(self, tmp_path): + loader = MZTabDatasetLoader( + residue_masses=_FULL_RESIDUE_MASSES, + residue_remapping=_STANDARD_REMAPPING, + add_index_cols=True, + ) + df = pl.DataFrame({"charge": [2], "mz_array": [[100.0]]}) + path = tmp_path / "spec.parquet" + df.write_parquet(path) + result_df, _ = loader._load_spectrum_data(path) + assert "experiment_name" in result_df.columns + assert "spectrum_id" in result_df.columns + assert result_df["spectrum_id"][0] == "spec:0" + + def test_load_spectrum_data_ipc_adds_index_cols_when_enabled(self, tmp_path): + loader = MZTabDatasetLoader( + residue_masses=_FULL_RESIDUE_MASSES, + residue_remapping=_STANDARD_REMAPPING, + add_index_cols=True, + ) + df = pl.DataFrame({"charge": [2], "mz_array": [[100.0]]}) + path = tmp_path / "spec.ipc" + df.write_ipc(path) + result_df, _ = loader._load_spectrum_data(path) + assert "experiment_name" in result_df.columns + assert "spectrum_id" in result_df.columns + assert result_df["spectrum_id"][0] == "spec:0" + + def test_load_spectrum_data_parquet_no_index_cols_by_default( + self, loader, tmp_path + ): + df = pl.DataFrame({"charge": [2], "mz_array": [[100.0]]}) + path = tmp_path / "data.parquet" + df.write_parquet(path) + result_df, _ = loader._load_spectrum_data(path) + assert "experiment_name" not in result_df.columns + assert "spectrum_id" not in result_df.columns + # ------------------------------------------------------------------ # _load_dataset # ------------------------------------------------------------------ From 72f017e5e8b0a18a928b986a81ea67b619d8aa19 Mon Sep 17 00:00:00 2001 From: Jemma Daniel <134346753+JemmaLDaniel@users.noreply.github.com> Date: Mon, 15 Jun 2026 16:19:31 +0100 Subject: [PATCH 3/4] chore: default to not adding spectrum_id columns --- winnow/configs/data_loader/mztab.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/winnow/configs/data_loader/mztab.yaml b/winnow/configs/data_loader/mztab.yaml index 95b3ba27..cb7935bb 100644 --- a/winnow/configs/data_loader/mztab.yaml +++ b/winnow/configs/data_loader/mztab.yaml @@ -12,7 +12,7 @@ load_beams: true # to search engine output predictions (e.g. Casanovo, InstaNovo). MGF files always receive # these columns regardless of this setting. Set to false if your parquet/ipc files already # contain a spectrum_id column. -add_index_cols: true +add_index_cols: false residue_remapping: # Used to map Casanovo-specific notations to UNIMOD tokens. "M+15.995": "M[UNIMOD:35]" # Oxidation From cb3b7b59fbe951e95f8a9eec91a2486b99f3d671 Mon Sep 17 00:00:00 2001 From: Jemma Daniel <134346753+JemmaLDaniel@users.noreply.github.com> Date: Mon, 15 Jun 2026 16:20:06 +0100 Subject: [PATCH 4/4] docs: update data loader documentation to clarify matching protocol --- docs/api/datasets.md | 37 +++++++++++++++++++++++++++++++++++++ docs/cli.md | 6 +++--- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/docs/api/datasets.md b/docs/api/datasets.md index 7fae70d1..65942368 100644 --- a/docs/api/datasets.md +++ b/docs/api/datasets.md @@ -90,6 +90,43 @@ loader = WinnowDatasetLoader() dataset = loader.load(data_path=Path("saved_dataset_directory")) ``` +#### Spectrum and prediction matching + +InstaNovo and MZTab use **different join keys** to pair spectrum rows with prediction rows. + +##### InstaNovo: join on `spectrum_id` + +The InstaNovo loader inner-joins the predictions CSV and the spectrum file on a shared `spectrum_id` column. Every prediction row must match exactly one spectrum row. + +- The predictions CSV must include `spectrum_id`. +- The spectrum file must include the same `spectrum_id` values (unless you use `add_index_cols`; see below). +- After merging, the row count must equal the number of prediction rows. If any prediction has no matching spectrum, or if `spectrum_id` is duplicated in the spectrum file, loading raises a `ValueError`. + +InstaNovo predictions from an MGF file use **0-based spectrum position** in the MGF, not the instrument scan number. If your parquet uses `experiment_name:scan_number` IDs but the CSV uses `{stem}:0`, `{stem}:1`, ... you must re-key the parquet (see `scripts/rekey_split_parquet_spectrum_ids.py`) or set `add_index_cols` appropriately. + +##### `add_index_cols` (InstaNovo and MZTab) + +Controls whether Winnow synthesises `experiment_name` and `spectrum_id` when loading parquet or IPC spectrum files. MGF inputs **always** receive these columns regardless of this setting. Configure via `data_loader.add_index_cols` in the loader YAML (see [configuration guide](../configuration.md#data-loader-configs)). + +When `add_index_cols: true`: + +- `experiment_name` is set to the spectrum file stem (e.g. `spectra` for `spectra.parquet`). +- If a `scan_number` column exists, `spectrum_id` is `{file_stem}:{scan_number}`. +- Otherwise, `spectrum_id` is `{file_stem}:{row_index}` (0-based row position in the file). + +##### MZTab: join on row index from `spectra_ref` + +The MZTab loader does **not** use `spectrum_id` to match predictions. Instead: + +1. Each PSM row in the `.mztab` file provides `spectra_ref` (e.g. `ms_run[1]:index=123`). +2. Winnow extracts the integer after `index=` as the join key. +3. Spectrum rows are numbered 0, 1, 2, … in **file order** via a row index. +4. Predictions and spectra are inner-joined on that numeric index. + +So row 0 in your parquet/IPC/MGF must correspond to `index=0` in the mzTab file, row 1 to `index=1`, and so on. Spectra with no matching prediction (or vice versa) are dropped. + +A pre-existing `spectrum_id` column (e.g. `{file_stem}:{scan_number}`) is preserved in the output metadata but **ignored for matching**. With `add_index_cols: false`, Winnow does not add or overwrite identifier columns in parquet/IPC files. + ### PSMDataset A dataset containing multiple peptide-spectrum matches (PSMs). Provides a container for managing collections of PSMs with iteration and indexing support. diff --git a/docs/cli.md b/docs/cli.md index 0adc331f..6a4757e3 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -179,7 +179,7 @@ For training (`winnow train`), you need: - **Labelled dataset**: Ground truth peptide sequences for evaluation - **Predictions**: Model predictions with confidence scores - **Spectral data**: MS/MS spectra and metadata -- **Unique identifiers**: Each PSM must have a unique `spectrum_id` in both input files +- **Aligned identifiers**: Spectrum and prediction rows must refer to the same spectra ### Feature export (`winnow compute-features`) @@ -192,14 +192,14 @@ For prediction (`winnow predict`), you need: - **Unlabelled dataset**: Predictions and spectra (no ground truth required for non-parametric FDR) - **Trained model**: Pretrained model from Hugging Face or output from `winnow train` - **Confidence scores**: Raw confidence values to calibrate -- **Unique identifiers**: Each PSM must have a unique `spectrum_id` in both input files +- **Aligned identifiers**: Same matching rules as training ### Data formats Winnow supports multiple input formats: - **InstaNovo**: Parquet, IPC, or MGF spectra + CSV predictions (beam search format) -- **MZTab**: Parquet or IPC spectra + MZTab predictions +- **MZTab**: Parquet, IPC or MGF spectra + MZTab predictions - **PointNovo**: Similar to InstaNovo format - **Winnow**: Internal format (directory with metadata.csv and predictions.pkl)