diff --git a/src/gax-internal/src/grpc.rs b/src/gax-internal/src/grpc.rs index 31c0296a41..5fe4e9a207 100644 --- a/src/gax-internal/src/grpc.rs +++ b/src/gax-internal/src/grpc.rs @@ -19,6 +19,7 @@ pub mod status; pub mod tonic; use crate::observability::attributes::{self, keys::*, otel_status_codes}; +use crate::universe_domain::DEFAULT_UNIVERSE_DOMAIN; use ::tonic::client::Grpc; use ::tonic::transport::Channel; use from_status::to_gax_error; @@ -102,9 +103,18 @@ impl Client { ) -> ClientBuilderResult { let credentials = Self::make_credentials(&config).await?; let tracing_enabled = crate::options::tracing_enabled(&config); + let universe_domain = + crate::universe_domain::resolve(config.universe_domain.as_deref(), &credentials) + .await?; - let (inner, tracing_attributes) = - Self::make_inner(&config, default_endpoint, tracing_enabled, instrumentation).await?; + let (inner, tracing_attributes) = Self::make_inner( + &config, + default_endpoint, + tracing_enabled, + &universe_domain, + instrumentation, + ) + .await?; Ok(Self { inner, @@ -428,12 +438,14 @@ impl Client { config: &crate::options::ClientConfig, default_endpoint: &str, tracing_enabled: bool, + universe_domain: &str, instrumentation: Option<&'static crate::options::InstrumentationClientInfo>, ) -> ClientBuilderResult<(InnerClient, Option)> { use ::tonic::transport::{Channel, channel::Change}; let endpoint = Self::make_endpoint( config.endpoint.clone(), default_endpoint, + universe_domain, config.grpc_max_header_list_size, ) .await?; @@ -478,15 +490,16 @@ impl Client { async fn make_endpoint( endpoint: Option, default_endpoint: &str, + universe_domain: &str, grpc_max_header_list_size: Option, ) -> ClientBuilderResult<::tonic::transport::Endpoint> { use ::tonic::transport::{ClientTlsConfig, Endpoint}; - let origin = crate::host::origin(endpoint.as_deref(), default_endpoint) + let service_endpoint = default_endpoint.replace(DEFAULT_UNIVERSE_DOMAIN, universe_domain); + let origin = crate::host::origin(endpoint.as_deref(), default_endpoint, universe_domain) .map_err(|e| e.client_builder())?; - let endpoint = - Endpoint::from_shared(endpoint.unwrap_or_else(|| default_endpoint.to_string())) - .map_err(BuilderError::transport)?; + let target_endpoint = endpoint.unwrap_or(service_endpoint); + let endpoint = Endpoint::from_shared(target_endpoint).map_err(BuilderError::transport)?; let endpoint = if endpoint .uri() .scheme() @@ -619,8 +632,50 @@ where #[cfg(test)] mod tests { - use super::Client; + use super::*; use crate::options::InstrumentationClientInfo; + use test_case::test_case; + + type TestResult = anyhow::Result<()>; + + #[tokio::test] + #[test_case(None, "my-universe-domain.com", "https://language.my-universe-domain.com/"; "default endpoint")] + #[test_case(Some("https://yet-another-universe-domain.com/"), "yet-another-universe-domain.com", "https://yet-another-universe-domain.com/"; "custom endpoint override")] + #[test_case(Some("https://rep.language.googleapis.com/"), "my-universe-domain.com", "https://rep.language.googleapis.com/"; "regional endpoint with universe domain")] + #[test_case(Some("https://us-central1-language.googleapis.com/"), "my-universe-domain.com", "https://us-central1-language.googleapis.com/"; "locational endpoint with universe domain")] + async fn make_endpoint_with_universe_domain( + endpoint_override: Option<&str>, + universe_domain: &str, + expected_uri: &str, + ) -> TestResult { + let default_endpoint = "https://language.googleapis.com"; + let endpoint = Client::make_endpoint( + endpoint_override.map(String::from), + default_endpoint, + universe_domain, + None, + ) + .await?; + + assert_eq!(endpoint.uri().to_string(), expected_uri); + + Ok(()) + } + + #[tokio::test] + async fn make_endpoint_with_universe_domain_mismatch() -> TestResult { + let mut config = crate::options::ClientConfig::default(); + config.universe_domain = Some("my-universe-domain.com".to_string()); + config.cred = Some(google_cloud_auth::credentials::anonymous::Builder::new().build()); + + let err = Client::new(config, "https://language.googleapis.com") + .await + .unwrap_err(); + + assert!(err.is_universe_domain_mismatch(), "{err:?}"); + + Ok(()) + } #[tokio::test(flavor = "multi_thread")] async fn test_new_with_instrumentation() { diff --git a/src/gax-internal/src/host.rs b/src/gax-internal/src/host.rs index a5dbefabc8..e6c00e0a10 100644 --- a/src/gax-internal/src/host.rs +++ b/src/gax-internal/src/host.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::universe_domain::DEFAULT_UNIVERSE_DOMAIN; use google_cloud_gax::client_builder::Error as BuilderError; #[cfg(any(test, feature = "_internal-http-client"))] use google_cloud_gax::error::Error; @@ -23,8 +24,12 @@ use std::str::FromStr; /// Notably, locational and regional endpoints are detected and used as the /// host. For VIPs and private networks, we need to use the default host. #[cfg(any(test, feature = "_internal-http-client"))] -pub(crate) fn header(endpoint: Option<&str>, default_endpoint: &str) -> Result { - origin_and_header(endpoint, default_endpoint).map(|(_, header)| header) +pub(crate) fn header( + endpoint: Option<&str>, + default_endpoint: &str, + universe_domain: &str, +) -> Result { + origin_and_header(endpoint, default_endpoint, universe_domain).map(|(_, header)| header) } /// Calculate the gRPC authority given the endpoint and default endpoint. @@ -34,15 +39,21 @@ pub(crate) fn header(endpoint: Option<&str>, default_endpoint: &str) -> Result, default_endpoint: &str) -> Result { - origin_and_header(endpoint, default_endpoint).map(|(origin, _)| origin) +pub(crate) fn origin( + endpoint: Option<&str>, + default_endpoint: &str, + universe_domain: &str, +) -> Result { + origin_and_header(endpoint, default_endpoint, universe_domain).map(|(origin, _)| origin) } fn origin_and_header( endpoint: Option<&str>, default_endpoint: &str, + universe_domain: &str, ) -> Result<(Uri, String), HostError> { - let default_origin = Uri::from_str(default_endpoint).map_err(HostError::Uri)?; + let service_endpoint = default_endpoint.replace(DEFAULT_UNIVERSE_DOMAIN, &universe_domain); + let default_origin = Uri::from_str(&service_endpoint).map_err(HostError::Uri)?; let default_host = default_origin .authority() .expect("missing authority in default endpoint") @@ -58,9 +69,11 @@ fn origin_and_header( .ok_or_else(|| HostError::MissingAuthority(endpoint.to_string()))? .host() .to_string(); + + let universe_suffix = format!(".{universe_domain}"); let (Some(prefix), Some(service)) = ( - custom_host.strip_suffix(".googleapis.com"), - default_host.strip_suffix(".googleapis.com"), + custom_host.strip_suffix(&universe_suffix), + default_host.strip_suffix(&universe_suffix), ) else { return Ok((default_origin, default_host)); }; @@ -115,6 +128,7 @@ impl HostError { #[cfg(test)] mod tests { use super::*; + use crate::universe_domain::DEFAULT_UNIVERSE_DOMAIN; use std::error::Error as _; use test_case::test_case; @@ -131,7 +145,11 @@ mod tests { #[test_case("localhost:5678", "test.googleapis.com"; "emulator")] #[test_case("https://localhost:5678", "test.googleapis.com"; "emulator with scheme")] fn header_success(input: &str, want: &str) -> anyhow::Result<()> { - let got = header(Some(input), "https://test.googleapis.com")?; + let got = header( + Some(input), + "https://test.googleapis.com", + DEFAULT_UNIVERSE_DOMAIN, + )?; assert_eq!(got, want, "input={input:?}"); Ok(()) } @@ -144,7 +162,23 @@ mod tests { #[test_case("localhost:5678", "localhost"; "emulator")] #[test_case("https://localhost:5678", "localhost"; "emulator with scheme")] fn header_default(input: &str, want: &str) -> anyhow::Result<()> { - let got = header(None, input)?; + let got = header(None, input, DEFAULT_UNIVERSE_DOMAIN)?; + assert_eq!(got, want, "input={input:?}"); + Ok(()) + } + + #[test_case("http://www.my-custom-universe.com", "test.my-custom-universe.com"; "global")] + #[test_case("http://private.my-custom-universe.com", "test.my-custom-universe.com"; "VPC-SC private")] + #[test_case("http://restricted.my-custom-universe.com", "test.my-custom-universe.com"; "VPC-SC restricted")] + #[test_case("http://test-my-private-ep.p.my-custom-universe.com", "test.my-custom-universe.com"; "PSC custom endpoint")] + #[test_case("https://us-central1-test.my-custom-universe.com", "us-central1-test.my-custom-universe.com"; "locational endpoint")] + #[test_case("https://test.us-central1.rep.my-custom-universe.com", "test.us-central1.rep.my-custom-universe.com"; "regional endpoint")] + fn header_universe_domain(input: &str, want: &str) -> anyhow::Result<()> { + let got = header( + Some(input), + "https://test.googleapis.com", + "my-custom-universe.com", + )?; assert_eq!(got, want, "input={input:?}"); Ok(()) } @@ -162,7 +196,11 @@ mod tests { #[test_case("localhost:5678", "https://test.googleapis.com"; "emulator")] #[test_case("http://localhost:5678", "https://test.googleapis.com"; "emulator with scheme")] fn origin_success(input: &str, want: &str) -> anyhow::Result<()> { - let got = origin(Some(input), "https://test.googleapis.com")?; + let got = origin( + Some(input), + "https://test.googleapis.com", + DEFAULT_UNIVERSE_DOMAIN, + )?; assert_eq!(got, want, "input={input:?}"); Ok(()) } @@ -175,21 +213,45 @@ mod tests { #[test_case("https://localhost:5678", "https://localhost:5678")] #[test_case("http://localhost:5678", "http://localhost:5678")] fn origin_default(input: &str, want: &str) -> anyhow::Result<()> { - let got = origin(None, input)?; + let got = origin(None, input, DEFAULT_UNIVERSE_DOMAIN)?; + assert_eq!(got, want, "input={input:?}"); + Ok(()) + } + + #[test_case("http://www.my-custom-universe.com", "https://test.my-custom-universe.com"; "global")] + #[test_case("http://private.my-custom-universe.com", "https://test.my-custom-universe.com"; "VPC-SC private")] + #[test_case("http://restricted.my-custom-universe.com", "https://test.my-custom-universe.com"; "VPC-SC restricted")] + #[test_case("http://test-my-private-ep.p.my-custom-universe.com", "https://test.my-custom-universe.com"; "PSC custom endpoint")] + #[test_case("https://us-central1-test.my-custom-universe.com", "https://us-central1-test.my-custom-universe.com"; "locational endpoint")] + #[test_case("https://test.us-central1.rep.my-custom-universe.com", "https://test.us-central1.rep.my-custom-universe.com"; "regional endpoint")] + fn origin_universe_domain(input: &str, want: &str) -> anyhow::Result<()> { + let got = origin( + Some(input), + "https://test.googleapis.com", + "my-custom-universe.com", + )?; assert_eq!(got, want, "input={input:?}"); Ok(()) } #[test] fn errors() { - let got = origin_and_header(Some("https:///a/b/c"), "https://test.googleapis.com"); + let got = origin_and_header( + Some("https:///a/b/c"), + "https://test.googleapis.com", + DEFAULT_UNIVERSE_DOMAIN, + ); assert!(matches!(got, Err(HostError::Uri(_))), "{got:?}"); - let got = origin_and_header(Some("/a/b/c"), "https://test.googleapis.com"); + let got = origin_and_header( + Some("/a/b/c"), + "https://test.googleapis.com", + DEFAULT_UNIVERSE_DOMAIN, + ); assert!( matches!(got, Err(HostError::MissingAuthority(ref e)) if e == "/a/b/c"), "{got:?}" ); - let got = origin_and_header(None, "https:///"); + let got = origin_and_header(None, "https:///", DEFAULT_UNIVERSE_DOMAIN); assert!(matches!(got, Err(HostError::Uri(_))), "{got:?}"); } diff --git a/src/gax-internal/src/http.rs b/src/gax-internal/src/http.rs index 499a16eed7..7cf91861a4 100644 --- a/src/gax-internal/src/http.rs +++ b/src/gax-internal/src/http.rs @@ -27,6 +27,7 @@ pub mod reqwest; use crate::as_inner::as_inner; use crate::attempt_info::AttemptInfo; use crate::observability::{HttpResultExt, RequestRecorder, create_http_attempt_span}; +use crate::universe_domain::DEFAULT_UNIVERSE_DOMAIN; use google_cloud_auth::credentials::{ Builder as CredentialsBuilder, CacheableResource, Credentials, }; @@ -67,6 +68,7 @@ pub struct ReqwestClient { polling_backoff_policy: Arc, instrumentation: Option<&'static crate::options::InstrumentationClientInfo>, _tracing_enabled: bool, + universe_domain: String, transport_metric: Option, } @@ -87,12 +89,17 @@ impl ReqwestClient { builder = builder.redirect(::reqwest::redirect::Policy::none()); } let inner = builder.build().map_err(BuilderError::transport)?; - let host = crate::host::header(config.endpoint.as_deref(), default_endpoint) - .map_err(|e| e.client_builder())?; + let universe_domain = + crate::universe_domain::resolve(config.universe_domain.as_deref(), &cred).await?; + let host = crate::host::header( + config.endpoint.as_deref(), + &default_endpoint, + &universe_domain, + ) + .map_err(|e| e.client_builder())?; + let service_endpoint = default_endpoint.replace(DEFAULT_UNIVERSE_DOMAIN, &universe_domain); let tracing_enabled = crate::options::tracing_enabled(&config); - let endpoint = config - .endpoint - .unwrap_or_else(|| default_endpoint.to_string()); + let endpoint = config.endpoint.unwrap_or(service_endpoint); Ok(Self { inner, cred, @@ -117,6 +124,7 @@ impl ReqwestClient { .unwrap_or_else(|| Arc::new(ExponentialBackoff::default())), instrumentation: None, _tracing_enabled: tracing_enabled, + universe_domain, transport_metric: None, }) } @@ -220,7 +228,8 @@ impl ReqwestClient { url: &str, default_endpoint: &str, ) -> Result { - let host = crate::host::header(Some(url), default_endpoint).map_err(|e| e.gax())?; + let host = crate::host::header(Some(url), &default_endpoint, &self.universe_domain) + .map_err(|e| e.gax())?; let builder = self .inner .request(method, url) @@ -581,9 +590,26 @@ mod tests { use crate::options::ClientConfig; use crate::options::InstrumentationClientInfo; use google_cloud_auth::credentials::anonymous::Builder as Anonymous; + use google_cloud_auth::credentials::{CacheableResource, CredentialsProvider}; + use google_cloud_auth::errors::CredentialsError; use http::{HeaderMap, HeaderValue, Method}; + use scoped_env::ScopedEnv; + use serial_test::serial; use test_case::test_case; + type AuthResult = std::result::Result; + type TestResult = anyhow::Result<()>; + + mockall::mock! { + #[derive(Debug)] + Credentials {} + + impl CredentialsProvider for Credentials { + async fn headers(&self, extensions: Extensions) -> AuthResult>; + async fn universe_domain(&self) -> Option; + } + } + #[tokio::test] async fn client_http_error_bytes() -> anyhow::Result<()> { let http_resp = http::Response::builder() @@ -767,6 +793,52 @@ mod tests { Ok(()) } + #[tokio::test] + #[test_case(None, "test.my-custom-universe.com"; "default")] + #[test_case(Some("http://www.my-custom-universe.com"), "test.my-custom-universe.com"; "global")] + #[test_case(Some("http://private.my-custom-universe.com"), "test.my-custom-universe.com"; "VPC-SC private")] + #[test_case(Some("http://restricted.my-custom-universe.com"), "test.my-custom-universe.com"; "VPC-SC restricted")] + #[test_case(Some("http://test-my-private-ep.p.my-custom-universe.com"), "test.my-custom-universe.com"; "PSC custom endpoint")] + #[test_case(Some("https://us-central1-test.my-custom-universe.com"), "us-central1-test.my-custom-universe.com"; "locational endpoint")] + #[test_case(Some("https://test.us-central1.rep.my-custom-universe.com"), "test.us-central1.rep.my-custom-universe.com"; "regional endpoint")] + #[serial] + async fn host_from_endpoint_with_universe_domain_success( + endpoint_override: Option<&str>, + expected_host: &str, + ) -> TestResult { + let _env = ScopedEnv::remove("GOOGLE_CLOUD_UNIVERSE_DOMAIN"); + let universe_domain = "my-custom-universe.com"; + let mut config = ClientConfig::default(); + config.universe_domain = Some(universe_domain.to_string()); + config.endpoint = endpoint_override.map(String::from); + + let mut cred = MockCredentials::new(); + cred.expect_universe_domain() + .returning(move || Some(universe_domain.to_string())); + config.cred = Some(cred.into()); + + let client = ReqwestClient::new(config, "https://test.googleapis.com").await?; + assert_eq!(client.universe_domain, universe_domain); + assert_eq!(client.host, expected_host); + + Ok(()) + } + + #[tokio::test] + async fn host_from_endpoint_with_universe_domain_mismatch_fails() -> TestResult { + let mut config = ClientConfig::default(); + config.universe_domain = Some("custom.com".to_string()); + config.cred = Some(Anonymous::new().build()); + + let err = ReqwestClient::new(config, "https://language.googleapis.com") + .await + .unwrap_err(); + + assert!(err.is_universe_domain_mismatch(), "{err:?}"); + + Ok(()) + } + #[test_case(None; "default")] #[test_case(Some("localhost:5678"); "custom")] #[tokio::test] diff --git a/src/gax-internal/src/universe_domain.rs b/src/gax-internal/src/universe_domain.rs index 011d05392a..9e303fc4b6 100644 --- a/src/gax-internal/src/universe_domain.rs +++ b/src/gax-internal/src/universe_domain.rs @@ -15,11 +15,9 @@ use google_cloud_auth::credentials::Credentials; use google_cloud_gax::client_builder::{Error, Result}; -#[allow(dead_code)] pub(crate) const DEFAULT_UNIVERSE_DOMAIN: &str = "googleapis.com"; const UNIVERSE_DOMAIN_VAR: &str = "GOOGLE_CLOUD_UNIVERSE_DOMAIN"; -#[allow(dead_code)] pub(crate) async fn resolve( universe_domain_client_override: Option<&str>, cred: &Credentials, diff --git a/src/gax-internal/tests/grpc_auth.rs b/src/gax-internal/tests/grpc_auth.rs index 0f0bea20e3..4c913b54ea 100644 --- a/src/gax-internal/tests/grpc_auth.rs +++ b/src/gax-internal/tests/grpc_auth.rs @@ -59,6 +59,7 @@ mod tests { mock.expect_headers() .times(retry_count..) .returning(|_extensions| Err(CredentialsError::from_msg(true, "mock retryable error"))); + mock.expect_universe_domain().returning(|| None); let retry_policy = Aip194Strict.with_attempt_limit(retry_count as u32); let client = builder(endpoint) @@ -90,6 +91,7 @@ mod tests { mock.expect_headers() .times(1) .returning(move |_extensions| headers_response.clone()); + mock.expect_universe_domain().returning(|| None); let client = builder(endpoint) .with_credentials(Credentials::from(mock)) diff --git a/src/gax-internal/tests/http_auth.rs b/src/gax-internal/tests/http_auth.rs index d7e3fd3f2e..03d117e2d6 100644 --- a/src/gax-internal/tests/http_auth.rs +++ b/src/gax-internal/tests/http_auth.rs @@ -61,6 +61,7 @@ mod tests { mock.expect_headers() .times(retry_count..) .returning(|_extensions| Err(CredentialsError::from_msg(true, "mock retryable error"))); + mock.expect_universe_domain().returning(|| None); let retry_policy = Aip194Strict.with_attempt_limit(retry_count as u32); let client = echo_server::builder(endpoint) @@ -97,6 +98,7 @@ mod tests { mock.expect_headers() .times(1) .returning(move |_extensions| headers_response.clone()); + mock.expect_universe_domain().returning(|| None); let client = echo_server::builder(endpoint) .with_credentials(Credentials::from(mock)) diff --git a/src/gax-internal/tests/mock_credentials.rs b/src/gax-internal/tests/mock_credentials.rs index 49b3d6cf1a..5e495b38ab 100644 --- a/src/gax-internal/tests/mock_credentials.rs +++ b/src/gax-internal/tests/mock_credentials.rs @@ -52,5 +52,6 @@ pub fn mock_credentials() -> MockCredentials { data: header, }) }); + mock.expect_universe_domain().returning(|| None); mock } diff --git a/src/storage/tests/default_credentials.rs b/src/storage/tests/default_credentials.rs index b8017b2536..0ac2733fe2 100644 --- a/src/storage/tests/default_credentials.rs +++ b/src/storage/tests/default_credentials.rs @@ -45,7 +45,7 @@ mod tests { "private_key_id": "test-private-key-id", "private_key": "-----BEGIN PRIVATE KEY-----\nBLAHBLAHBLAH\n-----END PRIVATE KEY-----\n", "client_email": "test-client-email", - "universe_domain": "test-universe-domain" + "universe_domain": "googleapis.com", }); std::fs::write(destination.clone(), contents.to_string())?;