Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[resolver]
incompatible-rust-versions = "fallback"
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ serde_derive = "1.0"
serde_json = "1"
take_mut = "0.2.2"
thiserror = "1"
tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros"] }
tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros", "fs"] }
tonic = { version = "0.10", features = ["tls", "gzip"] }

[dev-dependencies]
Expand Down
135 changes: 118 additions & 17 deletions src/common/security.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
// Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0.

use std::collections::hash_map::DefaultHasher;
use std::fs::File;
use std::hash::Hash;
use std::hash::Hasher;
use std::io::Read;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;
use std::time::SystemTime;

use log::info;
use regex::Regex;
Expand Down Expand Up @@ -43,12 +47,12 @@ fn load_pem_file(tag: &str, path: &Path) -> Result<Vec<u8>> {
/// Manages the TLS protocol
#[derive(Default)]
pub struct SecurityManager {
/// The PEM encoding of the server’s CA certificates.
ca: Vec<u8>,
/// The PEM encoding of the server’s certificate chain.
cert: Vec<u8>,
/// The path to the PEM encoding of the server’s CA certificates.
ca_path: Option<PathBuf>,
/// The path to the PEM encoding of the server’s certificate chain.
cert_path: Option<PathBuf>,
/// The path to the file that contains the PEM encoding of the server’s private key.
key: PathBuf,
key_path: Option<PathBuf>,
}

impl SecurityManager {
Expand All @@ -58,15 +62,41 @@ impl SecurityManager {
cert_path: impl AsRef<Path>,
key_path: impl Into<PathBuf>,
) -> Result<SecurityManager> {
let ca_path = ca_path.as_ref().to_path_buf();
let cert_path = cert_path.as_ref().to_path_buf();
let key_path = key_path.into();
check_pem_file("ca", &ca_path)?;
check_pem_file("certificate", &cert_path)?;
check_pem_file("private key", &key_path)?;
Ok(SecurityManager {
ca: load_pem_file("ca", ca_path.as_ref())?,
cert: load_pem_file("certificate", cert_path.as_ref())?,
key: key_path,
ca_path: Some(ca_path),
cert_path: Some(cert_path),
key_path: Some(key_path),
})
}

pub(crate) fn tls_configured(&self) -> bool {
self.ca_path.is_some()
}

pub(crate) async fn connection_cache_key(&self) -> Result<Option<u64>> {
if !self.tls_configured() {
return Ok(None);
}

let mut hasher = DefaultHasher::new();
file_signature(self.ca_path.as_ref().expect("tls_configured checked"))
.await?
.hash(&mut hasher);
file_signature(self.cert_path.as_ref().expect("tls_configured checked"))
.await?
.hash(&mut hasher);
file_signature(self.key_path.as_ref().expect("tls_configured checked"))
.await?
.hash(&mut hasher);
Ok(Some(hasher.finish()))
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

/// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection.
pub async fn connect<Factory, Client>(
&self,
Expand All @@ -78,7 +108,7 @@ impl SecurityManager {
Factory: FnOnce(Channel) -> Client,
{
info!("connect to rpc server at endpoint: {:?}", addr);
let channel = if !self.ca.is_empty() {
let channel = if self.tls_configured() {
self.tls_channel(addr).await?
} else {
self.default_channel(addr).await?
Expand All @@ -89,18 +119,37 @@ impl SecurityManager {
}

async fn tls_channel(&self, addr: &str) -> Result<Endpoint> {
let (ca, cert, key) = self.load_tls_materials()?;
let addr = "https://".to_string() + &SCHEME_REG.replace(addr, "");
let builder = self.endpoint(addr.to_string())?;
let tls = ClientTlsConfig::new()
.ca_certificate(Certificate::from_pem(&self.ca))
.identity(Identity::from_pem(
&self.cert,
load_pem_file("private key", &self.key)?,
));
.ca_certificate(Certificate::from_pem(ca))
.identity(Identity::from_pem(cert, key));
let builder = builder.tls_config(tls)?;
Ok(builder)
}

fn load_tls_materials(&self) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
let ca_path = self
.ca_path
.as_ref()
.ok_or_else(|| internal_err!("TLS is not configured"))?;
let cert_path = self
.cert_path
.as_ref()
.ok_or_else(|| internal_err!("TLS is not configured"))?;
let key_path = self
.key_path
.as_ref()
.ok_or_else(|| internal_err!("TLS is not configured"))?;

Ok((
load_pem_file("ca", ca_path)?,
load_pem_file("certificate", cert_path)?,
load_pem_file("private key", key_path)?,
))
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

async fn default_channel(&self, addr: &str) -> Result<Endpoint> {
let addr = "http://".to_string() + &SCHEME_REG.replace(addr, "");
self.endpoint(addr)
Expand All @@ -114,6 +163,18 @@ impl SecurityManager {
}
}

async fn file_signature(path: &Path) -> Result<(u64, Option<u128>)> {
let metadata = tokio::fs::metadata(path)
.await
.map_err(|e| internal_err!("failed to stat {}: {:?}", path.display(), e))?;
let modified = metadata.modified().ok().and_then(|t: SystemTime| {
t.duration_since(SystemTime::UNIX_EPOCH)
.ok()
.map(|d| d.as_nanos())
});
Ok((metadata.len(), modified))
}

#[cfg(test)]
mod tests {
use std::fs::File;
Expand All @@ -140,9 +201,49 @@ mod tests {
let key_path: PathBuf = format!("{}", example_pem.display()).into();
let ca_path: PathBuf = format!("{}", example_ca.display()).into();
let mgr = SecurityManager::load(ca_path, cert_path, &key_path).unwrap();
assert_eq!(mgr.ca, vec![0]);
assert_eq!(mgr.cert, vec![1]);
let key = load_pem_file("private key", &key_path).unwrap();
assert!(mgr.tls_configured());
let (ca, cert, key) = mgr.load_tls_materials().unwrap();
assert_eq!(ca, vec![0]);
assert_eq!(cert, vec![1]);
assert_eq!(key, vec![2]);
}

#[tokio::test]
async fn test_security_reload() {
let temp = tempfile::tempdir().unwrap();
let example_ca = temp.path().join("ca");
let example_cert = temp.path().join("cert");
let example_pem = temp.path().join("key");
for (id, f) in [&example_ca, &example_cert, &example_pem]
.iter()
.enumerate()
{
File::create(f).unwrap().write_all(&[id as u8]).unwrap();
}

let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap();
let first = mgr.load_tls_materials().unwrap();
let key1 = mgr.connection_cache_key().await.unwrap();

File::create(&example_ca)
.unwrap()
.write_all(&[9, 9])
.unwrap();
File::create(&example_cert)
.unwrap()
.write_all(&[8, 8, 8])
.unwrap();
File::create(&example_pem)
.unwrap()
.write_all(&[7, 7, 7, 7])
.unwrap();

let second = mgr.load_tls_materials().unwrap();
let key2 = mgr.connection_cache_key().await.unwrap();
assert_ne!(first, second);
assert_eq!(second.0, vec![9, 9]);
assert_eq!(second.1, vec![8, 8, 8]);
assert_eq!(second.2, vec![7, 7, 7, 7]);
assert_ne!(key1, key2);
}
}
142 changes: 134 additions & 8 deletions src/pd/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,17 @@ pub trait PdClient: Send + Sync + 'static {
pub struct PdRpcClient<KvC: KvConnect + Send + Sync + 'static = TikvConnect, Cl = Cluster> {
pd: Arc<RetryClient<Cl>>,
kv_connect: KvC,
kv_client_cache: Arc<RwLock<HashMap<String, KvC::KvClient>>>,
kv_client_cache: Arc<RwLock<HashMap<String, CachedKvClient<KvC::KvClient>>>>,
enable_codec: bool,
region_cache: RegionCache<RetryClient<Cl>>,
}

#[derive(Clone)]
struct CachedKvClient<Client> {
cache_key: Option<u64>,
client: Client,
}

#[async_trait]
impl<KvC: KvConnect + Send + Sync + 'static> PdClient for PdRpcClient<KvC> {
type KvClient = KvC::KvClient;
Expand Down Expand Up @@ -338,16 +344,26 @@ impl<KvC: KvConnect + Send + Sync + 'static, Cl> PdRpcClient<KvC, Cl> {
}

async fn kv_client(&self, address: &str) -> Result<KvC::KvClient> {
if let Some(client) = self.kv_client_cache.read().await.get(address) {
return Ok(client.clone());
};
let cache_key = self.kv_connect.connection_cache_key().await;
if let Ok(cache_key) = cache_key {
if let Some(cached) = self.kv_client_cache.read().await.get(address) {
if cached.cache_key == cache_key {
return Ok(cached.client.clone());
}
}
}
info!("connect to tikv endpoint: {:?}", address);
match self.kv_connect.connect(address).await {
Ok(client) => {
self.kv_client_cache
.write()
.await
.insert(address.to_owned(), client.clone());
if let Ok(cache_key) = cache_key {
self.kv_client_cache.write().await.insert(
address.to_owned(),
CachedKvClient {
cache_key,
client: client.clone(),
},
);
}
Ok(client)
}
Err(e) => Err(e),
Expand All @@ -364,11 +380,18 @@ fn make_key_range(start_key: Vec<u8>, end_key: Vec<u8>) -> kvrpcpb::KeyRange {

#[cfg(test)]
pub mod test {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
use futures::executor;
use futures::executor::block_on;

use super::*;
use crate::mock::*;
use crate::pd::RetryClient;
use crate::store::KvConnect;
use crate::Config;

#[tokio::test]
async fn test_kv_client_caching() {
Expand All @@ -384,6 +407,109 @@ pub mod test {
assert_eq!(kv2.addr, kv3.addr);
}

#[tokio::test]
async fn test_kv_client_cache_hits_when_key_is_stable() {
#[derive(Clone)]
struct CountingConnect {
connects: Arc<AtomicUsize>,
}

#[async_trait]
impl KvConnect for CountingConnect {
type KvClient = MockKvClient;

async fn connect(&self, address: &str) -> Result<Self::KvClient> {
self.connects.fetch_add(1, Ordering::SeqCst);
let mut client = MockKvClient::default();
client.addr = address.to_owned();
Ok(client)
}

async fn connection_cache_key(&self) -> Result<Option<u64>> {
Ok(Some(0))
}
}

let connects = Arc::new(AtomicUsize::new(0));
let connects_clone = connects.clone();
let client = PdRpcClient::new(
Config::default(),
move |_| CountingConnect {
connects: connects_clone.clone(),
},
|sm| async move {
Ok(RetryClient::new_with_cluster(
sm,
Config::default().timeout,
MockCluster,
))
},
false,
)
.await
.unwrap();

let kv1 = client.kv_client("foo").await.unwrap();
let kv2 = client.kv_client("foo").await.unwrap();
assert_eq!(kv1.addr, "foo");
assert_eq!(kv2.addr, "foo");
assert_eq!(connects.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn test_kv_client_cache_invalidate_on_key_change() {
#[derive(Clone)]
struct CountingConnect {
connects: Arc<AtomicUsize>,
cache_key: Arc<AtomicUsize>,
}

#[async_trait]
impl KvConnect for CountingConnect {
type KvClient = MockKvClient;

async fn connect(&self, address: &str) -> Result<Self::KvClient> {
self.connects.fetch_add(1, Ordering::SeqCst);
let mut client = MockKvClient::default();
client.addr = address.to_owned();
Ok(client)
}

async fn connection_cache_key(&self) -> Result<Option<u64>> {
Ok(Some(self.cache_key.load(Ordering::SeqCst) as u64))
}
}

let connects = Arc::new(AtomicUsize::new(0));
let cache_key = Arc::new(AtomicUsize::new(1));
let connects_clone = connects.clone();
let cache_key_clone = cache_key.clone();
let client = PdRpcClient::new(
Config::default(),
move |_| CountingConnect {
connects: connects_clone.clone(),
cache_key: cache_key_clone.clone(),
},
|sm| async move {
Ok(RetryClient::new_with_cluster(
sm,
Config::default().timeout,
MockCluster,
))
},
false,
)
.await
.unwrap();

let kv1 = client.kv_client("foo").await.unwrap();
cache_key.store(2, Ordering::SeqCst);
let kv2 = client.kv_client("foo").await.unwrap();
assert_eq!(kv1.addr, "foo");
assert_eq!(kv2.addr, "foo");
assert_eq!(connects.load(Ordering::SeqCst), 2);
}

#[test]
fn test_group_keys_by_region() {
let client = MockPdClient::default();
Expand Down
Loading
Loading