Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crypto-ffi/src/core_crypto/e2ei/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub(crate) mod identities;
impl CoreCryptoFfi {
/// Returns true if the PKI environment has been set up and its provider is configured.
pub async fn e2ei_is_pki_env_setup(&self) -> bool {
self.inner.get_pki_environment().read().await.is_some()
self.inner.get_pki_environment().await.is_some()
}

/// Returns true if end-to-end identity is enabled for the given ciphersuite.
Expand Down
2 changes: 1 addition & 1 deletion crypto-ffi/src/core_crypto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{CoreCryptoResult, Database};
/// CoreCrypto wraps around MLS and Proteus implementations and provides a transactional interface for each.
#[derive(Debug, uniffi::Object)]
pub struct CoreCryptoFfi {
pub(crate) inner: core_crypto::CoreCrypto,
pub(crate) inner: Arc<core_crypto::CoreCrypto>,
}

/// Construct a new `CoreCryptoFfi` instance.
Expand Down
8 changes: 4 additions & 4 deletions crypto-ffi/src/pki_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ impl CoreCryptoFfi {
///
/// Returns null if it is not set.
pub async fn get_pki_environment(&self) -> Option<Arc<PkiEnvironment>> {
let pki_env = self.inner.get_pki_environment();
(*pki_env.read().await)
.as_ref()
.map(|env| Arc::new(PkiEnvironment(env.clone())))
self.inner
.get_pki_environment()
.await
.map(|inner| Arc::new(PkiEnvironment(inner)))
}
}
2 changes: 1 addition & 1 deletion crypto/src/ephemeral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl CoreCrypto {
///
/// This client exposes the full interface of `CoreCrypto`, but it should only be used to decrypt messages.
/// Other use is a logic error.
pub async fn history_client(history_secret: HistorySecret) -> Result<Self> {
pub async fn history_client(history_secret: HistorySecret) -> Result<Arc<Self>> {
if !history_secret
.client_id
.starts_with(HISTORY_CLIENT_ID_PREFIX.as_bytes())
Expand Down
24 changes: 10 additions & 14 deletions crypto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,45 +123,41 @@ impl MlsTransport for CoreCryptoTransportNotImplementedProvider {
///
/// As [std::ops::Deref] is implemented, this struct is automatically dereferred to [mls::session::Session] apart from
/// `proteus_*` calls
///
/// This is cheap to clone as all internal members have `Arc` wrappers or are `Copy`.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct CoreCrypto {
database: Database,
pki_environment: Arc<RwLock<Option<Arc<PkiEnvironment>>>>,
mls: Arc<RwLock<Option<mls::session::Session<Database>>>>,
pki_environment: RwLock<Option<Arc<PkiEnvironment>>>,
mls: RwLock<Option<mls::session::Session<Database>>>,
#[cfg(feature = "proteus")]
proteus: Arc<Mutex<Option<proteus::ProteusCentral>>>,
proteus: Mutex<Option<proteus::ProteusCentral>>,
#[cfg(not(feature = "proteus"))]
#[allow(dead_code)]
proteus: (),
}

impl CoreCrypto {
/// Create an new CoreCrypto client without any initialized session.
pub fn new(database: Database) -> Self {
pub fn new(database: Database) -> Arc<Self> {
Self {
database,
pki_environment: Default::default(),
mls: Default::default(),
proteus: Default::default(),
}
.into()
}

/// Set the session's PKI Environment
pub async fn set_pki_environment(&self, pki_environment: Option<Arc<PkiEnvironment>>) {
*self.pki_environment.write().await = pki_environment;
*self.pki_environment.write().await = pki_environment.clone();
if let Some(mls_session) = self.mls.write().await.as_mut() {
mls_session
.crypto_provider
.set_pki_environment(self.pki_environment.clone())
.await;
mls_session.crypto_provider.set_pki_environment(pki_environment).await;
}
}

/// Get the session's PKI Environment
pub fn get_pki_environment(&self) -> Arc<RwLock<Option<Arc<PkiEnvironment>>>> {
self.pki_environment.clone()
pub async fn get_pki_environment(&self) -> Option<Arc<PkiEnvironment>> {
self.pki_environment.read().await.clone()
}

/// Get the mls session if initialized
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,9 @@ impl ConversationGuard {
let credential = message.credential();
let epoch = message.epoch();

let pki_env = provider.authentication_service().pki_env();
let guard = pki_env.read().await;
let pki_env = provider.authentication_service().pki_env().await;
let identity = credential
.extract_identity(self.ciphersuite().await, guard.as_ref().map(|v| &**v))
.extract_identity(self.ciphersuite().await, pki_env.as_deref())
.await
.map_err(RecursiveError::mls_credential("extracting identity"))?;

Expand Down
10 changes: 4 additions & 6 deletions crypto/src/mls/conversation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,14 @@ pub trait Conversation<'a>: ConversationWithMls<'a> {
let auth_service = mls_provider.authentication_service();
let conversation = self.conversation().await;

let pki_env = auth_service.pki_env();
let guard = pki_env.read().await;
let pki_env = auth_service.pki_env().await;

let mut identities = vec![];
for (id, credential) in conversation.members_with_key() {
if device_ids.iter().any(|client_id| client_id.borrow() == id) {
identities.push(
credential
.extract_identity(conversation.ciphersuite(), guard.as_ref().map(|v| &**v))
.extract_identity(conversation.ciphersuite(), pki_env.as_deref())
.await
.map_err(RecursiveError::mls_credential("extracting identity"))?,
);
Expand All @@ -255,8 +254,7 @@ pub trait Conversation<'a>: ConversationWithMls<'a> {
let conversation = self.conversation().await;
let user_ids = user_ids.iter().map(|uid| uid.as_bytes()).collect::<Vec<_>>();

let pki_env = auth_service.pki_env();
let guard = pki_env.read().await;
let pki_env = auth_service.pki_env().await;

let mut identities = HashMap::new();
for (id, credential) in conversation.members_with_key() {
Expand All @@ -271,7 +269,7 @@ pub trait Conversation<'a>: ConversationWithMls<'a> {

let uid = String::try_from(uid).map_err(RecursiveError::mls_client("getting user identities"))?;
let identity = credential
.extract_identity(conversation.ciphersuite(), guard.as_ref().map(|v| &**v))
.extract_identity(conversation.ciphersuite(), pki_env.as_deref())
.await
.map_err(RecursiveError::mls_credential("extracting identity"))?;
let value = identities.entry(uid).or_insert_with(Vec::new);
Expand Down
5 changes: 2 additions & 3 deletions crypto/src/mls/conversation/own_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@ impl MlsConversation {
credential: own_leaf.credential().clone(),
signature_key: own_leaf.signature_key().clone(),
};
let pki_env = provider.authentication_service().pki_env();
let guard = pki_env.read().await;
let pki_env = provider.authentication_service().pki_env().await;
let identity = own_leaf_credential_with_key
.extract_identity(self.ciphersuite(), guard.as_ref().map(|v| &**v))
.extract_identity(self.ciphersuite(), pki_env.as_deref())
.await
.map_err(RecursiveError::mls_credential("extracting identity"))?;

Expand Down
9 changes: 2 additions & 7 deletions crypto/src/mls/conversation/pending_conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,10 @@ impl PendingConversation {
credential: own_leaf.credential().clone(),
signature_key: own_leaf.signature_key().clone(),
};
let pki_env = self
.context
.pki_environment()
.await
.map_err(RecursiveError::transaction("getting PKI environment"))?;
let guard = pki_env.read().await;
let pki_env = self.context.pki_environment().await.ok();

let identity = own_leaf_credential_with_key
.extract_identity(conversation.ciphersuite(), guard.as_ref().map(|v| &**v))
.extract_identity(conversation.ciphersuite(), pki_env.as_deref())
.await
.map_err(RecursiveError::mls_credential("extracting identity"))?;

Expand Down
5 changes: 2 additions & 3 deletions crypto/src/mls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,9 @@ mod tests {
CertificateBundle::rand_identifier(&session_id, &[x509_test_chain.find_local_intermediate_ca()])
}
};
let pki_env = cc.get_pki_environment();
let guard = pki_env.read().await;
let pki_env = cc.get_pki_environment().await;
let session_id = identifier
.get_id(guard.as_ref().map(|v| &**v))
.get_id(pki_env.as_deref())
.await
.expect("get session_id from identifier")
.into_owned();
Expand Down
10 changes: 4 additions & 6 deletions crypto/src/mls/session/e2e_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,9 @@ impl<D> Session<D> {
_credential_type: CredentialType,
auth_service: &AuthenticationService,
) -> E2eiConversationState {
let pki_env = auth_service.pki_env();
let guard = pki_env.read().await;
let env = match *guard {
Some(ref env) => env,
None => return E2eiConversationState::NotEnabled,
let pki_env = auth_service.pki_env().await;
let Some(env) = pki_env else {
return E2eiConversationState::NotEnabled;
};

let mut is_e2ei = false;
Expand All @@ -136,7 +134,7 @@ impl<D> Session<D> {

is_e2ei = true;

let invalid_identity = cert.extract_identity(env, ciphersuite.e2ei_hash_alg()).await.is_err();
let invalid_identity = cert.extract_identity(&env, ciphersuite.e2ei_hash_alg()).await.is_err();

use openmls_x509_credential::X509Ext as _;
let is_time_valid = cert.is_time_valid().unwrap_or(false);
Expand Down
29 changes: 19 additions & 10 deletions crypto/src/mls_provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,22 @@ impl std::ops::DerefMut for EntropySeed {
}
}

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct AuthenticationService {
pki_env: Arc<RwLock<Option<Arc<PkiEnvironment>>>>,
/// The PKI Environment type is complicated, but it's all necessary:
///
/// - The inner `Arc` derives from two facts: the PKI environment is provided across FFI, and it's `!Clone`, so we
/// have to retain that `Arc` because the foreign environment is more-or-less guaranteed to have kept a reference
/// to it.
/// - The `Option` is there because the PKI environment is initially unset and may never be set, according to
/// client behavior.
/// - The `RwLock` is there because we need to be able to set the PKI environment, implying interior mutability.
pki_env: RwLock<Option<Arc<PkiEnvironment>>>,
}

impl AuthenticationService {
pub fn pki_env(&self) -> Arc<RwLock<Option<Arc<PkiEnvironment>>>> {
self.pki_env.clone()
pub async fn pki_env(&self) -> Option<Arc<PkiEnvironment>> {
self.pki_env.read().await.clone()
}
}

Expand Down Expand Up @@ -105,7 +113,7 @@ impl openmls_traits::authentication_service::AuthenticationServiceDelegate for A
pub struct MlsCryptoProvider {
crypto: Arc<RustCrypto>,
key_store: Database,
auth_service: AuthenticationService,
auth_service: Arc<AuthenticationService>,
}

impl MlsCryptoProvider {
Expand All @@ -115,12 +123,13 @@ impl MlsCryptoProvider {
///
/// - [Database::open]
pub fn new(key_store: Database) -> Self {
Self::new_with_pki_env(key_store, Arc::new(RwLock::new(None)))
Self::new_with_pki_env(key_store, None)
}

/// Construct a crypto provider with the given database and the PKI environment.
pub fn new_with_pki_env(key_store: Database, pki_env: Arc<RwLock<Option<Arc<PkiEnvironment>>>>) -> Self {
let auth_service = AuthenticationService { pki_env };
pub fn new_with_pki_env(key_store: Database, pki_env: Option<Arc<PkiEnvironment>>) -> Self {
let pki_env = RwLock::new(pki_env);
let auth_service = Arc::new(AuthenticationService { pki_env });
Self {
key_store,
crypto: Arc::clone(&CRYPTO),
Expand All @@ -135,8 +144,8 @@ impl MlsCryptoProvider {
}

/// Set pki_env to a new shared pki environment provider
pub async fn set_pki_environment(&mut self, pki_env: Arc<RwLock<Option<Arc<PkiEnvironment>>>>) {
self.auth_service.pki_env = pki_env;
pub async fn set_pki_environment(&mut self, pki_env: Option<Arc<PkiEnvironment>>) {
*self.auth_service.pki_env.write().await = pki_env;
}

/// Returns whether we have a PKI env setup
Expand Down
9 changes: 4 additions & 5 deletions crypto/src/proteus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ mod tests {
.await
.unwrap();

let cc: CoreCrypto = CoreCrypto::new(db);
let cc = CoreCrypto::new(db);
let context = cc.new_transaction().await.unwrap();
assert!(context.proteus_init().await.is_ok());
assert!(context.proteus_new_prekey(1).await.is_ok());
Expand All @@ -626,7 +626,7 @@ mod tests {
.await
.unwrap();

let cc: CoreCrypto = CoreCrypto::new(db.clone());
let cc = CoreCrypto::new(db.clone());
let hooks = Arc::new(DummyPkiEnvironmentHooks);
let pki_env = PkiEnvironment::new(hooks, db).await.expect("creating pki environment");
cc.set_pki_environment(Some(Arc::new(pki_env))).await;
Expand All @@ -645,10 +645,9 @@ mod tests {
CertificateBundle::rand_identifier(&session_id, &[x509_test_chain.find_local_intermediate_ca()])
}
};
let pki_env = cc.get_pki_environment();
let guard = pki_env.read().await;
let pki_env = cc.get_pki_environment().await;
let session_id = identifier
.get_id(guard.as_ref().map(|v| &**v))
.get_id(pki_env.as_deref())
.await
.expect("Getting session id from identifier")
.into_owned();
Expand Down
11 changes: 5 additions & 6 deletions crypto/src/test_utils/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,11 @@ impl SessionContext {
&expected_credential.mls_credential().mls_credential()
{
let session = self.session().await;
let pki_env = session.crypto_provider.authentication_service().pki_env();
let guard = pki_env.read().await;
assert!(guard.as_ref().is_some());
let pki_env = session.crypto_provider.authentication_service().pki_env().await;
assert!(pki_env.is_some());

let mls_identity = certificate
.extract_identity(case.ciphersuite(), guard.as_ref().map(|v| &**v))
.extract_identity(case.ciphersuite(), pki_env.as_deref())
.await
.unwrap();
let mls_client_id = mls_identity.client_id.as_bytes();
Expand All @@ -246,7 +245,7 @@ impl SessionContext {
let leaf: Vec<u8> = certificate.certificates.first().unwrap().clone().into();
let identity = leaf
.as_slice()
.extract_identity(guard.as_ref().unwrap(), case.ciphersuite().e2ei_hash_alg())
.extract_identity(pki_env.as_deref().unwrap(), case.ciphersuite().e2ei_hash_alg())
.await
.unwrap();
let identity = WireIdentity::try_from((identity, leaf.as_slice())).unwrap();
Expand All @@ -273,7 +272,7 @@ impl SessionContext {
let chain = x509_cert::Certificate::load_pem_chain(decrypted_x509_identity.certificate.as_bytes()).unwrap();
let leaf = chain.first().unwrap();
let cert_identity = leaf
.extract_identity(guard.as_ref().unwrap(), case.ciphersuite().e2ei_hash_alg())
.extract_identity(pki_env.as_deref().unwrap(), case.ciphersuite().e2ei_hash_alg())
.await
.unwrap();

Expand Down
Loading
Loading