diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 941612989d..6e3f5c4acb 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -230,6 +230,42 @@ impl BatchReadOnlyTransaction { } } +/// Options for executing a partition. +#[derive(Clone, Debug, Default)] +pub struct PartitionExecuteOptions { + /// If true, use separate server resources on Spanner to execute the query. + data_boost_enabled: bool, +} + +impl PartitionExecuteOptions { + /// Sets whether Data Boost is enabled. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::{Spanner, Statement}; + /// # use google_cloud_spanner::{PartitionOptions, PartitionExecuteOptions}; + /// # async fn run_query(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> { + /// # let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?; + /// # let transaction = db_client.batch_read_only_transaction().build().await?; + /// # let partitions = transaction + /// # .partition_query( + /// # Statement::builder("SELECT * FROM Users").build(), + /// # PartitionOptions::default(), + /// # ) + /// # .await?; + /// // On a worker receiving a partition, execute it with Data Boost: + /// let options = PartitionExecuteOptions::default() + /// .with_data_boost(true); + /// let mut result_set = partitions[0].execute(&db_client, options).await?; + /// # Ok(()) + /// # } + /// ``` + pub fn with_data_boost(mut self, enabled: bool) -> Self { + self.data_boost_enabled = enabled; + self + } +} + /// Defines the segments of data to be read in a partitioned read or query. /// These partitions can be serialized and processed across several /// different machines or processes. @@ -245,7 +281,7 @@ impl Partition { /// # Example: executing a query partition /// ``` /// # use google_cloud_spanner::client::{Spanner, Statement}; - /// # use google_cloud_spanner::PartitionOptions; + /// # use google_cloud_spanner::{PartitionOptions, PartitionExecuteOptions}; /// # async fn run_query(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> { /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?; /// let transaction = db_client.batch_read_only_transaction().build().await?; @@ -257,7 +293,7 @@ impl Partition { /// // ... send partitions to other workers ... /// /// // On a worker receiving a partition, execute it: - /// let mut result_set = partitions[0].execute(&db_client).await?; + /// let mut result_set = partitions[0].execute(&db_client, PartitionExecuteOptions::default()).await?; /// while let Some(row) = result_set.next().await.transpose()? { /// // process row /// } @@ -267,7 +303,7 @@ impl Partition { /// # Example: executing a read partition /// ``` /// # use google_cloud_spanner::client::{Spanner, ReadRequest, KeySet}; - /// # use google_cloud_spanner::PartitionOptions; + /// # use google_cloud_spanner::{PartitionOptions, PartitionExecuteOptions}; /// # async fn run_read(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> { /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?; /// let transaction = db_client.batch_read_only_transaction().build().await?; @@ -277,7 +313,7 @@ impl Partition { /// // ... send partitions to other workers ... /// /// // On a worker receiving a partition, execute it: - /// let mut result_set = partitions[0].execute(&db_client).await?; + /// let mut result_set = partitions[0].execute(&db_client, PartitionExecuteOptions::default()).await?; /// while let Some(row) = result_set.next().await.transpose()? { /// // process row /// } @@ -287,7 +323,11 @@ impl Partition { /// /// A partition can be executed by any `DatabaseClient` that is connected to /// the database that the partitions belong to. - pub async fn execute(&self, client: &DatabaseClient) -> crate::Result { + pub async fn execute( + &self, + client: &DatabaseClient, + options: PartitionExecuteOptions, + ) -> crate::Result { match &self.inner { PartitionedOperation::Query { partition_token, @@ -301,6 +341,7 @@ impl Partition { transaction_selector, session_name, statement, + options, ) .await } @@ -316,6 +357,7 @@ impl Partition { transaction_selector, session_name, read_request, + options, ) .await } @@ -328,13 +370,15 @@ impl Partition { transaction_selector: &crate::model::TransactionSelector, session_name: &str, statement: &Statement, + options: PartitionExecuteOptions, ) -> crate::Result { let request = statement .clone() .into_request() .set_session(session_name.to_string()) .set_transaction(transaction_selector.clone()) - .set_partition_token(partition_token.clone()); + .set_partition_token(partition_token.clone()) + .set_data_boost_enabled(options.data_boost_enabled); let stream = client .spanner @@ -361,13 +405,15 @@ impl Partition { transaction_selector: &crate::model::TransactionSelector, session_name: &str, read_request: &crate::read::ReadRequest, + options: PartitionExecuteOptions, ) -> crate::Result { let request = read_request .clone() .into_request() .set_session(session_name.to_string()) .set_transaction(transaction_selector.clone()) - .set_partition_token(partition_token.clone()); + .set_partition_token(partition_token.clone()) + .set_data_boost_enabled(options.data_boost_enabled); let stream = client .spanner @@ -418,12 +464,15 @@ pub(crate) mod tests { use spanner_grpc_mock::google::spanner::v1::{ Partition as MockPartition, PartitionResponse, Transaction, }; + use static_assertions::assert_impl_all; + use std::fmt::Debug; #[test] fn auto_traits() { - static_assertions::assert_impl_all!(BatchReadOnlyTransactionBuilder: Send, Sync); - static_assertions::assert_impl_all!(BatchReadOnlyTransaction: Send, Sync, std::fmt::Debug); - static_assertions::assert_impl_all!(Partition: Send, Sync, std::fmt::Debug); + assert_impl_all!(BatchReadOnlyTransactionBuilder: Send, Sync); + assert_impl_all!(BatchReadOnlyTransaction: Send, Sync, Debug); + assert_impl_all!(Partition: Send, Sync, Debug); + assert_impl_all!(PartitionExecuteOptions: Send, Sync, Debug, Default); } #[tokio::test] @@ -459,7 +508,9 @@ pub(crate) mod tests { }, }; - let _result_set = partition.execute(&db_client).await?; + let _result_set = partition + .execute(&db_client, PartitionExecuteOptions::default()) + .await?; Ok(()) } @@ -499,7 +550,9 @@ pub(crate) mod tests { }, }; - let _result_set = partition.execute(&db_client).await?; + let _result_set = partition + .execute(&db_client, PartitionExecuteOptions::default()) + .await?; Ok(()) } @@ -640,4 +693,64 @@ pub(crate) mod tests { } Ok(()) } + + #[tokio::test] + async fn execute_query_with_data_boost() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + + mock.expect_execute_streaming_sql().once().returning(|req| { + let req = req.into_inner(); + assert!(req.data_boost_enabled, "data_boost_enabled should be true"); + let (_, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::from(rx)) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let partition = Partition { + inner: PartitionedOperation::Query { + partition_token: b"partition_token_123".to_vec().into(), + transaction_selector: crate::model::TransactionSelector::default() + .set_id(b"tx_id_1".to_vec()), + session_name: "projects/p/instances/i/databases/d/sessions/123".into(), + statement: Statement::builder("SELECT * FROM Users").build(), + }, + }; + + let options = PartitionExecuteOptions::default().with_data_boost(true); + let _result_set = partition.execute(&db_client, options).await?; + + Ok(()) + } + + #[tokio::test] + async fn execute_read_with_data_boost() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + + mock.expect_streaming_read().once().returning(|req| { + let req = req.into_inner(); + assert!(req.data_boost_enabled, "data_boost_enabled should be true"); + let (_, rx) = tokio::sync::mpsc::channel(1); + Ok(Response::from(rx)) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let partition = Partition { + inner: PartitionedOperation::Read { + partition_token: b"partition_token_456".to_vec().into(), + transaction_selector: crate::model::TransactionSelector::default() + .set_id(b"tx_id_2".to_vec()), + session_name: "projects/p/instances/i/databases/d/sessions/123".into(), + read_request: SpannerReadRequest::builder("Users", vec!["Id", "Name"]) + .with_keys(KeySet::all()) + .build(), + }, + }; + + let options = PartitionExecuteOptions::default().with_data_boost(true); + let _result_set = partition.execute(&db_client, options).await?; + + Ok(()) + } } diff --git a/src/spanner/src/lib.rs b/src/spanner/src/lib.rs index 4c29f37749..ee459a96ec 100644 --- a/src/spanner/src/lib.rs +++ b/src/spanner/src/lib.rs @@ -24,7 +24,7 @@ pub use crate::model::PartitionOptions; pub use batch_dml::BatchDml; pub use batch_dml::BatchDmlBuilder; pub use batch_read_only_transaction::{ - BatchReadOnlyTransaction, BatchReadOnlyTransactionBuilder, Partition, + BatchReadOnlyTransaction, BatchReadOnlyTransactionBuilder, Partition, PartitionExecuteOptions, }; pub use error::BatchUpdateError; pub use google_cloud_gax::Result; diff --git a/tests/spanner/src/batch_read_only_transaction.rs b/tests/spanner/src/batch_read_only_transaction.rs index 24bd3ce9bd..4feb5d8c36 100644 --- a/tests/spanner/src/batch_read_only_transaction.rs +++ b/tests/spanner/src/batch_read_only_transaction.rs @@ -14,7 +14,7 @@ use crate::client::create_database_client; use google_cloud_spanner::client::{DatabaseClient, KeySet, Mutation, ReadRequest, Statement}; -use google_cloud_spanner::{PartitionOptions, key}; +use google_cloud_spanner::{PartitionExecuteOptions, PartitionOptions, key}; use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; pub async fn partitioned_query(db_client: &DatabaseClient) -> anyhow::Result<()> { @@ -76,7 +76,9 @@ pub async fn partitioned_query(db_client: &DatabaseClient) -> anyhow::Result<()> let mut rows_received = 0; for partition in partitions { - let mut rs = partition.execute(&execution_client).await?; + let mut rs = partition + .execute(&execution_client, PartitionExecuteOptions::default()) + .await?; while let Some(row) = rs.next().await.transpose()? { rows_received += 1; @@ -168,7 +170,9 @@ pub async fn partitioned_read(db_client: &DatabaseClient) -> anyhow::Result<()> let mut rows_received = 0; for partition in partitions { - let mut rs = partition.execute(&execution_client).await?; + let mut rs = partition + .execute(&execution_client, PartitionExecuteOptions::default()) + .await?; while let Some(row) = rs.next().await.transpose()? { rows_received += 1;