diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index da1cb08c5..1d2c80ccc 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -13,12 +13,18 @@ # limitations under the License. """Functions for processing confidence metrics.""" - +from jax import debug +from functools import partial +from jax import jit import jax.numpy as jnp import jax import numpy as np from alphafold.common import residue_constants import scipy.special +### + + + def compute_tol(prev_pos, current_pos, mask, use_jnp=False): # Early stopping criteria based on criteria used in @@ -197,6 +203,13 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, if chain_num is None: chain_num = 1 +# batch = {'asym_id': asym_id} +# apply_fn_jit = jax.jit(apply_fn) +# max_asym_id_traced = apply_fn_jit(batch) +# max_asym_id_as_int = int(max_asym_id_traced) + +# jax.debug.print('max_asym_id_as_int={max_asym_id_as_int}',max_asym_id_as_int=max_asym_id_as_int) + # residue_weights has to be in [0, 1], but can be floating-point, i.e. the # exp. resolved head's probability. if residue_weights is None: @@ -286,4 +299,4 @@ def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=F else: mean_score = confidence_metrics["mean_plddt"] confidence_metrics["ranking_confidence"] = mean_score - return confidence_metrics \ No newline at end of file + return confidence_metrics diff --git a/alphafold/model/model.py b/alphafold/model/model.py index 88e90f1f4..9327dd5fb 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -35,15 +35,17 @@ class RunModel: def __init__(self, config: ml_collections.ConfigDict, params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None, - is_training = False): + is_training = False, + chain_num=1): self.config = config self.params = params self.multimer_mode = config.model.global_config.multimer_mode + self.chain_num = chain_num if self.multimer_mode: def _forward_fn(batch): - model = modules_multimer.AlphaFold(self.config.model) + model = modules_multimer.AlphaFold(self.config.model,self.chain_num) return model(batch, is_training=is_training) else: def _forward_fn(batch): @@ -134,6 +136,7 @@ def predict(self, Returns: A dictionary of model outputs. """ + self.init_params(feat) logging.info('Running predict with shape(feat) = %s', tree.map_structure(lambda x: x.shape, feat)) @@ -146,6 +149,7 @@ def predict(self, else: num_ensemble = self.config.data.eval.num_ensemble L = aatype.shape[1] + # initialize @@ -169,6 +173,7 @@ def _jnp_to_np(x): # initialize random key key = jax.random.PRNGKey(random_seed) + # iterate through recyckes for r in range(num_iters): @@ -197,6 +202,5 @@ def _jnp_to_np(x): break if r > 0 and result["tol"] < self.config.model.recycle_early_stop_tolerance: break - logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) - return result, r \ No newline at end of file + return result, r diff --git a/alphafold/model/modules_multimer.py b/alphafold/model/modules_multimer.py index 1e909c944..aa43adbfb 100644 --- a/alphafold/model/modules_multimer.py +++ b/alphafold/model/modules_multimer.py @@ -405,11 +405,11 @@ class AlphaFold(hk.Module): """AlphaFold-Multimer model with recycling. """ - def __init__(self, config, name='alphafold'): + def __init__(self, config,chain_num, name='alphafold'): super().__init__(name=name) self.config = config self.global_config = config.global_config - + self.chain_num = chain_num def __call__( self, batch, @@ -418,6 +418,7 @@ def __call__( safe_key=None): c = self.config + chain_num = self.chain_num impl = AlphaFoldIteration(c, self.global_config) if safe_key is None: @@ -461,9 +462,6 @@ def apply_network(prev, safe_key): if not return_representations: del ret['representations'] - # Extract chain NUM - chain_num = c.embeddings_and_evoformer.max_relative_chain + 1 - # add confidence metrics ret.update(confidence.get_confidence_metrics( prediction_result=ret, diff --git a/setup.py b/setup.py index 244b7503c..b08a98bd7 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name='alphafold-colabfold', - version='2.3.6', + version='2.3.8', long_description_content_type='text/markdown', description='An implementation of the inference pipeline of AlphaFold v2.3.1. ' 'This is a completely new model that was entered as AlphaFold2 in CASP14 ' @@ -26,7 +26,7 @@ author='DeepMind', author_email='alphafold@deepmind.com', license='Apache License, Version 2.0', - url='https://github.com/sokrypton/alphafold', + url='https://github.com/ntnn19/alphafold/tree/chain_iptm', packages=find_packages(), install_requires=[ 'absl-py',