diff --git a/crypto-ffi/src/core_crypto/e2ei/mod.rs b/crypto-ffi/src/core_crypto/e2ei/mod.rs index 9a2ba1512a..f9f7d5ac2c 100644 --- a/crypto-ffi/src/core_crypto/e2ei/mod.rs +++ b/crypto-ffi/src/core_crypto/e2ei/mod.rs @@ -9,7 +9,7 @@ pub(crate) mod identities; impl CoreCryptoFfi { /// Returns true if the PKI environment has been set up and its provider is configured. pub async fn e2ei_is_pki_env_setup(&self) -> bool { - self.inner.get_pki_environment().read().await.is_some() + self.inner.get_pki_environment().await.is_some() } /// Returns true if end-to-end identity is enabled for the given ciphersuite. diff --git a/crypto-ffi/src/core_crypto/mod.rs b/crypto-ffi/src/core_crypto/mod.rs index 2e7e71023f..21aae03c46 100644 --- a/crypto-ffi/src/core_crypto/mod.rs +++ b/crypto-ffi/src/core_crypto/mod.rs @@ -17,7 +17,7 @@ use crate::{CoreCryptoResult, Database}; /// CoreCrypto wraps around MLS and Proteus implementations and provides a transactional interface for each. #[derive(Debug, uniffi::Object)] pub struct CoreCryptoFfi { - pub(crate) inner: core_crypto::CoreCrypto, + pub(crate) inner: Arc, } /// Construct a new `CoreCryptoFfi` instance. diff --git a/crypto-ffi/src/pki_env.rs b/crypto-ffi/src/pki_env.rs index a632db2d1b..b89f5ee046 100644 --- a/crypto-ffi/src/pki_env.rs +++ b/crypto-ffi/src/pki_env.rs @@ -265,9 +265,9 @@ impl CoreCryptoFfi { /// /// Returns null if it is not set. pub async fn get_pki_environment(&self) -> Option> { - let pki_env = self.inner.get_pki_environment(); - (*pki_env.read().await) - .as_ref() - .map(|env| Arc::new(PkiEnvironment(env.clone()))) + self.inner + .get_pki_environment() + .await + .map(|inner| Arc::new(PkiEnvironment(inner))) } } diff --git a/crypto/src/ephemeral.rs b/crypto/src/ephemeral.rs index 7ae4571f41..3a29150bb7 100644 --- a/crypto/src/ephemeral.rs +++ b/crypto/src/ephemeral.rs @@ -126,7 +126,7 @@ impl CoreCrypto { /// /// This client exposes the full interface of `CoreCrypto`, but it should only be used to decrypt messages. /// Other use is a logic error. - pub async fn history_client(history_secret: HistorySecret) -> Result { + pub async fn history_client(history_secret: HistorySecret) -> Result> { if !history_secret .client_id .starts_with(HISTORY_CLIENT_ID_PREFIX.as_bytes()) diff --git a/crypto/src/lib.rs b/crypto/src/lib.rs index 02ed69294e..296008a0b1 100644 --- a/crypto/src/lib.rs +++ b/crypto/src/lib.rs @@ -123,15 +123,13 @@ impl MlsTransport for CoreCryptoTransportNotImplementedProvider { /// /// As [std::ops::Deref] is implemented, this struct is automatically dereferred to [mls::session::Session] apart from /// `proteus_*` calls -/// -/// This is cheap to clone as all internal members have `Arc` wrappers or are `Copy`. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct CoreCrypto { database: Database, - pki_environment: Arc>>>, - mls: Arc>>>, + pki_environment: RwLock>>, + mls: RwLock>>, #[cfg(feature = "proteus")] - proteus: Arc>>, + proteus: Mutex>, #[cfg(not(feature = "proteus"))] #[allow(dead_code)] proteus: (), @@ -139,29 +137,27 @@ pub struct CoreCrypto { impl CoreCrypto { /// Create an new CoreCrypto client without any initialized session. - pub fn new(database: Database) -> Self { + pub fn new(database: Database) -> Arc { Self { database, pki_environment: Default::default(), mls: Default::default(), proteus: Default::default(), } + .into() } /// Set the session's PKI Environment pub async fn set_pki_environment(&self, pki_environment: Option>) { - *self.pki_environment.write().await = pki_environment; + *self.pki_environment.write().await = pki_environment.clone(); if let Some(mls_session) = self.mls.write().await.as_mut() { - mls_session - .crypto_provider - .set_pki_environment(self.pki_environment.clone()) - .await; + mls_session.crypto_provider.set_pki_environment(pki_environment).await; } } /// Get the session's PKI Environment - pub fn get_pki_environment(&self) -> Arc>>> { - self.pki_environment.clone() + pub async fn get_pki_environment(&self) -> Option> { + self.pki_environment.read().await.clone() } /// Get the mls session if initialized diff --git a/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs b/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs index 49c83fe14e..5466015593 100644 --- a/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs +++ b/crypto/src/mls/conversation/conversation_guard/decrypt/mod.rs @@ -191,10 +191,9 @@ impl ConversationGuard { let credential = message.credential(); let epoch = message.epoch(); - let pki_env = provider.authentication_service().pki_env(); - let guard = pki_env.read().await; + let pki_env = provider.authentication_service().pki_env().await; let identity = credential - .extract_identity(self.ciphersuite().await, guard.as_ref().map(|v| &**v)) + .extract_identity(self.ciphersuite().await, pki_env.as_deref()) .await .map_err(RecursiveError::mls_credential("extracting identity"))?; diff --git a/crypto/src/mls/conversation/mod.rs b/crypto/src/mls/conversation/mod.rs index 993e40089b..8d3fa51110 100644 --- a/crypto/src/mls/conversation/mod.rs +++ b/crypto/src/mls/conversation/mod.rs @@ -221,15 +221,14 @@ pub trait Conversation<'a>: ConversationWithMls<'a> { let auth_service = mls_provider.authentication_service(); let conversation = self.conversation().await; - let pki_env = auth_service.pki_env(); - let guard = pki_env.read().await; + let pki_env = auth_service.pki_env().await; let mut identities = vec![]; for (id, credential) in conversation.members_with_key() { if device_ids.iter().any(|client_id| client_id.borrow() == id) { identities.push( credential - .extract_identity(conversation.ciphersuite(), guard.as_ref().map(|v| &**v)) + .extract_identity(conversation.ciphersuite(), pki_env.as_deref()) .await .map_err(RecursiveError::mls_credential("extracting identity"))?, ); @@ -255,8 +254,7 @@ pub trait Conversation<'a>: ConversationWithMls<'a> { let conversation = self.conversation().await; let user_ids = user_ids.iter().map(|uid| uid.as_bytes()).collect::>(); - let pki_env = auth_service.pki_env(); - let guard = pki_env.read().await; + let pki_env = auth_service.pki_env().await; let mut identities = HashMap::new(); for (id, credential) in conversation.members_with_key() { @@ -271,7 +269,7 @@ pub trait Conversation<'a>: ConversationWithMls<'a> { let uid = String::try_from(uid).map_err(RecursiveError::mls_client("getting user identities"))?; let identity = credential - .extract_identity(conversation.ciphersuite(), guard.as_ref().map(|v| &**v)) + .extract_identity(conversation.ciphersuite(), pki_env.as_deref()) .await .map_err(RecursiveError::mls_credential("extracting identity"))?; let value = identities.entry(uid).or_insert_with(Vec::new); diff --git a/crypto/src/mls/conversation/own_commit.rs b/crypto/src/mls/conversation/own_commit.rs index ce3c47df54..cb2ef758df 100644 --- a/crypto/src/mls/conversation/own_commit.rs +++ b/crypto/src/mls/conversation/own_commit.rs @@ -91,10 +91,9 @@ impl MlsConversation { credential: own_leaf.credential().clone(), signature_key: own_leaf.signature_key().clone(), }; - let pki_env = provider.authentication_service().pki_env(); - let guard = pki_env.read().await; + let pki_env = provider.authentication_service().pki_env().await; let identity = own_leaf_credential_with_key - .extract_identity(self.ciphersuite(), guard.as_ref().map(|v| &**v)) + .extract_identity(self.ciphersuite(), pki_env.as_deref()) .await .map_err(RecursiveError::mls_credential("extracting identity"))?; diff --git a/crypto/src/mls/conversation/pending_conversation.rs b/crypto/src/mls/conversation/pending_conversation.rs index 2e1fbb5ec6..a714ac8b54 100644 --- a/crypto/src/mls/conversation/pending_conversation.rs +++ b/crypto/src/mls/conversation/pending_conversation.rs @@ -176,15 +176,10 @@ impl PendingConversation { credential: own_leaf.credential().clone(), signature_key: own_leaf.signature_key().clone(), }; - let pki_env = self - .context - .pki_environment() - .await - .map_err(RecursiveError::transaction("getting PKI environment"))?; - let guard = pki_env.read().await; + let pki_env = self.context.pki_environment().await.ok(); let identity = own_leaf_credential_with_key - .extract_identity(conversation.ciphersuite(), guard.as_ref().map(|v| &**v)) + .extract_identity(conversation.ciphersuite(), pki_env.as_deref()) .await .map_err(RecursiveError::mls_credential("extracting identity"))?; diff --git a/crypto/src/mls/mod.rs b/crypto/src/mls/mod.rs index be3d41d6e9..c9898553d1 100644 --- a/crypto/src/mls/mod.rs +++ b/crypto/src/mls/mod.rs @@ -117,10 +117,9 @@ mod tests { CertificateBundle::rand_identifier(&session_id, &[x509_test_chain.find_local_intermediate_ca()]) } }; - let pki_env = cc.get_pki_environment(); - let guard = pki_env.read().await; + let pki_env = cc.get_pki_environment().await; let session_id = identifier - .get_id(guard.as_ref().map(|v| &**v)) + .get_id(pki_env.as_deref()) .await .expect("get session_id from identifier") .into_owned(); diff --git a/crypto/src/mls/session/e2e_identity.rs b/crypto/src/mls/session/e2e_identity.rs index ed34f95475..e5f8fac630 100644 --- a/crypto/src/mls/session/e2e_identity.rs +++ b/crypto/src/mls/session/e2e_identity.rs @@ -115,11 +115,9 @@ impl Session { _credential_type: CredentialType, auth_service: &AuthenticationService, ) -> E2eiConversationState { - let pki_env = auth_service.pki_env(); - let guard = pki_env.read().await; - let env = match *guard { - Some(ref env) => env, - None => return E2eiConversationState::NotEnabled, + let pki_env = auth_service.pki_env().await; + let Some(env) = pki_env else { + return E2eiConversationState::NotEnabled; }; let mut is_e2ei = false; @@ -136,7 +134,7 @@ impl Session { is_e2ei = true; - let invalid_identity = cert.extract_identity(env, ciphersuite.e2ei_hash_alg()).await.is_err(); + let invalid_identity = cert.extract_identity(&env, ciphersuite.e2ei_hash_alg()).await.is_err(); use openmls_x509_credential::X509Ext as _; let is_time_valid = cert.is_time_valid().unwrap_or(false); diff --git a/crypto/src/mls_provider/mod.rs b/crypto/src/mls_provider/mod.rs index a5bc7fd044..1fdcf23224 100644 --- a/crypto/src/mls_provider/mod.rs +++ b/crypto/src/mls_provider/mod.rs @@ -70,14 +70,22 @@ impl std::ops::DerefMut for EntropySeed { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct AuthenticationService { - pki_env: Arc>>>, + /// The PKI Environment type is complicated, but it's all necessary: + /// + /// - The inner `Arc` derives from two facts: the PKI environment is provided across FFI, and it's `!Clone`, so we + /// have to retain that `Arc` because the foreign environment is more-or-less guaranteed to have kept a reference + /// to it. + /// - The `Option` is there because the PKI environment is initially unset and may never be set, according to + /// client behavior. + /// - The `RwLock` is there because we need to be able to set the PKI environment, implying interior mutability. + pki_env: RwLock>>, } impl AuthenticationService { - pub fn pki_env(&self) -> Arc>>> { - self.pki_env.clone() + pub async fn pki_env(&self) -> Option> { + self.pki_env.read().await.clone() } } @@ -105,7 +113,7 @@ impl openmls_traits::authentication_service::AuthenticationServiceDelegate for A pub struct MlsCryptoProvider { crypto: Arc, key_store: Database, - auth_service: AuthenticationService, + auth_service: Arc, } impl MlsCryptoProvider { @@ -115,12 +123,13 @@ impl MlsCryptoProvider { /// /// - [Database::open] pub fn new(key_store: Database) -> Self { - Self::new_with_pki_env(key_store, Arc::new(RwLock::new(None))) + Self::new_with_pki_env(key_store, None) } /// Construct a crypto provider with the given database and the PKI environment. - pub fn new_with_pki_env(key_store: Database, pki_env: Arc>>>) -> Self { - let auth_service = AuthenticationService { pki_env }; + pub fn new_with_pki_env(key_store: Database, pki_env: Option>) -> Self { + let pki_env = RwLock::new(pki_env); + let auth_service = Arc::new(AuthenticationService { pki_env }); Self { key_store, crypto: Arc::clone(&CRYPTO), @@ -135,8 +144,8 @@ impl MlsCryptoProvider { } /// Set pki_env to a new shared pki environment provider - pub async fn set_pki_environment(&mut self, pki_env: Arc>>>) { - self.auth_service.pki_env = pki_env; + pub async fn set_pki_environment(&mut self, pki_env: Option>) { + *self.auth_service.pki_env.write().await = pki_env; } /// Returns whether we have a PKI env setup diff --git a/crypto/src/proteus.rs b/crypto/src/proteus.rs index d14b7dbcbd..32093e2404 100644 --- a/crypto/src/proteus.rs +++ b/crypto/src/proteus.rs @@ -601,7 +601,7 @@ mod tests { .await .unwrap(); - let cc: CoreCrypto = CoreCrypto::new(db); + let cc = CoreCrypto::new(db); let context = cc.new_transaction().await.unwrap(); assert!(context.proteus_init().await.is_ok()); assert!(context.proteus_new_prekey(1).await.is_ok()); @@ -626,7 +626,7 @@ mod tests { .await .unwrap(); - let cc: CoreCrypto = CoreCrypto::new(db.clone()); + let cc = CoreCrypto::new(db.clone()); let hooks = Arc::new(DummyPkiEnvironmentHooks); let pki_env = PkiEnvironment::new(hooks, db).await.expect("creating pki environment"); cc.set_pki_environment(Some(Arc::new(pki_env))).await; @@ -645,10 +645,9 @@ mod tests { CertificateBundle::rand_identifier(&session_id, &[x509_test_chain.find_local_intermediate_ca()]) } }; - let pki_env = cc.get_pki_environment(); - let guard = pki_env.read().await; + let pki_env = cc.get_pki_environment().await; let session_id = identifier - .get_id(guard.as_ref().map(|v| &**v)) + .get_id(pki_env.as_deref()) .await .expect("Getting session id from identifier") .into_owned(); diff --git a/crypto/src/test_utils/context.rs b/crypto/src/test_utils/context.rs index 249e830eda..539b6ea73f 100644 --- a/crypto/src/test_utils/context.rs +++ b/crypto/src/test_utils/context.rs @@ -231,12 +231,11 @@ impl SessionContext { &expected_credential.mls_credential().mls_credential() { let session = self.session().await; - let pki_env = session.crypto_provider.authentication_service().pki_env(); - let guard = pki_env.read().await; - assert!(guard.as_ref().is_some()); + let pki_env = session.crypto_provider.authentication_service().pki_env().await; + assert!(pki_env.is_some()); let mls_identity = certificate - .extract_identity(case.ciphersuite(), guard.as_ref().map(|v| &**v)) + .extract_identity(case.ciphersuite(), pki_env.as_deref()) .await .unwrap(); let mls_client_id = mls_identity.client_id.as_bytes(); @@ -246,7 +245,7 @@ impl SessionContext { let leaf: Vec = certificate.certificates.first().unwrap().clone().into(); let identity = leaf .as_slice() - .extract_identity(guard.as_ref().unwrap(), case.ciphersuite().e2ei_hash_alg()) + .extract_identity(pki_env.as_deref().unwrap(), case.ciphersuite().e2ei_hash_alg()) .await .unwrap(); let identity = WireIdentity::try_from((identity, leaf.as_slice())).unwrap(); @@ -273,7 +272,7 @@ impl SessionContext { let chain = x509_cert::Certificate::load_pem_chain(decrypted_x509_identity.certificate.as_bytes()).unwrap(); let leaf = chain.first().unwrap(); let cert_identity = leaf - .extract_identity(guard.as_ref().unwrap(), case.ciphersuite().e2ei_hash_alg()) + .extract_identity(pki_env.as_deref().unwrap(), case.ciphersuite().e2ei_hash_alg()) .await .unwrap(); diff --git a/crypto/src/test_utils/mod.rs b/crypto/src/test_utils/mod.rs index 59d47053e8..df601b27aa 100644 --- a/crypto/src/test_utils/mod.rs +++ b/crypto/src/test_utils/mod.rs @@ -96,7 +96,7 @@ pub struct SessionContext { mls_transport: Arc>>, x509_test_chain: Arc>, history_observer: Arc>>>, - core_crypto: CoreCrypto, + core_crypto: Arc, // We need to store the `TempDir` struct for the duration of the test session, // because its drop implementation takes care of the directory deletion. _db: Option<(Database, Arc)>, @@ -130,10 +130,9 @@ impl SessionContext { chain.register_with_central(&transaction).await; } - let pki_env = core_crypto.get_pki_environment(); - let guard = pki_env.read().await; + let pki_env = core_crypto.get_pki_environment().await; let session_id = identifier - .get_id(guard.as_ref().map(|v| &**v)) + .get_id(pki_env.as_deref()) .await .map_err(RecursiveError::mls_client("getting client id"))? .into_owned(); @@ -164,7 +163,7 @@ impl SessionContext { pub(crate) async fn new_from_cc( context: &TestContext, - core_crypto: CoreCrypto, + core_crypto: Arc, chain: Option<&X509TestChain>, ) -> Self { let transport = context.transport.clone(); @@ -303,19 +302,18 @@ impl SessionContext { let signer = signer.expect("Missing intermediate CA"); let cert = CertificateBundle::rand(&session_id, signer); let session = self.session.read().await; - let pki_env = session.crypto_provider.authentication_service().pki_env(); - let session_id = match *pki_env.read().await { - None => { - return Err(RecursiveError::mls_credential("")( - crate::mls::credential::Error::MissingPKIEnvironment, - ) - .into()); - } - Some(ref pki_env) => cert - .get_client_id(pki_env) - .await - .expect("Getting client id from certificate bundle"), - }; + let pki_env = session + .crypto_provider + .authentication_service() + .pki_env() + .await + .ok_or_else(|| { + RecursiveError::mls_credential("")(crate::mls::credential::Error::MissingPKIEnvironment) + })?; + let session_id = cert + .get_client_id(&pki_env) + .await + .expect("Getting client id from certificate bundle"); let credential = Credential::x509(case.ciphersuite(), cert).expect("creating x509 credential"); (session_id, credential) diff --git a/crypto/src/test_utils/test_conversation/mod.rs b/crypto/src/test_utils/test_conversation/mod.rs index c838a041de..e80c95e26b 100644 --- a/crypto/src/test_utils/test_conversation/mod.rs +++ b/crypto/src/test_utils/test_conversation/mod.rs @@ -367,10 +367,9 @@ impl<'a> TestConversation<'a> { let mls_credential_with_key = credential.to_mls_credential_with_key(); let ciphersuite = self.case.ciphersuite(); let session = self.actor().session().await; - let pki_env = session.crypto_provider.authentication_service().pki_env(); - let guard = pki_env.read().await; + let pki_env = session.crypto_provider.authentication_service().pki_env().await; let local_identity = mls_credential_with_key - .extract_identity(ciphersuite, guard.as_ref().map(|v| &**v)) + .extract_identity(ciphersuite, pki_env.as_deref()) .await .unwrap(); @@ -392,7 +391,7 @@ impl<'a> TestConversation<'a> { assert_eq!(credential.credential.identity(), &cid.0); let keystore_identity = credential - .extract_identity(ciphersuite, guard.as_ref().map(|v| &**v)) + .extract_identity(ciphersuite, pki_env.as_deref()) .await .unwrap(); assert_eq!( diff --git a/crypto/src/test_utils/x509.rs b/crypto/src/test_utils/x509.rs index b6c4cf44bf..a098867c62 100644 --- a/crypto/src/test_utils/x509.rs +++ b/crypto/src/test_utils/x509.rs @@ -273,12 +273,7 @@ impl X509TestChain { pub async fn register_with_central(&self, context: &TransactionContext) { use x509_cert::der::Encode as _; - let pki_env = context.pki_environment().await.unwrap(); - let guard = pki_env.read().await; - let env = match *guard { - None => panic!("PKI environment must be set"), - Some(ref env) => env, - }; + let env = context.pki_environment().await.unwrap(); env.add_trust_anchor("root", self.trust_anchor.certificate.clone()) .await diff --git a/crypto/src/transaction_context/credential/check.rs b/crypto/src/transaction_context/credential/check.rs index 8669a3c448..af00329425 100644 --- a/crypto/src/transaction_context/credential/check.rs +++ b/crypto/src/transaction_context/credential/check.rs @@ -9,7 +9,7 @@ use crate::{ crl::{CrlUris, extract_crl_uris_from_credentials, extract_crl_uris_from_group}, ext::CredentialExt as _, }, - transaction_context::{TransactionContext, e2e_identity}, + transaction_context::TransactionContext, }; impl TransactionContext { @@ -17,12 +17,7 @@ impl TransactionContext { /// because in case x509 credentials are used, HTTP requests are done to fetch new certificate revocation lists. pub async fn check_credentials(&self) -> Result<()> { let database = self.database().await?; - let pki_env = self.pki_environment().await?; - let guard = pki_env.read().await; - let env = match *guard { - None => return Err(e2e_identity::Error::PkiEnvironmentUnset.into()), - Some(ref env) => env, - }; + let env = self.pki_environment().await?; let credentials = Credential::get_all(&database) .await @@ -57,7 +52,7 @@ impl TransactionContext { // check our own credentials for expiration or revocation for credential in credentials { - if self.check_credential(env, &credential).await.is_err() { + if self.check_credential(&env, &credential).await.is_err() { invalid_credential_refs.push(CredentialRef::from_credential(&credential)); } } diff --git a/crypto/src/transaction_context/mod.rs b/crypto/src/transaction_context/mod.rs index 9b25e7471b..cc63d7aac1 100644 --- a/crypto/src/transaction_context/mod.rs +++ b/crypto/src/transaction_context/mod.rs @@ -9,8 +9,6 @@ pub use error::{Error, Result}; use openmls_traits::OpenMlsCryptoProvider as _; use wire_e2e_identity::pki_env::PkiEnvironment; -#[cfg(feature = "proteus")] -use crate::proteus::ProteusCentral; use crate::{ ClientId, ConversationId, CoreCrypto, CredentialFindFilters, CredentialRef, KeystoreError, MlsConversation, MlsError, MlsTransport, RecursiveError, Session, @@ -45,13 +43,9 @@ pub struct TransactionContext { #[derive(Debug, Clone)] enum TransactionContextInner { Valid { - pki_environment: Arc>>>, - database: Database, - mls_session: Arc>>>, + core_crypto: Arc, mls_groups: Arc>>, pending_epoch_changes: Arc>>, - #[cfg(feature = "proteus")] - proteus_central: Arc>>, }, Invalid, } @@ -60,15 +54,8 @@ impl CoreCrypto { /// Creates a new transaction. All operations that persist data will be /// buffered in memory and when [TransactionContext::finish] is called, the data will be persisted /// in a single database transaction. - pub async fn new_transaction(&self) -> Result { - TransactionContext::new( - self.database.clone(), - self.pki_environment.clone(), - self.mls.clone(), - #[cfg(feature = "proteus")] - self.proteus.clone(), - ) - .await + pub async fn new_transaction(self: &Arc) -> Result { + TransactionContext::new(self.clone()).await } } @@ -91,27 +78,18 @@ impl HasSessionAndCrypto for TransactionContext { } impl TransactionContext { - async fn new( - keystore: Database, - pki_environment: Arc>>>, - mls_session: Arc>>>, - #[cfg(feature = "proteus")] proteus_central: Arc>>, - ) -> Result { - keystore + async fn new(core_crypto: Arc) -> Result { + core_crypto + .database .new_transaction() .await .map_err(MlsError::wrap("creating new transaction"))?; - let mls_groups = Arc::new(RwLock::new(Default::default())); Ok(Self { inner: Arc::new( TransactionContextInner::Valid { - database: keystore, - pki_environment, - mls_session: mls_session.clone(), - mls_groups, + core_crypto, + mls_groups: Default::default(), pending_epoch_changes: Default::default(), - #[cfg(feature = "proteus")] - proteus_central, } .into(), ), @@ -120,7 +98,7 @@ impl TransactionContext { pub(crate) async fn session(&self) -> Result> { match &*self.inner.read().await { - TransactionContextInner::Valid { mls_session, .. } => mls_session.read().await.as_ref().cloned().ok_or( + TransactionContextInner::Valid { core_crypto, .. } => core_crypto.mls.read().await.as_ref().cloned().ok_or( RecursiveError::mls_client("Getting mls session from transaction context")( mls::session::Error::MlsNotInitialized, ) @@ -133,8 +111,8 @@ impl TransactionContext { #[cfg(test)] pub(crate) async fn set_session_if_exists(&self, new_session: Session) { match &*self.inner.read().await { - TransactionContextInner::Valid { mls_session, .. } => { - let mut guard = mls_session.write().await; + TransactionContextInner::Valid { core_crypto, .. } => { + let mut guard = core_crypto.mls.write().await; if guard.as_ref().is_some() { *guard = Some(new_session) @@ -146,14 +124,18 @@ impl TransactionContext { pub(crate) async fn mls_transport(&self) -> Result> { match &*self.inner.read().await { - TransactionContextInner::Valid { mls_session, .. } => { - mls_session.read().await.as_ref().map(|s| s.transport.clone()).ok_or( + TransactionContextInner::Valid { core_crypto, .. } => core_crypto + .mls + .read() + .await + .as_ref() + .map(|s| s.transport.clone()) + .ok_or( RecursiveError::mls_client("Getting mls session from transaction context")( mls::session::Error::MlsNotInitialized, ) .into(), - ) - } + ), TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext), } @@ -162,7 +144,8 @@ impl TransactionContext { /// Clones all references that the [MlsCryptoProvider] comprises. pub async fn mls_provider(&self) -> Result { match &*self.inner.read().await { - TransactionContextInner::Valid { mls_session, .. } => mls_session + TransactionContextInner::Valid { core_crypto, .. } => core_crypto + .mls .read() .await .as_ref() @@ -179,14 +162,25 @@ impl TransactionContext { pub(crate) async fn database(&self) -> Result { match &*self.inner.read().await { - TransactionContextInner::Valid { database, .. } => Ok(database.clone()), + TransactionContextInner::Valid { core_crypto, .. } => Ok(core_crypto.database.clone()), TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext), } } - pub(crate) async fn pki_environment(&self) -> Result>>>> { + pub(crate) async fn pki_environment(&self) -> Result> { match &*self.inner.read().await { - TransactionContextInner::Valid { pki_environment, .. } => Ok(pki_environment.clone()), + TransactionContextInner::Valid { core_crypto, .. } => core_crypto + .pki_environment + .read() + .await + .as_ref() + .map(Clone::clone) + .ok_or( + RecursiveError::transaction("getting PKI environment from transaction context")( + e2e_identity::Error::PkiEnvironmentUnset, + ) + .into(), + ), TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext), } } @@ -210,36 +204,28 @@ impl TransactionContext { } } - #[cfg(feature = "proteus")] - pub(crate) async fn proteus_central(&self) -> Result>>> { - match &*self.inner.read().await { - TransactionContextInner::Valid { proteus_central, .. } => Ok(proteus_central.clone()), - TransactionContextInner::Invalid => Err(Error::InvalidTransactionContext), - } - } - /// Commits the transaction, meaning it takes all the enqueued operations and persist them into /// the keystore. After that the internal state is switched to invalid, causing errors if /// something is called from this object. pub async fn finish(&self) -> Result<()> { let mut guard = self.inner.write().await; let TransactionContextInner::Valid { - database, + core_crypto, pending_epoch_changes, - mls_session, .. } = &*guard else { return Err(Error::InvalidTransactionContext); }; - let commit_result = database + let commit_result = core_crypto + .database .commit_transaction() .await .map_err(KeystoreError::wrap("commiting transaction")) .map_err(Into::into); - if let Some(session) = mls_session.read_arc().await.clone() + if let Some(session) = core_crypto.mls.read().await.as_ref() && commit_result.is_ok() { // We need owned values, so we could just clone the conversation ids, but we don't need the events anymore, @@ -261,11 +247,12 @@ impl TransactionContext { pub async fn abort(&self) -> Result<()> { let mut guard = self.inner.write().await; - let TransactionContextInner::Valid { database: keystore, .. } = &*guard else { + let TransactionContextInner::Valid { core_crypto, .. } = &*guard else { return Err(Error::InvalidTransactionContext); }; - let result = keystore + let result = core_crypto + .database .rollback_transaction() .await .map_err(KeystoreError::wrap("rolling back transaction")) @@ -278,7 +265,7 @@ impl TransactionContext { /// Initializes the MLS client of [super::CoreCrypto]. pub async fn mls_init(&self, session_id: ClientId, transport: Arc) -> Result<()> { let database = self.database().await?; - let pki_env = self.pki_environment().await?; + let pki_env = self.pki_environment().await.ok(); let crypto_provider = MlsCryptoProvider::new_with_pki_env(database.clone(), pki_env); let session = Session::new(session_id.clone(), crypto_provider, database, transport); self.set_mls_session(session).await?; @@ -289,8 +276,8 @@ impl TransactionContext { /// Set the `mls_session` Arc (also sets it on the transaction's CoreCrypto instance) pub(crate) async fn set_mls_session(&self, session: Session) -> Result<()> { match &*self.inner.read().await { - TransactionContextInner::Valid { mls_session, .. } => { - let mut guard = mls_session.write().await; + TransactionContextInner::Valid { core_crypto, .. } => { + let mut guard = core_crypto.mls.write().await; *guard = Some(session); Ok(()) } diff --git a/crypto/src/transaction_context/proteus.rs b/crypto/src/transaction_context/proteus.rs index c4db11c479..876ef36d84 100644 --- a/crypto/src/transaction_context/proteus.rs +++ b/crypto/src/transaction_context/proteus.rs @@ -1,6 +1,6 @@ //! This module contains all [super::TransactionContext] methods concerning proteus. -use super::{Error, Result, TransactionContext}; +use super::{Error, Result, TransactionContext, TransactionContextInner}; use crate::{ RecursiveError, group_store::GroupStoreValue, @@ -21,8 +21,10 @@ impl TransactionContext { .await .map_err(RecursiveError::root("getting last resort prekey"))?; - let mutex = self.proteus_central().await?; - let mut guard = mutex.lock().await; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; *guard = Some(proteus_client); Ok(()) } @@ -32,9 +34,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or it will do /// nothing pub async fn proteus_reload_sessions(&self) -> Result<()> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let Some(proteus) = mutex.as_mut() else { return Ok(()) }; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let Some(proteus) = guard.as_mut() else { return Ok(()) }; let keystore = self.database().await?; proteus .reload_sessions(&keystore) @@ -52,9 +56,11 @@ impl TransactionContext { session_id: &str, prekey: &[u8], ) -> Result> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; let session = proteus .session_from_prekey(session_id, prekey) @@ -76,9 +82,11 @@ impl TransactionContext { session_id: &str, envelope: &[u8], ) -> Result<(GroupStoreValue, Vec)> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let mut keystore = self.database().await?; let (session, message) = proteus .session_from_message(&mut keystore, session_id, envelope) @@ -96,9 +104,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_session_save(&self, session_id: &str) -> Result<()> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; proteus .session_save(&keystore, session_id) @@ -112,9 +122,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_session_delete(&self, session_id: &str) -> Result<()> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; proteus .session_delete(&keystore, session_id) @@ -131,9 +143,11 @@ impl TransactionContext { &self, session_id: &str, ) -> Result>> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; proteus .session(session_id, &keystore) @@ -147,9 +161,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_session_exists(&self, session_id: &str) -> Result { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; Ok(proteus.session_exists(session_id, &keystore).await) } @@ -159,9 +175,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_decrypt(&self, session_id: &str, ciphertext: &[u8]) -> Result> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let mut keystore = self.database().await?; proteus .decrypt(&mut keystore, session_id, ciphertext) @@ -175,9 +193,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_encrypt(&self, session_id: &str, plaintext: &[u8]) -> Result> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let mut keystore = self.database().await?; proteus .encrypt(&mut keystore, session_id, plaintext) @@ -196,9 +216,11 @@ impl TransactionContext { sessions: &[impl AsRef], plaintext: &[u8], ) -> Result>> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let mut keystore = self.database().await?; proteus .encrypt_batched(&mut keystore, sessions, plaintext) @@ -212,9 +234,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_new_prekey(&self, prekey_id: u16) -> Result> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; proteus .new_prekey(prekey_id, &keystore) @@ -229,9 +253,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_new_prekey_auto(&self) -> Result<(u16, Vec)> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; proteus .new_prekey_auto(&keystore) @@ -242,9 +268,11 @@ impl TransactionContext { /// Returns the last resort prekey pub async fn proteus_last_resort_prekey(&self) -> Result> { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; proteus @@ -264,9 +292,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_fingerprint(&self) -> Result { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; Ok(proteus.fingerprint()) } @@ -275,9 +305,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_fingerprint_local(&self, session_id: &str) -> Result { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; proteus .fingerprint_local(session_id, &keystore) @@ -291,9 +323,11 @@ impl TransactionContext { /// Warning: The Proteus client **MUST** be initialized with [TransactionContext::proteus_init] first or an error /// will be returned pub async fn proteus_fingerprint_remote(&self, session_id: &str) -> Result { - let arc = self.proteus_central().await?; - let mut mutex = arc.lock().await; - let proteus = mutex.as_mut().ok_or(Error::ProteusNotInitialized)?; + let TransactionContextInner::Valid { core_crypto, .. } = &*self.inner.read().await else { + return Err(Error::InvalidTransactionContext); + }; + let mut guard = core_crypto.proteus.lock().await; + let proteus = guard.as_mut().ok_or(Error::ProteusNotInitialized)?; let keystore = self.database().await?; proteus .fingerprint_remote(session_id, &keystore)