Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions tests/calibration/features/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,73 @@ def test_find_matching_ions_no_matches(self):
assert match_fraction == 0.0 # 0 matches / 1 source ion
assert average_intensity == 0.0 # 0 match intensity / 1000 total intensity

def test_find_matching_ions_prevents_double_matching(self):
"""Test that each observed peak can only be matched once."""
# Two theoretical ions very close together, only one observed peak
source_mz = [100.0, 100.005]
target_mz = [100.002] # Within tolerance of both source ions
target_intensities = [1000.0]
tolerance = 0.02

match_fraction, average_intensity, matched_annotations, _ = find_matching_ions(
source_mz,
target_mz,
target_intensities,
source_annotations=["b1+1", "b2+1"],
mz_tolerance=tolerance,
)

# Only one source ion should match (the first one gets the peak)
assert match_fraction == pytest.approx(0.5) # 1 match / 2 source ions
assert len(matched_annotations) == 1
assert matched_annotations[0] == "b1+1"

def test_find_matching_ions_fallback_to_second_best(self):
"""Test fallback to next nearest peak when closest is already matched."""
# First source ion takes the middle peak, second should fall back to the other
source_mz = [100.0, 100.01]
target_mz = [100.002, 100.015] # Both within tolerance of second source
target_intensities = [1000.0, 2000.0]
tolerance = 0.02

match_fraction, average_intensity, matched_annotations, _ = find_matching_ions(
source_mz,
target_mz,
target_intensities,
source_annotations=["b1+1", "b2+1"],
mz_tolerance=tolerance,
)

# Both source ions should match (to different observed peaks)
assert match_fraction == 1.0 # 2 matches / 2 source ions
assert len(matched_annotations) == 2

def test_find_matching_ions_isotope_masking(self):
"""Test that isotope peaks are masked and not available for subsequent M0 matches."""
# First source ion at 100.0 has isotope at ~101.003 (for +1 charge)
# Second source ion at 101.0 should NOT match the isotope peak
source_mz = [100.0, 101.0]
# Observed peaks: M0 at 100.0, M+1 isotope at 101.003, and another at 150.0
target_mz = [100.0, 101.003, 150.0]
target_intensities = [1000.0, 500.0, 2000.0]
tolerance = 0.02

match_fraction, average_intensity, matched_annotations, _ = find_matching_ions(
source_mz,
target_mz,
target_intensities,
source_annotations=["b1+1", "b2+1"], # +1 charge, isotope spacing ~1.003
mz_tolerance=tolerance,
)

# First ion matches (M0 at 100.0, isotope at 101.003)
# Second ion at 101.0 should NOT match 101.003 (already claimed as isotope)
assert match_fraction == pytest.approx(0.5) # Only 1 M0 match / 2 source ions
assert len(matched_annotations) == 1
assert matched_annotations[0] == "b1+1"
# Intensity should include M0 (1000) + isotope (500) = 1500 / 3500 total
assert average_intensity == pytest.approx(1500.0 / 3500.0)


class TestModelInputHelpers:
"""Tests for validate_model_input_params and resolve_model_inputs utility functions."""
Expand Down
19 changes: 19 additions & 0 deletions tests/datasets/test_data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,25 @@ def test_create_beam_predictions_multiple_preds_sorted_by_confidence_desc(
[np.log(0.4), np.log(0.5), np.log(0.6)]
)

def test_create_beam_predictions_preserves_mztab_token_log_probabilities(
self, loader
):
"""mzTab aa_scores are already log-probabilities and must not be logged again."""
df = pl.DataFrame(
{
"index": pl.Series([0], dtype=pl.Int64),
"confidence": [0.9],
"prediction": [["P", "E", "P"]],
"token_scores": pl.Series(
[[-0.1, -0.2, -0.3]], dtype=pl.List(pl.Float64)
),
}
)

result = loader._create_beam_predictions(df, [0])

assert result[0][0].token_log_probabilities == pytest.approx([-0.1, -0.2, -0.3])

# ------------------------------------------------------------------
# load() – error handling
# ------------------------------------------------------------------
Expand Down
68 changes: 53 additions & 15 deletions winnow/calibration/features/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,32 +189,59 @@ def format_intensity_prediction_outputs(predictions: pd.DataFrame) -> pd.DataFra
########################################################


def _iter_candidates_by_distance(
target_mz: List[float], query_mz: float, insertion_point: int
) -> Iterator[Tuple[int, float]]:
"""Yield (index, distance) pairs outward from insertion point, closest first."""
left = insertion_point - 1
right = insertion_point
n = len(target_mz)

while left >= 0 or right < n:
left_dist = abs(target_mz[left] - query_mz) if left >= 0 else float("inf")
right_dist = abs(target_mz[right] - query_mz) if right < n else float("inf")

if left_dist <= right_dist:
yield left, left_dist # left neighbour is closer
left -= 1
else:
yield right, right_dist # right neighbour is closer
right += 1


def _find_peak_index(
target_mz: List[float], query_mz: float, mz_tolerance: float
target_mz: List[float],
query_mz: float,
mz_tolerance: float,
excluded_indices: set[int] | None = None,
) -> int | None:
"""Find index of peak in sorted target_mz within tolerance of query_mz.
"""Find index of nearest unmatched peak in sorted target_mz within tolerance.

Searches outward from the binary search insertion point to find the closest
peak within tolerance that is not in the excluded set.

Args:
target_mz: Sorted list of m/z values.
query_mz: The m/z value to search for.
mz_tolerance: Tolerance for matching (Daltons).
excluded_indices: Set of indices to skip (already matched peaks).

Returns:
Index of matching peak, or None if no match found.
"""
nearest = bisect.bisect_left(target_mz, query_mz)
if excluded_indices is None:
excluded_indices = set()

# Check right neighbour
if nearest < len(target_mz):
if target_mz[nearest] - query_mz < mz_tolerance:
return nearest
insertion_point = bisect.bisect_left(target_mz, query_mz)

# Check left neighbour
if nearest > 0:
if query_mz - target_mz[nearest - 1] < mz_tolerance:
return nearest - 1
# Search outward from insertion point, checking candidates by distance
for idx, dist in _iter_candidates_by_distance(target_mz, query_mz, insertion_point):
if dist >= mz_tolerance:
return None # out of tolerance and next candidates are further
if idx not in excluded_indices:
return idx # valid match (within tolerance and not already matched)

return None
return None # no valid match found


def find_matching_ions(
Expand All @@ -235,6 +262,10 @@ def find_matching_ions(
3. The list of matched theoretical ion annotations.
4. The list of matched theoretical ion m/z values.

Each observed peak can only be matched once. Once an observed peak is assigned to a
theoretical ion (either as M0 or as part of its isotopic envelope), it is excluded
from matching subsequent theoretical ions.

Args:
source_mz: List of m/z values from the source (theoretical) spectrum.
target_mz: List of m/z values from the target (observed) spectrum.
Expand All @@ -253,21 +284,25 @@ def find_matching_ions(
matched_ion_mz = []
total_target_intensity = sum(target_intensities)

# Track matched observed peak indices
matched_indices: set[int] = set()

# Decode the ion annotations to strings if they are bytes
source_annotations = [
ion_annotation.decode() if isinstance(ion_annotation, bytes) else ion_annotation
for ion_annotation in source_annotations
]

for ion_mz, ion_annotation in zip(source_mz, source_annotations):
# Find monoisotopic peak (M0)
# Find monoisotopic peak (M0), excluding already-matched peaks
source_ion_charge = extract_fragment_ion_charge(ion_annotation)
isotope_spacing = CARBON_ISOTOPE_MASS_SHIFT / source_ion_charge
m0_idx = _find_peak_index(target_mz, ion_mz, mz_tolerance)
m0_idx = _find_peak_index(target_mz, ion_mz, mz_tolerance, matched_indices)

if m0_idx is not None:
# Count match only for M0 (avoids noise inflation)
num_matches += 1
matched_indices.add(m0_idx)

# Add the ion annotation to the list of matched ion annotations
matched_ion_annotations.append(ion_annotation)
Expand All @@ -280,8 +315,11 @@ def find_matching_ions(
# Sum isotopic envelope intensities (M+1, M+2, M+3, M+4)
for i in range(1, 5):
isotope_mz = ion_mz + i * isotope_spacing
iso_idx = _find_peak_index(target_mz, isotope_mz, mz_tolerance)
iso_idx = _find_peak_index(
target_mz, isotope_mz, mz_tolerance, matched_indices
)
if iso_idx is not None:
matched_indices.add(iso_idx)
match_intensity += target_intensities[iso_idx]

return (
Expand Down
Loading