From 7a1736326c1db5eb0bbb80ac58a60773c916aade Mon Sep 17 00:00:00 2001 From: alex-hh Date: Mon, 24 Apr 2023 12:14:11 +0100 Subject: [PATCH 1/2] flag to specify lower case aa meaning in predict.py --- examples/variant-prediction/predict.py | 29 +++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/examples/variant-prediction/predict.py b/examples/variant-prediction/predict.py index 81d72c40..f15cd916 100644 --- a/examples/variant-prediction/predict.py +++ b/examples/variant-prediction/predict.py @@ -29,14 +29,26 @@ def remove_insertions(sequence: str) -> str: return sequence.translate(translation) -def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]: +def process_sequence(sequence: str, lowercase_type="nonfocus") -> str: + if lowercase_type == "nonfocus": + return sequence.upper() + elif lowercase_type == "insertions": + return remove_insertions(sequence) + else: + raise ValueError() + + +def read_msa(filename: str, nseq: int, lowercase_type="nonfocus") -> List[Tuple[str, str]]: """ Reads the first nseq sequences from an MSA file, automatically removes insertions. - The input file must be in a3m format (although we use the SeqIO fasta parser) - for remove_insertions to work properly.""" + If lowercase_type is 'insertion', the input file must be in a2m/a3m format + for remove_insertions to work properly. + + If lowercase_type is 'nonfocus', all sequences should have the same length. + """ msa = [ - (record.description, remove_insertions(str(record.seq))) + (record.description, process_sequence(str(record.seq), lowercase_type)) for record in itertools.islice(SeqIO.parse(filename, "fasta"), nseq) ] return msa @@ -99,6 +111,13 @@ def create_parser(): default=400, help="number of sequences to select from the start of the MSA" ) + parser.add_argument( + "--lowercase-type", choices=["insertion", "nonfocus"], default="nonfocus", + help=( + "How lowercase amino acids in MSA should be interpreted: " + "nonfocus for EVMutation/Gym-style, insertion for a2m/a3m." + ) + ) # fmt: on parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available") return parser @@ -159,7 +178,7 @@ def main(args): batch_converter = alphabet.get_batch_converter() if isinstance(model, MSATransformer): - data = [read_msa(args.msa_path, args.msa_samples)] + data = [read_msa(args.msa_path, args.msa_samples, lowercase_type=args.lowercase_type)] assert ( args.scoring_strategy == "masked-marginals" ), "MSA Transformer only supports masked marginal strategy" From 3c3eadc294636ee38cf2a367d998770beba01749 Mon Sep 17 00:00:00 2001 From: alex-hh Date: Mon, 24 Apr 2023 12:38:33 +0100 Subject: [PATCH 2/2] fix typo --- examples/variant-prediction/predict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/variant-prediction/predict.py b/examples/variant-prediction/predict.py index f15cd916..f05d65d4 100644 --- a/examples/variant-prediction/predict.py +++ b/examples/variant-prediction/predict.py @@ -32,10 +32,10 @@ def remove_insertions(sequence: str) -> str: def process_sequence(sequence: str, lowercase_type="nonfocus") -> str: if lowercase_type == "nonfocus": return sequence.upper() - elif lowercase_type == "insertions": + elif lowercase_type == "insertion": return remove_insertions(sequence) else: - raise ValueError() + raise ValueError(f"lowercase_type should be nonfocus or insert but got {lowercase_type}") def read_msa(filename: str, nseq: int, lowercase_type="nonfocus") -> List[Tuple[str, str]]: