Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 79 additions & 7 deletions pkg/metricsservice/utils/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"path"
"strings"
Expand Down Expand Up @@ -73,6 +74,9 @@ func LoadGrpcTLSCredentials(ctx context.Context, certDir string, server bool) (c
}

certMutex := sync.RWMutex{}
// currentPool holds the live CA pool and is replaced (never appended to)
// on every rotation so the pool does not grow unboundedly.
currentPool := certPool
go func() {
log.V(1).Info("starting mTLS certificates monitoring")
for {
Expand All @@ -97,22 +101,29 @@ func LoadGrpcTLSCredentials(ctx context.Context, certDir string, server bool) (c
log.Error(err, "error reading grpc ca certificate")
continue
}
if !certPool.AppendCertsFromPEM(pemClientCA) {
log.Error(err, "failed to add client CA's certificate")
// Rebuild the pool from scratch to prevent monotonic growth.
newPool, _ := x509.SystemCertPool()
if newPool == nil {
newPool = x509.NewCertPool()
}
if !newPool.AppendCertsFromPEM(pemClientCA) {
log.Error(fmt.Errorf("failed to parse CA PEM from %s", caPath), "failed to add client CA's certificate")
continue
}
log.V(1).Info("grpc ca certificate has been updated")

// Load certificate of the CA who signed client's certificate
// Load the new leaf certificate *before* swapping the pool so
// we never end up trusting only the new CA while still presenting
// the old cert (inconsistent state during rotation).
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
log.Error(err, "error reading grpc certificate")
continue
}
certMutex.Lock()
currentPool = newPool
mTLSCertificate = cert
certMutex.Unlock()
log.V(1).Info("grpc mTLS certificate has been updated")
log.V(1).Info("grpc CA and mTLS certificate have been updated")

case err, ok := <-watcher.Errors:
if !ok { // Channel was closed (i.e. Watcher.Close() was called).
Expand Down Expand Up @@ -144,10 +155,71 @@ func LoadGrpcTLSCredentials(ctx context.Context, certDir string, server bool) (c
}
if server {
config.ClientAuth = tls.RequireAndVerifyClientCert
config.ClientCAs = certPool
// GetConfigForClient is called per-connection; injecting currentPool
// here ensures every new handshake picks up the latest CA pool.
config.GetConfigForClient = func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
certMutex.RLock()
defer certMutex.RUnlock()
clone := config.Clone()
clone.ClientCAs = currentPool
// GetConfigForClient overrides GetCertificate; restore it.
clone.GetCertificate = config.GetCertificate
return clone, nil
}
} else {
config.RootCAs = certPool
}

return credentials.NewTLS(config), nil
baseCreds := credentials.NewTLS(config)
if server {
return baseCreds, nil
}
// Wrap the client credentials so that every new TLS handshake picks up the
// latest CA pool. credentials.NewTLS snapshots the config once; by cloning
// it on each ClientHandshake we ensure that a rotated CA is trusted without
// any data-race on config.RootCAs.
return &dynamicClientCredentials{
base: baseCreds,
baseConfig: config,
mu: &certMutex,
currentPool: &currentPool,
}, nil
}

// dynamicClientCredentials wraps a gRPC TransportCredentials and injects the
// latest CA pool into a fresh tls.Config clone on every client handshake.
type dynamicClientCredentials struct {
base credentials.TransportCredentials
baseConfig *tls.Config
mu *sync.RWMutex
currentPool **x509.CertPool
}

func (d *dynamicClientCredentials) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
d.mu.RLock()
cfg := d.baseConfig.Clone()
cfg.RootCAs = *d.currentPool
d.mu.RUnlock()
return credentials.NewTLS(cfg).ClientHandshake(ctx, authority, conn)
}

func (d *dynamicClientCredentials) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return d.base.ServerHandshake(conn)
}

func (d *dynamicClientCredentials) Info() credentials.ProtocolInfo {
return d.base.Info()
}

func (d *dynamicClientCredentials) Clone() credentials.TransportCredentials {
return &dynamicClientCredentials{
base: d.base.Clone(),
baseConfig: d.baseConfig.Clone(),
mu: d.mu,
currentPool: d.currentPool,
}
}

func (d *dynamicClientCredentials) OverrideServerName(name string) error {
return d.base.OverrideServerName(name)
}
Loading