-
Notifications
You must be signed in to change notification settings - Fork 71
Resolve milestone 2.1.0 issues #628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 14 commits
ad3480d
44dd5c4
18466ea
803ac74
366a8e6
4a4f6e9
97e4121
2ac4368
1228ff1
e38174a
4c8f90c
8551cf8
f597903
1a69d40
1aad0fa
10bc532
02b886b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| #!/usr/bin/env python | ||
| #!/usr/bin/env python3 | ||
|
|
||
| import pickle | ||
| import os | ||
|
|
@@ -11,6 +11,7 @@ | |
| import numpy as np | ||
| import csv | ||
| import string | ||
| import re | ||
| from utils import plddt_from_struct_b_factor, get_chain_ids | ||
|
|
||
| # TODO: Issue #309, make into a proper separate process, in its own module so that dependencies can be managed better | ||
|
|
@@ -85,13 +86,15 @@ def idx_to_letter(idx): | |
| break | ||
| return result | ||
|
|
||
| sorted_entries = sorted(chain_pair_entries.items(), key=lambda item: sort_model_label(item[0])) | ||
|
|
||
| if chain_ids: | ||
| #would be better with some model_id sorting | ||
| iptm_rows = [[""]+[f"{chain_ids[idx[0]]}:{chain_ids[idx[1]]}" for idx, val in next(iter(chain_pair_entries.values()))]] | ||
| iptm_rows = [[""]+[f"{chain_ids[idx[0]]}:{chain_ids[idx[1]]}" for idx, val in sorted_entries[0][1]]] | ||
| else: | ||
| iptm_rows = [[""]+[f"{idx_to_letter(idx[0])}:{idx_to_letter(idx[1])}" for idx, val in next(iter(chain_pair_entries.values()))]] | ||
| iptm_rows = [[""]+[f"{idx_to_letter(idx[0])}:{idx_to_letter(idx[1])}" for idx, val in sorted_entries[0][1]]] | ||
|
|
||
| for model_idx, chain_pair_entries_values in chain_pair_entries.items(): | ||
| for model_idx, chain_pair_entries_values in sorted_entries: | ||
| iptm_rows.append([model_idx]+[f"{val:.4f}" for idx, val in chain_pair_entries_values]) | ||
|
|
||
| return [list(row) for row in zip(*iptm_rows)] | ||
|
|
@@ -102,7 +105,7 @@ def format_pair_score_rows(pair_score_entries, pair_labels=None): | |
| pair_labels = sorted({label for score_values in pair_score_entries.values() for label, _ in score_values}) | ||
|
|
||
| rows = [[""] + pair_labels] | ||
| for model_idx, score_values in pair_score_entries.items(): | ||
| for model_idx, score_values in sorted(pair_score_entries.items(), key=lambda item: sort_model_label(item[0])): | ||
| score_map = {label: value for label, value in score_values} | ||
| rows.append([model_idx] + [f"{score_map[label]:.4f}" if label in score_map else "n/a" for label in pair_labels]) | ||
|
|
||
|
|
@@ -127,6 +130,46 @@ def write_tsv(file_path, rows): | |
| writer = csv.writer(out_f, delimiter='\t') | ||
| writer.writerows(rows) | ||
|
|
||
| def sort_model_label(label): | ||
| try: | ||
| return (0, int(label)) | ||
| except (TypeError, ValueError): | ||
| return (1, str(label)) | ||
|
|
||
| def infer_model_rank(file_path): | ||
|
Comment on lines
+135
to
+140
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Non-blocking for a ColabFold quick fix for v2.1, but it would be good for rank sorting to be in […]. That would allow easy extension of mode-specific ranking patterns.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A second benefit is that, if |
||
| normalized_path = file_path.replace(os.sep, "/") | ||
| rank_patterns = [ | ||
| r"ranked_(\d+)", | ||
| r"_rank_(\d+)", | ||
| r"-rank(\d+)(?:/|$)", | ||
| r"_model_(\d+)", | ||
| ] | ||
|
|
||
| for pattern in rank_patterns: | ||
| match = re.search(pattern, normalized_path) | ||
| if match: | ||
| return int(match.group(1)) | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| def sort_paths_by_rank(paths): | ||
| def sort_key(path): | ||
| rank = infer_model_rank(path) | ||
| if rank is None: | ||
| return (1, os.path.basename(path)) | ||
| return (0, rank, os.path.basename(path)) | ||
|
|
||
| return sorted(paths, key=sort_key) | ||
|
jscgh marked this conversation as resolved.
|
||
|
|
||
|
|
||
| def build_struct_map(struct_files): | ||
| struct_map = {} | ||
| for idx, struct_file in enumerate(sort_paths_by_rank(struct_files)): | ||
| rank = infer_model_rank(struct_file) | ||
| struct_map[rank if rank is not None else idx] = struct_file | ||
| return struct_map | ||
|
|
||
|
|
||
| def resolve_struct_for_model(struct_map, model_id): | ||
| if model_id in struct_map: | ||
|
|
@@ -135,7 +178,7 @@ def resolve_struct_for_model(struct_map, model_id): | |
| numeric_model_id = int(model_id) | ||
| except (TypeError, ValueError): | ||
| return None | ||
| return struct_map.get(numeric_model_id, struct_map.get(numeric_model_id - 1)) | ||
| return struct_map.get(numeric_model_id) | ||
|
|
||
|
|
||
| def parse_ipsae_text_report(report_path): | ||
|
|
@@ -222,13 +265,17 @@ def extract_structs_plddt_to_tsv(name, structures): | |
| Write out a tsv file containing pLDDTs for reading by MultiQC in nf-core/proteinfold | ||
| Uses utils function with BioPython PDB package to extract residue pLDDT values from the b-factor column. | ||
| """ | ||
| plddt_cols = [plddt_from_struct_b_factor(structure) for structure in structures] | ||
| sorted_structures = sort_paths_by_rank(structures) | ||
| plddt_cols = [plddt_from_struct_b_factor(structure) for structure in sorted_structures] | ||
| res_counts = [len(plddt_col) for plddt_col in plddt_cols] | ||
|
|
||
| if len(set(res_counts)) != 1: | ||
| raise ValueError("Not all structures have the same number of residues!") | ||
|
|
||
| rank_names = [f"rank_{i}" for i in range(len(structures))] | ||
| rank_names = [] | ||
| for idx, structure in enumerate(sorted_structures): | ||
| rank = infer_model_rank(structure) | ||
| rank_names.append(f"rank_{rank}" if rank is not None else f"rank_{idx}") | ||
| # Create header as the first row | ||
| plddt_rows = [["Positions"] + rank_names] | ||
| res_id_col = list(range(len(plddt_cols[0]))) | ||
|
|
@@ -244,10 +291,7 @@ def read_pkl(name, pkl_files, struct_files=None): | |
| ipsae_data = {} | ||
| chainwise_iptm = {} | ||
| chainwise_ipsae = {} | ||
| struct_map = {} | ||
| if struct_files: | ||
| for idx, struct_file in enumerate(sorted(struct_files)): | ||
| struct_map[idx] = struct_file | ||
| struct_map = build_struct_map(struct_files) if struct_files else {} | ||
| for pkl_file in pkl_files: | ||
| print(f"Processing {pkl_file}") | ||
| data = pickle.load(open(pkl_file, "rb")) | ||
|
|
@@ -357,10 +401,7 @@ def read_a3m(name, a3m_files): | |
| def read_npz(name, npz_files, struct_files=None): | ||
| ipsae_rows = [] | ||
| chainwise_ipsae = {} | ||
| struct_map = {} | ||
| if struct_files: | ||
| for idx, struct_file in enumerate(sorted(struct_files)): | ||
| struct_map[idx] = struct_file | ||
| struct_map = build_struct_map(struct_files) if struct_files else {} | ||
| for idx, npz_file in enumerate(npz_files): | ||
| data = np.load(npz_file) | ||
| #Boltz PAE files if --write_full_pae is used | ||
|
|
@@ -467,38 +508,23 @@ def read_json(name, json_files, struct_files=None): | |
| chain_pair_entries = {} | ||
| chainwise_ptms = {} | ||
| chain_ids = [] | ||
| struct_map = {} | ||
| if struct_files: | ||
| for idx, struct_file in enumerate(sorted(struct_files)): | ||
| struct_map[idx] = struct_file | ||
| struct_map = build_struct_map(struct_files) if struct_files else {} | ||
|
|
||
| for idx, json_file in enumerate(json_files): | ||
| with open(json_file, 'r') as f: | ||
| data = json.load(f) | ||
| if json_file.endswith("_data.json"): #AF3 output with MSA info | ||
| # Can't just use format_msa_rows since there are FASTA headers in the JSON content | ||
| paired_msa_rows = [] | ||
| unpaired_msa_rows = [] | ||
| for chain in data['sequences']: | ||
| unpaired_MSA = chain['protein']['unpairedMsa'] | ||
| unpaired_msa_lines = [''.join(c for c in line if not c.islower()) for line in unpaired_MSA.split("\n") if line.strip() and not line.startswith(">")] | ||
| unpaired_msa_rows.append([[str(AA_to_int.get(residue, 20)) for residue in line] for line in unpaired_msa_lines]) | ||
| paired_MSA = chain['protein']['pairedMsa'] | ||
| paired_msa_lines = [''.join(c for c in line if not c.islower()) for line in paired_MSA.split("\n") if line.strip() and not line.startswith(">")] | ||
| paired_msa_rows.append([[str(AA_to_int.get(residue, 20)) for residue in line] for line in paired_msa_lines]) | ||
|
|
||
| chains = len(data['sequences']) | ||
| final_rows = [] | ||
| # Paired | ||
| for i in range(len(paired_msa_rows[0])): #The number of paired lines is common to all MSAs | ||
| temp_row = [] | ||
| #This needs to be fixed if inference is batched in future. | ||
| for j in range(chains): | ||
| temp_row.extend(paired_msa_rows[j][i]) | ||
| final_rows.append(temp_row) | ||
|
|
||
| # Un-paired | ||
| msa_widths = [len(paired_msa_rows[chain][0]) for chain in range(chains)] | ||
| # Exclude the paired block for now; use the unpaired MSA only. | ||
| msa_widths = [len(unpaired_msa_rows[chain][0]) if unpaired_msa_rows[chain] else 0 for chain in range(chains)] | ||
| msa_heights = [len(unpaired_msa_rows[chain]) for chain in range(chains)] | ||
|
|
||
| cum_total_rows = np.cumsum(msa_heights) | ||
|
|
@@ -632,14 +658,13 @@ def read_colabfold_metrics(name, colabfold_metrics_fns, struct_files=None): | |
| ipsae_rows = [] | ||
| chainwise_iptm = {} | ||
| chainwise_ipsae = {} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Future (v3), not for this minor release: I like the move toward program-specific extraction logic rather than the prior file-extension-driven initial implementation. For future refactoring development, I think a lot of the metric row list initialisation and […] With something like: […] Then all program-specific compressed data deserialisation can be branched on |
||
| struct_map = {} | ||
| if struct_files: | ||
| for idx, struct_file in enumerate(sorted(struct_files)): | ||
| struct_map[idx] = struct_file | ||
| struct_map = build_struct_map(struct_files) if struct_files else {} | ||
| for fn in colabfold_metrics_fns: | ||
| with open(fn) as f: | ||
| data = json.load(f) | ||
| rank_id = int(fn.split("rank_")[1].split("_")[0])-1 | ||
| rank_id = infer_model_rank(fn) | ||
| if rank_id is None: | ||
| raise ValueError(f"Unable to infer ColabFold rank from metrics filename: {fn}") | ||
| if "pae" in data: | ||
| write_tsv(f"{name}_{rank_id}_pae.tsv", format_pae_rows(data["pae"])) | ||
| if "ptm" in data: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this `sorted()` lambda approach and the clear index access to the first sorted entry here 👍