diff --git a/pkg/metricsservice/utils/tls.go b/pkg/metricsservice/utils/tls.go index 4c37ede1338..898e8462d65 100644 --- a/pkg/metricsservice/utils/tls.go +++ b/pkg/metricsservice/utils/tls.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "net" "os" "path" "strings" @@ -73,6 +74,9 @@ func LoadGrpcTLSCredentials(ctx context.Context, certDir string, server bool) (c } certMutex := sync.RWMutex{} + // currentPool holds the live CA pool and is replaced (never appended to) + // on every rotation so the pool does not grow unboundedly. + currentPool := certPool go func() { log.V(1).Info("starting mTLS certificates monitoring") for { @@ -97,22 +101,29 @@ func LoadGrpcTLSCredentials(ctx context.Context, certDir string, server bool) (c log.Error(err, "error reading grpc ca certificate") continue } - if !certPool.AppendCertsFromPEM(pemClientCA) { - log.Error(err, "failed to add client CA's certificate") + // Rebuild the pool from scratch to prevent monotonic growth. + newPool, _ := x509.SystemCertPool() + if newPool == nil { + newPool = x509.NewCertPool() + } + if !newPool.AppendCertsFromPEM(pemClientCA) { + log.Error(fmt.Errorf("failed to parse CA PEM from %s", caPath), "failed to add client CA's certificate") continue } - log.V(1).Info("grpc ca certificate has been updated") - // Load certificate of the CA who signed client's certificate + // Load the new leaf certificate *before* swapping the pool so + // we never end up trusting only the new CA while still presenting + // the old cert (inconsistent state during rotation). cert, err := tls.LoadX509KeyPair(certPath, keyPath) if err != nil { log.Error(err, "error reading grpc certificate") continue } certMutex.Lock() + currentPool = newPool mTLSCertificate = cert certMutex.Unlock() - log.V(1).Info("grpc mTLS certificate has been updated") + log.V(1).Info("grpc CA and mTLS certificate have been updated") case err, ok := <-watcher.Errors: if !ok { // Channel was closed (i.e. Watcher.Close() was called). @@ -144,10 +155,71 @@ func LoadGrpcTLSCredentials(ctx context.Context, certDir string, server bool) (c } if server { config.ClientAuth = tls.RequireAndVerifyClientCert - config.ClientCAs = certPool + // GetConfigForClient is called per-connection; injecting currentPool + // here ensures every new handshake picks up the latest CA pool. + config.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) { + certMutex.RLock() + defer certMutex.RUnlock() + clone := config.Clone() + clone.ClientCAs = currentPool + // GetConfigForClient overrides GetCertificate; restore it. + clone.GetCertificate = config.GetCertificate + return clone, nil + } } else { config.RootCAs = certPool } - return credentials.NewTLS(config), nil + baseCreds := credentials.NewTLS(config) + if server { + return baseCreds, nil + } + // Wrap the client credentials so that every new TLS handshake picks up the + // latest CA pool. credentials.NewTLS snapshots the config once; by cloning + // it on each ClientHandshake we ensure that a rotated CA is trusted without + // any data-race on config.RootCAs. + return &dynamicClientCredentials{ + base: baseCreds, + baseConfig: config, + mu: &certMutex, + currentPool: ¤tPool, + }, nil +} + +// dynamicClientCredentials wraps a gRPC TransportCredentials and injects the +// latest CA pool into a fresh tls.Config clone on every client handshake. +type dynamicClientCredentials struct { + base credentials.TransportCredentials + baseConfig *tls.Config + mu *sync.RWMutex + currentPool **x509.CertPool +} + +func (d *dynamicClientCredentials) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + d.mu.RLock() + cfg := d.baseConfig.Clone() + cfg.RootCAs = *d.currentPool + d.mu.RUnlock() + return credentials.NewTLS(cfg).ClientHandshake(ctx, authority, conn) +} + +func (d *dynamicClientCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return d.base.ServerHandshake(conn) +} + +func (d *dynamicClientCredentials) Info() credentials.ProtocolInfo { + return d.base.Info() +} + +func (d *dynamicClientCredentials) Clone() credentials.TransportCredentials { + return &dynamicClientCredentials{ + base: d.base.Clone(), + baseConfig: d.baseConfig.Clone(), + mu: d.mu, + currentPool: d.currentPool, + } +} + +func (d *dynamicClientCredentials) OverrideServerName(name string) error { + return d.base.OverrideServerName(name) }