def handle_input_reference(reference: Path) -> list[Path]:
    """Validate the reference file, splitting multi-model ensembles.

    Parameters
    ----------
    reference : Path
        Path to the input reference structure, possibly containing
        an ensemble of models.

    Returns
    -------
    references : list[Path]
        Paths to the reference structure(s) to be used downstream:
        the input file itself when it holds a single model, otherwise
        one file per model obtained by splitting the ensemble.

    Raises
    ------
    ValueError
        If the reference file is empty.
    RuntimeError
        If the number of files produced by splitting does not match the
        number of models detected by ``pdb_wc``.
    """
    # Local imports keep this helper self-contained.
    from contextlib import redirect_stdout
    from io import StringIO

    if reference.stat().st_size == 0:
        raise ValueError(f"Reference file is empty: {reference}")

    # pdb_tools.pdb_wc writes its report to sys.stdout. Capture it with
    # redirect_stdout so stdout is restored even if pdb_wc raises (the
    # previous manual sys.stdout swap leaked a closed stdout on error).
    buffer = StringIO()
    with open(reference, "r") as fh, redirect_stdout(buffer):
        # "m" asks pdb_wc to count MODEL records only.
        pdb_wc(fh, "m")
    wc_return = buffer.getvalue()

    # Parse the captured report.
    # Using `\n` (and not os.linesep) as it is the separator pdb_wc emits.
    # pdb_wc treats a file without MODEL records as a single model.
    nb_models = 1
    for line in wc_return.split("\n"):
        if "No. models" in line:
            nb_models = int(line.strip().split()[-1])
            break

    # Single structure present: use the input file as-is.
    if nb_models == 1:
        return [reference]

    # More than one model in the reference: split it.
    haddock_log.info(
        f"Multiple structures ({nb_models}) found in reference file. "
        "Using all conformations as reference."
    )
    with open(reference, "r") as ref_in:
        split_model(ref_in, "reference_model")

    # Gather the individual references, sorted by their model index.
    references = sorted(
        Path(".").glob("reference_model_*.pdb"),
        key=lambda k: int(k.stem.split("_")[-1]),
    )
    # Raise a real exception: `assert` is stripped under `python -O`.
    if len(references) != nb_models:
        raise RuntimeError(
            "Issue while splitting references conformation: "
            f"{nb_models} detected, {len(references)} generated"
        )
    return references
def find_ff(models: "list[PDBFile]") -> str:
    """Find the force-field (all-atom or martini) used by a set of models.

    The force-field tag is read from the topology file name associated
    with the first model. Used in caprieval and rmsdfilter.

    The assumption is that the force-fields are identical between models.

    Parameters
    ----------
    models : list[PDBFile]
        List of models whose topology holds the force-field information.

    Returns
    -------
    ff : str
        The force-field used by those models: a ``martini*`` tag when one
        is found in the topology file name, otherwise ``"aa"`` (all-atom).
    """
    if not models:
        # Nothing to inspect: fall back to all-atom.
        return "aa"
    try:
        # `topology` may be a list of PDBFile objects...
        ff = Path(models[0].topology[0].rel_path).stem.split("_")[-1]
    except TypeError:
        try:
            # ...or a single PDBFile object...
            ff = Path(models[0].topology.rel_path).stem.split("_")[-1]
        except AttributeError:
            # ...or absent: fall back to all-atom.
            ff = "aa"
    # In case of any unrecognized tag, fall back to all-atom.
    if "martini" not in ff:
        ff = "aa"

    return ff
- """ - try: - ff = Path(models[0].topology[0].rel_path).stem.split("_")[-1] - except TypeError: - try: - ff = Path(models[0].topology.rel_path).stem.split("_")[-1] - except AttributeError: - ff = "aa" - # In case of issue, fall back to all-atom - if "martini" not in ff: - ff = "aa" - - return ff - - def handle_input_reference(self, reference: Path) -> list[Path]: - """Validate the reference file by returning only one model. - - Parameters - ---------- - reference : Path - Path to the input reference structure, possibly containing - an ensemble. - - Returns - ------- - reference or first_model_path : Path - Path to the reference structure to be used downstream. - """ - # Extremly complicated stuff to manage the gathering of the sys.stdout, - # as the pdb_tools.pdb_wc is basically writing on it. - import sys - from io import TextIOWrapper, BytesIO - # Memorize previous sys.stdout - original_stdout = sys.stdout - # setup the new stdout environment - sys.stdout = TextIOWrapper(BytesIO(), sys.stdout.encoding) - - # Count number of models - with open(reference, "r") as fh: - pdb_wc(fh, "m") - # Get output - sys.stdout.seek(0) # Jump to the start - wc_return = sys.stdout.read() # Read output - # Restore original stdout - sys.stdout.close() - sys.stdout = original_stdout - # Parse output - # using here `\n` (and not os.linesep) as it is the output for pdb_wc - for line in wc_return.split("\n"): - if "No. models" in line: - sline = line.strip().split() - nb_models = int(sline[-1]) - break - # Return reference as only one structure present - if nb_models == 1: - return [reference] - - self.log( - f"Multiple structures ({nb_models}) found in reference file. " - "Using all conformations as reference." 
- ) - # Split models - with open(reference, "r") as ref_in: - pdb_splitmodel(ref_in, "reference_model") - # Gather individual references and sort them - references = sorted( - list(Path(".").glob("reference_model_*.pdb")), - key=lambda k: int(k.stem.split("_")[-1]), - ) - assert len(references) == nb_models, ( - "Issue while splitting references conformation: " - f"{nb_models} detected, {len(references)} generated" - ) - return references def get_reference(self, models: list[PDBFile]) -> list[Path]: """Manage to obtain the reference structure to be used downstream. @@ -184,7 +94,7 @@ def get_reference(self, models: list[PDBFile]) -> list[Path]: """ if self.params["reference_fname"]: _reference = Path(self.params["reference_fname"]) - references = self.handle_input_reference(_reference) + references = handle_input_reference(_reference) else: self.log( "No reference structure provided. " @@ -216,7 +126,7 @@ def _run(self) -> None: dump_weights(self.order) # Find force-field - ff = self.find_ff(models) + ff = find_ff(models) # Get reference file if ff == "martini2": references = [ diff --git a/src/haddock/modules/analysis/rmsdfilter/__init__.py b/src/haddock/modules/analysis/rmsdfilter/__init__.py new file mode 100644 index 000000000..bf6a6f7cf --- /dev/null +++ b/src/haddock/modules/analysis/rmsdfilter/__init__.py @@ -0,0 +1,203 @@ +"""RMDS-based filtering module. + +This module calculates RMSD for input models against a user-supplied reference structure(s) +and filters out models based on RMSD threshold. The idea behind this module is to simplify +removal of a priori incorrect models, for example those generated with diffusion algorithms, +such as (partially) unfolded antibodies. + +The following file is generated: +- **rmsdfilter_ss.tsv**: a table with the global RMSD and score (if available) for each model. +- **rmsdfilter_ss_multiref.tsv**: when multiple references are provided, a table with RMSD for every (model, reference) pair. 
class HaddockModule(BaseHaddockModule):
    """HADDOCK3 module to filter models by RMSD against a reference."""

    name = RECIPE_PATH.name

    def __init__(
        self,
        order: int,
        path: Path,
        init_params: FilePath = DEFAULT_CONFIG,
    ) -> None:
        super().__init__(order, path, init_params)

    @classmethod
    def confirm_installation(cls) -> None:
        """Confirm if module is installed."""
        return

    @staticmethod
    def is_nested(models: list[Union[PDBFile, list[PDBFile]]]) -> bool:
        """Return True when any entry of `models` is itself a list."""
        for model in models:
            if isinstance(model, list):
                return True
        return False

    def _run(self) -> None:
        """Execute module."""
        # Fail early if the previous step produced an iterable of steps.
        # NOTE(review): `type(...) == iter` mirrors the check used by other
        # haddock3 modules; kept as-is for project consistency.
        if type(self.previous_io) == iter:
            self.finish_with_error(
                "[rmsdfilter] module cannot come after one that "
                "produced an iterable."
            )

        # Get the models generated in previous step
        models = self.previous_io.retrieve_models(individualize=True)
        if self.is_nested(models):
            raise ValueError(
                "[rmsdfilter] module cannot be executed after "
                "modules that produce a nested list of models."
            )
        if not models:
            # Guard: avoids a ZeroDivisionError in the statistics below.
            self.finish_with_error(
                "[rmsdfilter] No models were found in the previous step."
            )

        # Check if cluster info is present
        if any(m.clt_id is not None for m in models):
            self.finish_with_error(
                "Models have been clustered! "
                "[rmsdfilter] cannot be performed after clustering - "
                "filtering individual models after clustering would leave "
                "remaining models with stale and inconsistent cluster assignments."
            )

        # Load reference(s): mandatory for this module (unlike caprieval).
        if self.params["reference_fname"]:
            reference = Path(self.params["reference_fname"])
            references = handle_input_reference(reference)
        else:
            self.finish_with_error(
                "[rmsdfilter] No valid reference structure(s) provided! "
                "A reference structure is required for this module to work, "
                "please set the 'reference_fname' parameter to a valid PDB file."
            )

        # Detect force field (aa or cg)
        ff = find_ff(models)
        # If coarse-grained, convert the reference(s) to CG as well.
        if ff == "martini2":
            references = [
                Path(martinize(ref, self.path.resolve().parent, False))
                for ref in references
            ]

        # Build alignment function (to be passed to each job)
        align_func = get_align(
            method=self.params["alignment_method"],
            lovoalign_exec=self.params["lovoalign_exec"],
            keep_hetatm=self.params["keep_hetatm"],
        )

        # Create one job per (model, reference) pair
        jobs: list[RMSDFilter] = []
        for i, model in enumerate(models, start=1):
            for ref_id, reference in enumerate(references, start=1):
                jobs.append(
                    RMSDFilter(
                        identificator=i,
                        model=model,
                        reference=reference,
                        path=Path("."),
                        params=self.params,
                        align_func=align_func,
                        ref_id=ref_id,
                    )
                )

        engine = Scheduler(
            tasks=jobs,
            ncores=self.params["ncores"],
            max_cpus=self.params["max_cpus"],
        )
        engine.run()
        results: list[RMSDFilter] = engine.results

        # Map job identifiers back to their model objects.
        id_to_model: dict[int, PDBFile] = {
            i: m for i, m in enumerate(models, start=1)
        }

        # Keep the best (minimum) RMSD per model across all references.
        rmsd_map = collect_rmsd_map(results)

        # Get sorting parameters
        sortby = self.params["sortby"]
        sort_ascending = self.params["sort_ascending"]
        # Check if score sorting is possible
        has_score = any(
            m.score is not None and not (isinstance(m.score, float) and isnan(m.score))
            for m in models
        )
        if sortby == "score" and not has_score:
            self.log("Cannot sort models by score, falling back to sorting by RMSD.")
            sortby = "rmsd"

        rows = build_sorted_rows(rmsd_map, id_to_model, sortby, sort_ascending)
        # Get stats for the header of the tsv file and the log
        valid_rows = [(m, r) for m, r in rows if not isnan(r)]
        nan_count = len(rows) - len(valid_rows)
        filtered = [m for m, r in valid_rows if r <= self.params["threshold"]]
        percent_filtered = (1 - len(filtered) / len(models)) * 100
        # Write the per-structure tsv output
        write_rmsdfilter_ss(rows, filtered, percent_filtered, self.params["threshold"])

        if len(references) > 1:
            write_rmsdfilter_multiref(results, id_to_model)

        if nan_count > 0:
            self.log(
                f"{100 * nan_count / len(rows):6.2f}% of models had NaN RMSD "
                "(alignment failed) and will be excluded from filtering."
            )

        if not valid_rows:
            self.finish_with_error(
                "[rmsdfilter] All models have NaN RMSD - alignment failed for every model."
            )

        if not filtered:
            self.finish_with_error(
                f"[rmsdfilter] With threshold {self.params['threshold']:.3f} Å, "
                "ALL models were filtered out !"
            )

        self.log(
            f"with threshold {self.params['threshold']:.3f} Å: "
            f"{percent_filtered:6.2f}% of models were filtered out, "
            f"{len(filtered)} model(s) passed."
        )

        self.output_models = filtered
        self.export_io_models()
+ long: Alignment method used to match residue numbering between model and reference. + sequence alignment is the default and works well for most cases. structure + alignment requires a LovoAlign executable to be provided via lovoalign_exec. + group: analysis + explevel: easy + +sortby: + default: rmsd + type: string + minchars: 0 + maxchars: 10 + choices: + - rmsd + - score + title: Sort output models in the output file by this value + short: Column used to sort the output TSV. + long: Column used to sort the output TSV. Use rmsd (default) to sort by global + RMSD ascending. Use score to sort by model score if available; falls back to + rmsd if no model carries a valid score. + group: analysis + explevel: easy + +sort_ascending: + default: true + type: boolean + title: Sort in ascending order + short: Sort the output TSV in ascending order. + long: Sort the output TSV in ascending order. + group: analysis + explevel: easy + +allatoms: + default: false + type: boolean + title: Use all heavy atoms + short: Use all heavy atoms (including side chains) for RMSD calculation. + long: If false (default), only backbone atoms (CA, C, N, O) are used for the + RMSD calculation. If true, all heavy atoms including side chains are used. + group: analysis + explevel: easy + +keep_hetatm: + default: false + type: boolean + title: Consider HETATM records + short: Include HETATM atoms from the reference during coordinate loading. + long: If false (default), only ATOM coordinate lines are used. If true, HETATM + records from the reference are also included. + group: analysis + explevel: easy + +lovoalign_exec: + default: '' + type: string + minchars: 0 + maxchars: 200 + title: LovoAlign executable path + short: Path to the LovoAlign executable. + long: Path to the LovoAlign executable. Only required when alignment_method is + set to structure. 
class RMSDFilter:
    """Compute the global RMSD between one model and one reference.

    Instances are scheduled as independent parallel jobs; `run` returns a
    deep copy of the job carrying the computed `rmsd`, which stays NaN
    when alignment or coordinate matching fails.
    """

    def __init__(
        self,
        identificator: int,
        model: PDBFile,
        reference: Path,
        path: Path,
        params: dict,
        align_func,
        ref_id: int = 1,
    ) -> None:
        # 1-based index of the model in the input list; used to map results
        # back onto models after parallel execution.
        self.identificator = identificator
        self.model = model
        self.reference = reference
        self.path = path
        # Module parameters; `run` reads "allatoms" and "keep_hetatm".
        self.params = params
        # Callable (reference, model, path) -> (numbering map, chain map);
        # built by libalign.get_align in the module entry point.
        self.align_func = align_func
        # 1-based index of the reference, for multi-reference runs.
        self.ref_id = ref_id
        # NaN until successfully computed by `run`.
        self.rmsd = float("nan")

    def run(self) -> "RMSDFilter":
        """Compute RMSD and return self."""
        allatoms = self.params["allatoms"]
        keep_hetatm = self.params["keep_hetatm"]

        # map model residue numbers onto reference residue numbers
        try:
            model2ref_numbering, model2ref_chain_dict = self.align_func(
                self.reference, self.model, self.path
            )
        except ALIGNError:
            log.warning(
                f"Alignment failed between {self.reference} "
                f"and {self.model}, skipping..."
            )
            # deepcopy so the scheduler gets an independent object with rmsd=nan
            return copy.deepcopy(self)

        # Atom selection: backbone only, or all heavy atoms when allatoms=True.
        atoms = get_atoms(self.model, full=allatoms)
        atoms.update(get_atoms(self.reference, full=allatoms))

        ref_coord_dic, _ = load_coords(
            self.reference,
            atoms,
            keep_hetatm=keep_hetatm,
        )
        try:
            # numbering_dic remaps model residues to reference numbering so that
            # coordinate keys are directly comparable between the two dicts
            mod_coord_dic, _ = load_coords(
                self.model,
                atoms,
                numbering_dic=model2ref_numbering,
                model2ref_chain_dict=model2ref_chain_dict,
                keep_hetatm=keep_hetatm,
            )
        except ALIGNError as e:
            log.warning(e)
            return copy.deepcopy(self)

        # Only superpose on shared atoms between model and reference
        common_keys = ref_coord_dic.keys() & mod_coord_dic.keys()
        if not common_keys:
            log.warning(
                f"No common atoms found between {self.reference} and {self.model}"
            )
            return copy.deepcopy(self)

        # Iterating the same set object twice yields the same order, so the
        # reference (Q) and model (P) coordinate rows stay correctly paired.
        Q = np.asarray([ref_coord_dic[k] for k in common_keys])
        P = np.asarray([mod_coord_dic[k] for k in common_keys])

        # Kabsch superposition: centre both structures Q (ref) and P (model),
        # find optimal rotation U, apply it to P, then compute RMSD on the superposed coordinates.
        Q = Q - centroid(Q)
        P = P - centroid(P)
        U = kabsch(P, Q)
        P = np.dot(P, U)

        self.rmsd = calc_rmsd(P, Q)
        return copy.deepcopy(self)
# helper functions to not clutter init.py
def _sort_key(row: tuple, sortby: str) -> float:
    """Compute the sorting value of a single (model, rmsd) row."""
    model, rmsd = row
    if sortby != "score":
        # RMSD sorting: rows with a failed alignment (NaN) go last.
        return float("inf") if isnan(rmsd) else rmsd
    score = model.score
    # Score sorting: models without a usable score go last.
    if score is None or (isinstance(score, float) and isnan(score)):
        return float("inf")
    return score


def collect_rmsd_map(results: list) -> dict[int, float]:
    """Filter per-job results to minimum RMSD per model across all references."""
    rmsd_map: dict[int, float] = {}
    for job in results:
        previous = rmsd_map.get(job.identificator, float("nan"))
        if isnan(previous):
            # First (or only NaN-so-far) value for this model.
            rmsd_map[job.identificator] = job.rmsd
        elif not isnan(job.rmsd):
            # Keep the best RMSD seen across references.
            rmsd_map[job.identificator] = min(previous, job.rmsd)
    return rmsd_map


def build_sorted_rows(
    rmsd_map: dict[int, float],
    id_to_model: dict[int, PDBFile],
    sortby: str,
    sort_ascending: bool,
) -> list[tuple]:
    """Build a list of (model, rmsd) pairs sorted by sortby."""
    pairs = [(id_to_model[idx], rmsd_map[idx]) for idx in sorted(rmsd_map)]
    return sorted(
        pairs,
        key=lambda pair: _sort_key(pair, sortby),
        reverse=not sort_ascending,
    )


def write_rmsdfilter_ss(
    rows: list[tuple],
    filtered: list[PDBFile],
    percent_filtered: float,
    threshold: float,
    fname: str = "rmsdfilter_ss.tsv",
) -> None:
    """Write the per-structure TSV file with RMSD and score for every model."""
    lines = [
        (
            f"# RMSD filtering threshold is set to {threshold:.3f} Å; "
            f"{len(filtered)} model(s) were kept; {percent_filtered:.2f}% were filtered out. "
            "This file contains all models for user information.\n"
        ),
        "model\tscore\trmsd\n",
    ]
    for model, rmsd in rows:
        score = model.score
        if score is None or (isinstance(score, float) and isnan(score)):
            score_str = "nan"
        else:
            score_str = f"{score:.3f}"
        rmsd_str = "nan" if isnan(rmsd) else f"{rmsd:.3f}"
        lines.append(f"{model.rel_path}\t{score_str}\t{rmsd_str}\n")
    # Buffer all rows, then flush with a single writelines call.
    with open(fname, "w") as fh:
        fh.writelines(lines)


def write_rmsdfilter_multiref(
    results: list,
    id_to_model: dict[int, PDBFile],
    fname: str = "rmsdfilter_ss_multiref.tsv",
) -> None:
    """Write multiref file: one row per (model, reference) pair, sorted by model name then ref_id."""

    def _order(job) -> tuple:
        return (str(id_to_model[job.identificator].rel_path), job.ref_id)

    with open(fname, "w") as fh:
        fh.write("model\tref_id\trmsd\n")
        for job in sorted(results, key=_order):
            model = id_to_model[job.identificator]
            rmsd_str = "nan" if isnan(job.rmsd) else f"{job.rmsd:.3f}"
            fh.write(f"{model.rel_path}\t{job.ref_id}\t{rmsd_str}\n")