Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- [[#586](https://github.com/nf-core/proteinfold/pull/586)] - Allow local msa for Boltz with non-protein entities.
- [[#618](https://github.com/nf-core/proteinfold/pull/618)] - Resolve boltz `ext.args` in closure.
- [[PR #626](https://github.com/nf-core/proteinfold/pull/626)] - Move scientific validation tests and BioPython setup to manual workflow.
- [[#619](https://github.com/nf-core/proteinfold/issues/619)] - Fix `extract_metrics.py` shebang to use `python3` for compatibility with minimal containers.
- [[#209](https://github.com/nf-core/proteinfold/issues/209)] - Prevent local ColabFold runs from enabling remote template lookups unless `--colabfold_template_path` is provided.
- [[#456](https://github.com/nf-core/proteinfold/issues/456)] - Derive ranked metric ordering from structure filenames when generating TSV outputs.
- [[#489](https://github.com/nf-core/proteinfold/issues/489)] - Specify Boltz output paths as `boltz_results_<sample_id>/`.
- [[#576](https://github.com/nf-core/proteinfold/issues/576)] - Preserve native metric rank numbering and sort rank-derived outputs numerically.

| Old parameter | New parameter |
| -------------------------- | --------------- |
Expand Down
103 changes: 64 additions & 39 deletions bin/extract_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/usr/bin/env python3

import pickle
import os
Expand All @@ -11,6 +11,7 @@
import numpy as np
import csv
import string
import re
from utils import plddt_from_struct_b_factor, get_chain_ids

# TODO: Issue #309, make into a proper separate process, in its own module so that dependencies can be managed better
Expand Down Expand Up @@ -85,13 +86,15 @@ def idx_to_letter(idx):
break
return result

sorted_entries = sorted(chain_pair_entries.items(), key=lambda item: sort_model_label(item[0]))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this sorted() lambda approach and clear index access to the first sorted entry here 👍


if chain_ids:
#would be better with some model_id sorting
iptm_rows = [[""]+[f"{chain_ids[idx[0]]}:{chain_ids[idx[1]]}" for idx, val in next(iter(chain_pair_entries.values()))]]
iptm_rows = [[""]+[f"{chain_ids[idx[0]]}:{chain_ids[idx[1]]}" for idx, val in sorted_entries[0][1]]]
else:
iptm_rows = [[""]+[f"{idx_to_letter(idx[0])}:{idx_to_letter(idx[1])}" for idx, val in next(iter(chain_pair_entries.values()))]]
iptm_rows = [[""]+[f"{idx_to_letter(idx[0])}:{idx_to_letter(idx[1])}" for idx, val in sorted_entries[0][1]]]

for model_idx, chain_pair_entries_values in chain_pair_entries.items():
for model_idx, chain_pair_entries_values in sorted_entries:
iptm_rows.append([model_idx]+[f"{val:.4f}" for idx, val in chain_pair_entries_values])

return [list(row) for row in zip(*iptm_rows)]
Expand All @@ -102,7 +105,7 @@ def format_pair_score_rows(pair_score_entries, pair_labels=None):
pair_labels = sorted({label for score_values in pair_score_entries.values() for label, _ in score_values})

rows = [[""] + pair_labels]
for model_idx, score_values in pair_score_entries.items():
for model_idx, score_values in sorted(pair_score_entries.items(), key=lambda item: sort_model_label(item[0])):
score_map = {label: value for label, value in score_values}
rows.append([model_idx] + [f"{score_map[label]:.4f}" if label in score_map else "n/a" for label in pair_labels])

Expand All @@ -127,6 +130,46 @@ def write_tsv(file_path, rows):
writer = csv.writer(out_f, delimiter='\t')
writer.writerows(rows)

def sort_model_label(label):
    """Return a sort key that orders integer-like model labels numerically.

    Labels accepted by ``int()`` sort first, by integer value, so ``"10"``
    comes after ``"2"``. Every other label sorts afterwards, by its string
    form.

    Args:
        label: Model identifier, e.g. ``3``, ``"10"`` or a free-form string.

    Returns:
        tuple: ``(0, int(label))`` for integer-like labels, otherwise
        ``(1, str(label))``.
    """
    try:
        return (0, int(label))
    except (TypeError, ValueError, OverflowError):
        # int() raises OverflowError for infinite floats; treat those like
        # any other non-integer label instead of aborting the sort.
        return (1, str(label))

def infer_model_rank(file_path):
Comment on lines +135 to +140

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking for a colabfold quick fix for v2.1, but it would be good for rank sorting to be in utils.py with --prog passed to it. Then imported like the other previous helper functions.

That would allow easy extension of mode-specific ranking patterns.
Would also allow one central function to maintain that can be called anywhere else rank-based ordering is needed.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A second benefit is that, if --prog is passed through in a future function, it makes the intended pattern for that mode explicit and avoids ambiguous matches and potential regex collisions when new modes might do different filenaming

# Normalise OS-specific separators to "/" so the pattern that looks for a
# trailing "/" below behaves the same on every platform.
normalized_path = file_path.replace(os.sep, "/")
# Candidate patterns, tried in list order; each captures the rank digits in group 1.
rank_patterns = [
r"ranked_(\d+)",  # "ranked_<N>"
r"_rank_(\d+)",  # "_rank_<N>"
r"-rank(\d+)(?:/|$)",  # "-rank<N>" followed by "/" or the end of the path
r"_model_(\d+)",  # "_model_<N>"; listed last, so "_rank_<N>" wins when both appear
]

for pattern in rank_patterns:
match = re.search(pattern, normalized_path)
if match:
# First matching pattern wins; int() drops zero padding ("001" -> 1).
return int(match.group(1))

# No pattern matched anywhere in the path.
return None


def sort_paths_by_rank(paths):
    """Return ``paths`` as a new list ordered by inferred model rank.

    Paths for which ``infer_model_rank`` finds a rank come first, ordered by
    rank and then by file name. Paths without a rank follow, ordered by file
    name alone.

    Args:
        paths: Iterable of file path strings.

    Returns:
        list: The paths in rank order.
    """
    def _rank_then_name(candidate):
        inferred = infer_model_rank(candidate)
        file_name = os.path.basename(candidate)
        # The leading 0/1 flag keeps ranked paths ahead of unranked ones.
        return (0, inferred, file_name) if inferred is not None else (1, file_name)

    ordered = list(paths)
    ordered.sort(key=_rank_then_name)
    return ordered
Comment thread
jscgh marked this conversation as resolved.


def build_struct_map(struct_files):
    """Map each model rank to its structure file.

    Files are visited in ``sort_paths_by_rank`` order. A file whose rank can
    be inferred from its path is stored under that rank; if several files
    share a rank, the last one visited wins. A file without an inferable rank
    falls back to its position in the sorted list.

    Args:
        struct_files: Iterable of structure file paths.

    Returns:
        dict: ``{rank_or_position: struct_file}``.
    """
    struct_map = {}
    for idx, struct_file in enumerate(sort_paths_by_rank(struct_files)):
        rank = infer_model_rank(struct_file)
        if rank is not None:
            struct_map[rank] = struct_file
        else:
            # Unranked files sort after every ranked file, so their position
            # can coincide with a real rank (e.g. ranks 1..3 followed by one
            # unranked file at index 3). Never let the positional fallback
            # replace a structure that was stored under its true rank.
            struct_map.setdefault(idx, struct_file)
    return struct_map


def resolve_struct_for_model(struct_map, model_id):
if model_id in struct_map:
Expand All @@ -135,7 +178,7 @@ def resolve_struct_for_model(struct_map, model_id):
numeric_model_id = int(model_id)
except (TypeError, ValueError):
return None
return struct_map.get(numeric_model_id, struct_map.get(numeric_model_id - 1))
return struct_map.get(numeric_model_id)


def parse_ipsae_text_report(report_path):
Expand Down Expand Up @@ -222,13 +265,17 @@ def extract_structs_plddt_to_tsv(name, structures):
Write out a tsv file contain pLDDTs for reading by MultiQC in nf-core/proteinfold
Uses utils function with BioPython PDB package to extract residue pLDDT values from the b-factor column.
"""
plddt_cols = [plddt_from_struct_b_factor(structure) for structure in structures]
sorted_structures = sort_paths_by_rank(structures)
plddt_cols = [plddt_from_struct_b_factor(structure) for structure in sorted_structures]
res_counts = [len(plddt_col) for plddt_col in plddt_cols]

if len(set(res_counts)) != 1:
raise ValueError("Not all structures have the same number of residues!")

rank_names = [f"rank_{i}" for i in range(len(structures))]
rank_names = []
for idx, structure in enumerate(sorted_structures):
rank = infer_model_rank(structure)
rank_names.append(f"rank_{rank}" if rank is not None else f"rank_{idx}")
# Create header as the first row
plddt_rows = [["Positions"] + rank_names]
res_id_col = list(range(len(plddt_cols[0])))
Expand All @@ -244,10 +291,7 @@ def read_pkl(name, pkl_files, struct_files=None):
ipsae_data = {}
chainwise_iptm = {}
chainwise_ipsae = {}
struct_map = {}
if struct_files:
for idx, struct_file in enumerate(sorted(struct_files)):
struct_map[idx] = struct_file
struct_map = build_struct_map(struct_files) if struct_files else {}
for pkl_file in pkl_files:
print(f"Processing {pkl_file}")
data = pickle.load(open(pkl_file, "rb"))
Expand Down Expand Up @@ -357,10 +401,7 @@ def read_a3m(name, a3m_files):
def read_npz(name, npz_files, struct_files=None):
ipsae_rows = []
chainwise_ipsae = {}
struct_map = {}
if struct_files:
for idx, struct_file in enumerate(sorted(struct_files)):
struct_map[idx] = struct_file
struct_map = build_struct_map(struct_files) if struct_files else {}
for idx, npz_file in enumerate(npz_files):
data = np.load(npz_file)
#Boltz PAE files if --write_full_pae is used
Expand Down Expand Up @@ -467,38 +508,23 @@ def read_json(name, json_files, struct_files=None):
chain_pair_entries = {}
chainwise_ptms = {}
chain_ids = []
struct_map = {}
if struct_files:
for idx, struct_file in enumerate(sorted(struct_files)):
struct_map[idx] = struct_file
struct_map = build_struct_map(struct_files) if struct_files else {}

for idx, json_file in enumerate(json_files):
with open(json_file, 'r') as f:
data = json.load(f)
if json_file.endswith("_data.json"): #AF3 output with MSA info
# Can't just used format_msa_rows since there's FASTA headers in the json content
paired_msa_rows = []
unpaired_msa_rows = []
for chain in data['sequences']:
unpaired_MSA = chain['protein']['unpairedMsa']
unpaired_msa_lines = [''.join(c for c in line if not c.islower()) for line in unpaired_MSA.split("\n") if line.strip() and not line.startswith(">")]
unpaired_msa_rows.append([[str(AA_to_int.get(residue, 20)) for residue in line] for line in unpaired_msa_lines])
paired_MSA = chain['protein']['pairedMsa']
paired_msa_lines = [''.join(c for c in line if not c.islower()) for line in paired_MSA.split("\n") if line.strip() and not line.startswith(">")]
paired_msa_rows.append([[str(AA_to_int.get(residue, 20)) for residue in line] for line in paired_msa_lines])

chains = len(data['sequences'])
final_rows = []
# Paired
for i in range(len(paired_msa_rows[0])): #The number of paired lines is common to all MSAs
temp_row = []
#This needs to be fixed if inference is batched in future.
for j in range(chains):
temp_row.extend(paired_msa_rows[j][i])
final_rows.append(temp_row)

# Un-paired
msa_widths = [len(paired_msa_rows[chain][0]) for chain in range(chains)]
# Exclude the paired block for now; use the unpaired MSA only.
msa_widths = [len(unpaired_msa_rows[chain][0]) if unpaired_msa_rows[chain] else 0 for chain in range(chains)]
msa_heights = [len(unpaired_msa_rows[chain]) for chain in range(chains)]

cum_total_rows = np.cumsum(msa_heights)
Expand Down Expand Up @@ -632,14 +658,13 @@ def read_colabfold_metrics(name, colabfold_metrics_fns, struct_files=None):
ipsae_rows = []
chainwise_iptm = {}
chainwise_ipsae = {}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Future (v3), not for this minor release: I like the move toward program-specific extraction logic rather than the prior file-extension-driven initial implementation.

For future refactoring development, I think a lot of the metric row list initialisation and write_tsv() boilerplate can be avoided being repeated by building this into a shared read_metrics(...) function that has program details passed through the function signature.

With something like:
read_metrics(program=args.prog, struct_files=args.structs, metrics_files=...)

Then all program specific compressed data deserialisation can be branched on --prog=${meta.mode}, and all the great ranking and structure mapping helpers can be re-used without altering lines in different functions.

struct_map = {}
if struct_files:
for idx, struct_file in enumerate(sorted(struct_files)):
struct_map[idx] = struct_file
struct_map = build_struct_map(struct_files) if struct_files else {}
for fn in colabfold_metrics_fns:
with open(fn) as f:
data = json.load(f)
rank_id = int(fn.split("rank_")[1].split("_")[0])-1
rank_id = infer_model_rank(fn)
if rank_id is None:
raise ValueError(f"Unable to infer ColabFold rank from metrics filename: {fn}")
if "pae" in data:
write_tsv(f"{name}_{rank_id}_pae.tsv", format_pae_rows(data["pae"]))
if "ptm" in data:
Expand Down
7 changes: 4 additions & 3 deletions conf/modules_colabfold.config
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ process {
process {
withName: 'COLABFOLD_BATCH' {
accelerator = { params.use_gpu ? 1 : 0 }
ext.args = [
ext.args = {[
params.colabfold_use_gpu_relax ? '--use-gpu-relax' : '',
params.colabfold_use_amber ? '--amber' : '',
params.colabfold_use_templates ? '--templates' : '',
params.colabfold_use_templates && (params.use_msa_server || params.colabfold_template_path) ? '--templates' : '',
params.colabfold_template_path ? "--custom-template-path ${params.colabfold_template_path}" : '',
params.random_seed != null ? "--random-seed ${params.random_seed}" : '',
params.use_msa_server && params.msa_server_url ? "--host-url ${params.msa_server_url}" : ''
].join(' ').trim()
].findAll { arg -> arg }.join(' ').trim()}
publishDir = [
[
path: { "${params.outdir}/colabfold/${meta.id}/" },
Expand Down
4 changes: 2 additions & 2 deletions docs/output.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ In the HTML reports, chainwise iPTM and ipSAE are displayed as chain-by-chain ma
Predicted alignment error of residues `j` aligned by residue `i`, rounded to 4 decimal places.
The row number gives you the index of residue `i` and the column value within the row gives the index of residue `j` for the 2D PAE matrix.

Each model prediction generates a separate file containing the rank number. The `_0_pae.tsv` file corresponds to the top ranked model, other ranked results are stored within the `paes/` folder.
Each model prediction generates a separate file containing the rank number. Rank numbering follows the native convention of the underlying tool, so top-ranked models may appear as either `_0_pae.tsv` or `_1_pae.tsv` depending on the mode. Additional ranked results are stored within the `paes/` folder.

```
0.2500 1.5710 3.9037 6.2177 8.4471 11.4583 12.9679 15.1237 18.0263 18.3868 18.9381 20.5747 19.3314 20.1825 21.6145 23.2190
Expand Down Expand Up @@ -224,7 +224,7 @@ Examples include:

- `alphafold2/<MODE>/<SEQUENCE NAME>/raw/`
- `colabfold/<SEQUENCE NAME>/raw/`
- `boltz/<SEQUENCE NAME>/boltz_results_*/`
- `boltz/<SEQUENCE NAME>/boltz_results_<SEQUENCE NAME>/`
- `rosettafold_all_atom/<SEQUENCE NAME>/raw/`
- `alphafold3/<SEQUENCE NAME>/raw/`
- `helixfold3/<SEQUENCE NAME>/raw/`
Expand Down
3 changes: 2 additions & 1 deletion docs/usage/colabfold.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ See the [ColabFold](https://github.com/sokrypton/ColabFold) documentation for a
| `--colabfold_use_amber` | `true` | ColabFold outputs will sometimes contain physical violations such as steric clashes. These clashes can be resolved by post-processing the outputs with a short relaxation using the Amber Force Field. Non-clashing atoms are pinned to starting coordinates such that the relaxation has a minimal impact on final structures. |
| `--colabfold_db_load_mode` | `0` | Specify the way that MMSeqs2 will load the required databases in memory |
| `--colabfold_alphafold2_params_prefix` | `alphafold_params_2022-12-06` | Specify the alphafold2 params used for prediction. |
| `--colabfold_use_templates` | `false` | Use PDB templates to support predictions. The ColabFold notebooks do not use templates by default. |
| `--colabfold_use_templates` | `false` | Use PDB templates to support predictions. When `--use_msa_server` is disabled, this only takes effect if `--colabfold_template_path` is also set so ColabFold can use local templates without contacting the MMSeqs API. The ColabFold notebooks do not use templates by default. |
| `--colabfold_template_path` | `null` | Path to a local ColabFold template directory. Set this together with `--colabfold_use_templates` to enable template use in local ColabFold mode without remote template lookups. |
| `--colabfold_create_index` | `false` | Create index for ColabFold databases during setup. On network filesystems it can be more performant to re-compute the index on the fly |

> You can override any of these parameters via the command line or a params file.
16 changes: 13 additions & 3 deletions modules/local/colabfold_batch/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ process COLABFOLD_BATCH {
tuple val(meta), path ("${meta.id}_colabfold_msa.tsv") , emit: msa
tuple val(meta), path ("${meta.id}_plddt_mqc.tsv") , emit: multiqc
tuple val(meta), path ("${meta.id}_*_pae.tsv") , optional: true, emit: paes
tuple val(meta), path ("${meta.id}_0_pae.tsv") , optional: true, emit: pae
tuple val(meta), path ("${meta.id}_1_pae.tsv") , optional: true, emit: pae
tuple val(meta), path ("${meta.id}_ptm.tsv") , optional: true, emit: ptms
tuple val(meta), path ("${meta.id}_iptm.tsv") , optional: true, emit: iptms
tuple val(meta), path ("${meta.id}_ipsae.tsv") , optional: true, emit: ipsaes
Expand Down Expand Up @@ -83,9 +83,19 @@ process COLABFOLD_BATCH {
touch ./raw/${meta.id}_relaxed_rank_001_model_1_seed_000.pdb
touch ./raw/${meta.id}_relaxed_rank_002_model_2_seed_000.pdb
touch ./raw/${meta.id}_relaxed_rank_003_model_3_seed_000.pdb
touch ./raw/${meta.id}_relaxed_rank_004_model_4_seed_000.pdb
touch ./raw/${meta.id}_relaxed_rank_005_model_5_seed_000.pdb
touch ./${meta.id}_seq_coverage.png
touch ./raw/${meta.id}_scores_rank.json
touch ./${meta.id}_0_pae.tsv
touch ./raw/${meta.id}_scores_rank_001_model_1_seed_000.json
touch ./raw/${meta.id}_scores_rank_002_model_2_seed_000.json
touch ./raw/${meta.id}_scores_rank_003_model_3_seed_000.json
touch ./raw/${meta.id}_scores_rank_004_model_4_seed_000.json
touch ./raw/${meta.id}_scores_rank_005_model_5_seed_000.json
touch ./${meta.id}_1_pae.tsv
touch ./${meta.id}_2_pae.tsv
touch ./${meta.id}_3_pae.tsv
touch ./${meta.id}_4_pae.tsv
touch ./${meta.id}_5_pae.tsv
touch ./${meta.id}_ptm.tsv
touch ./${meta.id}_iptm.tsv
touch ./${meta.id}_ipsae.tsv
Expand Down
Loading
Loading