Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[resolver]
incompatible-rust-versions = "fallback"
107 changes: 88 additions & 19 deletions src/common/security.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ fn load_pem_file(tag: &str, path: &Path) -> Result<Vec<u8>> {
/// Manages the TLS protocol
#[derive(Default)]
pub struct SecurityManager {
/// The PEM encoding of the server’s CA certificates.
ca: Vec<u8>,
/// The PEM encoding of the server’s certificate chain.
cert: Vec<u8>,
/// The path to the PEM encoding of the server’s CA certificates.
ca_path: Option<PathBuf>,
/// The path to the PEM encoding of the server’s certificate chain.
cert_path: Option<PathBuf>,
/// The path to the file that contains the PEM encoding of the server’s private key.
key: PathBuf,
key_path: Option<PathBuf>,
}

impl SecurityManager {
Expand All @@ -58,15 +58,23 @@ impl SecurityManager {
cert_path: impl AsRef<Path>,
key_path: impl Into<PathBuf>,
) -> Result<SecurityManager> {
let ca_path = ca_path.as_ref().to_path_buf();
let cert_path = cert_path.as_ref().to_path_buf();
let key_path = key_path.into();
check_pem_file("ca", &ca_path)?;
check_pem_file("certificate", &cert_path)?;
check_pem_file("private key", &key_path)?;
Ok(SecurityManager {
ca: load_pem_file("ca", ca_path.as_ref())?,
cert: load_pem_file("certificate", cert_path.as_ref())?,
key: key_path,
ca_path: Some(ca_path),
cert_path: Some(cert_path),
key_path: Some(key_path),
})
}

pub(crate) fn tls_configured(&self) -> bool {
self.ca_path.is_some()
}

/// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection.
pub async fn connect<Factory, Client>(
&self,
Expand All @@ -78,7 +86,7 @@ impl SecurityManager {
Factory: FnOnce(Channel) -> Client,
{
info!("connect to rpc server at endpoint: {:?}", addr);
let channel = if !self.ca.is_empty() {
let channel = if self.tls_configured() {
self.tls_channel(addr).await?
} else {
self.default_channel(addr).await?
Expand All @@ -89,18 +97,42 @@ impl SecurityManager {
}

async fn tls_channel(&self, addr: &str) -> Result<Endpoint> {
let (ca, cert, key) = self.load_tls_materials().await?;
let addr = "https://".to_string() + &SCHEME_REG.replace(addr, "");
let builder = self.endpoint(addr.to_string())?;
let tls = ClientTlsConfig::new()
.ca_certificate(Certificate::from_pem(&self.ca))
.identity(Identity::from_pem(
&self.cert,
load_pem_file("private key", &self.key)?,
));
.ca_certificate(Certificate::from_pem(ca))
.identity(Identity::from_pem(cert, key));
let builder = builder.tls_config(tls)?;
Ok(builder)
}

async fn load_tls_materials(&self) -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
let ca_path = self
.ca_path
.clone()
.ok_or_else(|| internal_err!("TLS is not configured"))?;
let cert_path = self
.cert_path
.clone()
.ok_or_else(|| internal_err!("TLS is not configured"))?;
let key_path = self
.key_path
.clone()
.ok_or_else(|| internal_err!("TLS is not configured"))?;

let materials =
tokio::task::spawn_blocking(move || -> Result<(Vec<u8>, Vec<u8>, Vec<u8>)> {
Ok((
load_pem_file("ca", &ca_path)?,
load_pem_file("certificate", &cert_path)?,
load_pem_file("private key", &key_path)?,
))
})
.await??;
Ok(materials)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

async fn default_channel(&self, addr: &str) -> Result<Endpoint> {
let addr = "http://".to_string() + &SCHEME_REG.replace(addr, "");
self.endpoint(addr)
Expand All @@ -124,8 +156,8 @@ mod tests {

use super::*;

#[test]
fn test_security() {
#[tokio::test]
async fn test_security() {
let temp = tempfile::tempdir().unwrap();
let example_ca = temp.path().join("ca");
let example_cert = temp.path().join("cert");
Expand All @@ -140,9 +172,46 @@ mod tests {
let key_path: PathBuf = format!("{}", example_pem.display()).into();
let ca_path: PathBuf = format!("{}", example_ca.display()).into();
let mgr = SecurityManager::load(ca_path, cert_path, &key_path).unwrap();
assert_eq!(mgr.ca, vec![0]);
assert_eq!(mgr.cert, vec![1]);
let key = load_pem_file("private key", &key_path).unwrap();
assert!(mgr.tls_configured());
let (ca, cert, key) = mgr.load_tls_materials().await.unwrap();
assert_eq!(ca, vec![0]);
assert_eq!(cert, vec![1]);
assert_eq!(key, vec![2]);
}

#[tokio::test]
async fn test_security_reload() {
let temp = tempfile::tempdir().unwrap();
let example_ca = temp.path().join("ca");
let example_cert = temp.path().join("cert");
let example_pem = temp.path().join("key");
for (id, f) in [&example_ca, &example_cert, &example_pem]
.iter()
.enumerate()
{
File::create(f).unwrap().write_all(&[id as u8]).unwrap();
}

let mgr = SecurityManager::load(&example_ca, &example_cert, &example_pem).unwrap();
let first = mgr.load_tls_materials().await.unwrap();

File::create(&example_ca)
.unwrap()
.write_all(&[9, 9])
.unwrap();
File::create(&example_cert)
.unwrap()
.write_all(&[8, 8, 8])
.unwrap();
File::create(&example_pem)
.unwrap()
.write_all(&[7, 7, 7, 7])
.unwrap();

let second = mgr.load_tls_materials().await.unwrap();
assert_ne!(first, second);
assert_eq!(second.0, vec![9, 9]);
assert_eq!(second.1, vec![8, 8, 8]);
assert_eq!(second.2, vec![7, 7, 7, 7]);
}
}
122 changes: 112 additions & 10 deletions src/pd/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,10 @@ impl<KvC: KvConnect + Send + Sync + 'static> PdClient for PdRpcClient<KvC> {
}

async fn invalidate_store_cache(&self, store_id: StoreId) {
self.region_cache.invalidate_store_cache(store_id).await
let store = self.region_cache.invalidate_store_cache(store_id).await;
if let Some(store) = store {
self.invalidate_kv_client_cache(&store.address).await;
}
Comment on lines 276 to +280
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Address-only KV client caching still bypasses TLS rotation.

These changes only evict clients when store cache invalidation happens, but kv_client() still treats address as the entire cache key. After cert/key rotation, an existing entry for that store keeps reusing the old channel, so the new on-demand TLS reload never runs for long-lived clients. Please store a connector/TLS cache stamp alongside the client and force a reconnect when that stamp changes before returning a cache hit.

Also applies to: 343-353

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/pd/client.rs` around lines 276 - 280, The cache currently keys KV clients
only by store address so TLS cert/key rotations can be bypassed; update the KV
client cache to store and compare a TLS/connector "stamp" alongside the channel
and force a reconnect when the stamp changes. Concretely: extend the cached
entry used by kv_client() to include a tls_stamp (e.g., a version token from the
connector or TLS manager), update the client creation code path to record the
current tls_stamp when constructing a new client, modify kv_client() to check
the current tls_stamp against the cached entry and call
invalidate_kv_client_cache(&address) / recreate the client when they differ, and
ensure invalidate_store_cache (and the code referenced around
invalidate_kv_client_cache) also clears the stamp entry so subsequent
kv_client() calls will rebuild the channel with the new TLS materials.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

• Overall, the current implementation is acceptable if the goal is to tolerate TLS certificate rotation with at most one
  failure and rely on the caller to retry.

  The cache is still keyed only by store address, so an already cached long-lived gRPC/TLS client will not proactively pick
  up newly rotated certificates. However, once that stale connection fails with a gRPC transport or API error, the client
  invalidates the corresponding store/client cache entry. On the next full retry, it creates a new connection, reloads the
  TLS files, and recovers automatically without manual intervention.

  In short, the current design is reactive rather than proactive: it does not provide seamless certificate rotation or a
  zero-failure guarantee, but it is generally sufficient for a "fail once, then recover on retry" model. The main assumption
  is that the retry is a fresh end-to-end request, not reuse of an already captured stale store/client handle.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

@yongman, thanks for the clear explanation of the design intent. The reactive model — accept one failure on a stale TLS connection, let the error propagate through the existing retry infrastructure to invalidate the cache, then rebuild the channel with freshly loaded cert material — is a reasonable trade-off, particularly given the retry scaffolding already present in RetryableMultiRegion. The key pre-condition you called out (that retries are fresh end-to-end, not reuses of a captured RegionStore handle) is an important assumption to keep in mind for future changes.

I'll record this as a design note and withdraw the comment.


🏁 Script executed:

#!/bin/bash
# Resolve the review comment left by coderabbitai on this PR
gh api graphql -f query='
query {
  repository(owner: "tikv", name: "client-rust") {
    pullRequest(number: 537) {
      reviewThreads(first: 50) {
        nodes {
          id
          isResolved
          comments(first: 1) {
            nodes {
              author { login }
              body
            }
          }
        }
      }
    }
  }
}'

Length of output: 10108


🏁 Script executed:

#!/bin/bash
gh api graphql -f query='
mutation {
  resolveReviewThread(input: { threadId: "PRRT_kwDOCL4jiM6CTX__" }) {
    thread {
      id
      isResolved
    }
  }
}'

Length of output: 266


@yongman, thanks for the clear write-up of the design intent. The reactive model you've described — accept one failure, let the error path invalidate both the region and KV client cache entries, then rebuild the channel with freshly loaded TLS materials on the next retry — is a reasonable and pragmatic trade-off given the retry infrastructure already in place. The key assumption (retries are fresh end-to-end requests rather than reuses of a captured RegionStore handle) is an important invariant to maintain going forward.

I wasn't able to resolve the comment programmatically (insufficient permissions), but I'll leave that to you. Happy to withdraw the concern — the current design is acceptable for the "fail once, recover on retry" goal.


🧬 Code Graph Analysis Results

<no_relevant_code_snippets>


✏️ Learnings added
Learnt from: yongman
Repo: tikv/client-rust PR: 537
File: src/pd/client.rs:276-280
Timestamp: 2026-05-18T03:17:55.642Z
Learning: In `src/pd/client.rs` (tikv/client-rust), the KV client cache is intentionally keyed only by store address. TLS certificate rotation is handled reactively: when a stale cached gRPC channel fails with a transport/API error, the error path invalidates the corresponding store and KV client cache entries, and the next retry rebuilds the connection with freshly loaded TLS materials from disk. This is a deliberate "fail once, recover on retry" design — it does not provide zero-failure/seamless cert rotation, but avoids proactive per-request TLS stamp comparisons. The correctness assumption is that retries are always fresh end-to-end requests, not reuses of an already captured `RegionStore` handle.

Learnt from: yongman
Repo: tikv/client-rust PR: 537
File: src/pd/client.rs:276-280
Timestamp: 2026-05-18T03:17:55.642Z
Learning: In `src/pd/client.rs` (tikv/client-rust), the KV client cache (`kv_client_cache`) is intentionally keyed only by store address. TLS certificate rotation is handled **reactively**: when a stale cached gRPC channel fails (transport/API error), the error path invalidates the corresponding region-cache and KV client-cache entries, and the next retry rebuilds the connection by reloading TLS materials from disk. This is a deliberate "fail once, recover on retry" design — it does not provide zero-failure/seamless cert rotation, but avoids proactive per-request TLS stamp comparisons. The correctness assumption is that retries are always fresh end-to-end requests, not reuses of an already captured `RegionStore` handle.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

}

async fn load_keyspace(&self, keyspace: &str) -> Result<keyspacepb::KeyspaceMeta> {
Expand Down Expand Up @@ -338,9 +341,9 @@ impl<KvC: KvConnect + Send + Sync + 'static, Cl> PdRpcClient<KvC, Cl> {
}

async fn kv_client(&self, address: &str) -> Result<KvC::KvClient> {
if let Some(client) = self.kv_client_cache.read().await.get(address) {
return Ok(client.clone());
};
if let Some(cached) = self.kv_client_cache.read().await.get(address) {
return Ok(cached.clone());
}
info!("connect to tikv endpoint: {:?}", address);
match self.kv_connect.connect(address).await {
Ok(client) => {
Expand All @@ -353,6 +356,10 @@ impl<KvC: KvConnect + Send + Sync + 'static, Cl> PdRpcClient<KvC, Cl> {
Err(e) => Err(e),
}
}

async fn invalidate_kv_client_cache(&self, address: &str) {
self.kv_client_cache.write().await.remove(address);
}
}

fn make_key_range(start_key: Vec<u8>, end_key: Vec<u8>) -> kvrpcpb::KeyRange {
Expand All @@ -364,26 +371,121 @@ fn make_key_range(start_key: Vec<u8>, end_key: Vec<u8>) -> kvrpcpb::KeyRange {

#[cfg(test)]
pub mod test {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
use futures::executor;
use futures::executor::block_on;

use super::*;
use crate::mock::*;
use crate::pd::RetryClient;
use crate::store::KvConnect;
use crate::Config;

#[tokio::test]
async fn test_kv_client_caching() {
let client = block_on(pd_rpc_client());

let addr1 = "foo";
let addr2 = "bar";

let kv1 = client.kv_client(addr1).await.unwrap();
let kv2 = client.kv_client(addr2).await.unwrap();
let kv3 = client.kv_client(addr2).await.unwrap();
let kv1 = client.kv_client("foo").await.unwrap();
let kv2 = client.kv_client("bar").await.unwrap();
let kv3 = client.kv_client("bar").await.unwrap();
assert!(kv1.addr != kv2.addr);
assert_eq!(kv2.addr, kv3.addr);
}

#[tokio::test]
async fn test_kv_client_cache_hits_lazily() {
#[derive(Clone)]
struct CountingConnect {
connects: Arc<AtomicUsize>,
}

#[async_trait]
impl KvConnect for CountingConnect {
type KvClient = MockKvClient;

async fn connect(&self, address: &str) -> Result<Self::KvClient> {
self.connects.fetch_add(1, Ordering::SeqCst);
let mut client = MockKvClient::default();
client.addr = address.to_owned();
Ok(client)
}
}

let connects = Arc::new(AtomicUsize::new(0));
let connects_clone = connects.clone();
let client = PdRpcClient::new(
Config::default(),
move |_| CountingConnect {
connects: connects_clone.clone(),
},
|sm| async move {
Ok(RetryClient::new_with_cluster(
sm,
Config::default().timeout,
MockCluster,
))
},
false,
)
.await
.unwrap();

let kv1 = client.kv_client("foo").await.unwrap();
let kv2 = client.kv_client("foo").await.unwrap();
assert_eq!(kv1.addr, "foo");
assert_eq!(kv2.addr, "foo");
assert_eq!(connects.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn test_kv_client_cache_reconnects_after_invalidation() {
#[derive(Clone)]
struct CountingConnect {
connects: Arc<AtomicUsize>,
}

#[async_trait]
impl KvConnect for CountingConnect {
type KvClient = MockKvClient;

async fn connect(&self, address: &str) -> Result<Self::KvClient> {
self.connects.fetch_add(1, Ordering::SeqCst);
let mut client = MockKvClient::default();
client.addr = address.to_owned();
Ok(client)
}
}

let connects = Arc::new(AtomicUsize::new(0));
let connects_clone = connects.clone();
let client = PdRpcClient::new(
Config::default(),
move |_| CountingConnect {
connects: connects_clone.clone(),
},
|sm| async move {
Ok(RetryClient::new_with_cluster(
sm,
Config::default().timeout,
MockCluster,
))
},
false,
)
.await
.unwrap();

let kv1 = client.kv_client("foo").await.unwrap();
client.invalidate_kv_client_cache("foo").await;
let kv2 = client.kv_client("foo").await.unwrap();
assert_eq!(kv1.addr, "foo");
assert_eq!(kv2.addr, "foo");
assert_eq!(connects.load(Ordering::SeqCst), 2);
}

#[test]
fn test_group_keys_by_region() {
let client = MockPdClient::default();
Expand Down
4 changes: 2 additions & 2 deletions src/region_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ impl<C: RetryClientTrait> RegionCache<C> {
}
}

pub async fn invalidate_store_cache(&self, store_id: StoreId) {
pub async fn invalidate_store_cache(&self, store_id: StoreId) -> Option<Store> {
let mut cache = self.store_cache.write().await;
cache.remove(&store_id);
cache.remove(&store_id)
}

pub async fn read_through_all_stores(&self) -> Result<Vec<Store>> {
Expand Down
Loading
Loading