From e7d34b069064aa93113ec4909cbcdb390160f94e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 1 Apr 2026 16:15:22 +0200 Subject: [PATCH 01/17] perf(spanner): inline BeginTransaction with first query (step 1) Adds support for inlining the BeginTransaction with the first query in a read-only transaction. This saves one round-trip to Spanner for multi-use read-only transactions. This implementation is intentionally simple: 1. It does not support parallel queries at the start of the transaction. 2. It does not include error handling for the first query. 3. It only supports read-only transactions. This is step 1. Follow-up pull requests addresses the above points. --- .../src/batch_read_only_transaction.rs | 11 +- src/spanner/src/read_only_transaction.rs | 269 ++++++++++++++++-- src/spanner/src/read_write_transaction.rs | 12 +- src/spanner/src/result_set.rs | 17 +- tests/spanner/src/query.rs | 29 +- 5 files changed, 299 insertions(+), 39 deletions(-) diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index c31447a5f3..179ac26dc2 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -43,7 +43,8 @@ pub struct BatchReadOnlyTransactionBuilder { impl BatchReadOnlyTransactionBuilder { pub(crate) fn new(client: DatabaseClient) -> Self { Self { - inner: MultiUseReadOnlyTransactionBuilder::new(client), + inner: MultiUseReadOnlyTransactionBuilder::new(client) + .with_explicit_begin_transaction(true), } } @@ -147,7 +148,7 @@ impl BatchReadOnlyTransaction { .clone() .into_partition_query_request() .set_session(self.inner.context.client.session.name.clone()) - .set_transaction(self.inner.context.transaction_selector.clone()) + .set_transaction(self.inner.context.transaction_selector.selector()) .set_partition_options(options); let response = self @@ -164,7 +165,7 @@ impl BatchReadOnlyTransaction { .map(|p| Partition { inner: 
PartitionedOperation::Query { partition_token: p.partition_token, - transaction_selector: self.inner.context.transaction_selector.clone(), + transaction_selector: self.inner.context.transaction_selector.selector(), session_name: self.inner.context.client.session.name.clone(), statement: statement.clone(), }, @@ -202,7 +203,7 @@ impl BatchReadOnlyTransaction { .clone() .into_partition_read_request() .set_session(self.inner.context.client.session.name.clone()) - .set_transaction(self.inner.context.transaction_selector.clone()) + .set_transaction(self.inner.context.transaction_selector.selector()) .set_partition_options(options); let response = self @@ -219,7 +220,7 @@ impl BatchReadOnlyTransaction { .map(|p| Partition { inner: PartitionedOperation::Read { partition_token: p.partition_token, - transaction_selector: self.inner.context.transaction_selector.clone(), + transaction_selector: self.inner.context.transaction_selector.selector(), session_name: self.inner.context.client.session.name.clone(), read_request: read.clone(), }, diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 3f0df51a5c..44782326f3 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -19,6 +19,7 @@ use crate::precommit::PrecommitTokenTracker; use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; +use std::sync::{Arc, Mutex}; /// A builder for [SingleUseReadOnlyTransaction]. 
/// @@ -91,7 +92,10 @@ impl SingleUseReadOnlyTransactionBuilder { SingleUseReadOnlyTransaction { context: ReadContext { client: self.client, - transaction_selector, + transaction_selector: ReadContextTransactionSelector::Fixed( + transaction_selector, + None, + ), precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, }, @@ -204,6 +208,7 @@ impl SingleUseReadOnlyTransaction { pub struct MultiUseReadOnlyTransactionBuilder { client: DatabaseClient, timestamp_bound: Option, + explicit_begin: bool, } impl MultiUseReadOnlyTransactionBuilder { @@ -211,9 +216,44 @@ impl MultiUseReadOnlyTransactionBuilder { Self { client, timestamp_bound: None, + explicit_begin: false, } } + /// Sets whether the transaction should be explicitly started using a `BeginTransaction` RPC. + /// + /// # Example + /// ``` + /// # use google_cloud_spanner::client::Spanner; + /// # use google_cloud_spanner::client::Statement; + /// # async fn set_explicit_begin(spanner: Spanner) -> Result<(), google_cloud_spanner::Error> { + /// let db_client = spanner.database_client("projects/p/instances/i/databases/d").build().await?; + /// let transaction = db_client.read_only_transaction().with_explicit_begin_transaction(true).build().await?; + /// let statement = Statement::builder("SELECT * FROM users").build(); + /// let result_set = transaction.execute_query(statement).await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// By default, the Spanner client will inline the `BeginTransaction` call with the first query + /// in the transaction. This reduces the number of round-trips to Spanner that are needed for a + /// transaction. Setting this option to `true` can be beneficial for specific transaction shapes: + /// + /// 1. When the transaction executes multiple parallel queries at the start of the transaction. 
+ /// Only one query can include a `BeginTransaction` option, and all other queries must wait for + /// the first query to return the first result before they can proceed to execute. A + /// `BeginTransaction` RPC will quickly return a transaction ID and allow all queries to start + /// execution in parallel once the transaction ID has been returned. + /// 2. When the first query in the transaction could fail. If the query fails, then it will also + /// not start a transaction and return a transaction ID. The transaction will then fall back to + /// executing a `BeginTransaction` RPC and retry the first query. + /// + /// Default is `false` (inline begin). + pub fn with_explicit_begin_transaction(mut self, explicit: bool) -> Self { + self.explicit_begin = explicit; + self + } + /// Sets the timestamp bound for the read-only transaction. /// /// # Example @@ -231,6 +271,29 @@ impl MultiUseReadOnlyTransactionBuilder { self } + async fn begin( + &self, + options: TransactionOptions, + ) -> crate::Result { + let request = crate::model::BeginTransactionRequest::default() + .set_session(self.client.session.name.clone()) + .set_options(options); + + // TODO(#4972): make request options configurable + let response = self + .client + .spanner + .begin_transaction(request, crate::RequestOptions::default()) + .await?; + + let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); + + Ok(ReadContextTransactionSelector::Fixed( + transaction_selector, + response.read_timestamp, + )) + } + /// Builds the [MultiUseReadOnlyTransaction] and starts the transaction /// by calling the `BeginTransaction` RPC. 
/// @@ -245,30 +308,27 @@ impl MultiUseReadOnlyTransactionBuilder { /// ``` pub async fn build(self) -> crate::Result { let read_only = ReadOnly::default().set_return_read_timestamp(true); - let read_only = match self.timestamp_bound { - Some(b) => read_only.set_timestamp_bound(b.0), + let read_only = match self.timestamp_bound.as_ref() { + Some(b) => read_only.set_timestamp_bound(b.0.clone()), None => read_only.set_strong(true), }; - let request = crate::model::BeginTransactionRequest::default() - .set_session(self.client.session.name.clone()) - .set_options(TransactionOptions::default().set_read_only(read_only)); + let options = TransactionOptions::default().set_read_only(read_only); - // TODO(#4972): make request options configurable - let response = self - .client - .spanner - .begin_transaction(request, crate::RequestOptions::default()) - .await?; + let selector = if self.explicit_begin { + self.begin(options).await? + } else { + ReadContextTransactionSelector::Lazy(Arc::new(Mutex::new( + TransactionState::NotStarted(options), + ))) + }; - let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); Ok(MultiUseReadOnlyTransaction { context: ReadContext { client: self.client, - transaction_selector, + transaction_selector: selector, precommit_token_tracker: PrecommitTokenTracker::new_noop(), transaction_tag: None, }, - read_timestamp: response.read_timestamp, }) } } @@ -297,13 +357,12 @@ impl MultiUseReadOnlyTransactionBuilder { #[derive(Debug)] pub struct MultiUseReadOnlyTransaction { pub(crate) context: ReadContext, - pub(crate) read_timestamp: Option, } impl MultiUseReadOnlyTransaction { /// Returns the read timestamp chosen for the transaction. pub fn read_timestamp(&self) -> Option { - self.read_timestamp + self.context.transaction_selector.read_timestamp() } /// Executes a query using this transaction. 
@@ -370,10 +429,71 @@ impl MultiUseReadOnlyTransaction { } } +#[derive(Clone, Debug)] +pub(crate) enum ReadContextTransactionSelector { + Fixed(crate::model::TransactionSelector, Option), + Lazy(Arc>), +} + +#[derive(Clone, Debug)] +pub(crate) enum TransactionState { + NotStarted(crate::model::TransactionOptions), + Started(crate::model::TransactionSelector, Option), +} + +impl TransactionState { + fn selector(&self) -> crate::model::TransactionSelector { + match self { + Self::Started(selector, _) => selector.clone(), + Self::NotStarted(options) => { + crate::model::TransactionSelector::default().set_begin(options.clone()) + } + } + } +} + +impl ReadContextTransactionSelector { + pub(crate) fn selector(&self) -> crate::model::TransactionSelector { + match self { + Self::Fixed(selector, _) => selector.clone(), + Self::Lazy(lazy) => lazy + .lock() + .expect("transaction state mutex poisoned") + .selector(), + } + } + + pub(crate) fn update(&self, id: bytes::Bytes, timestamp: Option) { + if let Self::Lazy(lazy) = self { + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + if matches!(&*guard, TransactionState::NotStarted(_)) { + *guard = TransactionState::Started( + crate::model::TransactionSelector::default().set_id(id), + timestamp, + ); + } + } + } + + pub(crate) fn read_timestamp(&self) -> Option { + match self { + Self::Fixed(_, timestamp) => *timestamp, + Self::Lazy(lazy) => { + let guard = lazy.lock().expect("transaction state mutex poisoned"); + if let TransactionState::Started(_, timestamp) = &*guard { + *timestamp + } else { + None + } + } + } + } +} + #[derive(Clone, Debug)] pub(crate) struct ReadContext { pub(crate) client: DatabaseClient, - pub(crate) transaction_selector: crate::model::TransactionSelector, + pub(crate) transaction_selector: ReadContextTransactionSelector, pub(crate) precommit_token_tracker: PrecommitTokenTracker, pub(crate) transaction_tag: Option, } @@ -405,7 +525,7 @@ impl ReadContext { .into() .into_request() 
.set_session(self.client.session.name.clone()) - .set_transaction(self.transaction_selector.clone()); + .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); let stream = self @@ -418,6 +538,7 @@ impl ReadContext { Ok(ResultSet::new( stream, + Some(self.transaction_selector.clone()), self.precommit_token_tracker.clone(), self.client.clone(), StreamOperation::Query(request), @@ -432,7 +553,7 @@ impl ReadContext { .into() .into_request() .set_session(self.client.session.name.clone()) - .set_transaction(self.transaction_selector.clone()); + .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); let stream = self @@ -445,6 +566,7 @@ impl ReadContext { Ok(ResultSet::new( stream, + Some(self.transaction_selector.clone()), self.precommit_token_tracker.clone(), self.client.clone(), StreamOperation::Read(request), @@ -525,9 +647,8 @@ pub(crate) mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = db_client.single_use().build(); - let ro = tx - .context - .transaction_selector + let selector = tx.context.transaction_selector.selector(); + let ro = selector .single_use() .expect("Expected SingleUse selector") .read_only() @@ -543,9 +664,8 @@ pub(crate) mod tests { std::time::Duration::from_secs(10), )) .build(); - let ro2 = tx2 - .context - .transaction_selector + let selector = tx2.context.transaction_selector.selector(); + let ro2 = selector .single_use() .expect("Expected SingleUse selector") .read_only() @@ -646,6 +766,7 @@ pub(crate) mod tests { let tx = db_client .read_only_transaction() + .with_explicit_begin_transaction(true) .build() .await .expect("Failed to start tx"); @@ -670,6 +791,102 @@ pub(crate) mod tests { } } + #[tokio::test] + async fn execute_multi_query_inline_begin() -> anyhow::Result<()> { + use super::super::result_set::tests::string_val; + use crate::client::Statement; 
+ use crate::value::Value; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + let mut mock = create_session_mock(); + + // No explicit begin_transaction should be called. + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + + // First call: Should have Selector::Begin + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin"), + } + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: Some(prost_types::Timestamp { + seconds: 987654321, + nanos: 0, + }), + ..Default::default() + }); + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(rs)]), + ))) + }); + + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + // Second call: Should have Selector::Id using the ID returned in the first call + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(setup_select1())]), + ))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // The read timestamp is not available until the first query is executed. 
+ assert!(tx.read_timestamp().is_none()); + + for i in 0..2 { + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + + let row = rs.next().await.expect("Expected a row")?; + assert_eq!(row.raw_values(), [Value(string_val("1"))]); + + let result = rs.next().await; + assert!(result.is_none(), "Expected None, got {result:?}"); + + if i == 0 { + // Read timestamp becomes available. + assert_eq!( + tx.read_timestamp() + .expect("Expected read timestamp") + .seconds(), + 987654321 + ); + } + } + + Ok(()) + } + #[tokio::test] async fn execute_single_read() { use super::super::result_set::tests::string_val; diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 9a84b1bb87..920d69bfda 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -100,7 +100,11 @@ impl ReadWriteTransactionBuilder { .begin_transaction(request, RequestOptions::default()) .await?; - let transaction_selector = TransactionSelector::default().set_id(response.id); + let transaction_selector = + crate::read_only_transaction::ReadContextTransactionSelector::Fixed( + TransactionSelector::default().set_id(response.id), + None, + ); Ok(ReadWriteTransaction { context: ReadContext { client: self.client.clone(), @@ -144,7 +148,7 @@ impl ReadWriteTransaction { .into() .into_request() .set_session(self.context.client.session.name.clone()) - .set_transaction(self.context.transaction_selector.clone()) + .set_transaction(self.context.transaction_selector.selector()) .set_seqno(seqno); request.request_options = self.context.amend_request_options(request.request_options); @@ -245,7 +249,7 @@ impl ReadWriteTransaction { let request = ExecuteBatchDmlRequest::default() .set_session(self.context.client.session.name.clone()) - .set_transaction(self.context.transaction_selector.clone()) + .set_transaction(self.context.transaction_selector.selector()) .set_seqno(seqno) 
.set_statements(statements) .set_or_clear_request_options( @@ -271,7 +275,7 @@ impl ReadWriteTransaction { } pub(crate) fn transaction_id(&self) -> crate::Result { - match &self.context.transaction_selector.selector { + match &self.context.transaction_selector.selector().selector { Some(Selector::Id(id)) => Ok(id.clone()), _ => Err(internal_error("Transaction ID is missing")), } diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index cfe0397cae..6bd848b175 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -16,6 +16,7 @@ use crate::database_client::DatabaseClient; use crate::error::internal_error; use crate::google::spanner::v1::PartialResultSet; use crate::precommit::PrecommitTokenTracker; +use crate::read_only_transaction::ReadContextTransactionSelector; use crate::result_set_metadata::ResultSetMetadata; use crate::row::Row; use crate::server_streaming::stream::PartialResultSetStream; @@ -58,6 +59,7 @@ pub struct ResultSet { safe_to_retry: bool, max_buffered_partial_result_sets: usize, retry_count: usize, + transaction_selector: Option, } /// Errors that can occur when interacting with a [`ResultSet`]. @@ -84,6 +86,7 @@ impl ResultSet { /// Creates a new result set. 
pub(crate) fn new( stream: PartialResultSetStream, + transaction_selector: Option, precommit_token_tracker: PrecommitTokenTracker, client: DatabaseClient, operation: StreamOperation, @@ -102,6 +105,7 @@ impl ResultSet { safe_to_retry: true, max_buffered_partial_result_sets: MAX_BUFFERED_PARTIAL_RESULT_SETS, retry_count: 0, + transaction_selector, } } @@ -274,8 +278,19 @@ impl ResultSet { (Some(_), Some(_)) => { return Err(internal_error("Additional metadata after first result set")); } - (None, Some(m)) => { + (None, Some(mut m)) => { + let transaction = m.transaction.take(); self.metadata = Some(ResultSetMetadata::new(Some(m))); + if let (Some(selector), Some(transaction)) = + (&self.transaction_selector, transaction) + { + selector.update( + transaction.id, + transaction + .read_timestamp + .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), + ); + } } } diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index e38b51b989..fe02427a19 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -194,17 +194,40 @@ pub async fn result_set_metadata(db_client: &DatabaseClient) -> anyhow::Result<( } pub async fn multi_use_read_only_transaction(db_client: &DatabaseClient) -> anyhow::Result<()> { + for explicit_begin in [false, true] { + test_multi_use_read_only_transaction(db_client, explicit_begin).await?; + } + Ok(()) +} + +async fn test_multi_use_read_only_transaction( + db_client: &DatabaseClient, + explicit_begin: bool, +) -> anyhow::Result<()> { // Start a multi-use read-only transaction. - let tx = db_client.read_only_transaction().build().await?; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; - // Expect a read timestamp to have been chosen. - assert!(tx.read_timestamp().is_some()); + if explicit_begin { + // Expect a read timestamp to have been chosen immediately. 
+ assert!(tx.read_timestamp().is_some()); + } else { + // Expect a read timestamp to NOT have been chosen yet. + assert!(tx.read_timestamp().is_none()); + } // Execute the first query. let mut rs1 = tx .execute_query(Statement::builder("SELECT 1 AS col_int").build()) .await?; let row1 = rs1.next().await.transpose()?.expect("should yield a row"); + + // The read timestamp is now always available. + assert!(tx.read_timestamp().is_some()); + let val1 = row1.raw_values()[0].as_string(); assert_eq!(val1, "1"); let next1 = rs1.next().await.transpose()?; From ef4be7468a1bc4a9cdf496ebf2295c7dac00baff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 2 Apr 2026 11:26:16 +0200 Subject: [PATCH 02/17] perf(spanner): inline begin transaction error handling Adds error handling for inline-begin-transaction. If the first statement in a transaction fails, and that statement included a BeginTransaction option, then the transaction has not been started. In order to keep the semantics of the transaction consistent for an 'outside observer', we need to do the following: 1. Catch the error that was thrown by the initial statement. 2. Start the transaction using an explicit BeginTransaction RPC. 3. Retry the initial statement, but now using the transaction ID from step 2. 4. Return the error or result for the retried initial statement. The above makes sure that: 1. The transaction is actually started when the first statement is executed, also when the statement failed. 2. The statement becomes part of the transaction, and the result of the statement is consistent with the read-timestamp of the transaction. The second part is important in order to comply with Spanner's strong consistency guarantees; If for example a statement returns a 'Table not found' error, then that error is only valid for the read timestamp that was used for executing the statement. 
This is the reason that we retry the statement after the BeginTransaction RPC to be able to return a result that is guaranteed to be consistent with any other queries/reads that will be executed in the same transaction. --- src/spanner/src/read_only_transaction.rs | 486 +++++++++++++++++++++-- src/spanner/src/result_set.rs | 442 ++++++++++++++++++++- tests/spanner/src/query.rs | 36 ++ tests/spanner/tests/driver.rs | 4 + 4 files changed, 910 insertions(+), 58 deletions(-) diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 44782326f3..041d74e3b1 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -275,16 +275,7 @@ impl MultiUseReadOnlyTransactionBuilder { &self, options: TransactionOptions, ) -> crate::Result { - let request = crate::model::BeginTransactionRequest::default() - .set_session(self.client.session.name.clone()) - .set_options(options); - - // TODO(#4972): make request options configurable - let response = self - .client - .spanner - .begin_transaction(request, crate::RequestOptions::default()) - .await?; + let response = execute_begin_transaction(&self.client, options).await?; let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); @@ -429,6 +420,22 @@ impl MultiUseReadOnlyTransaction { } } +/// Executes an explicit `BeginTransaction` RPC on Spanner. 
+async fn execute_begin_transaction( + client: &crate::database_client::DatabaseClient, + options: crate::model::TransactionOptions, +) -> crate::Result { + let request = crate::model::BeginTransactionRequest::default() + .set_session(client.session.name.clone()) + .set_options(options); + + // TODO(#4972): make request options configurable + client + .spanner + .begin_transaction(request, crate::RequestOptions::default()) + .await +} + #[derive(Clone, Debug)] pub(crate) enum ReadContextTransactionSelector { Fixed(crate::model::TransactionSelector, Option), @@ -463,6 +470,32 @@ impl ReadContextTransactionSelector { } } + /// Explicitly begins a transaction if the transaction selector is a `Lazy` + /// selector and the transaction has not yet been started. This is used by + /// the client to force the start of a transaction if the first statement + /// failed. + pub(crate) async fn begin_explicitly( + &self, + client: &crate::database_client::DatabaseClient, + ) -> crate::Result<()> { + let Self::Lazy(lazy) = self else { + return Ok(()); + }; + + let options = { + let guard = lazy.lock().expect("transaction state mutex poisoned"); + let TransactionState::NotStarted(options) = &*guard else { + return Ok(()); + }; + options.clone() + }; + + let response = execute_begin_transaction(client, options).await?; + self.update(response.id, response.read_timestamp); + + Ok(()) + } + pub(crate) fn update(&self, id: bytes::Bytes, timestamp: Option) { if let Self::Lazy(lazy) = self { let mut guard = lazy.lock().expect("transaction state mutex poisoned"); @@ -517,6 +550,64 @@ impl ReadContext { options } + /// Attempts to execute an explicit `begin_transaction` RPC if the current transaction + /// selector is still in the `Lazy(NotStarted)` state. This is used as a + /// fallback mechanism when an initial implicit begin attempt failed. 
+ async fn begin_explicitly_if_not_started(&self) -> crate::Result { + let ReadContextTransactionSelector::Lazy(lazy) = &self.transaction_selector else { + return Ok(false); + }; + let is_started = matches!(&*lazy.lock().unwrap(), TransactionState::Started(_, _)); + if is_started { + return Ok(false); + } + + self.transaction_selector + .begin_explicitly(&self.client) + .await?; + Ok(true) + } +} + +/// Helper macro to execute a streaming SQL or streaming read RPC with retry logic. +macro_rules! execute_stream_with_retry { + ($self:expr, $request:ident, $rpc_method:ident, $operation_variant:path) => {{ + let stream = match $self + .client + .spanner + // TODO(#4972): make request options configurable + .$rpc_method($request.clone(), crate::RequestOptions::default()) + .send() + .await + { + Ok(s) => s, + Err(e) => { + if $self.begin_explicitly_if_not_started().await? { + $request.transaction = Some($self.transaction_selector.selector()); + $self + .client + .spanner + // TODO(#4972): make request options configurable + .$rpc_method($request.clone(), crate::RequestOptions::default()) + .send() + .await? 
+ } else { + return Err(e); + } + } + }; + + Ok(ResultSet::new( + stream, + Some($self.transaction_selector.clone()), + $self.precommit_token_tracker.clone(), + $self.client.clone(), + $operation_variant($request), + )) + }}; +} + +impl ReadContext { pub(crate) async fn execute_query>( &self, statement: T, @@ -528,21 +619,7 @@ impl ReadContext { .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); - let stream = self - .client - .spanner - // TODO(#4972): make request options configurable - .execute_streaming_sql(request.clone(), crate::RequestOptions::default()) - .send() - .await?; - - Ok(ResultSet::new( - stream, - Some(self.transaction_selector.clone()), - self.precommit_token_tracker.clone(), - self.client.clone(), - StreamOperation::Query(request), - )) + execute_stream_with_retry!(self, request, execute_streaming_sql, StreamOperation::Query) } pub(crate) async fn execute_read>( @@ -556,27 +633,15 @@ impl ReadContext { .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); - let stream = self - .client - .spanner - // TODO(#4972): make request options configurable - .streaming_read(request.clone(), crate::RequestOptions::default()) - .send() - .await?; - - Ok(ResultSet::new( - stream, - Some(self.transaction_selector.clone()), - self.precommit_token_tracker.clone(), - self.client.clone(), - StreamOperation::Read(request), - )) + execute_stream_with_retry!(self, request, streaming_read, StreamOperation::Read) } } #[cfg(test)] pub(crate) mod tests { use super::*; + use crate::result_set::tests::string_val; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; #[test] fn auto_traits() { @@ -922,4 +987,345 @@ pub(crate) mod tests { let result = rs.next().await; assert!(result.is_none(), "expected None, got {result:?}"); } + + #[tokio::test] + async fn inline_begin_failure_retry_success() -> 
anyhow::Result<()> { + use crate::value::Value; + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + // Return a transaction with ID + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. Retry of the query succeeds + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure it uses the new transaction ID + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + + let row = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected a row but stream cleanly exhausted"))??; + assert_eq!( + row.raw_values(), + [Value(string_val("1"))], + "The parsed row value safely matched the underlying stream chunk" + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_retry_failure() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use 
gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error first"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. Retry of the query fails again + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error second"))); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!( + rs_result.is_err(), + "The failed execution bubbled upwards securely" + ); + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("Internal error second"), + "Secondary error message accurately propagates: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_fallback_rpc_fails() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error query"))); + + // 2. 
Explicit begin transaction fails + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error begin tx"))); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!( + rs_result.is_err(), + "The explicitly errored fallback boot securely propagated outwards" + ); + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("Internal error begin tx"), + "Natively propagated specific BeginTx bounds: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { + use crate::client::{KeySet, ReadRequest}; + use crate::value::Value; + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial read fails + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + })) + }); + + // 3. 
Retry of the read succeeds + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure it uses the new transaction ID + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let read = ReadRequest::builder("Users", vec!["Id", "Name"]) + .with_keys(KeySet::all()) + .build(); + let mut rs = tx.execute_read(read).await?; + + let row = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected a row uniquely returned"))??; + assert_eq!( + row.raw_values(), + [Value(string_val("1"))], + "The macro correctly unpacked read arrays seamlessly" + ); + + Ok(()) + } + + #[tokio::test] + async fn single_use_query_send_error_returns_immediately() -> anyhow::Result<()> { + use crate::client::Statement; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + + mock.expect_execute_streaming_sql() + .times(1) + .returning(|_| Err(Status::internal("Internal error single use query"))); + + mock.expect_begin_transaction().never(); + + let (db_client, _server) = setup_db_client(mock).await; + // single_use creates a Fixed selector + let tx = db_client.single_use().build(); + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!(rs_result.is_err()); + let err_str = rs_result.unwrap_err().to_string(); + assert!(err_str.contains("Internal error single use query")); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_already_started_query_send_error_returns_immediately() + -> anyhow::Result<()> { + use crate::client::Statement; + use 
gaxi::grpc::tonic::Status; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + mock.expect_begin_transaction().never(); + + // 1. First query executes successfully and implicitly starts the transaction. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |_req| { + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: None, + ..Default::default() + }); + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(rs)]), + ))) + }); + + // 2. Second query fails immediately upon send() + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error second query"))); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Run first query (starts tx) + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.expect("has row")?; + + // Run second query (fails) + let rs_result = tx + .execute_query(Statement::builder("SELECT 2").build()) + .await; + + assert!(rs_result.is_err()); + let err_str = rs_result.unwrap_err().to_string(); + assert!(err_str.contains("Internal error second query")); + + Ok(()) + } } diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 6bd848b175..435ffe7a1a 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -233,7 +233,29 @@ impl ResultSet { return Ok(()); } - Err(e) + // Check if this stream included an inlined BeginTransaction option + // and has not yet returned a transaction ID. If so, we explicitly + // begin the transaction and restart the stream. 
+ let Some(ReadContextTransactionSelector::Lazy(lazy)) = &self.transaction_selector else { + return Err(e); + }; + let is_started = matches!( + &*lazy.lock().unwrap(), + crate::read_only_transaction::TransactionState::Started(_, _) + ); + if is_started { + return Err(e); + } + + self.transaction_selector + .as_ref() + .unwrap() + .begin_explicitly(&self.client) + .await?; + + self.partial_result_sets_buffer.clear(); + self.restart_stream().await?; + Ok(()) } fn handle_stream_end(&mut self) -> crate::Result> { @@ -281,15 +303,25 @@ impl ResultSet { (None, Some(mut m)) => { let transaction = m.transaction.take(); self.metadata = Some(ResultSetMetadata::new(Some(m))); - if let (Some(selector), Some(transaction)) = - (&self.transaction_selector, transaction) - { - selector.update( - transaction.id, - transaction - .read_timestamp - .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), - ); + if let Some(selector) = &self.transaction_selector { + if let Some(transaction) = transaction { + selector.update( + transaction.id, + transaction + .read_timestamp + .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), + ); + } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { + let is_started = matches!( + &*lazy.lock().expect("transaction state mutex poisoned"), + crate::read_only_transaction::TransactionState::Started(_, _) + ); + if !is_started { + return Err(internal_error( + "Spanner failed to return a transaction ID for a query that included a BeginTransaction option", + )); + } + } } } } @@ -336,9 +368,15 @@ impl ResultSet { } async fn restart_stream(&mut self) -> crate::Result<()> { + // Get the latest transaction selector for this transaction. 
+ let transaction_selector = self.transaction_selector.as_ref().map(|s| s.selector()); + match &mut self.operation { StreamOperation::Query(req) => { req.resume_token = self.last_resume_token.clone(); + req.transaction = transaction_selector + .clone() + .or_else(|| req.transaction.take()); let stream = self .client .spanner @@ -349,6 +387,9 @@ impl ResultSet { } StreamOperation::Read(req) => { req.resume_token = self.last_resume_token.clone(); + req.transaction = transaction_selector + .clone() + .or_else(|| req.transaction.take()); let stream = self .client .spanner @@ -465,6 +506,7 @@ pub(crate) mod tests { use super::*; use crate::client::Spanner; use gaxi::grpc::tonic::Response; + use google_cloud_auth::credentials::anonymous::Builder as Anonymous; use prost_types::Value; use spanner_grpc_mock::MockSpanner; use spanner_grpc_mock::google::spanner::v1::spanner_server::Spanner as SpannerTrait; @@ -528,7 +570,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await .expect("Failed to build client"); @@ -725,7 +767,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1081,7 +1123,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1137,7 +1179,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1207,7 
+1249,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1296,7 +1338,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1375,7 +1417,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1427,7 +1469,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1448,4 +1490,368 @@ pub(crate) mod tests { Ok(()) } + + #[tokio::test] + async fn result_set_inline_begin_stream_error_fallback() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Stream yields an error on the first chunk before returning transaction metadata. + // E.g., INVALID_ARGUMENT because the query is malformed. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = + tokio_stream::iter(vec![Err(Status::invalid_argument("Invalid query"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. The explicit BeginTransaction fallback gets triggered. 
+ mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. The ResultSet gracefully restarts the stream using the transaction ID returned by BeginTransaction. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure the explicitly yielded ID is routed into the new stream transaction selector + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs.next().await.ok_or_else(|| { + anyhow::anyhow!("Expected row returned successfully despite stream breaking") + })??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verify the returned stream successfully resumed with the correct payload" + ); + + Ok(()) + } + + #[tokio::test] + async fn result_set_retry_inline_begin_transient_error() -> 
anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial stream throws UNAVAILABLE before metadata. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = + tokio_stream::iter(vec![Err(Status::unavailable("Transient network issue"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. We retry the stream since it was a transient error. + // The retry should use the same transaction selector as the original request. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin on stream retry"), + } + + let mut meta = metadata(1).unwrap(); + meta.transaction = Some(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + }); + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + .read_only_transaction() + 
.with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream to recover safely"))??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verify resumed stream returns data" + ); + + Ok(()) + } + + #[tokio::test] + async fn result_set_retry_inline_begin_id_recovered() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Stream successfully returns metadata chunk then throws UNAVAILABLE on chunk 2. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let mut meta = metadata(1).unwrap(); + meta.transaction = Some(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + }); + let stream = tokio_stream::iter(vec![ + Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + resume_token: b"token1".to_vec(), + ..Default::default() + }), + Err(Status::unavailable("Transient mid-stream network issue")), + ]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. Stream resumes using Selector::Id. 
+ mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id on stream retry"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + values: vec![string_val("2")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream row1 extracted"))??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verified chunk 1 payload" + ); + let row2 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream row2 recovered"))??; + assert_eq!( + row2.raw_values()[0].0, + string_val("2"), + "Verified chunk 2 reboot dynamically intercepted ID bounds correctly" + ); + + Ok(()) + } + + #[tokio::test] + async fn result_set_inline_begin_metadata_missing_transaction_fails() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. 
Initial stream successfully returns metadata chunk but completely lacks the `Transaction` entity. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), // Missing `.transaction` natively + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + // Use explicitly deferred Lazy begin transaction! + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let rs_result = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected explicit crash bound properly"))?; + assert!( + rs_result.is_err(), + "Securely aborted when metadata failed to package internal bounds properly" + ); + + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("failed to return a transaction ID"), + "Caught implicit gap boundary: {}", + err_str + ); + + Ok(()) + } } diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index fe02427a19..a1483a65cc 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -246,6 +246,42 @@ async fn test_multi_use_read_only_transaction( Ok(()) } +pub async fn multi_use_read_only_transaction_invalid_query_fallback( + db_client: &DatabaseClient, +) -> anyhow::Result<()> { + // Start a multi-use read-only transaction with implicit begin. 
+ let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Expect a read timestamp to NOT have been chosen yet. + assert!(tx.read_timestamp().is_none()); + + // Execute the first query with invalid syntax. + let rs_result = tx + .execute_query(Statement::builder("SELECT * FROM NonExistentTable").build()) + .await; + + assert!(rs_result.is_err(), "Expected an error from an invalid query"); + + // The read timestamp should now be available because the transaction + // fell back to an explicit BeginTransaction. + assert!(tx.read_timestamp().is_some()); + + // It should be possible to use the transaction. + let mut rs2 = tx + .execute_query(Statement::builder("SELECT 2 AS col_int").build()) + .await?; + + let row2 = rs2.next().await.transpose()?.expect("should yield a row"); + let val2 = row2.raw_values()[0].as_string(); + assert_eq!(val2, "2"); + + Ok(()) +} + fn verify_null_row(row: &google_cloud_spanner::client::Row) { let raw_values = row.raw_values(); assert_eq!(raw_values.len(), 20, "Row should have exactly 20 columns"); diff --git a/tests/spanner/tests/driver.rs b/tests/spanner/tests/driver.rs index 5d6ff838dc..976867b652 100644 --- a/tests/spanner/tests/driver.rs +++ b/tests/spanner/tests/driver.rs @@ -26,6 +26,10 @@ mod spanner { integration_tests_spanner::query::query_with_parameters(&db_client).await?; integration_tests_spanner::query::result_set_metadata(&db_client).await?; integration_tests_spanner::query::multi_use_read_only_transaction(&db_client).await?; + integration_tests_spanner::query::multi_use_read_only_transaction_invalid_query_fallback( + &db_client, + ) + .await?; Ok(()) } From 255a00797bbbbc0bcd2484828a848a7e67ebd6d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 2 Apr 2026 20:49:19 +0200 Subject: [PATCH 03/17] test(spanner): add integration test for inline-begin error handling Adds an integration test for error handling for inline-begin-transaction. 
This test uses a gRPC proxy to intercept calls from the client to Spanner to be able to deterministically emulate specific concurrency issues. This test shows how a query that failed during the first attempt, and thereby also failed to start the transaction, could succeed during a retry after the transaction has been started with an explicit BeginTransaction RPC. --- Cargo.lock | 3 + deny.toml | 1 + tests/spanner/Cargo.toml | 3 + tests/spanner/src/client.rs | 21 +-- tests/spanner/src/lib.rs | 1 + tests/spanner/src/query.rs | 178 ++++++++++++++++++++- tests/spanner/src/test_proxy.rs | 269 ++++++++++++++++++++++++++++++++ tests/spanner/tests/driver.rs | 1 + 8 files changed, 466 insertions(+), 11 deletions(-) create mode 100644 tests/spanner/src/test_proxy.rs diff --git a/Cargo.lock b/Cargo.lock index 5bce4a9341..635b1a79ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5863,7 +5863,10 @@ dependencies = [ "prost-types", "reqwest 0.13.2", "serde_json", + "spanner-grpc-mock", "tokio", + "tokio-stream", + "tonic", "tracing", ] diff --git a/deny.toml b/deny.toml index fa61a89904..67b8359d88 100644 --- a/deny.toml +++ b/deny.toml @@ -115,6 +115,7 @@ wrappers = [ # Use in tests is fine. 
"grpc-server", "integration-tests-o11y", + "integration-tests-spanner", "pubsub-grpc-mock", "spanner-grpc-mock", "storage-grpc-mock", diff --git a/tests/spanner/Cargo.toml b/tests/spanner/Cargo.toml index 461c18d110..30b6625f58 100644 --- a/tests/spanner/Cargo.toml +++ b/tests/spanner/Cargo.toml @@ -36,7 +36,10 @@ google-cloud-test-utils = { workspace = true } prost-types.workspace = true reqwest = { workspace = true, features = ["json"] } serde_json = { workspace = true } +spanner-grpc-mock = { path = "../../src/spanner/grpc-mock" } tokio = { workspace = true, features = ["sync"] } +tokio-stream = { workspace = true } +tonic = { workspace = true } tracing.workspace = true [lints] diff --git a/tests/spanner/src/client.rs b/tests/spanner/src/client.rs index 7b05ecf3c5..b06f310ab8 100644 --- a/tests/spanner/src/client.rs +++ b/tests/spanner/src/client.rs @@ -40,7 +40,7 @@ pub async fn wait_for_emulator(endpoint: &str) { static PROVISION_EMULATOR: tokio::sync::OnceCell<()> = tokio::sync::OnceCell::const_new(); static DATABASE_ID: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); -async fn get_database_id() -> &'static str { +pub async fn get_database_id() -> &'static str { DATABASE_ID .get_or_init(|| async { std::env::var("SPANNER_EMULATOR_TEST_DB") @@ -59,16 +59,19 @@ pub async fn provision_emulator(endpoint: &str) { .await; } +pub fn get_emulator_rest_endpoint(grpc_endpoint: &str) -> String { + let rest_endpoint = std::env::var("SPANNER_EMULATOR_REST_HOST") + .unwrap_or_else(|_| grpc_endpoint.replace("9010", "9020")); + if rest_endpoint.starts_with("http://") || rest_endpoint.starts_with("https://") { + rest_endpoint + } else { + format!("http://{}", rest_endpoint) + } +} + async fn do_provision_emulator(endpoint: &str) { // TODO(#4973): Re-write this to use the admin clients once those also support the Emulator. 
- let rest_endpoint = std::env::var("SPANNER_EMULATOR_REST_HOST") - .unwrap_or_else(|_| endpoint.replace("9010", "9020")); - let rest_endpoint = - if rest_endpoint.starts_with("http://") || rest_endpoint.starts_with("https://") { - rest_endpoint - } else { - format!("http://{}", rest_endpoint) - }; + let rest_endpoint = get_emulator_rest_endpoint(endpoint); let client = reqwest::Client::new(); // Create a test instance and ignore any ALREADY_EXISTS errors. diff --git a/tests/spanner/src/lib.rs b/tests/spanner/src/lib.rs index a59c4ee1a5..c300c74674 100644 --- a/tests/spanner/src/lib.rs +++ b/tests/spanner/src/lib.rs @@ -17,4 +17,5 @@ pub mod partitioned_dml; pub mod query; pub mod read; pub mod read_write_transaction; +pub mod test_proxy; pub mod write; diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index a1483a65cc..06cae8cd6a 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -12,7 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use google_cloud_spanner::client::{DatabaseClient, Kind, Statement}; +use crate::client::{get_database_id, get_emulator_host, get_emulator_rest_endpoint}; +use crate::test_proxy::{InterceptedSpanner, SpannerInterceptor}; +use google_cloud_spanner::client::{DatabaseClient, Kind, Spanner, Statement}; +use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; +use spanner_grpc_mock::google::spanner::v1::spanner_server::SpannerServer; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio::sync::Notify; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::{Channel, Server}; pub async fn simple_query(db_client: &DatabaseClient) -> anyhow::Result<()> { let rot = db_client.single_use().build(); @@ -264,7 +275,10 @@ pub async fn multi_use_read_only_transaction_invalid_query_fallback( .execute_query(Statement::builder("SELECT * FROM NonExistentTable").build()) .await; - assert!(rs_result.is_err(), "Expected an error from an invalid query"); + assert!( + rs_result.is_err(), + "Expected an error from an invalid query" + ); // The read timestamp should now be available because the transaction // fell back to an explicit BeginTransaction. 
@@ -405,3 +419,163 @@ fn verify_row_2(row: &google_cloud_spanner::client::Row) { "2026-03-11T16:20:00Z" ); } + +struct DelayedBeginProxy { + emulator_client: SpannerClient, + latch: Arc, + begin_transaction_entered_latch: Arc, +} + +#[tonic::async_trait] +impl SpannerInterceptor for DelayedBeginProxy { + fn emulator_client(&self) -> SpannerClient { + self.emulator_client.clone() + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.begin_transaction_entered_latch.notify_one(); + self.latch.notified().await; + self.emulator_client().begin_transaction(request).await + } +} + +// This test verifies that the client correctly falls back to `BeginTransaction` when the +// first statement in a transaction fails. It also shows that the statement is retried and +// could (theoretically) succeed during this retry. It achieves this by doing the following: +// 1. It uses a proxy that allows it to intercept the RPCs that are being sent to Spanner. +// 2. It creates a read-only transaction that uses inline-begin-transaction. +// 3. It executes a query that tries to read from a table that does not exist. +// 4. As the first statement in the transaction fails, the client falls back to using +// an explicit BeginTransaction RPC. +// 5. The proxy blocks this BeginTransaction RPC, and in the meantime the test creates +// the missing table. +// 6. The proxy unblocks the BeginTransaction RPC. +// 7. The statement is retried and succeeds. The test never sees the error. +// +// This test might seem like an extreme corner case for a read-only transaction like this. +// However, for read/write transactions, similar types of failures are more likely to occur, +// for example if a transaction tries to insert a row that violates the primary key. Another +// transaction could delete the row in the time between the first attempt failed, and the +// BeginTransaction RPC has been executed. 
+pub async fn inline_begin_fallback(_db_client: &DatabaseClient) -> anyhow::Result<()> { + let emulator_host = get_emulator_host().expect("SPANNER_EMULATOR_HOST must be set"); + let latch = Arc::new(Notify::new()); + let begin_transaction_entered_latch = Arc::new(Notify::new()); + + // Create a raw gRPC client that connects to the Spanner Emulator. + // This will be used by the proxy server to forward requests to the Emulator. + let endpoint = Channel::from_shared(format!("http://{}", emulator_host))? + .connect() + .await?; + let raw_client = SpannerClient::new(endpoint); + + // Create a local TCP listener to bind our proxy server to. + let listener = TcpListener::bind("127.0.0.1:0").await?; + let local_addr = listener.local_addr()?; + let proxy_address = format!("{}:{}", local_addr.ip(), local_addr.port()); + + let proxy = DelayedBeginProxy { + emulator_client: raw_client, + latch: Arc::clone(&latch), + begin_transaction_entered_latch: Arc::clone(&begin_transaction_entered_latch), + }; + + let _server_handle = tokio::spawn(async move { + let stream = TcpListenerStream::new(listener); + Server::builder() + .add_service(SpannerServer::new(InterceptedSpanner(proxy))) + .serve_with_incoming(stream) + .await + .expect("Proxy server failed"); + }); + + // We build the Spanner DatabaseClient pointing directly to our proxy address over HTTP. + let proxy_db_client = Spanner::builder() + .with_endpoint(format!("http://{}", proxy_address)) + .build() + .await? + .database_client(format!( + "projects/test-project/instances/test-instance/databases/{}", + get_database_id().await + )) + .build() + .await?; + + let tx = proxy_db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let table_name = LowercaseAlphanumeric.random_string(10); + let table_name = format!("LateLoadedTable_{}", table_name); + + // Create a task that tries to query the table before it exists. 
+ // This will initially fail, and the client will fall back to using + // an explicit BeginTransaction RPC. The table will then be created + // BEFORE the BeginTransaction RPC is executed, which will cause the + // query to succeed when it is retried using the transaction ID that + // was returned by BeginTransaction. This task will never see the + // initial error, and instead it will seem like the query simply + // succeeded. + let query_task = tokio::spawn({ + let table_name = table_name.clone(); + async move { + let stmt = Statement::builder(format!("SELECT * FROM {}", table_name)).build(); + let mut rs = tx.execute_query(stmt).await?; + let _ = rs.next().await; + Ok::<_, anyhow::Error>(tx) + } + }); + + // Wait until the query task above has been executed and has triggered an + // explicit BeginTransaction RPC. The BeginTransaction RPC is blocked until + // `latch` is notified. + begin_transaction_entered_latch.notified().await; + + // Create the table on the emulator while the BeginTransaction RPC is blocked. + let rest_endpoint = get_emulator_rest_endpoint(&emulator_host); + + let client = reqwest::Client::new(); + let database_payload = serde_json::json!({ + "statements": [ + format!("CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", table_name) + ] + }); + + let db_path = format!( + "projects/test-project/instances/test-instance/databases/{}", + get_database_id().await + ); + let update_database_ddl_url = format!("{}/v1/{}/ddl", rest_endpoint, db_path); + + let res: reqwest::Response = client + .patch(&update_database_ddl_url) + .json(&database_payload) + .send() + .await + .expect("Failed to send CREATE TABLE request"); + + assert!( + res.status().is_success(), + "Failed to update DDL: {}", + res.text().await.unwrap() + ); + + // Unblock the BeginTransaction RPC. + latch.notify_one(); + + // Wait for the query task to complete. It should succeed and never see + // the initial error. 
+ let tx = query_task.await??; + + assert!( + tx.read_timestamp().is_some(), + "The transaction should have a read timestamp" + ); + + Ok(()) +} diff --git a/tests/spanner/src/test_proxy.rs b/tests/spanner/src/test_proxy.rs new file mode 100644 index 0000000000..7f0a5837a2 --- /dev/null +++ b/tests/spanner/src/test_proxy.rs @@ -0,0 +1,269 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; + +#[tonic::async_trait] +pub trait SpannerInterceptor: Send + Sync + 'static { + fn emulator_client(&self) -> SpannerClient; + + async fn create_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().create_session(request).await + } + + async fn batch_create_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.emulator_client().batch_create_sessions(request).await + } + + async fn get_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().get_session(request).await + } + + async fn list_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().list_sessions(request).await + } + + async fn delete_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + 
self.emulator_client().delete_session(request).await + } + + async fn execute_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().execute_sql(request).await + } + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.emulator_client().execute_streaming_sql(request).await + } + + async fn execute_batch_dml( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.emulator_client().execute_batch_dml(request).await + } + + async fn read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().read(request).await + } + + async fn streaming_read( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.emulator_client().streaming_read(request).await + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().begin_transaction(request).await + } + + async fn commit( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().commit(request).await + } + + async fn rollback( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().rollback(request).await + } + + async fn partition_query( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().partition_query(request).await + } + + async fn partition_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().partition_read(request).await + } + + async fn batch_write( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.emulator_client().batch_write(request).await + } +} + +pub struct 
InterceptedSpanner(pub T); + +#[tonic::async_trait] +impl spanner_v1::spanner_server::Spanner for InterceptedSpanner { + async fn create_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.create_session(request).await + } + + async fn batch_create_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.0.batch_create_sessions(request).await + } + + async fn get_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.get_session(request).await + } + + async fn list_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.list_sessions(request).await + } + + async fn delete_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.delete_session(request).await + } + + async fn execute_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.execute_sql(request).await + } + + type ExecuteStreamingSqlStream = tonic::codec::Streaming; + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.execute_streaming_sql(request).await + } + + async fn execute_batch_dml( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.0.execute_batch_dml(request).await + } + + async fn read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.read(request).await + } + + type StreamingReadStream = tonic::codec::Streaming; + + async fn streaming_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.streaming_read(request).await + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.begin_transaction(request).await + } + + async fn commit( + &self, + request: tonic::Request, + ) -> 
std::result::Result, tonic::Status> { + self.0.commit(request).await + } + + async fn rollback( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.rollback(request).await + } + + async fn partition_query( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.partition_query(request).await + } + + async fn partition_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.partition_read(request).await + } + + type BatchWriteStream = tonic::codec::Streaming; + + async fn batch_write( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.batch_write(request).await + } +} diff --git a/tests/spanner/tests/driver.rs b/tests/spanner/tests/driver.rs index 976867b652..6751a22484 100644 --- a/tests/spanner/tests/driver.rs +++ b/tests/spanner/tests/driver.rs @@ -30,6 +30,7 @@ mod spanner { &db_client, ) .await?; + integration_tests_spanner::query::inline_begin_fallback(&db_client).await?; Ok(()) } From c7512d1498dd1dcaa22f7cc2c7003751c2873366 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 3 Apr 2026 18:10:18 +0200 Subject: [PATCH 04/17] perf(spanner): support concurrent queries with inline-begin-transaction Adds support for running concurrent queries in combination with inline-begin-transaction. Only one of the queries will include the BeginTransaction option. The other queries will wait until the first query has returned a transaction ID. 
--- Cargo.lock | 3 + .../src/batch_read_only_transaction.rs | 10 +- src/spanner/src/read_only_transaction.rs | 940 ++++++++++++++++-- src/spanner/src/read_write_transaction.rs | 12 +- src/spanner/src/result_set.rs | 12 +- src/spanner/src/transaction_runner.rs | 2 +- tests/spanner/Cargo.toml | 3 + tests/spanner/src/concurrent_inline_begin.rs | 277 ++++++ tests/spanner/src/lib.rs | 1 + tests/spanner/src/test_proxy.rs | 52 +- tests/spanner/tests/driver.rs | 5 + 11 files changed, 1234 insertions(+), 83 deletions(-) create mode 100644 tests/spanner/src/concurrent_inline_begin.rs diff --git a/Cargo.lock b/Cargo.lock index 635b1a79ba..7c17ccc0f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5860,10 +5860,13 @@ dependencies = [ "google-cloud-lro", "google-cloud-spanner", "google-cloud-test-utils", + "google-cloud-wkt", "prost-types", + "rand 0.10.0", "reqwest 0.13.2", "serde_json", "spanner-grpc-mock", + "time", "tokio", "tokio-stream", "tonic", diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 179ac26dc2..0332f67ffd 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -143,12 +143,13 @@ impl BatchReadOnlyTransaction { statement: T, options: PartitionOptions, ) -> crate::Result> { + let selector = self.inner.context.transaction_selector.selector().await?; let statement = statement.into(); let request = statement .clone() .into_partition_query_request() .set_session(self.inner.context.client.session.name.clone()) - .set_transaction(self.inner.context.transaction_selector.selector()) + .set_transaction(selector.clone()) .set_partition_options(options); let response = self @@ -165,7 +166,7 @@ impl BatchReadOnlyTransaction { .map(|p| Partition { inner: PartitionedOperation::Query { partition_token: p.partition_token, - transaction_selector: self.inner.context.transaction_selector.selector(), + transaction_selector: selector.clone(), session_name: 
self.inner.context.client.session.name.clone(), statement: statement.clone(), }, @@ -198,12 +199,13 @@ impl BatchReadOnlyTransaction { read: T, options: PartitionOptions, ) -> crate::Result> { + let selector = self.inner.context.transaction_selector.selector().await?; let read = read.into(); let request = read .clone() .into_partition_read_request() .set_session(self.inner.context.client.session.name.clone()) - .set_transaction(self.inner.context.transaction_selector.selector()) + .set_transaction(selector.clone()) .set_partition_options(options); let response = self @@ -220,7 +222,7 @@ impl BatchReadOnlyTransaction { .map(|p| Partition { inner: PartitionedOperation::Read { partition_token: p.partition_token, - transaction_selector: self.inner.context.transaction_selector.selector(), + transaction_selector: selector.clone(), session_name: self.inner.context.client.session.name.clone(), read_request: read.clone(), }, diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 041d74e3b1..45d5bde4b4 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -20,6 +20,7 @@ use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; use std::sync::{Arc, Mutex}; +use tokio::sync::Notify; /// A builder for [SingleUseReadOnlyTransaction]. 
/// @@ -445,28 +446,72 @@ pub(crate) enum ReadContextTransactionSelector { #[derive(Clone, Debug)] pub(crate) enum TransactionState { NotStarted(crate::model::TransactionOptions), + Starting(crate::model::TransactionOptions, Arc), Started(crate::model::TransactionSelector, Option), + Failed(Arc), } -impl TransactionState { - fn selector(&self) -> crate::model::TransactionSelector { - match self { - Self::Started(selector, _) => selector.clone(), - Self::NotStarted(options) => { - crate::model::TransactionSelector::default().set_begin(options.clone()) - } - } - } +enum SelectorStatus { + Ready(crate::model::TransactionSelector), + Wait(std::sync::Arc), } impl ReadContextTransactionSelector { - pub(crate) fn selector(&self) -> crate::model::TransactionSelector { + pub(crate) async fn selector(&self) -> crate::Result { match self { - Self::Fixed(selector, _) => selector.clone(), - Self::Lazy(lazy) => lazy - .lock() - .expect("transaction state mutex poisoned") - .selector(), + Self::Fixed(selector, _) => Ok(selector.clone()), + Self::Lazy(_) => loop { + match self.poll_selector_status()? { + SelectorStatus::Ready(selector) => return Ok(selector), + SelectorStatus::Wait(notify) => notify.notified().await, + } + }, + } + } + + /// Inspects the current lazy selector state returning whether it is ready, + /// failed, or needs to wait for the transaction to start. + fn poll_selector_status(&self) -> crate::Result { + let Self::Lazy(lazy) = self else { + unreachable!("poll_selector_status called on non-Lazy selector"); + }; + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + + // Fast path: Transaction is already started. + if let TransactionState::Started(selector, _) = &*guard { + return Ok(SelectorStatus::Ready(selector.clone())); + } + + // If the transaction has not started, extract options and proceed to transition. 
+ let pending_options = if let TransactionState::NotStarted(options) = &*guard { + Some(options.clone()) + } else { + None + }; + if let Some(options) = pending_options { + let notify = Arc::new(Notify::new()); + *guard = TransactionState::Starting(options.clone(), Arc::clone(¬ify)); + return Ok(SelectorStatus::Ready( + crate::model::TransactionSelector::default().set_begin(options), + )); + } + + // Handle other states: yield error or wait. + match &*guard { + // Note: Failed will only be reached if the following happens: + // 1. The first query fails and the transaction falls back to an explicit BeginTransaction RPC. + // 2. The BeginTransaction RPC fails. This is the error that will be returned to all the waiting queries. + TransactionState::Failed(err) => { + let error = if let Some(status) = err.status() { + crate::Error::service(status.clone()) + } else { + crate::error::internal_error(format!("Transaction failed to start: {}", err)) + }; + Err(error) + } + // Transaction is starting. Wait until a transaction ID is returned. + TransactionState::Starting(_, notify) => Ok(SelectorStatus::Wait(Arc::clone(notify))), + TransactionState::Started(_, _) | TransactionState::NotStarted(_) => unreachable!(), } } @@ -482,29 +527,101 @@ impl ReadContextTransactionSelector { return Ok(()); }; - let options = { + let (options, notify_opt) = { let guard = lazy.lock().expect("transaction state mutex poisoned"); - let TransactionState::NotStarted(options) = &*guard else { - return Ok(()); - }; - options.clone() + match &*guard { + // This should never happen in the current implementation. 
+ TransactionState::NotStarted(_) => { + return Err(crate::error::internal_error( + "explicit begin with NotStarted state is currently unsupported", + )); + } + TransactionState::Starting(options, notify) => { + (options.clone(), Some(Arc::clone(notify))) + } + TransactionState::Started(_, _) | TransactionState::Failed(_) => return Ok(()), + } + }; + + let response = match execute_begin_transaction(client, options).await { + Ok(r) => r, + Err(e) => { + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + let error = Arc::new(e); + *guard = TransactionState::Failed(Arc::clone(&error)); + // Release the lock and notify all the waiting queries that + // the transaction has failed. + drop(guard); + if let Some(notify) = notify_opt { + notify.notify_waiters(); + } + + let return_error = if let Some(status) = error.status() { + crate::Error::service(status.clone()) + } else { + crate::error::internal_error(format!("Transaction failed to start: {}", error)) + }; + return Err(return_error); + } }; - let response = execute_begin_transaction(client, options).await?; - self.update(response.id, response.read_timestamp); + self.update(response.id, response.read_timestamp)?; Ok(()) } - pub(crate) fn update(&self, id: bytes::Bytes, timestamp: Option) { - if let Self::Lazy(lazy) = self { - let mut guard = lazy.lock().expect("transaction state mutex poisoned"); - if matches!(&*guard, TransactionState::NotStarted(_)) { - *guard = TransactionState::Started( + pub(crate) fn update( + &self, + id: bytes::Bytes, + timestamp: Option, + ) -> crate::Result<()> { + let Self::Lazy(lazy) = self else { + return Ok(()); + }; + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + + if matches!( + &*guard, + TransactionState::NotStarted(_) | TransactionState::Starting(_, _) + ) { + let previous_state = std::mem::replace( + &mut *guard, + TransactionState::Started( crate::model::TransactionSelector::default().set_id(id), timestamp, - ); + ), + ); + 
drop(guard); + + // Notify all queries that are waiting for the transaction. + if let TransactionState::Starting(_, notify) = previous_state { + notify.notify_waiters(); } + Ok(()) + } else { + Err(crate::error::internal_error( + "got a transaction id for an already Started or Failed transaction", + )) + } + } + + /// Resets the selector state from `Starting` back to `NotStarted`. + /// + /// This is used during stream resume fallbacks when the first query stream + /// fails before yielding a transaction ID. It unlocks any parked waiters + /// allowing them (or the retry attempt) to include the begin option again. + pub(crate) fn maybe_reset_starting(&self) { + let Self::Lazy(lazy) = self else { + return; + }; + + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + if let TransactionState::Starting(options, notify) = &*guard { + let options = options.clone(); + let notify = Arc::clone(notify); + *guard = TransactionState::NotStarted(options); + drop(guard); + notify.notify_waiters(); } } @@ -583,7 +700,7 @@ macro_rules! execute_stream_with_retry { Ok(s) => s, Err(e) => { if $self.begin_explicitly_if_not_started().await? 
{ - $request.transaction = Some($self.transaction_selector.selector()); + $request.transaction = Some($self.transaction_selector.selector().await?); $self .client .spanner @@ -616,7 +733,7 @@ impl ReadContext { .into() .into_request() .set_session(self.client.session.name.clone()) - .set_transaction(self.transaction_selector.selector()); + .set_transaction(self.transaction_selector.selector().await?); request.request_options = self.amend_request_options(request.request_options); execute_stream_with_retry!(self, request, execute_streaming_sql, StreamOperation::Query) @@ -630,7 +747,7 @@ impl ReadContext { .into() .into_request() .set_session(self.client.session.name.clone()) - .set_transaction(self.transaction_selector.selector()); + .set_transaction(self.transaction_selector.selector().await?); request.request_options = self.amend_request_options(request.request_options); execute_stream_with_retry!(self, request, streaming_read, StreamOperation::Read) @@ -640,8 +757,16 @@ impl ReadContext { #[cfg(test)] pub(crate) mod tests { use super::*; + use crate::client::Statement; use crate::result_set::tests::string_val; + use crate::value::Value; + use gaxi::grpc::tonic::{self, Code, Response, Status}; + use mock_v1::transaction_selector::Selector; use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use std::pin::Pin; + use std::sync::Arc; + use std::task::{Context, Poll}; + use tokio::sync::{Barrier, Mutex, Notify, mpsc}; #[test] fn auto_traits() { @@ -655,12 +780,10 @@ pub(crate) mod tests { pub(crate) fn create_session_mock() -> spanner_grpc_mock::MockSpanner { let mut mock = spanner_grpc_mock::MockSpanner::new(); mock.expect_create_session().once().returning(|_| { - Ok(gaxi::grpc::tonic::Response::new( - spanner_grpc_mock::google::spanner::v1::Session { - name: "projects/p/instances/i/databases/d/sessions/123".to_string(), - ..Default::default() - }, - )) + Ok(Response::new(mock_v1::Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + 
..Default::default() + })) }); mock } @@ -689,6 +812,7 @@ pub(crate) mod tests { let (address, server) = spanner_grpc_mock::start("0.0.0.0:0", mock) .await .expect("Failed to start mock server"); + let spanner = Spanner::builder() .with_endpoint(address) .with_credentials(Anonymous::new().build()) @@ -712,7 +836,12 @@ pub(crate) mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = db_client.single_use().build(); - let selector = tx.context.transaction_selector.selector(); + let selector = tx + .context + .transaction_selector + .selector() + .await + .expect("Failed to get selector"); let ro = selector .single_use() .expect("Expected SingleUse selector") @@ -729,7 +858,12 @@ pub(crate) mod tests { std::time::Duration::from_secs(10), )) .build(); - let selector = tx2.context.transaction_selector.selector(); + let selector = tx2 + .context + .transaction_selector + .selector() + .await + .expect("Failed to get selector"); let ro2 = selector .single_use() .expect("Expected SingleUse selector") @@ -761,9 +895,9 @@ pub(crate) mod tests { ); assert_eq!(req.sql, "SELECT 1"); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -795,7 +929,7 @@ pub(crate) mod tests { req.session, "projects/p/instances/i/databases/d/sessions/123" ); - Ok(gaxi::grpc::tonic::Response::new(mock_v1::Transaction { + Ok(tonic::Response::new(mock_v1::Transaction { id: vec![1, 2, 3], // prost_types::Timestamp fields need to be explicitly set because default is 0 for both read_timestamp: Some(prost_types::Timestamp { @@ -822,9 +956,9 @@ pub(crate) mod tests { mock_v1::transaction_selector::Selector::Id(vec![1, 2, 3]) ); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + 
Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -894,9 +1028,9 @@ pub(crate) mod tests { }), ..Default::default() }); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(rs)]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(rs), + ])))) }); mock.expect_execute_streaming_sql() @@ -911,9 +1045,9 @@ pub(crate) mod tests { } _ => panic!("Expected Selector::Id"), } - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -969,9 +1103,9 @@ pub(crate) mod tests { assert_eq!(req.table, "Users"); assert_eq!(req.columns, vec!["Id".to_string(), "Name".to_string()]); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -991,8 +1125,8 @@ pub(crate) mod tests { #[tokio::test] async fn inline_begin_failure_retry_success() -> anyhow::Result<()> { use crate::value::Value; - use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; + use tonic::Response; let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); @@ -1068,8 +1202,8 @@ pub(crate) mod tests { #[tokio::test] async fn inline_begin_failure_retry_failure() -> anyhow::Result<()> { - use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; + use tonic::Response; let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); @@ -1174,8 +1308,8 @@ pub(crate) mod tests { async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { use crate::client::{KeySet, ReadRequest}; use crate::value::Value; - use 
gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; + use tonic::Response; let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); @@ -1292,9 +1426,9 @@ pub(crate) mod tests { read_timestamp: None, ..Default::default() }); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(rs)]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(rs), + ])))) }); // 2. Second query fails immediately upon send() @@ -1328,4 +1462,688 @@ pub(crate) mod tests { Ok(()) } + + /// A wrapper that implements `tokio_stream::Stream` for a `mpsc::Receiver`. + /// Useful in mock setups to yield controlled streaming test responses. + struct ReceiverStream(mpsc::Receiver); + impl tokio_stream::Stream for ReceiverStream { + type Item = T; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_recv(cx) + } + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. First query: should include Selector::Begin + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + task1_ready_clone.notify_one(); + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. 
The other queries: should include populated Selector::Id + mock.expect_execute_streaming_sql() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id for other queries"), + } + + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + // Spawn 3 concurrent queries. + // Task 1 launches first and executes the first query. + let tx1 = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = tx1 + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + // Read the first result to get the transaction ID. + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock server. + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + tx2.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + let tx3 = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + tx3.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + // Provide the first result (including the transaction ID) to Task 1. 
+ // This transitions the selector to 'Started' and unblocks Tasks 2 and 3. + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: Some(prost_types::Timestamp { + seconds: 987654321, + nanos: 0, + }), + ..Default::default() + }); + tx_sender.send(Ok(rs)).await.expect("channel broken"); + drop(tx_sender); + + // Collect all results + let mut rs1 = handle1.await??; + let mut rs2 = handle2.await??; + let mut rs3 = handle3.await??; + + // Verify the query results + assert!(rs1.next().await.is_none()); + + let row2 = rs2.next().await.expect("Expected a row")?; + assert_eq!(row2.raw_values(), [Value(string_val("1"))]); + assert!(rs2.next().await.is_none()); + + let row3 = rs3.next().await.expect("Expected a row")?; + assert_eq!(row3.raw_values(), [Value(string_val("1"))]); + assert!(rs3.next().await.is_none()); + + // Verify that the read timestamp was populated + assert_eq!(tx.read_timestamp().unwrap().seconds(), 987654321); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin_failed_cascade() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. Return a stream connected to tx_sender. + // We will use tx_sender later in the test to inject a failed first chunk. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |_req| { + task1_ready_clone.notify_one(); + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(tonic::Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. 
Fallback BeginTransaction RPC fails + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Err(gaxi::grpc::tonic::Status::internal( + "Fallback BeginTransaction failed", + )) + }); + + // The other queries will never be executed. + mock.expect_execute_streaming_sql().times(0).returning(|_| { + panic!("Other queries should not launch after failure to start the transaction") + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + // Spawn 3 concurrent queries. + let tx1 = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = tx1 + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock and transition the selector to Starting. + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + tx2.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + let tx3 = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + tx3.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + // Push error to channel failing first query stream! 
+ tx_sender + .send(Err(gaxi::grpc::tonic::Status::internal( + "Mocked boot failed", + ))) + .await + .expect("channel broken"); + drop(tx_sender); + + // Collect all results - all should fail with identical cached error! + let err1 = handle1.await?.unwrap_err().to_string(); + let err2 = handle2.await?.unwrap_err().to_string(); + let err3 = handle3.await?.unwrap_err().to_string(); + + assert!( + err1.contains("Fallback BeginTransaction failed"), + "err1: {}", + err1 + ); + assert!( + err2.contains("Fallback BeginTransaction failed"), + "err2: {}", + err2 + ); + assert!( + err3.contains("Fallback BeginTransaction failed"), + "err3: {}", + err3 + ); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin_stream_restart_deadlock_prevention() + -> crate::Result<()> { + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. Task 1 initial query: Return a stream connected to tx_sender for error injection. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + // Return a stream connected to tx_sender. + // We will use tx_sender later in the test to inject a transient error. + task1_ready_clone.notify_one(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. Task 1 restart query: should include Selector::Begin, since + // it failed with a transient error. 
+ mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => { + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }); + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok(rs)])))) + } + _ => panic!("Expected Selector::Begin for stream restart query"), + } + }); + + // 3. Tasks 2 & 3: should include populated Selector::Id + mock.expect_execute_streaming_sql() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + } + _ => panic!("Expected Selector::Id for concurrent queries"), + } + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + let handle1_tx = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = handle1_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock and transition the selector to Starting. 
+ task1_ready.notified().await; + + let handle2_tx = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + let mut rs = handle2_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + let handle3_tx = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + let mut rs = handle3_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + let grpc_status = Status::new(gaxi::grpc::tonic::Code::Unavailable, "transient error"); + tx_sender.send(Err(grpc_status)).await.expect("send failed"); + drop(tx_sender); + + // Collect and verify all results. + // handle.await returns Result, JoinError>. + // The first ? handles the potential JoinError (panic in the task), + // and the second ? handles the Spanner error. + let mut rs1 = handle1.await.expect("Task 1 panicked")?; + let mut rs2 = handle2.await.expect("Task 2 panicked")?; + let mut rs3 = handle3.await.expect("Task 3 panicked")?; + + // Verify that all results have been exhausted. + // (The tasks themselves already successfully read the first row). 
+ assert!(rs1.next().await.is_none(), "Stream 1 should be exhausted"); + assert!(rs2.next().await.is_none(), "Stream 2 should be exhausted"); + assert!(rs3.next().await.is_none(), "Stream 3 should be exhausted"); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_late_arrival_failure() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + Err(Status::internal("Initial inline-begin failed")) + }); + + // 2. Fallback BeginTransaction RPC also fails. + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Fallback BeginTransaction failed"))); + + // Any further attempts would panic because we haven't mocked them. + mock.expect_execute_streaming_sql().never(); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // First query: triggers the failure and transitions the state to Failed. + let err1 = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await + .expect_err("First query should fail"); + assert!( + err1.to_string() + .contains("Fallback BeginTransaction failed") + ); + + // Second query: starts AFTER the failure is already cached. + // It should immediately return the same error without invoking the mock server. 
+ let err2 = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await + .expect_err("Late query should fail immediately"); + assert!( + err2.to_string() + .contains("Fallback BeginTransaction failed") + ); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_reads_inline_begin() -> anyhow::Result<()> { + use crate::client::{KeySet, ReadRequest}; + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. First read: should include Selector::Begin + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + task1_ready_clone.notify_one(); + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first read"), + } + + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. 
The other reads: should include populated Selector::Id + mock.expect_streaming_read() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id for other reads"), + } + + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + let read_req = ReadRequest::builder("Table", vec!["Col"]) + .with_keys(KeySet::all()) + .build(); + + // Spawn 3 concurrent reads. + let tx1 = Arc::clone(&tx); + let read1 = read_req.clone(); + let handle1 = tokio::spawn(async move { + let mut rs = tx1.execute_read(read1).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let read2 = read_req.clone(); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + let mut rs = tx2.execute_read(read2).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + let tx3 = Arc::clone(&tx); + let read3 = read_req.clone(); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + let mut rs = tx3.execute_read(read3).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + tasks_started.wait().await; + tokio::task::yield_now().await; + + // Provide the transaction ID. 
+ let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }); + tx_sender.send(Ok(rs)).await.expect("send failed"); + drop(tx_sender); + + let mut rs1 = handle1.await.expect("Task 1 panicked")?; + let mut rs2 = handle2.await.expect("Task 2 panicked")?; + let mut rs3 = handle3.await.expect("Task 3 panicked")?; + + assert!(rs1.next().await.is_none()); + assert!(rs2.next().await.is_none()); + assert!(rs3.next().await.is_none()); + + Ok(()) + } + + #[tokio::test] + async fn execute_inline_begin_idempotent_update() -> anyhow::Result<()> { + let (db_client, _server) = setup_db_client(create_session_mock()).await; + // Access internal state for unit testing. + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let id1 = bytes::Bytes::from_static(b"tx1"); + let id2 = bytes::Bytes::from_static(b"tx2"); + + // 1. Initial update. + tx.context.transaction_selector.update(id1.clone(), None)?; + assert_eq!( + tx.context + .transaction_selector + .selector() + .await? + .id() + .unwrap(), + &id1 + ); + + // 2. Redundant update with same ID should result in an error. + // The implementation explicitly prevents redundant updates to ensure state consistency. + let err1 = tx + .context + .transaction_selector + .update(id1.clone(), None) + .expect_err("Redundant update should fail"); + assert!(err1.to_string().contains("already Started or Failed")); + + // 3. Update with DIFFERENT ID after already Started should also fail. + let err2 = tx + .context + .transaction_selector + .update(id2, None) + .expect_err("Update after Started should fail"); + assert!(err2.to_string().contains("already Started or Failed")); + + Ok(()) + } + + #[tokio::test] + async fn execute_inline_begin_with_transient_failure() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. 
First attempt fails transiently. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::new(Code::Unavailable, "Transient 1"))); + + // 2. Fallback BeginTransaction succeeds. + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + + // 3. The manual retry of the query (which happens after explicit begin fallback). + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + assert!(rs.next().await.is_some()); + assert!(rs.next().await.is_none()); + + Ok(()) + } } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 920d69bfda..dd7f9e0553 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -148,7 +148,7 @@ impl ReadWriteTransaction { .into() .into_request() .set_session(self.context.client.session.name.clone()) - .set_transaction(self.context.transaction_selector.selector()) + .set_transaction(self.context.transaction_selector.selector().await?) .set_seqno(seqno); request.request_options = self.context.amend_request_options(request.request_options); @@ -249,7 +249,7 @@ impl ReadWriteTransaction { let request = ExecuteBatchDmlRequest::default() .set_session(self.context.client.session.name.clone()) - .set_transaction(self.context.transaction_selector.selector()) + .set_transaction(self.context.transaction_selector.selector().await?) 
.set_seqno(seqno) .set_statements(statements) .set_or_clear_request_options( @@ -274,8 +274,8 @@ impl ReadWriteTransaction { } } - pub(crate) fn transaction_id(&self) -> crate::Result { - match &self.context.transaction_selector.selector().selector { + pub(crate) async fn transaction_id(&self) -> crate::Result { + match &self.context.transaction_selector.selector().await?.selector { Some(Selector::Id(id)) => Ok(id.clone()), _ => Err(internal_error("Transaction ID is missing")), } @@ -283,7 +283,7 @@ impl ReadWriteTransaction { /// Commits the transaction. pub(crate) async fn commit(self) -> crate::Result { - let transaction_id = self.transaction_id()?; + let transaction_id = self.transaction_id().await?; let precommit_token = self.context.precommit_token_tracker.get(); let request = CommitRequest::default() .set_session(self.context.client.session.name.clone()) @@ -323,7 +323,7 @@ impl ReadWriteTransaction { /// Rolls back the transaction. pub(crate) async fn rollback(self) -> crate::Result<()> { - let transaction_id = self.transaction_id()?; + let transaction_id = self.transaction_id().await?; let request = RollbackRequest::default() .set_session(self.context.client.session.name.clone()) diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 435ffe7a1a..07d8d8d89b 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -310,7 +310,7 @@ impl ResultSet { transaction .read_timestamp .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), - ); + )?; } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { let is_started = matches!( &*lazy.lock().expect("transaction state mutex poisoned"), @@ -368,8 +368,16 @@ impl ResultSet { } async fn restart_stream(&mut self) -> crate::Result<()> { + if let Some(s) = &self.transaction_selector { + s.maybe_reset_starting(); + } + // Get the latest transaction selector for this transaction. 
- let transaction_selector = self.transaction_selector.as_ref().map(|s| s.selector()); + let transaction_selector = if let Some(s) = &self.transaction_selector { + Some(s.selector().await?) + } else { + None + }; match &mut self.operation { StreamOperation::Query(req) => { diff --git a/src/spanner/src/transaction_runner.rs b/src/spanner/src/transaction_runner.rs index e6f5b4da6d..089e6f7d65 100644 --- a/src/spanner/src/transaction_runner.rs +++ b/src/spanner/src/transaction_runner.rs @@ -223,7 +223,7 @@ impl TransactionRunner { let mut current_tx_id = None; let attempt_result = async { let transaction = self.builder.begin_transaction().await?; - current_tx_id = transaction.transaction_id().ok(); + current_tx_id = transaction.transaction_id().await.ok(); let result = match work(transaction.clone()).await { Ok(res) => res, diff --git a/tests/spanner/Cargo.toml b/tests/spanner/Cargo.toml index 30b6625f58..2869708160 100644 --- a/tests/spanner/Cargo.toml +++ b/tests/spanner/Cargo.toml @@ -33,10 +33,13 @@ google-cloud-gax = { workspace = true } google-cloud-lro = { workspace = true } google-cloud-spanner = { workspace = true, features = ["unstable-stream"] } google-cloud-test-utils = { workspace = true } +google-cloud-wkt = { workspace = true } prost-types.workspace = true +rand = { workspace = true } reqwest = { workspace = true, features = ["json"] } serde_json = { workspace = true } spanner-grpc-mock = { path = "../../src/spanner/grpc-mock" } +time = { workspace = true } tokio = { workspace = true, features = ["sync"] } tokio-stream = { workspace = true } tonic = { workspace = true } diff --git a/tests/spanner/src/concurrent_inline_begin.rs b/tests/spanner/src/concurrent_inline_begin.rs new file mode 100644 index 0000000000..325064b0ea --- /dev/null +++ b/tests/spanner/src/concurrent_inline_begin.rs @@ -0,0 +1,277 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance 
with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::client::{ + get_database_id, get_emulator_host, get_emulator_rest_endpoint, provision_emulator, +}; +use crate::test_proxy::{InterceptedSpanner, SpannerInterceptor}; +use futures::stream::{self, StreamExt}; +use google_cloud_spanner::client::{ResultSet, Row, Spanner, TimestampBound}; +use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; +use std::collections::HashMap; +use std::sync::Arc; +use time::OffsetDateTime; +use tokio::net::TcpListener; +use tokio::sync::{Barrier, Mutex}; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::{Channel, Server}; + +/// An interceptor that injects transient (Unavailable) and permanent (Internal) failures +/// into streaming SQL responses for specific query patterns. +pub struct ConcurrentFaultInterceptor { + emulator_client: SpannerClient, + /// Tracks failure counts to allow transient recovery. 
+ failure_counts: Arc>>, +} + +impl ConcurrentFaultInterceptor { + pub fn new(emulator_client: SpannerClient) -> Self { + Self { + emulator_client, + failure_counts: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +#[tonic::async_trait] +impl SpannerInterceptor for ConcurrentFaultInterceptor { + fn emulator_client(&self) -> SpannerClient { + self.emulator_client.clone() + } + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + let sql = request.get_ref().sql.clone(); + + // Emulates a transient stream failure. + if sql.starts_with("SELECT 'Transient-") { + let mut counts = self.failure_counts.lock().await; + let count = counts.entry(sql.clone()).or_insert(0); + if *count == 0 { + *count += 1; + // Return a stream that fails immediately with Unavailable. + let stream = stream::once(async { + Err(tonic::Status::unavailable("Transient stream failure")) + }); + return Ok(tonic::Response::new(stream.boxed())); + } + // Second attempt succeeds (fall through to emulator). + } + + // Emulates a permanent stream failure. + if sql == "SELECT 'Permanent'" { + // Returns a stream that always fails with an Internal error. + let stream = + stream::once(async { Err(tonic::Status::internal("Permanent stream failure")) }); + return Ok(tonic::Response::new(stream.boxed())); + } + + // Forward other queries to the emulator. 
+ let res = self + .emulator_client() + .execute_streaming_sql(request) + .await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) + } +} + +pub async fn test_concurrent_inline_begin_with_snapshot_consistency() -> anyhow::Result<()> { + let emulator_host = match get_emulator_host() { + Some(host) => host, + None => return Ok(()), + }; + provision_emulator(&emulator_host).await; + let rest_endpoint = get_emulator_rest_endpoint(&emulator_host); + let db_id = get_database_id().await; + let db_path = format!( + "projects/test-project/instances/test-instance/databases/{}", + db_id + ); + + // 1. Setup Table 1 (Exists at snapshot time) + let suffix = LowercaseAlphanumeric.random_string(6); + let table_success = format!("TableSuccess_{}", suffix); + let table_not_found = format!("TableNotFound_{}", suffix); + + let client = reqwest::Client::new(); + client + .patch(format!("{}/v1/{}/ddl", rest_endpoint, db_path)) + .json(&serde_json::json!({ + "statements": [format!("CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", table_success)] + })) + .send() + .await? + .error_for_status()?; + + // 2. Capture snapshot time. + let spanner = Spanner::builder() + .with_endpoint(format!("http://{}", emulator_host)) + .build() + .await?; + let db_client = spanner.database_client(&db_path).build().await?; + + let mut rs: ResultSet = db_client + .single_use() + .build() + .execute_query("SELECT CURRENT_TIMESTAMP") + .await?; + let row: Row = rs.next().await.unwrap().unwrap(); + let snapshot_time: OffsetDateTime = row.try_get(0)?; + + // 3. Setup Table 2 (Does NOT exist at snapshot time) + client + .patch(format!("{}/v1/{}/ddl", rest_endpoint, db_path)) + .json(&serde_json::json!({ + "statements": [format!("CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", table_not_found)] + })) + .send() + .await? + .error_for_status()?; + + // 4. 
Start the Intercepted Server + let listener = TcpListener::bind("127.0.0.1:0").await?; + let local_addr = listener.local_addr()?; + let emulator_channel = Channel::from_shared(format!("http://{}", emulator_host))? + .connect() + .await?; + let interceptor = ConcurrentFaultInterceptor::new(SpannerClient::new(emulator_channel)); + let service = InterceptedSpanner(interceptor); + + tokio::spawn(async move { + Server::builder() + .add_service(spanner_v1::spanner_server::SpannerServer::new(service)) + .serve_with_incoming(TcpListenerStream::new(listener)) + .await + .expect("Server failed"); + }); + + // 5. Build Client pointing to Interceptor + let intercepted_spanner = Spanner::builder() + .with_endpoint(format!("http://{}", local_addr)) + .build() + .await?; + let intercepted_db = intercepted_spanner + .database_client(&db_path) + .build() + .await?; + + // 6. Spawn 20 tasks with random workloads + let tx = intercepted_db + .read_only_transaction() + .with_timestamp_bound(TimestampBound::read_timestamp(snapshot_time)) + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + let barrier = Arc::new(Barrier::new(20)); + let mut handles = Vec::new(); + + for i in 0..20 { + let role = rand::random_range(0..4); + let tx = Arc::clone(&tx); + let barrier = Arc::clone(&barrier); + let table_success = table_success.clone(); + let table_not_found = table_not_found.clone(); + + handles.push(tokio::spawn(async move { + barrier.wait().await; + match role { + 0 => { + // Success + let mut result_set: ResultSet = tx + .execute_query(format!("SELECT * FROM {}", table_success)) + .await?; + while let Some(row) = result_set.next().await { + row?; + } + Ok::<_, anyhow::Error>(format!("Task {} Success: OK", i)) + } + 1 => { + // Table not found + let res: Result = tx + .execute_query(format!("SELECT * FROM {}", table_not_found)) + .await; + match res { + Err(e) + if e.to_string().contains("not found") + || e.to_string().contains("NotFound") => + { + 
Ok(format!("Task {} NotFound: OK", i)) + } + Ok(_) => anyhow::bail!("Task {} expected NotFound but got Success", i), + Err(e) => anyhow::bail!("Task {} expected NotFound but got: {:?}", i, e), + } + } + 2 => { + // Transient stream error. This will trigger a retry of the stream. + let sql = format!("SELECT 'Transient-{}'", i); + let mut result_set: ResultSet = tx.execute_query(sql).await?; + while let Some(row) = result_set.next().await { + row?; + } + Ok(format!("Task {} Transient: OK", i)) + } + 3 => { + // Permanent stream error. + let result_set_res: Result = + tx.execute_query("SELECT 'Permanent'").await; + let mut result_set = match result_set_res { + Ok(rs) => rs, + Err(e) => anyhow::bail!( + "Task {} expected successful RPC initiation but got: {:?}", + i, + e + ), + }; + + let next = result_set.next().await; + match next { + Some(Err(e)) + if e.to_string().contains("Permanent") + || e.to_string().contains("Internal") => + { + Ok(format!("Task {} Permanent: OK", i)) + } + Some(Ok(_)) => { + anyhow::bail!("Task {} expected Permanent error but got a valid row", i) + } + _ => anyhow::bail!( + "Task {} expected Permanent error but succeeded or got empty results", + i + ), + } + } + _ => unreachable!(), + } + })); + } + + for handle in handles { + handle.await??; + } + + Ok(()) +} diff --git a/tests/spanner/src/lib.rs b/tests/spanner/src/lib.rs index c300c74674..3449720728 100644 --- a/tests/spanner/src/lib.rs +++ b/tests/spanner/src/lib.rs @@ -13,6 +13,7 @@ // limitations under the License. pub mod client; +pub mod concurrent_inline_begin; pub mod partitioned_dml; pub mod query; pub mod read; diff --git a/tests/spanner/src/test_proxy.rs b/tests/spanner/src/test_proxy.rs index 7f0a5837a2..f7a07d881f 100644 --- a/tests/spanner/src/test_proxy.rs +++ b/tests/spanner/src/test_proxy.rs @@ -12,9 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use futures::stream::{BoxStream, StreamExt}; use spanner_grpc_mock::google::spanner::v1 as spanner_v1; use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; +pub type ExecuteStreamingSqlStream = + BoxStream<'static, std::result::Result>; + #[tonic::async_trait] pub trait SpannerInterceptor: Send + Sync + 'static { fn emulator_client(&self) -> SpannerClient; @@ -66,10 +70,21 @@ pub trait SpannerInterceptor: Send + Sync + 'static { &self, request: tonic::Request, ) -> std::result::Result< - tonic::Response>, + tonic::Response< + BoxStream<'static, std::result::Result>, + >, tonic::Status, > { - self.emulator_client().execute_streaming_sql(request).await + let res = self + .emulator_client() + .execute_streaming_sql(request) + .await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) } async fn execute_batch_dml( @@ -91,10 +106,18 @@ pub trait SpannerInterceptor: Send + Sync + 'static { &self, request: tonic::Request, ) -> std::result::Result< - tonic::Response>, + tonic::Response< + BoxStream<'static, std::result::Result>, + >, tonic::Status, > { - self.emulator_client().streaming_read(request).await + let res = self.emulator_client().streaming_read(request).await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) } async fn begin_transaction( @@ -136,10 +159,18 @@ pub trait SpannerInterceptor: Send + Sync + 'static { &self, request: tonic::Request, ) -> std::result::Result< - tonic::Response>, + tonic::Response< + BoxStream<'static, std::result::Result>, + >, tonic::Status, > { - self.emulator_client().batch_write(request).await + let res = self.emulator_client().batch_write(request).await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) } } @@ -190,7 +221,8 @@ impl 
spanner_v1::spanner_server::Spanner for InterceptedS self.0.execute_sql(request).await } - type ExecuteStreamingSqlStream = tonic::codec::Streaming; + type ExecuteStreamingSqlStream = + BoxStream<'static, std::result::Result>; async fn execute_streaming_sql( &self, @@ -214,7 +246,8 @@ impl spanner_v1::spanner_server::Spanner for InterceptedS self.0.read(request).await } - type StreamingReadStream = tonic::codec::Streaming; + type StreamingReadStream = + BoxStream<'static, std::result::Result>; async fn streaming_read( &self, @@ -258,7 +291,8 @@ impl spanner_v1::spanner_server::Spanner for InterceptedS self.0.partition_read(request).await } - type BatchWriteStream = tonic::codec::Streaming; + type BatchWriteStream = + BoxStream<'static, std::result::Result>; async fn batch_write( &self, diff --git a/tests/spanner/tests/driver.rs b/tests/spanner/tests/driver.rs index 6751a22484..2bab0f0725 100644 --- a/tests/spanner/tests/driver.rs +++ b/tests/spanner/tests/driver.rs @@ -100,4 +100,9 @@ mod spanner { Ok(()) } + + #[tokio::test] + async fn run_concurrent_inline_begin_tests() -> anyhow::Result<()> { + integration_tests_spanner::concurrent_inline_begin::test_concurrent_inline_begin_with_snapshot_consistency().await + } } From 1250553e5eaf68e235c28eaa1b677cc29afd2073 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sat, 4 Apr 2026 11:43:30 +0200 Subject: [PATCH 05/17] fix(spanner): prevent deadlock if ResultSet#next() is never called If a query included a BeginTransaction option and the application never called ResultSet#next(), then the transaction ID would never be returned. This would block any other query from using the transaction. This change refactors ResultSet to use a background worker to read from the stream. This prevents that a deadlock can happen if the application does not call ResultSet#next(). It also allows the application to call ResultSet#metadata() without first calling ResultSet#next(). 
Finally, it also allows the ResultSet to decode data from the server asynchronously while the application processes rows that it has already read from the ResultSet. --- src/spanner/src/result_set.rs | 538 ++++++++++++++++++++----- src/spanner/src/result_set_metadata.rs | 4 +- tests/spanner/src/query.rs | 18 +- tests/spanner/src/write.rs | 1 + 4 files changed, 445 insertions(+), 116 deletions(-) diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 07d8d8d89b..2dd5ba23aa 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -25,6 +25,10 @@ use gaxi::prost::FromProto; use google_cloud_gax::error::rpc::Code; use std::collections::VecDeque; use std::mem::take; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use tokio::sync::mpsc; +use tokio::sync::watch; #[cfg(feature = "unstable-stream")] use futures::Stream; @@ -44,11 +48,20 @@ use futures::Stream; /// ``` #[derive(Debug)] pub struct ResultSet { + receiver: mpsc::Receiver>, + metadata: watch::Receiver>, + // This field is only modified in tests to set a small buffer size. + #[allow(dead_code)] + max_buffered_partial_result_sets: Arc, +} + +#[derive(Debug)] +struct ResultSetWorker { stream: PartialResultSetStream, buffered_values: Vec, chunked: bool, ready_rows: VecDeque, - metadata: Option, + metadata: watch::Sender>, precommit_token_tracker: PrecommitTokenTracker, // Fields for retries and buffering of a stream of PartialResultSets. 
@@ -57,7 +70,7 @@ pub struct ResultSet { last_resume_token: Bytes, partial_result_sets_buffer: VecDeque, safe_to_retry: bool, - max_buffered_partial_result_sets: usize, + max_buffered_partial_result_sets: Arc, retry_count: usize, transaction_selector: Option, } @@ -91,21 +104,33 @@ impl ResultSet { client: DatabaseClient, operation: StreamOperation, ) -> Self { - Self { + let (sender, receiver) = mpsc::channel(4); + let (metadata_sender, metadata_receiver) = watch::channel(None); + let max_buffered_partial_result_sets = + Arc::new(AtomicUsize::new(MAX_BUFFERED_PARTIAL_RESULT_SETS)); + + let mut worker = ResultSetWorker::new( stream, - buffered_values: Vec::new(), - chunked: false, - ready_rows: VecDeque::new(), - metadata: None, + transaction_selector, precommit_token_tracker, client, operation, - last_resume_token: Bytes::new(), - partial_result_sets_buffer: VecDeque::new(), - safe_to_retry: true, - max_buffered_partial_result_sets: MAX_BUFFERED_PARTIAL_RESULT_SETS, - retry_count: 0, - transaction_selector, + metadata_sender, + Arc::clone(&max_buffered_partial_result_sets), + ); + + tokio::spawn(async move { + while let Some(row) = worker.next().await { + if sender.send(row).await.is_err() { + break; // Receiver dropped + } + } + }); + + Self { + receiver, + metadata: metadata_receiver, + max_buffered_partial_result_sets, } } @@ -114,21 +139,29 @@ impl ResultSet { /// # Example /// ``` /// # use google_cloud_spanner::client::{ResultSet, Row}; - /// # async fn fetch_metadata(mut rs: ResultSet) -> Result<(), Box> { - /// if let Some(row) = rs.next().await.transpose()? 
{ - /// let metadata = rs.metadata()?; - /// for column in metadata.column_names() { - /// println!("Column name: {}", column); - /// } + /// # async fn fetch_metadata(mut result_set: ResultSet) -> Result<(), Box> { + /// let metadata = result_set.metadata().await?; + /// for column in metadata.column_names() { + /// println!("Column name: {}", column); /// } /// # Ok(()) /// # } /// ``` /// - /// The metadata is only available after the first call to [`next`](Self::next). - /// If called before the first `next()` call, it returns a [`ResultSetError::MetadataNotAvailable`] error. - pub fn metadata(&self) -> Result { - self.metadata + /// This method blocks until the metadata is available, which is after the + /// first chunk is received from the server. If the stream ends or fails + /// before metadata is available, it returns [`ResultSetError::MetadataNotAvailable`]. + pub async fn metadata(&self) -> Result { + let mut receiver = self.metadata.clone(); + if let Some(metadata) = &*receiver.borrow() { + return Ok(metadata.clone()); + } + receiver + .changed() + .await + .map_err(|_| ResultSetError::MetadataNotAvailable)?; + receiver + .borrow() .clone() .ok_or(ResultSetError::MetadataNotAvailable) } @@ -148,6 +181,73 @@ impl ResultSet { /// /// Returns `None` when all rows have been retrieved. pub async fn next(&mut self) -> Option> { + self.receiver.recv().await + } + + /// Converts the [`ResultSet`] into a [`Stream`]. + /// + /// # Example + /// + /// ``` + /// # use google_cloud_spanner::client::ResultSet; + /// # use futures::TryStreamExt; + /// # use std::future::ready; + /// # async fn example(result_set: ResultSet) -> Result<(), google_cloud_spanner::Error> { + /// let rows: Vec<_> = result_set + /// .into_stream() + /// .try_filter(|row| { + /// let id = row.get::("Id"); + /// ready(id == "id1") + /// }) + /// .try_collect() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// This consumes the [`ResultSet`] and returns a stream of rows. 
+ #[cfg(feature = "unstable-stream")] + pub fn into_stream(self) -> impl Stream> + Unpin { + use futures::stream::unfold; + Box::pin(unfold(self, |mut result_set| async move { + result_set.next().await.map(|row| (row, result_set)) + })) + } +} + +impl ResultSetWorker { + /// Creates a new result set worker. + pub(crate) fn new( + stream: PartialResultSetStream, + transaction_selector: Option, + precommit_token_tracker: PrecommitTokenTracker, + client: DatabaseClient, + operation: StreamOperation, + metadata: watch::Sender>, + max_buffered_partial_result_sets: Arc, + ) -> Self { + Self { + stream, + buffered_values: Vec::new(), + chunked: false, + ready_rows: VecDeque::new(), + metadata, + precommit_token_tracker, + client, + operation, + last_resume_token: Bytes::new(), + partial_result_sets_buffer: VecDeque::new(), + safe_to_retry: true, + max_buffered_partial_result_sets, + retry_count: 0, + transaction_selector, + } + } + + /// Fetches the next row from the result set. + /// + /// Returns `None` when all rows have been retrieved. + pub(crate) async fn next(&mut self) -> Option> { if let Some(row) = self.ready_rows.pop_front() { return Some(Ok(row)); } @@ -210,7 +310,11 @@ impl ResultSet { // The PartialResultSet did not have a resume_token. Buffer the result // and continue with the next PartialResultSet, unless the buffer is full. - if self.partial_result_sets_buffer.len() >= self.max_buffered_partial_result_sets { + if self.partial_result_sets_buffer.len() + >= self + .max_buffered_partial_result_sets + .load(Ordering::Relaxed) + { // Mark this stream as 'unsafe to retry', meaning that any transient error // that we see will not be retried. We will instead propagate the error. 
self.safe_to_retry = false; @@ -290,37 +394,47 @@ impl ResultSet { &mut self, partial_result_set: PartialResultSet, ) -> crate::Result<()> { - match (self.metadata.as_ref(), partial_result_set.metadata) { - (Some(_), None) => {} - (None, None) => { - return Err(internal_error( - "First PartialResultSet did not contain metadata", - )); - } - (Some(_), Some(_)) => { - return Err(internal_error("Additional metadata after first result set")); + let update_selector = { + let metadata_ref = self.metadata.borrow(); + match (&*metadata_ref, partial_result_set.metadata) { + (Some(_), None) => None, + (None, None) => { + return Err(internal_error( + "First PartialResultSet did not contain metadata", + )); + } + (Some(_), Some(_)) => { + return Err(internal_error("Additional metadata after first result set")); + } + (None, Some(mut m)) => { + let transaction = m.transaction.take(); + Some((ResultSetMetadata::new(Some(m)), transaction)) + } } - (None, Some(mut m)) => { - let transaction = m.transaction.take(); - self.metadata = Some(ResultSetMetadata::new(Some(m))); - if let Some(selector) = &self.transaction_selector { - if let Some(transaction) = transaction { - selector.update( - transaction.id, - transaction - .read_timestamp - .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), - )?; - } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { - let is_started = matches!( - &*lazy.lock().expect("transaction state mutex poisoned"), - crate::read_only_transaction::TransactionState::Started(_, _) - ); - if !is_started { - return Err(internal_error( - "Spanner failed to return a transaction ID for a query that included a BeginTransaction option", - )); - } + }; + + if let Some((metadata, transaction)) = update_selector { + self.metadata + .send(Some(metadata)) + .map_err(|_| internal_error("Failed to send metadata"))?; + + if let Some(selector) = &self.transaction_selector { + if let Some(transaction) = transaction { + selector.update( + transaction.id, + 
transaction + .read_timestamp + .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), + )?; + } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { + let is_started = matches!( + &*lazy.lock().expect("transaction state mutex poisoned"), + crate::read_only_transaction::TransactionState::Started(_, _) + ); + if !is_started { + return Err(internal_error( + "Spanner failed to return a transaction ID for a query that included a BeginTransaction option", + )); } } } @@ -329,7 +443,8 @@ impl ResultSet { if partial_result_set.values.is_empty() { return Ok(()); } - let metadata = self.metadata.as_ref().unwrap(); + + let metadata = self.metadata.borrow().as_ref().unwrap().clone(); if metadata.column_types.is_empty() { return Err(internal_error( "PartialResultSet contained values but no column metadata was provided", @@ -418,36 +533,6 @@ impl ResultSet { e.status() .is_some_and(|status| status.code == Code::Unavailable) } - - /// Converts the [`ResultSet`] into a [`Stream`]. - /// - /// # Example - /// - /// ``` - /// # use google_cloud_spanner::client::ResultSet; - /// # use futures::TryStreamExt; - /// # use std::future::ready; - /// # async fn example(result_set: ResultSet) -> Result<(), google_cloud_spanner::Error> { - /// let rows: Vec<_> = result_set - /// .into_stream() - /// .try_filter(|row| { - /// let id = row.get::("Id"); - /// ready(id == "id1") - /// }) - /// .try_collect() - /// .await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// This consumes the [`ResultSet`] and returns a stream of rows. - #[cfg(feature = "unstable-stream")] - pub fn into_stream(self) -> impl Stream> + Unpin { - use futures::stream::unfold; - Box::pin(unfold(self, |mut result_set| async move { - result_set.next().await.map(|row| (row, result_set)) - })) - } } /// Merges two values from successive `PartialResultSet`s into a single value. 
@@ -505,7 +590,8 @@ fn merge_values(target: &mut prost_types::Value, source: prost_types::Value) -> #[cfg(test)] impl ResultSet { pub(crate) fn set_max_buffered_partial_result_sets(&mut self, limit: usize) { - self.max_buffered_partial_result_sets = limit; + self.max_buffered_partial_result_sets + .store(limit, Ordering::Relaxed); } } @@ -1046,21 +1132,62 @@ pub(crate) mod tests { } #[tokio::test] - async fn test_result_set_precommit_token_tracked() { - let mut rs = run_mock_query(vec![PartialResultSet { - metadata: metadata(1), - precommit_token: Some( - spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken { - precommit_token: b"test_token".to_vec(), - seq_num: 99, - }, - ), - ..Default::default() - }]) - .await; + async fn test_result_set_precommit_token_tracked() -> anyhow::Result<()> { + let mut mock = MockSpanner::new(); + mock.expect_execute_streaming_sql() + .returning(move |_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + precommit_token: Some( + spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken { + precommit_token: b"test_token".to_vec(), + seq_num: 99, + }, + ), + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let req = crate::model::ExecuteSqlRequest::default() + .set_session(db_client.session.name.clone()) + .set_sql("SELECT 1".to_string()); + + let stream = db_client + .spanner + .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .send() + .await?; 
- // Force tracking mode since run_mock_query uses a ReadOnly transaction (NoOp). - rs.precommit_token_tracker = PrecommitTokenTracker::new(); + let tracker = PrecommitTokenTracker::new(); // Track mode! + + let mut rs = ResultSet::new( + stream, + None, + tracker.clone(), + db_client.clone(), + StreamOperation::Query(req), + ); // Read a row to trigger precommit token extraction assert!( @@ -1069,12 +1196,11 @@ pub(crate) mod tests { ); // Validate the tracker correctly intercepted and preserved the token - let token = rs - .precommit_token_tracker - .get() - .expect("token should be tracked"); + let token = tracker.get().expect("token should be tracked"); assert_eq!(token.seq_num, 99); assert_eq!(token.precommit_token, bytes::Bytes::from("test_token")); + + Ok(()) } #[tokio::test] @@ -1862,4 +1988,208 @@ pub(crate) mod tests { Ok(()) } + + #[tokio::test] + async fn test_lazy_begin_deadlock_fixed() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // Setup mock to return metadata with transaction ID on first query. 
+ mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let mut meta = metadata(1).expect("failed to create metadata"); + meta.transaction = Some(mock_v1::Transaction { + id: b"lazy_tx_id".to_vec(), + ..Default::default() + }); + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // Mock call for second query which must carry the returned transaction ID + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction component") + .selector + .expect("missing selector component"); + + match selector { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, b"lazy_tx_id".to_vec()); + } + _ => panic!("Expected Selector::Id"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("2")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + // Use inline begin transaction + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Execute query but DO NOT call rs.next() + let _rs = tx.execute_query("SELECT 1").await?; + + // Execute second query against same transaction + let mut rs2 = 
tx.execute_query("SELECT 2").await?; + + // Assert it does not hang and yielded elements properly + let row2 = rs2.next().await; + assert!( + row2.is_some(), + "Implicit deadlock encountered; query 2 stalled!" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_result_set_metadata_not_available() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + + // Setup mock to return a stream that fails immediately. + mock.expect_execute_streaming_sql().returning(|_request| { + let stream = tokio_stream::iter(vec![Err(Status::internal("Internal error"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + use spanner_grpc_mock::google::spanner::v1::Session; + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let rs = tx.execute_query("SELECT 1").await?; + + // Call metadata() immediately. It should fail because the stream ends without metadata. + let result = rs.metadata().await; + assert!(result.is_err(), "Expected error but got Ok"); + assert!( + matches!(result.unwrap_err(), ResultSetError::MetadataNotAvailable), + "Expected MetadataNotAvailable error" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_result_set_metadata_available_before_next() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + + // Setup mock to return metadata in first chunk. 
+ mock.expect_execute_streaming_sql().returning(|_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + use spanner_grpc_mock::google::spanner::v1::Session; + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let mut rs = tx.execute_query("SELECT 1").await?; + + // Call metadata() BEFORE next(). It should succeed. + let metadata = rs.metadata().await?; + assert_eq!(metadata.column_names().len(), 1); + assert_eq!(metadata.column_names()[0], "col0"); + + // Now consume the row + let row = rs.next().await; + assert!(row.is_some()); + + Ok(()) + } } diff --git a/src/spanner/src/result_set_metadata.rs b/src/spanner/src/result_set_metadata.rs index 2a6cd2bf1e..3e48cf2b09 100644 --- a/src/spanner/src/result_set_metadata.rs +++ b/src/spanner/src/result_set_metadata.rs @@ -26,9 +26,7 @@ use std::sync::Arc; /// let tx = db.single_use().build(); /// let mut rs = tx.execute_query(Statement::builder("SELECT 1 AS Number").build()).await?; /// -/// // Metadata is available after the first `next` call -/// let _ = rs.next().await.transpose()?; -/// let metadata = rs.metadata()?; +/// let metadata = rs.metadata().await?; /// /// for (name, type_) in metadata.column_names().iter().zip(metadata.column_types().iter()) { /// println!("Column: {} has type: {:?}", name, type_.code()); diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index 06cae8cd6a..e09874c1a6 
100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -158,14 +158,14 @@ pub async fn result_set_metadata(db_client: &DatabaseClient) -> anyhow::Result<( // 1. Simple normal query let sql = "SELECT 1 as num, 'Alice' as name"; - let mut rs = rot.execute_query(Statement::builder(sql).build()).await?; + let mut result_set = rot.execute_query(Statement::builder(sql).build()).await?; - assert!(rs.next().await.transpose()?.is_some()); - let metadata = rs.metadata()?; + let metadata = result_set.metadata().await?; assert_eq!( metadata.column_names(), &["num".to_string(), "name".to_string()] ); + assert!(result_set.next().await.transpose()?.is_some()); // 2. Query that returns zero rows let sql_zero_rows = r#" @@ -174,25 +174,25 @@ pub async fn result_set_metadata(db_client: &DatabaseClient) -> anyhow::Result<( ) SELECT num, name FROM Data WHERE 1=0 "#; - let mut rs_zero_rows = rot + let mut result_set_zero_rows = rot .execute_query(Statement::builder(sql_zero_rows).build()) .await?; - assert!(rs_zero_rows.next().await.transpose()?.is_none()); - let metadata_zero_rows = rs_zero_rows.metadata()?; + let metadata_zero_rows = result_set_zero_rows.metadata().await?; assert_eq!( metadata_zero_rows.column_names(), &["num".to_string(), "name".to_string()] ); + assert!(result_set_zero_rows.next().await.transpose()?.is_none()); // 3. 
Query with duplicate aliases let sql_dup = "SELECT 1 as dup, 2 as dup"; - let mut rs_dup = rot + let mut result_set_dup = rot .execute_query(Statement::builder(sql_dup).build()) .await?; - let row_dup = rs_dup.next().await.transpose()?.unwrap(); - let metadata_dup = rs_dup.metadata()?; + let row_dup = result_set_dup.next().await.transpose()?.unwrap(); + let metadata_dup = result_set_dup.metadata().await?; assert_eq!( metadata_dup.column_names(), &["dup".to_string(), "dup".to_string()] diff --git a/tests/spanner/src/write.rs b/tests/spanner/src/write.rs index cd2f9ae485..f0007df60b 100644 --- a/tests/spanner/src/write.rs +++ b/tests/spanner/src/write.rs @@ -526,6 +526,7 @@ async fn write_internal( let metadata = rs .metadata() + .await .expect("result set metadata is unexpectedly missing"); let column_count = metadata.column_names().len(); assert_eq!(row2.raw_values().len(), column_count); From 43d906b2c7c2e17b38071f5225171b125dd9ad7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sun, 5 Apr 2026 12:35:18 +0200 Subject: [PATCH 06/17] perf(spanner): inline-begin for read/write transactions Adds support for inline-begin for read/write transactions. This reduces the number of round-trips to Spanner by one for read/write transactions. 
--- src/spanner/src/read_only_transaction.rs | 30 + src/spanner/src/read_write_transaction.rs | 1604 ++++++++++++++++----- src/spanner/src/transaction_runner.rs | 2 +- 3 files changed, 1298 insertions(+), 338 deletions(-) diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 45d5bde4b4..28c693acd0 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -19,6 +19,7 @@ use crate::precommit::PrecommitTokenTracker; use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; +use crate::transaction_retry_policy::is_aborted; use std::sync::{Arc, Mutex}; use tokio::sync::Notify; @@ -605,6 +606,32 @@ impl ReadContextTransactionSelector { } } + /// Returns the transaction ID if it is already available, without waiting. + /// + /// This method inspects the selector and returns the transaction ID if the + /// transaction has already started. It returns `None` if the transaction + /// has not yet started or is in a state without an ID. + #[allow(dead_code)] + pub(crate) fn get_id_no_wait(&self) -> Option { + use crate::generated::gapic_dataplane::model::transaction_selector::Selector; + match self { + Self::Fixed(selector, _) => { + if let Some(Selector::Id(id)) = &selector.selector { + return Some(id.clone()); + } + } + Self::Lazy(lazy) => { + let guard = lazy.lock().expect("transaction state mutex poisoned"); + if let TransactionState::Started(selector, _) = &*guard { + if let Some(Selector::Id(id)) = &selector.selector { + return Some(id.clone()); + } + } + } + } + None + } + /// Resets the selector state from `Starting` back to `NotStarted`. /// /// This is used during stream resume fallbacks when the first query stream @@ -699,6 +726,9 @@ macro_rules! execute_stream_with_retry { { Ok(s) => s, Err(e) => { + if is_aborted(&e) { + return Err(e); + } if $self.begin_explicitly_if_not_started().await? 
{ $request.transaction = Some($self.transaction_selector.selector().await?); $self diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index dd7f9e0553..78ef3ab5d4 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -33,7 +33,9 @@ use crate::precommit::PrecommitTokenTracker; use crate::read_only_transaction::ReadContext; use crate::result_set::ResultSet; use crate::statement::Statement; +use crate::transaction_retry_policy::is_aborted; use std::sync::Arc; +use std::sync::Mutex; use std::sync::atomic::{AtomicI64, Ordering}; /// A builder for [ReadWriteTransaction]. @@ -42,6 +44,7 @@ pub(crate) struct ReadWriteTransactionBuilder { client: DatabaseClient, options: TransactionOptions, transaction_tag: Option, + explicit_begin: bool, } impl ReadWriteTransactionBuilder { @@ -50,6 +53,7 @@ impl ReadWriteTransactionBuilder { client, options: TransactionOptions::default().set_read_write(ReadWrite::default()), transaction_tag: None, + explicit_begin: false, } } @@ -83,28 +87,58 @@ impl ReadWriteTransactionBuilder { self } - pub(crate) async fn begin_transaction(&self) -> crate::Result { - let mut request = BeginTransactionRequest::default() - .set_session(self.client.session.name.clone()) - .set_options(self.options.clone()); - if let Some(tag) = &self.transaction_tag { - request = request.set_request_options( - crate::model::RequestOptions::default().set_transaction_tag(tag.clone()), - ); - } + /// Sets whether the transaction should be explicitly started using a `BeginTransaction` RPC. + /// + /// By default, the Spanner client will inline the `BeginTransaction` call with the first query + /// or DML statement in the transaction. This reduces the number of round-trips to Spanner that + /// are needed for a transaction. Setting this option to `true` can be beneficial for specific + /// transaction shapes: + /// + /// 1. 
When the transaction executes multiple parallel queries at the start of the transaction. + /// Only one query can include a `BeginTransaction` option, and all other queries must wait for + /// the first query to return the first result before they can proceed to execute. A + /// `BeginTransaction` RPC will quickly return a transaction ID and allow all queries to start + /// execution in parallel once the transaction ID has been returned. + /// 2. When the first statement in the transaction could fail. If the statement fails, then it + /// will also not start a transaction and return a transaction ID. The transaction will then + /// fall back to executing a `BeginTransaction` RPC and retry the first statement. + /// + /// Default is `false` (inline begin). + pub fn with_explicit_begin_transaction(mut self, explicit: bool) -> Self { + self.explicit_begin = explicit; + self + } - // TODO(#4972): make request options configurable - let response = self - .client - .spanner - .begin_transaction(request, RequestOptions::default()) - .await?; + pub(crate) async fn build(&self) -> crate::Result { + let transaction_selector = if self.explicit_begin { + let mut request = BeginTransactionRequest::default() + .set_session(self.client.session.name.clone()) + .set_options(self.options.clone()); + if let Some(tag) = &self.transaction_tag { + request = request.set_request_options( + crate::model::RequestOptions::default().set_transaction_tag(tag.clone()), + ); + } + + // TODO(#4972): make request options configurable + let response = self + .client + .spanner + .begin_transaction(request, RequestOptions::default()) + .await?; - let transaction_selector = crate::read_only_transaction::ReadContextTransactionSelector::Fixed( TransactionSelector::default().set_id(response.id), None, - ); + ) + } else { + crate::read_only_transaction::ReadContextTransactionSelector::Lazy(Arc::new( + Mutex::new(crate::read_only_transaction::TransactionState::NotStarted( + self.options.clone(), + )), + )) 
+ }; + Ok(ReadWriteTransaction { context: ReadContext { client: self.client.clone(), @@ -124,6 +158,64 @@ pub struct ReadWriteTransaction { seqno: Arc, } +/// Helper macro to execute a DML or BatchDML RPC with retry logic if the +/// request included a BeginTransaction option. +macro_rules! execute_with_retry { + ($self:expr, $request:ident, $rpc_method:ident, $extract_id:expr) => {{ + let is_starting = matches!( + $request + .transaction + .as_ref() + .and_then(|t| t.selector.as_ref()), + Some(Selector::Begin(_)) + ); + + let response_result = $self + .context + .client + .spanner + .$rpc_method($request.clone(), RequestOptions::default()) + .await; + + let response = match response_result { + Ok(response) => { + if is_starting { + let id = $extract_id(&response).ok_or_else(|| { + crate::error::internal_error("Transaction ID was not returned by Spanner") + })?; + $self.context.transaction_selector.update(id, None)?; + } + response + } + Err(error) => { + if !is_starting { + return Err(error); + } + if is_aborted(&error) { + return Err(error); + } + + $self + .context + .transaction_selector + .begin_explicitly(&$self.context.client) + .await?; + + $request.transaction = Some($self.context.transaction_selector.selector().await?); + + $self + .context + .client + .spanner + .$rpc_method($request.clone(), RequestOptions::default()) + .await? + } + }; + + response + }}; +} + impl ReadWriteTransaction { /// Executes a query using this transaction. 
pub async fn execute_query>( @@ -152,12 +244,19 @@ impl ReadWriteTransaction { .set_seqno(seqno); request.request_options = self.context.amend_request_options(request.request_options); - let response = self - .context - .client - .spanner - .execute_sql(request, RequestOptions::default()) - .await?; + let response = execute_with_retry!( + self, + request, + execute_sql, + |response: &crate::model::ResultSet| { + response + .metadata + .as_ref() + .and_then(|md| md.transaction.as_ref()) + .map(|t| t.id.clone()) + } + ); + self.context .precommit_token_tracker .update(response.precommit_token); @@ -241,37 +340,39 @@ impl ReadWriteTransaction { pub async fn execute_batch_update(&self, batch: BatchDml) -> crate::Result> { let seqno = self.seqno.fetch_add(1, Ordering::SeqCst); - let statements: Vec = batch - .statements + let BatchDml { + statements, + request_options, + } = batch; + let statements: Vec = statements .into_iter() .map(|stmt: crate::statement::Statement| stmt.into_batch_statement()) .collect(); - let request = ExecuteBatchDmlRequest::default() + let mut request = ExecuteBatchDmlRequest::default() .set_session(self.context.client.session.name.clone()) .set_transaction(self.context.transaction_selector.selector().await?) 
.set_seqno(seqno) .set_statements(statements) - .set_or_clear_request_options( - self.context.amend_request_options(batch.request_options), - ); - - let response_result = self - .context - .client - .spanner - .execute_batch_dml(request, RequestOptions::default()) - .await; + .set_or_clear_request_options(self.context.amend_request_options(request_options)); - match response_result { - Ok(response) => { - self.context - .precommit_token_tracker - .update(response.precommit_token.clone()); - crate::batch_dml::process_response(response) + let response = execute_with_retry!( + self, + request, + execute_batch_dml, + |response: &crate::model::ExecuteBatchDmlResponse| { + response + .result_sets + .first() + .and_then(|rs| rs.metadata.as_ref()) + .and_then(|md| md.transaction.as_ref()) + .map(|t| t.id.clone()) } - Err(e) => Err(e), - } + ); + self.context + .precommit_token_tracker + .update(response.precommit_token.clone()); + crate::batch_dml::process_response(response) } pub(crate) async fn transaction_id(&self) -> crate::Result { @@ -347,6 +448,9 @@ mod tests { use gaxi::grpc::tonic; use spanner_grpc_mock::google::spanner::v1; use std::fmt::Debug; + use v1::result_set_stats::RowCount; + use v1::transaction_options::Mode; + use v1::transaction_selector::Selector; #[test] fn auto_traits() { @@ -355,28 +459,61 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_commit_retry() { + async fn read_write_transaction_commit_retry_explicit() -> anyhow::Result<()> { + run_read_write_transaction_commit_retry(true).await + } + + #[tokio::test] + async fn read_write_transaction_commit_retry_inline() -> anyhow::Result<()> { + run_read_write_transaction_commit_retry(false).await + } + + async fn run_read_write_transaction_commit_retry(explicit_begin: bool) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - 
"projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![0, 0, 7], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![0, 0, 7], + ..Default::default() + })) + }); + } // execute_update returns a precommit token. - mock.expect_execute_sql().once().returning(|req| { + mock.expect_execute_sql().once().returning(move |req| { let req = req.into_inner(); assert_eq!(req.sql, "UPDATE Users SET Name = 'Bob' WHERE Id = 1"); + + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + } + + let mut metadata = v1::ResultSetMetadata { + row_type: Some(v1::StructType { fields: vec![] }), + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![0, 0, 7], + ..Default::default() + }); + } + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + row_count: Some(RowCount::RowCountExact(1)), ..Default::default() }), precommit_token: Some(v1::MultiplexedSessionPrecommitToken { @@ -387,94 +524,135 @@ mod tests { })) }); + let mut seq = mockall::Sequence::new(); + // Simulate that commit returns a precommit token in the response. // This would normally not happen, but we test it here to verify // that the commit is retried. 
- mock.expect_commit().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.precommit_token, - Some(v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![101], - seq_num: 1, - }) - ); - Ok(tonic::Response::new(v1::CommitResponse { - commit_timestamp: Some(prost_types::Timestamp { - seconds: 1000, - nanos: 0, - }), - multiplexed_session_retry: Some( - v1::commit_response::MultiplexedSessionRetry::PrecommitToken( - v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![202], - seq_num: 2, - }, + mock.expect_commit() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.precommit_token, + Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![101], + seq_num: 1, + }) + ); + Ok(tonic::Response::new(v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1000, + nanos: 0, + }), + multiplexed_session_retry: Some( + v1::commit_response::MultiplexedSessionRetry::PrecommitToken( + v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![202], + seq_num: 2, + }, + ), ), - ), - ..Default::default() - })) - }); + ..Default::default() + })) + }); // Second commit retry is automatically issued with the new token - mock.expect_commit().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.precommit_token, - Some(v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![202], - seq_num: 2, - }) - ); - Ok(tonic::Response::new(v1::CommitResponse { - commit_timestamp: Some(prost_types::Timestamp { - seconds: 1001, - nanos: 0, - }), - ..Default::default() - })) - }); + mock.expect_commit() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.precommit_token, + Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![202], + seq_num: 2, + }) + ); + Ok(tonic::Response::new(v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1001, + 
nanos: 0, + }), + ..Default::default() + })) + }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; let count = tx .execute_update("UPDATE Users SET Name = 'Bob' WHERE Id = 1") - .await - .unwrap(); + .await?; assert_eq!(count, 1); - let timestamp = tx.commit().await.unwrap(); + let timestamp = tx.commit().await?; assert_eq!(timestamp.seconds(), 1001); + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_execute_update_explicit() { + run_read_write_transaction_execute_update(true).await; } #[tokio::test] - async fn read_write_transaction_execute_update() { + async fn read_write_transaction_execute_update_inline() { + run_read_write_transaction_execute_update(false).await; + } + + async fn run_read_write_transaction_execute_update(explicit_begin: bool) { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![1, 2, 3], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + })) + }); + } - mock.expect_execute_sql().once().returning(|req| { + mock.expect_execute_sql().once().returning(move |req| { let req = req.into_inner(); assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); assert_eq!(req.seqno, 1); + + if !explicit_begin { + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = 
transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + } + + let mut metadata = v1::ResultSetMetadata { + row_type: Some(v1::StructType { fields: vec![] }), + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + }); + } + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + row_count: Some(RowCount::RowCountExact(1)), ..Default::default() }), ..Default::default() @@ -505,7 +683,8 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .with_explicit_begin_transaction(explicit_begin) + .build() .await .expect("Failed to build transaction"); let count = tx @@ -519,20 +698,55 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_execute_update_invalid_stats() { + async fn read_write_transaction_execute_update_invalid_stats_explicit() -> anyhow::Result<()> { + run_read_write_transaction_execute_update_invalid_stats(true).await + } + + #[tokio::test] + async fn read_write_transaction_execute_update_invalid_stats_inline() -> anyhow::Result<()> { + run_read_write_transaction_execute_update_invalid_stats(false).await + } + + async fn run_read_write_transaction_execute_update_invalid_stats( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![1, 2, 3], + if explicit_begin { + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + })) + }); + } + + mock.expect_execute_sql().once().returning(move |req| { + let req = req.into_inner(); + if !explicit_begin { + let 
transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + } + + let mut metadata = v1::ResultSetMetadata { + row_type: Some(v1::StructType { fields: vec![] }), ..Default::default() - })) - }); + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![1, 2, 3], + ..Default::default() + }); + } - mock.expect_execute_sql().once().returning(|_| { Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountLowerBound(1)), + row_count: Some(RowCount::RowCountLowerBound(1)), ..Default::default() }), ..Default::default() @@ -542,9 +756,9 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; let result = tx .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") @@ -556,97 +770,177 @@ mod tests { "Error did not contain expected message: {:?}", err ); + Ok(()) } #[tokio::test] - async fn read_write_transaction_rollback() { + async fn read_write_transaction_rollback_explicit() -> anyhow::Result<()> { + run_read_write_transaction_rollback(true).await + } + + #[tokio::test] + async fn read_write_transaction_rollback_inline() -> anyhow::Result<()> { + run_read_write_transaction_rollback(false).await + } + + async fn run_read_write_transaction_rollback(explicit_begin: bool) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - 
id: vec![9, 9, 9], - ..Default::default() - })) - }); + let transaction_id = vec![9, 9, 9]; + + if explicit_begin { + let id = transaction_id.clone(); + mock.expect_begin_transaction().once().returning(move |_| { + Ok(tonic::Response::new(v1::Transaction { + id: id.clone(), + ..Default::default() + })) + }); + } else { + let id = transaction_id.clone(); + mock.expect_execute_sql().once().returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(v1::ResultSetMetadata { + transaction: Some(v1::Transaction { + id: id.clone(), + ..Default::default() + }), + ..Default::default() + }), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + } - mock.expect_rollback().once().returning(|req| { + let id = transaction_id.clone(); + mock.expect_rollback().once().returning(move |req| { let req = req.into_inner(); assert_eq!( req.session, "projects/p/instances/i/databases/d/sessions/123" ); - assert_eq!(req.transaction_id, vec![9, 9, 9]); + assert_eq!(req.transaction_id, id); Ok(tonic::Response::new(())) }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; + + if !explicit_begin { + tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") + .await + .expect("Failed to execute update"); + } + + tx.rollback().await?; + Ok(()) + } - tx.rollback().await.expect("Failed to rollback"); + #[tokio::test] + async fn read_write_transaction_execute_batch_update_explicit() -> 
anyhow::Result<()> { + run_read_write_transaction_execute_batch_update(true).await } #[tokio::test] - async fn read_write_transaction_execute_batch_update() -> anyhow::Result<()> { + async fn read_write_transaction_execute_batch_update_inline() -> anyhow::Result<()> { + run_read_write_transaction_execute_batch_update(false).await + } + + async fn run_read_write_transaction_execute_batch_update( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![4, 5, 6], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + })) + }); + } - mock.expect_execute_batch_dml().once().returning(|req| { - let req = req.into_inner(); - assert_eq!(req.statements.len(), 2); - assert_eq!( - req.statements[0].sql, - "UPDATE Users SET Name = 'Alice' WHERE Id = 1" - ); - assert_eq!( - req.statements[1].sql, - "UPDATE Users SET Name = 'Bob' WHERE Id = 2" - ); + mock.expect_execute_batch_dml() + .once() + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.statements.len(), 2); + assert_eq!( + req.statements[0].sql, + "UPDATE Users SET Name = 'Alice' WHERE Id = 1" + ); + assert_eq!( + req.statements[1].sql, + "UPDATE Users SET Name = 'Bob' WHERE Id = 2" + ); + + if !explicit_begin { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!(selector, Selector::Begin(_))); + } - Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { - result_sets: vec![ - v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), - ..Default::default() - }), - ..Default::default() - }, - v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: 
Some(v1::result_set_stats::RowCount::RowCountExact(1)), - ..Default::default() - }), + let mut metadata = v1::ResultSetMetadata { + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![4, 5, 6], ..Default::default() - }, - ], - status: Some(spanner_grpc_mock::google::rpc::Status { - code: 0, - message: "OK".into(), - details: vec![], - }), - ..Default::default() - })) - }); - - let (db_client, _server) = setup_db_client(mock).await; + }); + } + + Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { + result_sets: vec![ + v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }, + v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }, + ], + status: Some(spanner_grpc_mock::google::rpc::Status { + code: 0, + message: "OK".into(), + details: vec![], + }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .with_explicit_begin_transaction(explicit_begin) + .build() .await?; let batch = BatchDml::builder() @@ -660,38 +954,77 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_execute_batch_update_partial_failure() -> anyhow::Result<()> { + async fn read_write_transaction_execute_batch_update_partial_failure_explicit() + -> anyhow::Result<()> { + run_read_write_transaction_execute_batch_update_partial_failure(true).await + } + + #[tokio::test] + async fn read_write_transaction_execute_batch_update_partial_failure_inline() + -> anyhow::Result<()> { + run_read_write_transaction_execute_batch_update_partial_failure(false).await + } + + async fn run_read_write_transaction_execute_batch_update_partial_failure( + explicit_begin: bool, + ) -> anyhow::Result<()> { let 
mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![7, 8, 9], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + } - mock.expect_execute_batch_dml().once().returning(|_| { - Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { - result_sets: vec![v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + mock.expect_execute_batch_dml() + .once() + .returning(move |req| { + let req = req.into_inner(); + if !explicit_begin { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!(selector, Selector::Begin(_))); + } + + let mut metadata = v1::ResultSetMetadata { + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![7, 8, 9], ..Default::default() + }); + } + + Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { + result_sets: vec![v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }], + status: Some(spanner_grpc_mock::google::rpc::Status { + code: gaxi::grpc::tonic::Code::AlreadyExists as i32, + message: "row already exists".into(), + details: vec![], }), ..Default::default() - }], - status: Some(spanner_grpc_mock::google::rpc::Status { - code: gaxi::grpc::tonic::Code::AlreadyExists as i32, - message: "row already exists".into(), - details: vec![], - }), - ..Default::default() - })) - }); + })) + }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client) - .begin_transaction() + .with_explicit_begin_transaction(explicit_begin) + 
.build() .await?; let batch = BatchDml::builder() @@ -715,94 +1048,236 @@ mod tests { } #[tokio::test] - async fn read_write_transaction_execute_multiple_updates() { + async fn read_write_transaction_execute_multiple_updates_explicit() -> anyhow::Result<()> { + run_read_write_transaction_execute_multiple_updates(true).await + } + + #[tokio::test] + async fn read_write_transaction_execute_multiple_updates_inline() -> anyhow::Result<()> { + run_read_write_transaction_execute_multiple_updates(false).await + } + + async fn run_read_write_transaction_execute_multiple_updates( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![4, 5, 6], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + })) + }); + } - let counter = Arc::new(AtomicI64::new(1)); - mock.expect_execute_sql().times(3).returning(move |req| { - let req = req.into_inner(); - assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); - let c = counter.fetch_add(1, Ordering::SeqCst); - assert_eq!(req.seqno, c); + let mut seq = mockall::Sequence::new(); - Ok(tonic::Response::new(v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + // First update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + assert_eq!(req.seqno, 1); + + let mut metadata = v1::ResultSetMetadata { 
..Default::default() - }), - ..Default::default() - })) - }); + }; + + if !explicit_begin { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!(selector, Selector::Begin(_))); + metadata.transaction = Some(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }); + } else { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + } + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + // Second update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + assert_eq!(req.seqno, 2); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + // Third update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + assert_eq!(req.seqno, 3); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + 
} + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; for i in 1..=3 { let count = tx .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") .await - .unwrap_or_else(|_| panic!("Failed to execute update {}", i)); + .map_err(|e| anyhow::anyhow!("Failed to execute update {}: {:?}", i, e))?; assert_eq!(count, 1); } + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_execute_query_explicit() -> anyhow::Result<()> { + run_read_write_transaction_execute_query(true).await } #[tokio::test] - async fn read_write_transaction_execute_query() { + async fn read_write_transaction_execute_query_inline() -> anyhow::Result<()> { + run_read_write_transaction_execute_query(false).await + } + + async fn run_read_write_transaction_execute_query(explicit_begin: bool) -> anyhow::Result<()> { use crate::client::Statement; let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - Ok(tonic::Response::new(v1::Transaction { - id: vec![7, 8, 9], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + Ok(tonic::Response::new(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + } - mock.expect_execute_streaming_sql().once().returning(|req| { + mock.expect_execute_streaming_sql().once().returning(move |req| { let req = 
req.into_inner(); assert_eq!(req.sql, "SELECT 1"); // Queries do not need to include a sequence number. assert_eq!(req.seqno, 0); - assert_eq!( - req.transaction, - Some(v1::TransactionSelector { - selector: Some(v1::transaction_selector::Selector::Id(vec![7, 8, 9])) - }) - ); + if !explicit_begin { + let transaction = req.transaction.as_ref().expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + } else { + assert_eq!( + req.transaction, + Some(v1::TransactionSelector { + selector: Some(Selector::Id(vec![7, 8, 9])) + }) + ); + } type StreamType = ::ExecuteStreamingSqlStream; - let stream: tokio_stream::Empty> = tokio_stream::empty(); + + let mut metadata = v1::ResultSetMetadata { + row_type: Some(v1::StructType { fields: vec![] }), + ..Default::default() + }; + if !explicit_begin { + metadata.transaction = Some(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + }); + } + + let first_response = v1::PartialResultSet { + metadata: Some(metadata), + ..Default::default() + }; + + let stream = tokio_stream::iter(vec![Ok(first_response)]); Ok(tonic::Response::new(Box::pin(stream) as StreamType)) }); let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; let mut rs = tx .execute_query(Statement::builder("SELECT 1").build()) @@ -811,6 +1286,7 @@ mod tests { let result = rs.next().await; assert!(result.is_none(), "expected None, got empty stream"); + Ok(()) } #[tokio::test] @@ -827,7 +1303,7 @@ mod tests { let options = req.options.expect("missing transaction options"); let mode = options.mode.expect("missing mode"); match mode { - v1::transaction_options::Mode::ReadWrite(rw) => { + Mode::ReadWrite(rw) => { assert_eq!( 
rw.read_lock_mode, v1::transaction_options::read_write::ReadLockMode::Pessimistic as i32 @@ -852,44 +1328,148 @@ mod tests { let _tx = ReadWriteTransactionBuilder::new(db_client.clone()) .with_isolation_level(IsolationLevel::Serializable) .with_read_lock_mode(ReadLockMode::Pessimistic) - .begin_transaction() + .build() .await .expect("Failed to build transaction"); } #[tokio::test] - async fn read_write_transaction_tracks_highest_precommit_token() { + async fn read_write_transaction_tracks_highest_precommit_token_explicit() -> anyhow::Result<()> + { + run_read_write_transaction_tracks_highest_precommit_token(true).await + } + + #[tokio::test] + async fn read_write_transaction_tracks_highest_precommit_token_inline() -> anyhow::Result<()> { + run_read_write_transaction_tracks_highest_precommit_token(false).await + } + + async fn run_read_write_transaction_tracks_highest_precommit_token( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![4, 2], - ..Default::default() - })) - }); + if explicit_begin { + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![4, 2], + ..Default::default() + })) + }); + } - // 3 sequential updates returning tokens [seq 2, seq 5, seq 3] - let tokens_iter = vec![2, 5, 3].into_iter(); - let counter_mutex = std::sync::Mutex::new(tokens_iter); + let mut seq = mockall::Sequence::new(); - mock.expect_execute_sql().times(3).returning(move |_req| { - let seq = counter_mutex - .lock() - .expect("Failed to lock mutex") - .next() - .expect("Failed to get next token"); - Ok(tonic::Response::new(v1::ResultSet { - stats: Some(v1::ResultSetStats { - row_count: Some(v1::result_set_stats::RowCount::RowCountExact(1)), + // First update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + 
let mut metadata = v1::ResultSetMetadata { ..Default::default() - }), - precommit_token: Some(v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![seq as u8], - seq_num: seq, - }), - ..Default::default() - })) - }); + }; + + if !explicit_begin { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + assert!(matches!(selector, Selector::Begin(_))); + metadata.transaction = Some(v1::Transaction { + id: vec![4, 2], + ..Default::default() + }); + } else { + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 2]); + } + _ => panic!("Expected Selector::Id"), + } + } + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(metadata), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + precommit_token: Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![2], + seq_num: 2, + }), + ..Default::default() + })) + }); + + // Second update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 2]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + precommit_token: Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![5], + seq_num: 5, + }), + ..Default::default() + })) + }); + + // Third update + mock.expect_execute_sql() + .once() + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction 
selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 2]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + precommit_token: Some(v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![3], + seq_num: 3, + }), + ..Default::default() + })) + }); // Commit should only use the highest token (seq 5) mock.expect_commit().once().returning(|req| { @@ -912,9 +1492,9 @@ mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() - .await - .expect("Failed to build transaction"); + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; for _ in 0..3 { tx.execute_update("UPDATE Y") @@ -923,74 +1503,424 @@ mod tests { } let ts = tx.commit().await.expect("Failed to commit transaction"); assert_eq!(ts.seconds(), 12345); + Ok(()) } #[tokio::test] - async fn read_write_transaction_commit_retry_exactly_once() { + async fn read_write_transaction_commit_retry_exactly_once_explicit() -> anyhow::Result<()> { + run_read_write_transaction_commit_retry_exactly_once(true).await + } + + #[tokio::test] + async fn read_write_transaction_commit_retry_exactly_once_inline() -> anyhow::Result<()> { + run_read_write_transaction_commit_retry_exactly_once(false).await + } + + async fn run_read_write_transaction_commit_retry_exactly_once( + explicit_begin: bool, + ) -> anyhow::Result<()> { let mut mock = create_session_mock(); - mock.expect_begin_transaction().once().returning(|_| { - Ok(tonic::Response::new(v1::Transaction { - id: vec![7, 7], + let transaction_id = vec![7, 7]; + + if explicit_begin { + let id = transaction_id.clone(); + mock.expect_begin_transaction().once().returning(move |_| { + Ok(tonic::Response::new(v1::Transaction { + id: id.clone(), + 
..Default::default() + })) + }); + } else { + let id = transaction_id.clone(); + mock.expect_execute_sql().once().returning(move |req| { + let req = req.into_inner(); + let transaction = req + .transaction + .as_ref() + .expect("transaction options required for inline begin"); + let selector = transaction.selector.as_ref().expect("selector required"); + assert!(matches!(selector, Selector::Begin(_))); + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(v1::ResultSetMetadata { + transaction: Some(v1::Transaction { + id: id.clone(), + ..Default::default() + }), + ..Default::default() + }), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + } + + let mut seq = mockall::Sequence::new(); + + // Initial commit returns a retry token (seq 2) + mock.expect_commit() + .once() + .in_sequence(&mut seq) + .returning(|_| { + Ok(tonic::Response::new(v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 1000, + nanos: 0, + }), + multiplexed_session_retry: Some( + v1::commit_response::MultiplexedSessionRetry::PrecommitToken( + v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![2], + seq_num: 2, + }, + ), + ), + ..Default::default() + })) + }); + + // Retry commit returns another retry token (seq 3). + // The library should not retry multiple times. 
+ mock.expect_commit() + .once() + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.precommit_token + .as_ref() + .expect("Missing precommit token in retry req") + .seq_num, + 2 + ); + + Ok(tonic::Response::new(v1::CommitResponse { + commit_timestamp: Some(prost_types::Timestamp { + seconds: 9999, + nanos: 0, + }), + multiplexed_session_retry: Some( + v1::commit_response::MultiplexedSessionRetry::PrecommitToken( + v1::MultiplexedSessionPrecommitToken { + precommit_token: vec![3], + seq_num: 3, + }, + ), + ), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = ReadWriteTransactionBuilder::new(db_client.clone()) + .with_explicit_begin_transaction(explicit_begin) + .build() + .await?; + + if !explicit_begin { + tx.execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") + .await?; + } + + let ts = tx.commit().await.expect("Failed to commit transaction"); + assert_eq!(ts.seconds(), 9999); + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_execute_update_inline_begin() { + let mut mock = create_session_mock(); + + mock.expect_execute_sql().once().returning(|req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + assert_eq!(req.seqno, 1); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Begin(options) => { + assert!(options.mode.is_some()); + } + _ => panic!("Expected Selector::Begin"), + } + + Ok(tonic::Response::new(v1::ResultSet { + metadata: Some(v1::ResultSetMetadata { + transaction: Some(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + }), + ..Default::default() + }), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), ..Default::default() })) }); - // Initial commit returns a retry token (seq 2) - 
mock.expect_commit().once().returning(|_| { + mock.expect_commit().once().returning(|req| { + let req = req.into_inner(); + match req.transaction.expect("missing transaction") { + v1::commit_request::Transaction::TransactionId(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected TransactionId"), + } Ok(tonic::Response::new(v1::CommitResponse { commit_timestamp: Some(prost_types::Timestamp { - seconds: 1000, + seconds: 123456789, nanos: 0, }), - multiplexed_session_retry: Some( - v1::commit_response::MultiplexedSessionRetry::PrecommitToken( - v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![2], - seq_num: 2, - }, - ), - ), ..Default::default() })) }); - // Retry commit returns another retry token (seq 3). - // The library should not retry multiple times. - mock.expect_commit().once().returning(|req| { + let (db_client, _server) = setup_db_client(mock).await; + + let tx = ReadWriteTransactionBuilder::new(db_client.clone()) + .build() + .await + .expect("Failed to build transaction"); + + let count = tx + .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") + .await + .expect("Failed to execute update"); + assert_eq!(count, 1); + + let ts = tx.commit().await.expect("Failed to commit"); + assert_eq!(ts.seconds(), 123456789); + } + + #[tokio::test] + async fn read_write_transaction_execute_batch_update_inline_begin() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + + mock.expect_execute_batch_dml().once().returning(|req| { let req = req.into_inner(); - assert_eq!( - req.precommit_token - .as_ref() - .expect("Missing precommit token in retry req") - .seq_num, - 2 - ); + assert_eq!(req.statements.len(), 1); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Begin(options) => { + assert!(options.mode.is_some()); + } + _ => panic!("Expected Selector::Begin"), + } + + 
Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { + result_sets: vec![v1::ResultSet { + metadata: Some(v1::ResultSetMetadata { + transaction: Some(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }), + ..Default::default() + }), + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }], + status: Some(spanner_grpc_mock::google::rpc::Status { + code: 0, + message: "OK".into(), + details: vec![], + }), + ..Default::default() + })) + }); + mock.expect_commit().once().returning(|req| { + let req = req.into_inner(); + match req.transaction.expect("missing transaction") { + v1::commit_request::Transaction::TransactionId(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected TransactionId"), + } Ok(tonic::Response::new(v1::CommitResponse { commit_timestamp: Some(prost_types::Timestamp { - seconds: 9999, + seconds: 123456789, nanos: 0, }), - multiplexed_session_retry: Some( - v1::commit_response::MultiplexedSessionRetry::PrecommitToken( - v1::MultiplexedSessionPrecommitToken { - precommit_token: vec![3], - seq_num: 3, - }, - ), - ), ..Default::default() })) }); let (db_client, _server) = setup_db_client(mock).await; + + let tx = ReadWriteTransactionBuilder::new(db_client).build().await?; + + let batch = + BatchDml::builder().add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + + let counts = tx.execute_batch_update(batch.build()).await?; + + assert_eq!(counts, vec![1]); + + let ts = tx.commit().await?; + assert_eq!(ts.seconds(), 123456789); + + Ok(()) + } + + #[tokio::test] + async fn read_write_transaction_execute_update_fallback() { + let mut mock = create_session_mock(); + + // 1. First DML attempt fails! 
+ mock.expect_execute_sql().once().returning(|req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin"), + } + + Err(tonic::Status::new(tonic::Code::Internal, "internal error")) + }); + + // 2. Client falls back to explicit BeginTransaction! + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + + // 3. Client retries DML with new ID! + mock.expect_execute_sql().once().returning(|req| { + let req = req.into_inner(); + assert_eq!(req.sql, "UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = ReadWriteTransactionBuilder::new(db_client.clone()) - .begin_transaction() + .build() .await .expect("Failed to build transaction"); - let ts = tx.commit().await.expect("Failed to commit transaction"); - assert_eq!(ts.seconds(), 9999); + let count = tx + .execute_update("UPDATE Users SET Name = 'Alice' WHERE Id = 1") + .await + .expect("Failed to execute update after fallback"); + assert_eq!(count, 1); + } + + #[tokio::test] + async fn read_write_transaction_execute_batch_update_fallback() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + + // 1. First Batch DML attempt fails! 
+ mock.expect_execute_batch_dml().once().returning(|req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin"), + } + + Err(tonic::Status::new(tonic::Code::Internal, "internal error")) + }); + + // 2. Client falls back to explicit BeginTransaction! + mock.expect_begin_transaction().once().returning(|_| { + Ok(tonic::Response::new(v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + })) + }); + + // 3. Client retries Batch DML with new ID! + mock.expect_execute_batch_dml().once().returning(|req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction selector") + .selector + .expect("missing selector"); + match selector { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + + Ok(tonic::Response::new(v1::ExecuteBatchDmlResponse { + result_sets: vec![v1::ResultSet { + stats: Some(v1::ResultSetStats { + row_count: Some(RowCount::RowCountExact(1)), + ..Default::default() + }), + ..Default::default() + }], + status: Some(spanner_grpc_mock::google::rpc::Status { + code: 0, + message: "OK".into(), + details: vec![], + }), + ..Default::default() + })) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = ReadWriteTransactionBuilder::new(db_client).build().await?; + + let batch = + BatchDml::builder().add_statement("UPDATE Users SET Name = 'Alice' WHERE Id = 1"); + + let counts = tx.execute_batch_update(batch.build()).await?; + + assert_eq!(counts, vec![1]); + + Ok(()) } } diff --git a/src/spanner/src/transaction_runner.rs b/src/spanner/src/transaction_runner.rs index 089e6f7d65..3570e51e3e 100644 --- a/src/spanner/src/transaction_runner.rs +++ b/src/spanner/src/transaction_runner.rs @@ -222,7 +222,7 @@ impl TransactionRunner { let mut current_tx_id = None; let 
attempt_result = async { - let transaction = self.builder.begin_transaction().await?; + let transaction = self.builder.clone().with_explicit_begin_transaction(true).build().await?; current_tx_id = transaction.transaction_id().await.ok(); let result = match work(transaction.clone()).await { From 57eb4035dd1fc002e80ca3033aed9a8a7e5133c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 8 Apr 2026 08:56:33 +0200 Subject: [PATCH 07/17] fix(spanner): modify constructor call in BatchReadOnlyTransaction --- src/spanner/src/batch_read_only_transaction.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index 44ddb8cd07..ff59050803 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -16,7 +16,7 @@ use crate::database_client::DatabaseClient; use crate::model::PartitionOptions; use crate::precommit::PrecommitTokenTracker; use crate::read_only_transaction::{ - MultiUseReadOnlyTransaction, MultiUseReadOnlyTransactionBuilder, + MultiUseReadOnlyTransaction, MultiUseReadOnlyTransactionBuilder, ReadContextTransactionSelector, }; use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; @@ -345,6 +345,10 @@ impl Partition { Ok(ResultSet::new( stream, + Some(ReadContextTransactionSelector::Fixed( + transaction_selector.clone(), + None, + )), PrecommitTokenTracker::new_noop(), client.clone(), StreamOperation::Query(request), @@ -374,6 +378,10 @@ impl Partition { Ok(ResultSet::new( stream, + Some(ReadContextTransactionSelector::Fixed( + transaction_selector.clone(), + None, + )), PrecommitTokenTracker::new_noop(), client.clone(), StreamOperation::Read(request), From d8af76e44d5049df1326f19625ed2c28c9f8daaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 8 Apr 2026 09:35:55 +0200 Subject: [PATCH 08/17] test(spanner): 
add missing test for Read --- src/spanner/src/read_only_transaction.rs | 97 ++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 44782326f3..ea57a75853 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -922,4 +922,101 @@ pub(crate) mod tests { let result = rs.next().await; assert!(result.is_none(), "expected None, got {result:?}"); } + + #[tokio::test] + async fn execute_multi_read() -> anyhow::Result<()> { + use super::super::result_set::tests::string_val; + use crate::client::{KeySet, ReadRequest}; + use crate::value::Value; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + let mut mock = create_session_mock(); + + // No explicit begin_transaction should be called. + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + + // First call: Should have Selector::Begin + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin"), + } + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: Some(prost_types::Timestamp { + seconds: 987654321, + nanos: 0, + }), + ..Default::default() + }); + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(rs)]), + ))) + }); + + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + // Second call: Should have Selector::Id using the ID returned in the first call + match req.transaction.unwrap().selector.unwrap() { + 
mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(setup_select1())]), + ))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // The read timestamp is not available until the first query is executed. + assert!(tx.read_timestamp().is_none()); + + for i in 0..2 { + let read = ReadRequest::builder("Users", vec!["Id", "Name"]) + .with_keys(KeySet::all()) + .build(); + let mut rs = tx.execute_read(read).await?; + + let row = rs.next().await.expect("Expected a row")?; + assert_eq!(row.raw_values(), [Value(string_val("1"))]); + + let result = rs.next().await; + assert!(result.is_none(), "Expected None, got {result:?}"); + + if i == 0 { + // Read timestamp becomes available. + assert_eq!( + tx.read_timestamp() + .expect("Expected read timestamp") + .seconds(), + 987654321 + ); + } + } + + Ok(()) + } } From c3e8b332ca0629f0f439e7761a7d401d45fd934c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 2 Apr 2026 11:26:16 +0200 Subject: [PATCH 09/17] perf(spanner): inline begin transaction error handling Adds error handling for inline-begin-transaction. If the first statement in a transaction fails, and that statement included a BeginTransaction option, then the transaction has not been started. In order to keep the semantics of the transaction consistent for an 'outside observer', we need to do the following: 1. Catch the error that was thrown by the initial statement. 2. Start the transaction using an explicit BeginTransaction RPC. 3. Retry the initial statement, but now using the transaction ID from step 2. 4. Return the error or result for the retried initial statement. The above makes sure that: 1. 
The transaction is actually started when the first statement is executed, also when the statement failed. 2. The statement becomes part of the transaction, and the result of the statement is consistent with the read-timestamp of the transaction. The second part is important in order to comply with Spanner's strong consistency guarantees; If for example a statement returns a 'Table not found' error, then that error is only valid for the read timestamp that was used for executing the statement. This is the reason that we retry the statement after the BeginTransaction RPC to be able to return a result that is guaranteed to be consistent with any other queries/reads that will be executed in the same transaction. --- src/spanner/src/read_only_transaction.rs | 486 ++++++++++++++++++++-- src/spanner/src/result_set.rs | 503 ++++++++++++++++++++++- tests/spanner/src/query.rs | 39 ++ tests/spanner/tests/driver.rs | 4 + 4 files changed, 974 insertions(+), 58 deletions(-) diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index ea57a75853..67ca13feb4 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -275,16 +275,7 @@ impl MultiUseReadOnlyTransactionBuilder { &self, options: TransactionOptions, ) -> crate::Result { - let request = crate::model::BeginTransactionRequest::default() - .set_session(self.client.session.name.clone()) - .set_options(options); - - // TODO(#4972): make request options configurable - let response = self - .client - .spanner - .begin_transaction(request, crate::RequestOptions::default()) - .await?; + let response = execute_begin_transaction(&self.client, options).await?; let transaction_selector = crate::model::TransactionSelector::default().set_id(response.id); @@ -429,6 +420,22 @@ impl MultiUseReadOnlyTransaction { } } +/// Executes an explicit `BeginTransaction` RPC on Spanner. 
+async fn execute_begin_transaction( + client: &crate::database_client::DatabaseClient, + options: crate::model::TransactionOptions, +) -> crate::Result { + let request = crate::model::BeginTransactionRequest::default() + .set_session(client.session.name.clone()) + .set_options(options); + + // TODO(#4972): make request options configurable + client + .spanner + .begin_transaction(request, crate::RequestOptions::default()) + .await +} + #[derive(Clone, Debug)] pub(crate) enum ReadContextTransactionSelector { Fixed(crate::model::TransactionSelector, Option), @@ -463,6 +470,32 @@ impl ReadContextTransactionSelector { } } + /// Explicitly begins a transaction if the transaction selector is a `Lazy` + /// selector and the transaction has not yet been started. This is used by + /// the client to force the start of a transaction if the first statement + /// failed. + pub(crate) async fn begin_explicitly( + &self, + client: &crate::database_client::DatabaseClient, + ) -> crate::Result<()> { + let Self::Lazy(lazy) = self else { + return Ok(()); + }; + + let options = { + let guard = lazy.lock().expect("transaction state mutex poisoned"); + let TransactionState::NotStarted(options) = &*guard else { + return Ok(()); + }; + options.clone() + }; + + let response = execute_begin_transaction(client, options).await?; + self.update(response.id, response.read_timestamp); + + Ok(()) + } + pub(crate) fn update(&self, id: bytes::Bytes, timestamp: Option) { if let Self::Lazy(lazy) = self { let mut guard = lazy.lock().expect("transaction state mutex poisoned"); @@ -517,6 +550,64 @@ impl ReadContext { options } + /// Attempts to execute an explicit `begin_transaction` RPC if the current transaction + /// selector is still in the `Lazy(NotStarted)` state. This is used as a + /// fallback mechanism when an initial implicit begin attempt failed. 
+ async fn begin_explicitly_if_not_started(&self) -> crate::Result { + let ReadContextTransactionSelector::Lazy(lazy) = &self.transaction_selector else { + return Ok(false); + }; + let is_started = matches!(&*lazy.lock().unwrap(), TransactionState::Started(_, _)); + if is_started { + return Ok(false); + } + + self.transaction_selector + .begin_explicitly(&self.client) + .await?; + Ok(true) + } +} + +/// Helper macro to execute a streaming SQL or streaming read RPC with retry logic. +macro_rules! execute_stream_with_retry { + ($self:expr, $request:ident, $rpc_method:ident, $operation_variant:path) => {{ + let stream = match $self + .client + .spanner + // TODO(#4972): make request options configurable + .$rpc_method($request.clone(), crate::RequestOptions::default()) + .send() + .await + { + Ok(s) => s, + Err(e) => { + if $self.begin_explicitly_if_not_started().await? { + $request.transaction = Some($self.transaction_selector.selector()); + $self + .client + .spanner + // TODO(#4972): make request options configurable + .$rpc_method($request.clone(), crate::RequestOptions::default()) + .send() + .await? 
+ } else { + return Err(e); + } + } + }; + + Ok(ResultSet::new( + stream, + Some($self.transaction_selector.clone()), + $self.precommit_token_tracker.clone(), + $self.client.clone(), + $operation_variant($request), + )) + }}; +} + +impl ReadContext { pub(crate) async fn execute_query>( &self, statement: T, @@ -528,21 +619,7 @@ impl ReadContext { .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); - let stream = self - .client - .spanner - // TODO(#4972): make request options configurable - .execute_streaming_sql(request.clone(), crate::RequestOptions::default()) - .send() - .await?; - - Ok(ResultSet::new( - stream, - Some(self.transaction_selector.clone()), - self.precommit_token_tracker.clone(), - self.client.clone(), - StreamOperation::Query(request), - )) + execute_stream_with_retry!(self, request, execute_streaming_sql, StreamOperation::Query) } pub(crate) async fn execute_read>( @@ -556,27 +633,15 @@ impl ReadContext { .set_transaction(self.transaction_selector.selector()); request.request_options = self.amend_request_options(request.request_options); - let stream = self - .client - .spanner - // TODO(#4972): make request options configurable - .streaming_read(request.clone(), crate::RequestOptions::default()) - .send() - .await?; - - Ok(ResultSet::new( - stream, - Some(self.transaction_selector.clone()), - self.precommit_token_tracker.clone(), - self.client.clone(), - StreamOperation::Read(request), - )) + execute_stream_with_retry!(self, request, streaming_read, StreamOperation::Read) } } #[cfg(test)] pub(crate) mod tests { use super::*; + use crate::result_set::tests::string_val; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; #[test] fn auto_traits() { @@ -1019,4 +1084,345 @@ pub(crate) mod tests { Ok(()) } + + #[tokio::test] + async fn inline_begin_failure_retry_success() -> anyhow::Result<()> { + use crate::value::Value; + use gaxi::grpc::tonic::Response; + use 
gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + // Return a transaction with ID + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. Retry of the query succeeds + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure it uses the new transaction ID + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + + let row = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected a row but stream cleanly exhausted"))??; + assert_eq!( + row.raw_values(), + [Value(string_val("1"))], + "The parsed row value safely matched the underlying stream chunk" + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_retry_failure() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = 
mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error first"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. Retry of the query fails again + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error second"))); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!( + rs_result.is_err(), + "The failed execution bubbled upwards securely" + ); + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("Internal error second"), + "Secondary error message accurately propagates: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_fallback_rpc_fails() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error query"))); + + // 2. 
Explicit begin transaction fails + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error begin tx"))); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!( + rs_result.is_err(), + "The explicitly errored fallback boot securely propagated outwards" + ); + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("Internal error begin tx"), + "Natively propagated specific BeginTx bounds: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { + use crate::client::{KeySet, ReadRequest}; + use crate::value::Value; + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial read fails + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + })) + }); + + // 3. 
Retry of the read succeeds + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure it uses the new transaction ID + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let read = ReadRequest::builder("Users", vec!["Id", "Name"]) + .with_keys(KeySet::all()) + .build(); + let mut rs = tx.execute_read(read).await?; + + let row = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected a row uniquely returned"))??; + assert_eq!( + row.raw_values(), + [Value(string_val("1"))], + "The macro correctly unpacked read arrays seamlessly" + ); + + Ok(()) + } + + #[tokio::test] + async fn single_use_query_send_error_returns_immediately() -> anyhow::Result<()> { + use crate::client::Statement; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + + mock.expect_execute_streaming_sql() + .times(1) + .returning(|_| Err(Status::internal("Internal error single use query"))); + + mock.expect_begin_transaction().never(); + + let (db_client, _server) = setup_db_client(mock).await; + // single_use creates a Fixed selector + let tx = db_client.single_use().build(); + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!(rs_result.is_err()); + let err_str = rs_result.unwrap_err().to_string(); + assert!(err_str.contains("Internal error single use query")); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_already_started_query_send_error_returns_immediately() + -> anyhow::Result<()> { + use crate::client::Statement; + use 
gaxi::grpc::tonic::Status; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + mock.expect_begin_transaction().never(); + + // 1. First query executes successfully and implicitly starts the transaction. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |_req| { + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: None, + ..Default::default() + }); + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(rs)]), + ))) + }); + + // 2. Second query fails immediately upon send() + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error second query"))); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Run first query (starts tx) + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.expect("has row")?; + + // Run second query (fails) + let rs_result = tx + .execute_query(Statement::builder("SELECT 2").build()) + .await; + + assert!(rs_result.is_err()); + let err_str = rs_result.unwrap_err().to_string(); + assert!(err_str.contains("Internal error second query")); + + Ok(()) + } } diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 6bd848b175..2d775b7412 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -233,7 +233,29 @@ impl ResultSet { return Ok(()); } - Err(e) + // Check if this stream included an inlined BeginTransaction option + // and has not yet returned a transaction ID. If so, we explicitly + // begin the transaction and restart the stream. 
+ let Some(ReadContextTransactionSelector::Lazy(lazy)) = &self.transaction_selector else { + return Err(e); + }; + let is_started = matches!( + &*lazy.lock().unwrap(), + crate::read_only_transaction::TransactionState::Started(_, _) + ); + if is_started { + return Err(e); + } + + self.transaction_selector + .as_ref() + .unwrap() + .begin_explicitly(&self.client) + .await?; + + self.partial_result_sets_buffer.clear(); + self.restart_stream().await?; + Ok(()) } fn handle_stream_end(&mut self) -> crate::Result> { @@ -281,15 +303,25 @@ impl ResultSet { (None, Some(mut m)) => { let transaction = m.transaction.take(); self.metadata = Some(ResultSetMetadata::new(Some(m))); - if let (Some(selector), Some(transaction)) = - (&self.transaction_selector, transaction) - { - selector.update( - transaction.id, - transaction - .read_timestamp - .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), - ); + if let Some(selector) = &self.transaction_selector { + if let Some(transaction) = transaction { + selector.update( + transaction.id, + transaction + .read_timestamp + .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), + ); + } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { + let is_started = matches!( + &*lazy.lock().expect("transaction state mutex poisoned"), + crate::read_only_transaction::TransactionState::Started(_, _) + ); + if !is_started { + return Err(internal_error( + "Spanner failed to return a transaction ID for a query that included a BeginTransaction option", + )); + } + } } } } @@ -336,9 +368,15 @@ impl ResultSet { } async fn restart_stream(&mut self) -> crate::Result<()> { + // Get the latest transaction selector for this transaction. 
+ let transaction_selector = self.transaction_selector.as_ref().map(|s| s.selector()); + match &mut self.operation { StreamOperation::Query(req) => { req.resume_token = self.last_resume_token.clone(); + req.transaction = transaction_selector + .clone() + .or_else(|| req.transaction.take()); let stream = self .client .spanner @@ -349,6 +387,9 @@ impl ResultSet { } StreamOperation::Read(req) => { req.resume_token = self.last_resume_token.clone(); + req.transaction = transaction_selector + .clone() + .or_else(|| req.transaction.take()); let stream = self .client .spanner @@ -465,6 +506,7 @@ pub(crate) mod tests { use super::*; use crate::client::Spanner; use gaxi::grpc::tonic::Response; + use google_cloud_auth::credentials::anonymous::Builder as Anonymous; use prost_types::Value; use spanner_grpc_mock::MockSpanner; use spanner_grpc_mock::google::spanner::v1::spanner_server::Spanner as SpannerTrait; @@ -528,7 +570,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await .expect("Failed to build client"); @@ -564,6 +606,39 @@ pub(crate) mod tests { assert!(next.is_none()); } + #[tokio::test] + async fn test_result_set_metadata() -> anyhow::Result<()> { + let mut rs = run_mock_query(vec![PartialResultSet { + metadata: metadata(2), + values: vec![string_val("a"), string_val("b")], + last: true, + ..Default::default() + }]) + .await; + + // Called before next() -> returns MetadataNotAvailable + let meta_err = rs.metadata(); + assert!(meta_err.is_err()); + assert!(matches!( + meta_err.unwrap_err(), + ResultSetError::MetadataNotAvailable + )); + + // Advance to fetch metadata + let _next = rs.next().await.expect("Expected a row")?; + + // Called after next() -> returns metadata + let meta = rs.metadata(); + assert!(meta.is_ok()); + let meta = meta.unwrap(); + assert_eq!( + meta.column_names(), 
+ &["col0".to_string(), "col1".to_string()] + ); + + Ok(()) + } + #[tokio::test] async fn test_result_set_handle_partial_result_set_error() -> anyhow::Result<()> { let mut rs = run_mock_query(vec![PartialResultSet { @@ -586,6 +661,34 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn test_result_set_handle_partial_result_set_error_immediate() -> anyhow::Result<()> { + let mut rs = run_mock_query(vec![ + PartialResultSet { + values: vec![string_val("row1")], + ..Default::default() + }, + PartialResultSet { + resume_token: b"token".to_vec(), + ..Default::default() + }, + ]) + .await; + + let res = rs.next().await; + assert!(res.is_some(), "Expected an error but got None"); + let res = res.expect("Expected some response but got None"); + assert!(res.is_err(), "Expected an error but got Ok"); + let err_str = res.expect_err("Expected should be an error").to_string(); + assert!( + err_str.contains("First PartialResultSet did not contain metadata"), + "Expected error to contain 'First PartialResultSet did not contain metadata', but got '{}'", + err_str + ); + + Ok(()) + } + #[tokio::test] async fn test_result_set_stream_ended_with_chunked_value() -> anyhow::Result<()> { let mut rs = run_mock_query(vec![PartialResultSet { @@ -725,7 +828,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1081,7 +1184,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1137,7 +1240,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + 
.with_credentials(Anonymous::new().build()) .build() .await?; @@ -1207,7 +1310,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1296,7 +1399,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1375,7 +1478,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1427,7 +1530,7 @@ pub(crate) mod tests { let client: Spanner = Spanner::builder() .with_endpoint(address) - .with_credentials(google_cloud_auth::credentials::anonymous::Builder::new().build()) + .with_credentials(Anonymous::new().build()) .build() .await?; @@ -1448,4 +1551,368 @@ pub(crate) mod tests { Ok(()) } + + #[tokio::test] + async fn result_set_inline_begin_stream_error_fallback() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Stream yields an error on the first chunk before returning transaction metadata. + // E.g., INVALID_ARGUMENT because the query is malformed. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = + tokio_stream::iter(vec![Err(Status::invalid_argument("Invalid query"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. 
The explicit BeginTransaction fallback gets triggered. + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. The ResultSet gracefully restarts the stream using the transaction ID returned by BeginTransaction. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure the explicitly yielded ID is routed into the new stream transaction selector + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs.next().await.ok_or_else(|| { + anyhow::anyhow!("Expected row returned successfully despite stream breaking") + })??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verify the returned stream successfully resumed with the correct payload" + ); + + Ok(()) + } + + #[tokio::test] + async fn 
result_set_retry_inline_begin_transient_error() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial stream throws UNAVAILABLE before metadata. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = + tokio_stream::iter(vec![Err(Status::unavailable("Transient network issue"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. We retry the stream since it was a transient error. + // The retry should use the same transaction selector as the original request. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin on stream retry"), + } + + let mut meta = metadata(1).unwrap(); + meta.transaction = Some(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + }); + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + 
.read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream to recover safely"))??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verify resumed stream returns data" + ); + + Ok(()) + } + + #[tokio::test] + async fn result_set_retry_inline_begin_id_recovered() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. Stream successfully returns metadata chunk then throws UNAVAILABLE on chunk 2. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let mut meta = metadata(1).unwrap(); + meta.transaction = Some(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + }); + let stream = tokio_stream::iter(vec![ + Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + resume_token: b"token1".to_vec(), + ..Default::default() + }), + Err(Status::unavailable("Transient mid-stream network issue")), + ]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // 2. Stream resumes using Selector::Id. 
+ mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id on stream retry"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + values: vec![string_val("2")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let row1 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream row1 extracted"))??; + assert_eq!( + row1.raw_values()[0].0, + string_val("1"), + "Verified chunk 1 payload" + ); + let row2 = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected stream row2 recovered"))??; + assert_eq!( + row2.raw_values()[0].0, + string_val("2"), + "Verified chunk 2 reboot dynamically intercepted ID bounds correctly" + ); + + Ok(()) + } + + #[tokio::test] + async fn result_set_inline_begin_metadata_missing_transaction_fails() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // 1. 
Initial stream successfully returns metadata chunk but completely lacks the `Transaction` entity. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), // Missing `.transaction` natively + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + // Use explicitly deferred Lazy begin transaction! + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let mut rs = tx.execute_query("SELECT 1").await?; + + let rs_result = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected explicit crash bound properly"))?; + assert!( + rs_result.is_err(), + "Securely aborted when metadata failed to package internal bounds properly" + ); + + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("failed to return a transaction ID"), + "Caught implicit gap boundary: {}", + err_str + ); + + Ok(()) + } } diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index fe02427a19..0f1a54553e 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -246,6 +246,45 @@ async fn test_multi_use_read_only_transaction( Ok(()) } +pub async fn multi_use_read_only_transaction_invalid_query_fallback( + db_client: &DatabaseClient, +) -> anyhow::Result<()> { + // Start a multi-use read-only transaction with implicit begin. 
+ let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Expect a read timestamp to NOT have been chosen yet. + assert!(tx.read_timestamp().is_none()); + + // Execute the first query with invalid syntax. + let rs_result = tx + .execute_query(Statement::builder("SELECT * FROM NonExistentTable").build()) + .await; + + assert!( + rs_result.is_err(), + "Expected an error from an invalid query" + ); + + // The read timestamp should now be available because the transaction + // fell back to an explicit BeginTransaction. + assert!(tx.read_timestamp().is_some()); + + // It should be possible to use the transaction. + let mut rs2 = tx + .execute_query(Statement::builder("SELECT 2 AS col_int").build()) + .await?; + + let row2 = rs2.next().await.transpose()?.expect("should yield a row"); + let val2 = row2.raw_values()[0].as_string(); + assert_eq!(val2, "2"); + + Ok(()) +} + fn verify_null_row(row: &google_cloud_spanner::client::Row) { let raw_values = row.raw_values(); assert_eq!(raw_values.len(), 20, "Row should have exactly 20 columns"); diff --git a/tests/spanner/tests/driver.rs b/tests/spanner/tests/driver.rs index d27cc10dc5..2492088f67 100644 --- a/tests/spanner/tests/driver.rs +++ b/tests/spanner/tests/driver.rs @@ -26,6 +26,10 @@ mod spanner { integration_tests_spanner::query::query_with_parameters(&db_client).await?; integration_tests_spanner::query::result_set_metadata(&db_client).await?; integration_tests_spanner::query::multi_use_read_only_transaction(&db_client).await?; + integration_tests_spanner::query::multi_use_read_only_transaction_invalid_query_fallback( + &db_client, + ) + .await?; Ok(()) } From a0cfa919c19fbb23142f4c503ec46a755ccbb6ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 2 Apr 2026 11:26:16 +0200 Subject: [PATCH 10/17] perf(spanner): inline begin transaction error handling Adds error handling for inline-begin-transaction. 
If the first statement in a transaction fails, and that statement included a BeginTransaction option, then the transaction has not been started. In order to keep the semantics of the transaction consistent for an 'outside observer', we need to do the following: 1. Catch the error that was thrown by the initial statement. 2. Start the transaction using an explicit BeginTransaction RPC. 3. Retry the initial statement, but now using the transaction ID from step 2. 4. Return the error or result for the retried initial statement. The above makes sure that: 1. The transaction is actually started when the first statement is executed, also when the statement failed. 2. The statement becomes part of the transaction, and the result of the statement is consistent with the read-timestamp of the transaction. The second part is important in order to comply with Spanner's strong consistency guarantees; If for example a statement returns a 'Table not found' error, then that error is only valid for the read timestamp that was used for executing the statement. This is the reason that we retry the statement after the BeginTransaction RPC to be able to return a result that is guaranteed to be consistent with any other queries/reads that will be executed in the same transaction. --- src/spanner/src/read_only_transaction.rs | 341 +++++++++++++++++++++++ tests/spanner/src/query.rs | 36 +++ 2 files changed, 377 insertions(+) diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 67ca13feb4..4856a0495f 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -1425,4 +1425,345 @@ pub(crate) mod tests { Ok(()) } + + #[tokio::test] + async fn inline_begin_failure_retry_success() -> anyhow::Result<()> { + use crate::value::Value; + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. 
Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + assert_eq!( + req.session, + "projects/p/instances/i/databases/d/sessions/123" + ); + // Return a transaction with ID + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. Retry of the query succeeds + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure it uses the new transaction ID + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + + let row = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected a row but stream cleanly exhausted"))??; + assert_eq!( + row.raw_values(), + [Value(string_val("1"))], + "The parsed row value safely matched the underlying stream chunk" + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_retry_failure() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. 
Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error first"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: Some(prost_types::Timestamp { + seconds: 123456789, + nanos: 0, + }), + ..Default::default() + })) + }); + + // 3. Retry of the query fails again + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error second"))); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!( + rs_result.is_err(), + "The failed execution bubbled upwards securely" + ); + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("Internal error second"), + "Secondary error message accurately propagates: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_failure_fallback_rpc_fails() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error query"))); + + // 2. 
Explicit begin transaction fails + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error begin tx"))); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!( + rs_result.is_err(), + "The explicitly errored fallback boot securely propagated outwards" + ); + let err_str = rs_result.unwrap_err().to_string(); + assert!( + err_str.contains("Internal error begin tx"), + "Natively propagated specific BeginTx bounds: {}", + err_str + ); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { + use crate::client::{KeySet, ReadRequest}; + use crate::value::Value; + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial read fails + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error"))); + + // 2. Explicit begin transaction succeeds + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + read_timestamp: None, + ..Default::default() + })) + }); + + // 3. 
Retry of the read succeeds + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + // Ensure it uses the new transaction ID + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![7, 8, 9]); + } + _ => panic!("Expected Selector::Id"), + } + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let read = ReadRequest::builder("Users", vec!["Id", "Name"]) + .with_keys(KeySet::all()) + .build(); + let mut rs = tx.execute_read(read).await?; + + let row = rs + .next() + .await + .ok_or_else(|| anyhow::anyhow!("Expected a row uniquely returned"))??; + assert_eq!( + row.raw_values(), + [Value(string_val("1"))], + "The macro correctly unpacked read arrays seamlessly" + ); + + Ok(()) + } + + #[tokio::test] + async fn single_use_query_send_error_returns_immediately() -> anyhow::Result<()> { + use crate::client::Statement; + use gaxi::grpc::tonic::Status; + + let mut mock = create_session_mock(); + + mock.expect_execute_streaming_sql() + .times(1) + .returning(|_| Err(Status::internal("Internal error single use query"))); + + mock.expect_begin_transaction().never(); + + let (db_client, _server) = setup_db_client(mock).await; + // single_use creates a Fixed selector + let tx = db_client.single_use().build(); + + let rs_result = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await; + + assert!(rs_result.is_err()); + let err_str = rs_result.unwrap_err().to_string(); + assert!(err_str.contains("Internal error single use query")); + + Ok(()) + } + + #[tokio::test] + async fn inline_begin_already_started_query_send_error_returns_immediately() + -> anyhow::Result<()> { + use crate::client::Statement; + use 
gaxi::grpc::tonic::Status; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + mock.expect_begin_transaction().never(); + + // 1. First query executes successfully and implicitly starts the transaction. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |_req| { + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: None, + ..Default::default() + }); + Ok(gaxi::grpc::tonic::Response::new(Box::pin( + tokio_stream::iter(vec![Ok(rs)]), + ))) + }); + + // 2. Second query fails immediately upon send() + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Internal error second query"))); + + let (db_client, _server) = setup_db_client(mock).await; + + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Run first query (starts tx) + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.expect("has row")?; + + // Run second query (fails) + let rs_result = tx + .execute_query(Statement::builder("SELECT 2").build()) + .await; + + assert!(rs_result.is_err()); + let err_str = rs_result.unwrap_err().to_string(); + assert!(err_str.contains("Internal error second query")); + + Ok(()) + } } From 0efcab82a6697fb7e2cc4b979ab83c253e9d8628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 2 Apr 2026 20:49:19 +0200 Subject: [PATCH 11/17] test(spanner): add integration test for inline-begin error handling Adds an integration test for error handling for inline-begin-transaction. This test uses a gRPC proxy to intercept calls from the client to Spanner to be able to deterministically emulate specific concurrency issues. This test shows how a query that failed during the first attempt, and thereby also failed to start the transaction, could succeed during a retry after the transaction has been started with an explicit BeginTransaction RPC.
--- Cargo.lock | 3 + deny.toml | 1 + src/spanner/grpc-mock/src/lib.rs | 1 + src/spanner/src/read_only_transaction.rs | 341 ----------------------- tests/spanner/Cargo.toml | 3 + tests/spanner/src/client.rs | 70 ++++- tests/spanner/src/lib.rs | 1 + tests/spanner/src/query.rs | 184 +++++++++--- tests/spanner/src/test_proxy.rs | 269 ++++++++++++++++++ tests/spanner/tests/driver.rs | 1 + 10 files changed, 487 insertions(+), 387 deletions(-) create mode 100644 tests/spanner/src/test_proxy.rs diff --git a/Cargo.lock b/Cargo.lock index ec1f0475c9..52c1732b1c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5914,7 +5914,10 @@ dependencies = [ "prost-types", "reqwest 0.13.2", "serde_json", + "spanner-grpc-mock", "tokio", + "tokio-stream", + "tonic", "tracing", ] diff --git a/deny.toml b/deny.toml index fa61a89904..67b8359d88 100644 --- a/deny.toml +++ b/deny.toml @@ -115,6 +115,7 @@ wrappers = [ # Use in tests is fine. "grpc-server", "integration-tests-o11y", + "integration-tests-spanner", "pubsub-grpc-mock", "spanner-grpc-mock", "storage-grpc-mock", diff --git a/src/spanner/grpc-mock/src/lib.rs b/src/spanner/grpc-mock/src/lib.rs index 72b19b07bd..10a3fb74c6 100644 --- a/src/spanner/grpc-mock/src/lib.rs +++ b/src/spanner/grpc-mock/src/lib.rs @@ -64,6 +64,7 @@ pub mod google { include!("generated/protos/google.rpc.rs"); } pub mod spanner { + #[allow(rustdoc::broken_intra_doc_links, rustdoc::bare_urls)] pub mod v1 { include!("generated/protos/google.spanner.v1.rs"); } diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 4856a0495f..67ca13feb4 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -1425,345 +1425,4 @@ pub(crate) mod tests { Ok(()) } - - #[tokio::test] - async fn inline_begin_failure_retry_success() -> anyhow::Result<()> { - use crate::value::Value; - use gaxi::grpc::tonic::Response; - use gaxi::grpc::tonic::Status; - - let mut mock = create_session_mock(); - let mut 
seq = mockall::Sequence::new(); - - // 1. Initial query fails - mock.expect_execute_streaming_sql() - .times(1) - .in_sequence(&mut seq) - .returning(|_| Err(Status::internal("Internal error"))); - - // 2. Explicit begin transaction succeeds - mock.expect_begin_transaction() - .times(1) - .in_sequence(&mut seq) - .returning(|req| { - let req = req.into_inner(); - assert_eq!( - req.session, - "projects/p/instances/i/databases/d/sessions/123" - ); - // Return a transaction with ID - Ok(Response::new(mock_v1::Transaction { - id: vec![7, 8, 9], - read_timestamp: Some(prost_types::Timestamp { - seconds: 123456789, - nanos: 0, - }), - ..Default::default() - })) - }); - - // 3. Retry of the query succeeds - mock.expect_execute_streaming_sql() - .times(1) - .in_sequence(&mut seq) - .returning(|req| { - let req = req.into_inner(); - // Ensure it uses the new transaction ID - match req.transaction.unwrap().selector.unwrap() { - mock_v1::transaction_selector::Selector::Id(id) => { - assert_eq!(id, vec![7, 8, 9]); - } - _ => panic!("Expected Selector::Id"), - } - Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( - setup_select1(), - )])))) - }); - - let (db_client, _server) = setup_db_client(mock).await; - let tx = db_client - .read_only_transaction() - .with_explicit_begin_transaction(false) - .build() - .await?; - - let mut rs = tx - .execute_query(Statement::builder("SELECT 1").build()) - .await?; - - let row = rs - .next() - .await - .ok_or_else(|| anyhow::anyhow!("Expected a row but stream cleanly exhausted"))??; - assert_eq!( - row.raw_values(), - [Value(string_val("1"))], - "The parsed row value safely matched the underlying stream chunk" - ); - - Ok(()) - } - - #[tokio::test] - async fn inline_begin_failure_retry_failure() -> anyhow::Result<()> { - use gaxi::grpc::tonic::Response; - use gaxi::grpc::tonic::Status; - - let mut mock = create_session_mock(); - let mut seq = mockall::Sequence::new(); - - // 1. 
Initial query fails - mock.expect_execute_streaming_sql() - .times(1) - .in_sequence(&mut seq) - .returning(|_| Err(Status::internal("Internal error first"))); - - // 2. Explicit begin transaction succeeds - mock.expect_begin_transaction() - .times(1) - .in_sequence(&mut seq) - .returning(|_| { - Ok(Response::new(mock_v1::Transaction { - id: vec![7, 8, 9], - read_timestamp: Some(prost_types::Timestamp { - seconds: 123456789, - nanos: 0, - }), - ..Default::default() - })) - }); - - // 3. Retry of the query fails again - mock.expect_execute_streaming_sql() - .times(1) - .in_sequence(&mut seq) - .returning(|_| Err(Status::internal("Internal error second"))); - - let (db_client, _server) = setup_db_client(mock).await; - let tx = db_client - .read_only_transaction() - .with_explicit_begin_transaction(false) - .build() - .await?; - - let rs_result = tx - .execute_query(Statement::builder("SELECT 1").build()) - .await; - - assert!( - rs_result.is_err(), - "The failed execution bubbled upwards securely" - ); - let err_str = rs_result.unwrap_err().to_string(); - assert!( - err_str.contains("Internal error second"), - "Secondary error message accurately propagates: {}", - err_str - ); - - Ok(()) - } - - #[tokio::test] - async fn inline_begin_failure_fallback_rpc_fails() -> anyhow::Result<()> { - use gaxi::grpc::tonic::Status; - - let mut mock = create_session_mock(); - let mut seq = mockall::Sequence::new(); - - // 1. Initial query fails - mock.expect_execute_streaming_sql() - .times(1) - .in_sequence(&mut seq) - .returning(|_| Err(Status::internal("Internal error query"))); - - // 2. 
Explicit begin transaction fails - mock.expect_begin_transaction() - .times(1) - .in_sequence(&mut seq) - .returning(|_| Err(Status::internal("Internal error begin tx"))); - - let (db_client, _server) = setup_db_client(mock).await; - let tx = db_client - .read_only_transaction() - .with_explicit_begin_transaction(false) - .build() - .await?; - - let rs_result = tx - .execute_query(Statement::builder("SELECT 1").build()) - .await; - - assert!( - rs_result.is_err(), - "The explicitly errored fallback boot securely propagated outwards" - ); - let err_str = rs_result.unwrap_err().to_string(); - assert!( - err_str.contains("Internal error begin tx"), - "Natively propagated specific BeginTx bounds: {}", - err_str - ); - - Ok(()) - } - - #[tokio::test] - async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { - use crate::client::{KeySet, ReadRequest}; - use crate::value::Value; - use gaxi::grpc::tonic::Response; - use gaxi::grpc::tonic::Status; - - let mut mock = create_session_mock(); - let mut seq = mockall::Sequence::new(); - - // 1. Initial read fails - mock.expect_streaming_read() - .times(1) - .in_sequence(&mut seq) - .returning(|_| Err(Status::internal("Internal error"))); - - // 2. Explicit begin transaction succeeds - mock.expect_begin_transaction() - .times(1) - .in_sequence(&mut seq) - .returning(|_| { - Ok(Response::new(mock_v1::Transaction { - id: vec![7, 8, 9], - read_timestamp: None, - ..Default::default() - })) - }); - - // 3. 
Retry of the read succeeds - mock.expect_streaming_read() - .times(1) - .in_sequence(&mut seq) - .returning(|req| { - let req = req.into_inner(); - // Ensure it uses the new transaction ID - match req.transaction.unwrap().selector.unwrap() { - mock_v1::transaction_selector::Selector::Id(id) => { - assert_eq!(id, vec![7, 8, 9]); - } - _ => panic!("Expected Selector::Id"), - } - Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( - setup_select1(), - )])))) - }); - - let (db_client, _server) = setup_db_client(mock).await; - let tx = db_client - .read_only_transaction() - .with_explicit_begin_transaction(false) - .build() - .await?; - - let read = ReadRequest::builder("Users", vec!["Id", "Name"]) - .with_keys(KeySet::all()) - .build(); - let mut rs = tx.execute_read(read).await?; - - let row = rs - .next() - .await - .ok_or_else(|| anyhow::anyhow!("Expected a row uniquely returned"))??; - assert_eq!( - row.raw_values(), - [Value(string_val("1"))], - "The macro correctly unpacked read arrays seamlessly" - ); - - Ok(()) - } - - #[tokio::test] - async fn single_use_query_send_error_returns_immediately() -> anyhow::Result<()> { - use crate::client::Statement; - use gaxi::grpc::tonic::Status; - - let mut mock = create_session_mock(); - - mock.expect_execute_streaming_sql() - .times(1) - .returning(|_| Err(Status::internal("Internal error single use query"))); - - mock.expect_begin_transaction().never(); - - let (db_client, _server) = setup_db_client(mock).await; - // single_use creates a Fixed selector - let tx = db_client.single_use().build(); - - let rs_result = tx - .execute_query(Statement::builder("SELECT 1").build()) - .await; - - assert!(rs_result.is_err()); - let err_str = rs_result.unwrap_err().to_string(); - assert!(err_str.contains("Internal error single use query")); - - Ok(()) - } - - #[tokio::test] - async fn inline_begin_already_started_query_send_error_returns_immediately() - -> anyhow::Result<()> { - use crate::client::Statement; - use 
gaxi::grpc::tonic::Status; - use spanner_grpc_mock::google::spanner::v1 as mock_v1; - - let mut mock = create_session_mock(); - let mut seq = mockall::Sequence::new(); - - mock.expect_begin_transaction().never(); - - // 1. First query executes successfully and implicitly starts the transaction. - mock.expect_execute_streaming_sql() - .times(1) - .in_sequence(&mut seq) - .returning(move |_req| { - let mut rs = setup_select1(); - rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { - id: vec![4, 5, 6], - read_timestamp: None, - ..Default::default() - }); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(rs)]), - ))) - }); - - // 2. Second query fails immediately upon send() - mock.expect_execute_streaming_sql() - .times(1) - .in_sequence(&mut seq) - .returning(|_| Err(Status::internal("Internal error second query"))); - - let (db_client, _server) = setup_db_client(mock).await; - - let tx = db_client - .read_only_transaction() - .with_explicit_begin_transaction(false) - .build() - .await?; - - // Run first query (starts tx) - let mut rs = tx - .execute_query(Statement::builder("SELECT 1").build()) - .await?; - let _ = rs.next().await.expect("has row")?; - - // Run second query (fails) - let rs_result = tx - .execute_query(Statement::builder("SELECT 2").build()) - .await; - - assert!(rs_result.is_err()); - let err_str = rs_result.unwrap_err().to_string(); - assert!(err_str.contains("Internal error second query")); - - Ok(()) - } } diff --git a/tests/spanner/Cargo.toml b/tests/spanner/Cargo.toml index 461c18d110..30b6625f58 100644 --- a/tests/spanner/Cargo.toml +++ b/tests/spanner/Cargo.toml @@ -36,7 +36,10 @@ google-cloud-test-utils = { workspace = true } prost-types.workspace = true reqwest = { workspace = true, features = ["json"] } serde_json = { workspace = true } +spanner-grpc-mock = { path = "../../src/spanner/grpc-mock" } tokio = { workspace = true, features = ["sync"] } +tokio-stream = { workspace = true } +tonic = 
{ workspace = true } tracing.workspace = true [lints] diff --git a/tests/spanner/src/client.rs b/tests/spanner/src/client.rs index 7b05ecf3c5..3516bc011c 100644 --- a/tests/spanner/src/client.rs +++ b/tests/spanner/src/client.rs @@ -14,6 +14,8 @@ use google_cloud_spanner::client::{KeySet, Mutation, Spanner}; use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use std::time::Duration; +use tokio::time::sleep; const PROJECT_ID: &str = "test-project"; const INSTANCE_ID: &str = "test-instance"; @@ -40,7 +42,7 @@ pub async fn wait_for_emulator(endpoint: &str) { static PROVISION_EMULATOR: tokio::sync::OnceCell<()> = tokio::sync::OnceCell::const_new(); static DATABASE_ID: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); -async fn get_database_id() -> &'static str { +pub async fn get_database_id() -> &'static str { DATABASE_ID .get_or_init(|| async { std::env::var("SPANNER_EMULATOR_TEST_DB") @@ -59,16 +61,19 @@ pub async fn provision_emulator(endpoint: &str) { .await; } +pub fn get_emulator_rest_endpoint(grpc_endpoint: &str) -> String { + let rest_endpoint = std::env::var("SPANNER_EMULATOR_REST_HOST") + .unwrap_or_else(|_| grpc_endpoint.replace("9010", "9020")); + if rest_endpoint.starts_with("http://") || rest_endpoint.starts_with("https://") { + rest_endpoint + } else { + format!("http://{}", rest_endpoint) + } +} + async fn do_provision_emulator(endpoint: &str) { // TODO(#4973): Re-write this to use the admin clients once those also support the Emulator. - let rest_endpoint = std::env::var("SPANNER_EMULATOR_REST_HOST") - .unwrap_or_else(|_| endpoint.replace("9010", "9020")); - let rest_endpoint = - if rest_endpoint.starts_with("http://") || rest_endpoint.starts_with("https://") { - rest_endpoint - } else { - format!("http://{}", rest_endpoint) - }; + let rest_endpoint = get_emulator_rest_endpoint(endpoint); let client = reqwest::Client::new(); // Create a test instance and ignore any ALREADY_EXISTS errors. 
@@ -196,3 +201,50 @@ pub async fn create_database_client() -> Option anyhow::Result<()> { + let emulator_host = get_emulator_host().expect("SPANNER_EMULATOR_HOST must be set"); + let rest_endpoint = get_emulator_rest_endpoint(&emulator_host); + let db_path = format!( + "projects/{}/instances/{}/databases/{}", + PROJECT_ID, + INSTANCE_ID, + get_database_id().await + ); + let url = format!("{}/v1/{}/ddl", rest_endpoint, db_path); + let client = reqwest::Client::new(); + let payload = serde_json::json!({ + "statements": [statement] + }); + + let mut attempts = 0; + const MAX_ATTEMPTS: u32 = 25; + + loop { + attempts += 1; + let res = client.patch(&url).json(&payload).send().await?; + + let status = res.status(); + let text = res.text().await?; + + if status.is_success() { + return Ok(()); + } + + // Check if the error is the specific one we want to retry. + // Code 9 is FailedPrecondition. + if text.contains("\"code\":9") && text.contains("Schema change operation rejected") { + if attempts >= MAX_ATTEMPTS { + anyhow::bail!( + "Failed to update DDL after {} attempts. Last error: {}", + attempts, + text + ); + } + sleep(Duration::from_millis(100)).await; + continue; + } + + anyhow::bail!("Failed to update DDL: status={}, body={}", status, text); + } +} diff --git a/tests/spanner/src/lib.rs b/tests/spanner/src/lib.rs index 2be88b7a18..d8a95bf9c1 100644 --- a/tests/spanner/src/lib.rs +++ b/tests/spanner/src/lib.rs @@ -18,4 +18,5 @@ pub mod partitioned_dml; pub mod query; pub mod read; pub mod read_write_transaction; +pub mod test_proxy; pub mod write; diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index b62b84a99c..58c785c0e1 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -12,7 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use google_cloud_spanner::client::{DatabaseClient, Kind, Statement}; +use crate::client::{get_database_id, get_emulator_host}; +use crate::test_proxy::{InterceptedSpanner, SpannerInterceptor}; +use google_cloud_spanner::client::{DatabaseClient, Kind, Spanner, Statement}; +use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; +use spanner_grpc_mock::google::spanner::v1::spanner_server::SpannerServer; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio::sync::Notify; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::{Channel, Server}; pub async fn simple_query(db_client: &DatabaseClient) -> anyhow::Result<()> { let rot = db_client.single_use().build(); @@ -285,42 +296,6 @@ pub async fn multi_use_read_only_transaction_invalid_query_fallback( Ok(()) } -pub async fn multi_use_read_only_transaction_invalid_query_fallback( - db_client: &DatabaseClient, -) -> anyhow::Result<()> { - // Start a multi-use read-only transaction with implicit begin. - let tx = db_client - .read_only_transaction() - .with_explicit_begin_transaction(false) - .build() - .await?; - - // Expect a read timestamp to NOT have been chosen yet. - assert!(tx.read_timestamp().is_none()); - - // Execute the first query with invalid syntax. - let rs_result = tx - .execute_query(Statement::builder("SELECT * FROM NonExistentTable").build()) - .await; - - assert!(rs_result.is_err(), "Expected an error from an invalid query"); - - // The read timestamp should now be available because the transaction - // fell back to an explicit BeginTransaction. - assert!(tx.read_timestamp().is_some()); - - // It should be possible to use the transaction. 
- let mut rs2 = tx - .execute_query(Statement::builder("SELECT 2 AS col_int").build()) - .await?; - - let row2 = rs2.next().await.transpose()?.expect("should yield a row"); - let val2 = row2.raw_values()[0].as_string(); - assert_eq!(val2, "2"); - - Ok(()) -} - fn verify_null_row(row: &google_cloud_spanner::client::Row) { let raw_values = row.raw_values(); assert_eq!(raw_values.len(), 20, "Row should have exactly 20 columns"); @@ -444,3 +419,138 @@ fn verify_row_2(row: &google_cloud_spanner::client::Row) { "2026-03-11T16:20:00Z" ); } + +struct DelayedBeginProxy { + emulator_client: SpannerClient, + latch: Arc, + begin_transaction_entered_latch: Arc, +} + +#[tonic::async_trait] +impl SpannerInterceptor for DelayedBeginProxy { + fn emulator_client(&self) -> SpannerClient { + self.emulator_client.clone() + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.begin_transaction_entered_latch.notify_one(); + self.latch.notified().await; + self.emulator_client().begin_transaction(request).await + } +} + +// This test verifies that the client correctly falls back to `BeginTransaction` when the +// first statement in a transaction fails. It also shows that the statement is retried and +// could (theoretically) succeed during this retry. It achieves this by doing the following: +// 1. It uses a proxy that allows it to intercept the RPCs that are being sent to Spanner. +// 2. It creates a read-only transaction that uses inline-begin-transaction. +// 3. It executes a query that tries to read from a table that does not exist. +// 4. As the first statement in the transaction fails, the client falls back to using +// an explicit BeginTransaction RPC. +// 5. The proxy blocks this BeginTransaction RPC, and in the meantime the test creates +// the missing table. +// 6. The proxy unblocks the BeginTransaction RPC. +// 7. The statement is retried and succeeds. The test never sees the error. 
+// +// This test might seem like an extreme corner case for a read-only transaction like this. +// However, for read/write transactions, similar types of failures are more likely to occur, +// for example if a transaction tries to insert a row that violates the primary key. Another +// transaction could delete the row in the time between the first attempt failed, and the +// BeginTransaction RPC has been executed. +pub async fn inline_begin_fallback(_db_client: &DatabaseClient) -> anyhow::Result<()> { + let emulator_host = get_emulator_host().expect("SPANNER_EMULATOR_HOST must be set"); + let latch = Arc::new(Notify::new()); + let begin_transaction_entered_latch = Arc::new(Notify::new()); + + // Create a raw gRPC client that connects to the Spanner Emulator. + // This will be used by the proxy server to forward requests to the Emulator. + let endpoint = Channel::from_shared(format!("http://{}", emulator_host))? + .connect() + .await?; + let raw_client = SpannerClient::new(endpoint); + + // Create a local TCP listener to bind our proxy server to. + let listener = TcpListener::bind("127.0.0.1:0").await?; + let local_addr = listener.local_addr()?; + let proxy_address = format!("{}:{}", local_addr.ip(), local_addr.port()); + + let proxy = DelayedBeginProxy { + emulator_client: raw_client, + latch: Arc::clone(&latch), + begin_transaction_entered_latch: Arc::clone(&begin_transaction_entered_latch), + }; + + let _server_handle = tokio::spawn(async move { + let stream = TcpListenerStream::new(listener); + Server::builder() + .add_service(SpannerServer::new(InterceptedSpanner(proxy))) + .serve_with_incoming(stream) + .await + .expect("Proxy server failed"); + }); + + // We build the Spanner DatabaseClient pointing directly to our proxy address over HTTP. + let proxy_db_client = Spanner::builder() + .with_endpoint(format!("http://{}", proxy_address)) + .build() + .await? 
+ .database_client(format!( + "projects/test-project/instances/test-instance/databases/{}", + get_database_id().await + )) + .build() + .await?; + + let tx = proxy_db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let table_name = LowercaseAlphanumeric.random_string(10); + let table_name = format!("LateLoadedTable_{}", table_name); + + // Create a task that tries to query the table before it exists. + // This will initially fail, and the client will fall back to using + // an explicit BeginTransaction RPC. The table will then be created + // BEFORE the BeginTransaction RPC is executed, which will cause the + // query to succeed when it is retried using the transaction ID that + // was returned by BeginTransaction. This task will never see the + // initial error, and instead it will seem like the query simply + // succeeded. + let query_task = tokio::spawn({ + let table_name = table_name.clone(); + async move { + let stmt = Statement::builder(format!("SELECT * FROM {}", table_name)).build(); + let mut rs = tx.execute_query(stmt).await?; + let _ = rs.next().await; + Ok::<_, anyhow::Error>(tx) + } + }); + + // Wait until the query task above has been executed and has triggered an + // explicit BeginTransaction RPC. The BeginTransaction RPC is blocked until + // `latch` is notified. + begin_transaction_entered_latch.notified().await; + + // Create the table on the emulator while the BeginTransaction RPC is blocked. + let statement = format!("CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", table_name); + crate::client::update_database_ddl(statement).await?; + + // Unblock the BeginTransaction RPC. + latch.notify_one(); + + // Wait for the query task to complete. It should succeed and never see + // the initial error. 
+ let tx = query_task.await??; + + assert!( + tx.read_timestamp().is_some(), + "The transaction should have a read timestamp" + ); + + Ok(()) +} diff --git a/tests/spanner/src/test_proxy.rs b/tests/spanner/src/test_proxy.rs new file mode 100644 index 0000000000..7f0a5837a2 --- /dev/null +++ b/tests/spanner/src/test_proxy.rs @@ -0,0 +1,269 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; + +#[tonic::async_trait] +pub trait SpannerInterceptor: Send + Sync + 'static { + fn emulator_client(&self) -> SpannerClient; + + async fn create_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().create_session(request).await + } + + async fn batch_create_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.emulator_client().batch_create_sessions(request).await + } + + async fn get_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().get_session(request).await + } + + async fn list_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().list_sessions(request).await + } + + async fn delete_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + 
self.emulator_client().delete_session(request).await + } + + async fn execute_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().execute_sql(request).await + } + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.emulator_client().execute_streaming_sql(request).await + } + + async fn execute_batch_dml( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.emulator_client().execute_batch_dml(request).await + } + + async fn read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().read(request).await + } + + async fn streaming_read( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.emulator_client().streaming_read(request).await + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().begin_transaction(request).await + } + + async fn commit( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().commit(request).await + } + + async fn rollback( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().rollback(request).await + } + + async fn partition_query( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().partition_query(request).await + } + + async fn partition_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.emulator_client().partition_read(request).await + } + + async fn batch_write( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.emulator_client().batch_write(request).await + } +} + +pub struct 
InterceptedSpanner(pub T); + +#[tonic::async_trait] +impl spanner_v1::spanner_server::Spanner for InterceptedSpanner { + async fn create_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.create_session(request).await + } + + async fn batch_create_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.0.batch_create_sessions(request).await + } + + async fn get_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.get_session(request).await + } + + async fn list_sessions( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.list_sessions(request).await + } + + async fn delete_session( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.delete_session(request).await + } + + async fn execute_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.execute_sql(request).await + } + + type ExecuteStreamingSqlStream = tonic::codec::Streaming; + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.execute_streaming_sql(request).await + } + + async fn execute_batch_dml( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> + { + self.0.execute_batch_dml(request).await + } + + async fn read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.read(request).await + } + + type StreamingReadStream = tonic::codec::Streaming; + + async fn streaming_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.streaming_read(request).await + } + + async fn begin_transaction( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.begin_transaction(request).await + } + + async fn commit( + &self, + request: tonic::Request, + ) -> 
std::result::Result, tonic::Status> { + self.0.commit(request).await + } + + async fn rollback( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.rollback(request).await + } + + async fn partition_query( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.partition_query(request).await + } + + async fn partition_read( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.partition_read(request).await + } + + type BatchWriteStream = tonic::codec::Streaming; + + async fn batch_write( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status> { + self.0.batch_write(request).await + } +} diff --git a/tests/spanner/tests/driver.rs b/tests/spanner/tests/driver.rs index 2492088f67..4d45413e13 100644 --- a/tests/spanner/tests/driver.rs +++ b/tests/spanner/tests/driver.rs @@ -30,6 +30,7 @@ mod spanner { &db_client, ) .await?; + integration_tests_spanner::query::inline_begin_fallback(&db_client).await?; Ok(()) } From 97cb2cd445b1382bd0d2a3f2e536f16f6af63416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 2 Apr 2026 20:49:19 +0200 Subject: [PATCH 12/17] test(spanner): add integration test for inline-begin error handling Adds an integration test for error handling for inline-begin-transaction. This test uses a gRPC proxy to intercept calls from the client to Spanner to be able to deterministically emulate specific concurrency issues. This test shows how a query that failed during the first attempt, and thereby also failed to start the transaction, could succeed during a retry after the transaction has been started with an explicit BeginTransaction RPC. 
--- tests/spanner/src/query.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index 58c785c0e1..7532a80084 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -12,7 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +<<<<<<< HEAD use crate::client::{get_database_id, get_emulator_host}; +======= +use crate::client::{get_database_id, get_emulator_host, get_emulator_rest_endpoint}; +>>>>>>> 255a00797 (test(spanner): add integration test for inline-begin error handling) use crate::test_proxy::{InterceptedSpanner, SpannerInterceptor}; use google_cloud_spanner::client::{DatabaseClient, Kind, Spanner, Statement}; use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; From 6b7cffc997305b6639a584932e4b67c01c44a4d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 3 Apr 2026 18:10:18 +0200 Subject: [PATCH 13/17] perf(spanner): support concurrent queries with inline-begin-transaction Adds support for running concurrent queries in combination with inline-begin-transaction. Only one of the queries will include the BeginTransaction option. The other queries will wait until the first query has returned a transaction ID. 
--- Cargo.lock | 3 + .../src/batch_read_only_transaction.rs | 10 +- src/spanner/src/read_only_transaction.rs | 940 ++++++++++++++++-- src/spanner/src/read_write_transaction.rs | 12 +- src/spanner/src/result_set.rs | 12 +- src/spanner/src/transaction_runner.rs | 2 +- tests/spanner/Cargo.toml | 3 + tests/spanner/src/concurrent_inline_begin.rs | 264 +++++ tests/spanner/src/lib.rs | 1 + tests/spanner/src/query.rs | 4 - tests/spanner/src/test_proxy.rs | 52 +- tests/spanner/tests/driver.rs | 5 + 12 files changed, 1221 insertions(+), 87 deletions(-) create mode 100644 tests/spanner/src/concurrent_inline_begin.rs diff --git a/Cargo.lock b/Cargo.lock index 52c1732b1c..5207f37093 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5911,10 +5911,13 @@ dependencies = [ "google-cloud-lro", "google-cloud-spanner", "google-cloud-test-utils", + "google-cloud-wkt", "prost-types", + "rand 0.10.0", "reqwest 0.13.2", "serde_json", "spanner-grpc-mock", + "time", "tokio", "tokio-stream", "tonic", diff --git a/src/spanner/src/batch_read_only_transaction.rs b/src/spanner/src/batch_read_only_transaction.rs index ff59050803..313f291d22 100644 --- a/src/spanner/src/batch_read_only_transaction.rs +++ b/src/spanner/src/batch_read_only_transaction.rs @@ -144,12 +144,13 @@ impl BatchReadOnlyTransaction { statement: T, options: PartitionOptions, ) -> crate::Result> { + let selector = self.inner.context.transaction_selector.selector().await?; let statement = statement.into(); let request = statement .clone() .into_partition_query_request() .set_session(self.inner.context.client.session.name.clone()) - .set_transaction(self.inner.context.transaction_selector.selector()) + .set_transaction(selector.clone()) .set_partition_options(options); let response = self @@ -166,7 +167,7 @@ impl BatchReadOnlyTransaction { .map(|p| Partition { inner: PartitionedOperation::Query { partition_token: p.partition_token, - transaction_selector: self.inner.context.transaction_selector.selector(), + transaction_selector: 
selector.clone(), session_name: self.inner.context.client.session.name.clone(), statement: statement.clone(), }, @@ -199,12 +200,13 @@ impl BatchReadOnlyTransaction { read: T, options: PartitionOptions, ) -> crate::Result> { + let selector = self.inner.context.transaction_selector.selector().await?; let read = read.into(); let request = read .clone() .into_partition_read_request() .set_session(self.inner.context.client.session.name.clone()) - .set_transaction(self.inner.context.transaction_selector.selector()) + .set_transaction(selector.clone()) .set_partition_options(options); let response = self @@ -221,7 +223,7 @@ impl BatchReadOnlyTransaction { .map(|p| Partition { inner: PartitionedOperation::Read { partition_token: p.partition_token, - transaction_selector: self.inner.context.transaction_selector.selector(), + transaction_selector: selector.clone(), session_name: self.inner.context.client.session.name.clone(), read_request: read.clone(), }, diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 67ca13feb4..871a6d8624 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -20,6 +20,7 @@ use crate::result_set::{ResultSet, StreamOperation}; use crate::statement::Statement; use crate::timestamp_bound::TimestampBound; use std::sync::{Arc, Mutex}; +use tokio::sync::Notify; /// A builder for [SingleUseReadOnlyTransaction]. 
/// @@ -445,28 +446,72 @@ pub(crate) enum ReadContextTransactionSelector { #[derive(Clone, Debug)] pub(crate) enum TransactionState { NotStarted(crate::model::TransactionOptions), + Starting(crate::model::TransactionOptions, Arc), Started(crate::model::TransactionSelector, Option), + Failed(Arc), } -impl TransactionState { - fn selector(&self) -> crate::model::TransactionSelector { - match self { - Self::Started(selector, _) => selector.clone(), - Self::NotStarted(options) => { - crate::model::TransactionSelector::default().set_begin(options.clone()) - } - } - } +enum SelectorStatus { + Ready(crate::model::TransactionSelector), + Wait(std::sync::Arc), } impl ReadContextTransactionSelector { - pub(crate) fn selector(&self) -> crate::model::TransactionSelector { + pub(crate) async fn selector(&self) -> crate::Result { match self { - Self::Fixed(selector, _) => selector.clone(), - Self::Lazy(lazy) => lazy - .lock() - .expect("transaction state mutex poisoned") - .selector(), + Self::Fixed(selector, _) => Ok(selector.clone()), + Self::Lazy(_) => loop { + match self.poll_selector_status()? { + SelectorStatus::Ready(selector) => return Ok(selector), + SelectorStatus::Wait(notify) => notify.notified().await, + } + }, + } + } + + /// Inspects the current lazy selector state returning whether it is ready, + /// failed, or needs to wait for the transaction to start. + fn poll_selector_status(&self) -> crate::Result { + let Self::Lazy(lazy) = self else { + unreachable!("poll_selector_status called on non-Lazy selector"); + }; + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + + // Fast path: Transaction is already started. + if let TransactionState::Started(selector, _) = &*guard { + return Ok(SelectorStatus::Ready(selector.clone())); + } + + // If the transaction has not started, extract options and proceed to transition. 
+ let pending_options = if let TransactionState::NotStarted(options) = &*guard { + Some(options.clone()) + } else { + None + }; + if let Some(options) = pending_options { + let notify = Arc::new(Notify::new()); + *guard = TransactionState::Starting(options.clone(), Arc::clone(¬ify)); + return Ok(SelectorStatus::Ready( + crate::model::TransactionSelector::default().set_begin(options), + )); + } + + // Handle other states: yield error or wait. + match &*guard { + // Note: Failed will only be reached if the following happens: + // 1. The first query fails and the transaction falls back to an explicit BeginTransaction RPC. + // 2. The BeginTransaction RPC fails. This is the error that will be returned to all the waiting queries. + TransactionState::Failed(err) => { + let error = if let Some(status) = err.status() { + crate::Error::service(status.clone()) + } else { + crate::error::internal_error(format!("Transaction failed to start: {}", err)) + }; + Err(error) + } + // Transaction is starting. Wait until a transaction ID is returned. + TransactionState::Starting(_, notify) => Ok(SelectorStatus::Wait(Arc::clone(notify))), + TransactionState::Started(_, _) | TransactionState::NotStarted(_) => unreachable!(), } } @@ -482,29 +527,101 @@ impl ReadContextTransactionSelector { return Ok(()); }; - let options = { + let (options, notify_opt) = { let guard = lazy.lock().expect("transaction state mutex poisoned"); - let TransactionState::NotStarted(options) = &*guard else { - return Ok(()); - }; - options.clone() + match &*guard { + // This should never happen in the current implementation. 
+ TransactionState::NotStarted(_) => { + return Err(crate::error::internal_error( + "explicit begin with NotStarted state is currently unsupported", + )); + } + TransactionState::Starting(options, notify) => { + (options.clone(), Some(Arc::clone(notify))) + } + TransactionState::Started(_, _) | TransactionState::Failed(_) => return Ok(()), + } }; - let response = execute_begin_transaction(client, options).await?; - self.update(response.id, response.read_timestamp); + let response = match execute_begin_transaction(client, options).await { + Ok(r) => r, + Err(e) => { + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + let error = Arc::new(e); + *guard = TransactionState::Failed(Arc::clone(&error)); + // Release the lock and notify all the waiting queries that + // the transaction has failed. + drop(guard); + if let Some(notify) = notify_opt { + notify.notify_waiters(); + } + + let return_error = if let Some(status) = error.status() { + crate::Error::service(status.clone()) + } else { + crate::error::internal_error(format!("Transaction failed to start: {}", error)) + }; + return Err(return_error); + } + }; + + self.update(response.id, response.read_timestamp)?; Ok(()) } - pub(crate) fn update(&self, id: bytes::Bytes, timestamp: Option) { - if let Self::Lazy(lazy) = self { - let mut guard = lazy.lock().expect("transaction state mutex poisoned"); - if matches!(&*guard, TransactionState::NotStarted(_)) { - *guard = TransactionState::Started( + pub(crate) fn update( + &self, + id: bytes::Bytes, + timestamp: Option, + ) -> crate::Result<()> { + let Self::Lazy(lazy) = self else { + return Ok(()); + }; + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + + if matches!( + &*guard, + TransactionState::NotStarted(_) | TransactionState::Starting(_, _) + ) { + let previous_state = std::mem::replace( + &mut *guard, + TransactionState::Started( crate::model::TransactionSelector::default().set_id(id), timestamp, - ); + ), + ); + 
drop(guard); + + // Notify all queries that are waiting for the transaction. + if let TransactionState::Starting(_, notify) = previous_state { + notify.notify_waiters(); } + Ok(()) + } else { + Err(crate::error::internal_error( + "got a transaction id for an already Started or Failed transaction", + )) + } + } + + /// Resets the selector state from `Starting` back to `NotStarted`. + /// + /// This is used during stream resume fallbacks when the first query stream + /// fails before yielding a transaction ID. It unlocks any parked waiters + /// allowing them (or the retry attempt) to include the begin option again. + pub(crate) fn maybe_reset_starting(&self) { + let Self::Lazy(lazy) = self else { + return; + }; + + let mut guard = lazy.lock().expect("transaction state mutex poisoned"); + if let TransactionState::Starting(options, notify) = &*guard { + let options = options.clone(); + let notify = Arc::clone(notify); + *guard = TransactionState::NotStarted(options); + drop(guard); + notify.notify_waiters(); } } @@ -583,7 +700,7 @@ macro_rules! execute_stream_with_retry { Ok(s) => s, Err(e) => { if $self.begin_explicitly_if_not_started().await? 
{ - $request.transaction = Some($self.transaction_selector.selector()); + $request.transaction = Some($self.transaction_selector.selector().await?); $self .client .spanner @@ -616,7 +733,7 @@ impl ReadContext { .into() .into_request() .set_session(self.client.session.name.clone()) - .set_transaction(self.transaction_selector.selector()); + .set_transaction(self.transaction_selector.selector().await?); request.request_options = self.amend_request_options(request.request_options); execute_stream_with_retry!(self, request, execute_streaming_sql, StreamOperation::Query) @@ -630,7 +747,7 @@ impl ReadContext { .into() .into_request() .set_session(self.client.session.name.clone()) - .set_transaction(self.transaction_selector.selector()); + .set_transaction(self.transaction_selector.selector().await?); request.request_options = self.amend_request_options(request.request_options); execute_stream_with_retry!(self, request, streaming_read, StreamOperation::Read) @@ -640,8 +757,16 @@ impl ReadContext { #[cfg(test)] pub(crate) mod tests { use super::*; + use crate::client::Statement; use crate::result_set::tests::string_val; + use crate::value::Value; + use gaxi::grpc::tonic::{self, Code, Response, Status}; + use mock_v1::transaction_selector::Selector; use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use std::pin::Pin; + use std::sync::Arc; + use std::task::{Context, Poll}; + use tokio::sync::{Barrier, Mutex, Notify, mpsc}; #[test] fn auto_traits() { @@ -655,12 +780,10 @@ pub(crate) mod tests { pub(crate) fn create_session_mock() -> spanner_grpc_mock::MockSpanner { let mut mock = spanner_grpc_mock::MockSpanner::new(); mock.expect_create_session().once().returning(|_| { - Ok(gaxi::grpc::tonic::Response::new( - spanner_grpc_mock::google::spanner::v1::Session { - name: "projects/p/instances/i/databases/d/sessions/123".to_string(), - ..Default::default() - }, - )) + Ok(Response::new(mock_v1::Session { + name: "projects/p/instances/i/databases/d/sessions/123".to_string(), + 
..Default::default() + })) }); mock } @@ -689,6 +812,7 @@ pub(crate) mod tests { let (address, server) = spanner_grpc_mock::start("0.0.0.0:0", mock) .await .expect("Failed to start mock server"); + let spanner = Spanner::builder() .with_endpoint(address) .with_credentials(Anonymous::new().build()) @@ -712,7 +836,12 @@ pub(crate) mod tests { let (db_client, _server) = setup_db_client(mock).await; let tx = db_client.single_use().build(); - let selector = tx.context.transaction_selector.selector(); + let selector = tx + .context + .transaction_selector + .selector() + .await + .expect("Failed to get selector"); let ro = selector .single_use() .expect("Expected SingleUse selector") @@ -729,7 +858,12 @@ pub(crate) mod tests { std::time::Duration::from_secs(10), )) .build(); - let selector = tx2.context.transaction_selector.selector(); + let selector = tx2 + .context + .transaction_selector + .selector() + .await + .expect("Failed to get selector"); let ro2 = selector .single_use() .expect("Expected SingleUse selector") @@ -761,9 +895,9 @@ pub(crate) mod tests { ); assert_eq!(req.sql, "SELECT 1"); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -795,7 +929,7 @@ pub(crate) mod tests { req.session, "projects/p/instances/i/databases/d/sessions/123" ); - Ok(gaxi::grpc::tonic::Response::new(mock_v1::Transaction { + Ok(tonic::Response::new(mock_v1::Transaction { id: vec![1, 2, 3], // prost_types::Timestamp fields need to be explicitly set because default is 0 for both read_timestamp: Some(prost_types::Timestamp { @@ -822,9 +956,9 @@ pub(crate) mod tests { mock_v1::transaction_selector::Selector::Id(vec![1, 2, 3]) ); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + 
Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -894,9 +1028,9 @@ pub(crate) mod tests { }), ..Default::default() }); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(rs)]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(rs), + ])))) }); mock.expect_execute_streaming_sql() @@ -911,9 +1045,9 @@ pub(crate) mod tests { } _ => panic!("Expected Selector::Id"), } - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -969,9 +1103,9 @@ pub(crate) mod tests { assert_eq!(req.table, "Users"); assert_eq!(req.columns, vec!["Id".to_string(), "Name".to_string()]); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(setup_select1())]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(setup_select1()), + ])))) }); let (db_client, _server) = setup_db_client(mock).await; @@ -1088,8 +1222,8 @@ pub(crate) mod tests { #[tokio::test] async fn inline_begin_failure_retry_success() -> anyhow::Result<()> { use crate::value::Value; - use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; + use tonic::Response; let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); @@ -1165,8 +1299,8 @@ pub(crate) mod tests { #[tokio::test] async fn inline_begin_failure_retry_failure() -> anyhow::Result<()> { - use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; + use tonic::Response; let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); @@ -1271,8 +1405,8 @@ pub(crate) mod tests { async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { use crate::client::{KeySet, ReadRequest}; use crate::value::Value; - use 
gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; + use tonic::Response; let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); @@ -1389,9 +1523,9 @@ pub(crate) mod tests { read_timestamp: None, ..Default::default() }); - Ok(gaxi::grpc::tonic::Response::new(Box::pin( - tokio_stream::iter(vec![Ok(rs)]), - ))) + Ok(tonic::Response::new(Box::pin(tokio_stream::iter(vec![ + Ok(rs), + ])))) }); // 2. Second query fails immediately upon send() @@ -1425,4 +1559,688 @@ pub(crate) mod tests { Ok(()) } + + /// A wrapper that implements `tokio_stream::Stream` for a `mpsc::Receiver`. + /// Useful in mock setups to yield controlled streaming test responses. + struct ReceiverStream(mpsc::Receiver); + impl tokio_stream::Stream for ReceiverStream { + type Item = T; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.0.poll_recv(cx) + } + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. First query: should include Selector::Begin + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + task1_ready_clone.notify_one(); + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. 
The other queries: should include populated Selector::Id + mock.expect_execute_streaming_sql() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id for other queries"), + } + + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + // Spawn 3 concurrent queries. + // Task 1 launches first and executes the first query. + let tx1 = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = tx1 + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + // Read the first result to get the transaction ID. + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock server. + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + tx2.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + let tx3 = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + tx3.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + // Provide the first result (including the transaction ID) to Task 1. 
+ // This transitions the selector to 'Started' and unblocks Tasks 2 and 3. + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + read_timestamp: Some(prost_types::Timestamp { + seconds: 987654321, + nanos: 0, + }), + ..Default::default() + }); + tx_sender.send(Ok(rs)).await.expect("channel broken"); + drop(tx_sender); + + // Collect all results + let mut rs1 = handle1.await??; + let mut rs2 = handle2.await??; + let mut rs3 = handle3.await??; + + // Verify the query results + assert!(rs1.next().await.is_none()); + + let row2 = rs2.next().await.expect("Expected a row")?; + assert_eq!(row2.raw_values(), [Value(string_val("1"))]); + assert!(rs2.next().await.is_none()); + + let row3 = rs3.next().await.expect("Expected a row")?; + assert_eq!(row3.raw_values(), [Value(string_val("1"))]); + assert!(rs3.next().await.is_none()); + + // Verify that the read timestamp was populated + assert_eq!(tx.read_timestamp().unwrap().seconds(), 987654321); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin_failed_cascade() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. Return a stream connected to tx_sender. + // We will use tx_sender later in the test to inject a failed first chunk. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |_req| { + task1_ready_clone.notify_one(); + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(tonic::Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. 
Fallback BeginTransaction RPC fails + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Err(gaxi::grpc::tonic::Status::internal( + "Fallback BeginTransaction failed", + )) + }); + + // The other queries will never be executed. + mock.expect_execute_streaming_sql().times(0).returning(|_| { + panic!("Other queries should not launch after failure to start the transaction") + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + // Spawn 3 concurrent queries. + let tx1 = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = tx1 + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock and transition the selector to Starting. + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + tx2.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + let tx3 = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + tx3.execute_query(Statement::builder("SELECT 1").build()) + .await + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + // Push error to channel failing first query stream! 
+ tx_sender + .send(Err(gaxi::grpc::tonic::Status::internal( + "Mocked boot failed", + ))) + .await + .expect("channel broken"); + drop(tx_sender); + + // Collect all results - all should fail with identical cached error! + let err1 = handle1.await?.unwrap_err().to_string(); + let err2 = handle2.await?.unwrap_err().to_string(); + let err3 = handle3.await?.unwrap_err().to_string(); + + assert!( + err1.contains("Fallback BeginTransaction failed"), + "err1: {}", + err1 + ); + assert!( + err2.contains("Fallback BeginTransaction failed"), + "err2: {}", + err2 + ); + assert!( + err3.contains("Fallback BeginTransaction failed"), + "err3: {}", + err3 + ); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_inline_begin_stream_restart_deadlock_prevention() + -> crate::Result<()> { + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. Task 1 initial query: Return a stream connected to tx_sender for error injection. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + // Return a stream connected to tx_sender. + // We will use tx_sender later in the test to inject a transient error. + task1_ready_clone.notify_one(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. Task 1 restart query: should include Selector::Begin, since + // it failed with a transient error. 
+ mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => { + let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }); + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok(rs)])))) + } + _ => panic!("Expected Selector::Begin for stream restart query"), + } + }); + + // 3. Tasks 2 & 3: should include populated Selector::Id + mock.expect_execute_streaming_sql() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + } + _ => panic!("Expected Selector::Id for concurrent queries"), + } + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + let handle1_tx = Arc::clone(&tx); + let handle1 = tokio::spawn(async move { + let mut rs = handle1_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Wait for Task 1 to reach the mock and transition the selector to Starting. 
+ task1_ready.notified().await; + + let handle2_tx = Arc::clone(&tx); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + let mut rs = handle2_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + let handle3_tx = Arc::clone(&tx); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + let mut rs = handle3_tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + let _ = rs.next().await.ok_or_else(|| { + crate::error::internal_error("stream exhausted (this should never happen)") + })??; + Ok::<_, crate::Error>(rs) + }); + + // Ensure both Tasks 2 and 3 have reached the barrier before proceeding. + tasks_started.wait().await; + + // Flush the scheduler on this single-threaded executor. + // This guarantees that Tasks 2 & 3 run until they both hit the internal + // selector Notify latch and become suspended. + tokio::task::yield_now().await; + + let grpc_status = Status::new(gaxi::grpc::tonic::Code::Unavailable, "transient error"); + tx_sender.send(Err(grpc_status)).await.expect("send failed"); + drop(tx_sender); + + // Collect and verify all results. + // handle.await returns Result, JoinError>. + // The first ? handles the potential JoinError (panic in the task), + // and the second ? handles the Spanner error. + let mut rs1 = handle1.await.expect("Task 1 panicked")?; + let mut rs2 = handle2.await.expect("Task 2 panicked")?; + let mut rs3 = handle3.await.expect("Task 3 panicked")?; + + // Verify that all results have been exhausted. + // (The tasks themselves already successfully read the first row). 
+ assert!(rs1.next().await.is_none(), "Stream 1 should be exhausted"); + assert!(rs2.next().await.is_none(), "Stream 2 should be exhausted"); + assert!(rs3.next().await.is_none(), "Stream 3 should be exhausted"); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_queries_late_arrival_failure() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. Initial query fails. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first query"), + } + Err(Status::internal("Initial inline-begin failed")) + }); + + // 2. Fallback BeginTransaction RPC also fails. + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::internal("Fallback BeginTransaction failed"))); + + // Any further attempts would panic because we haven't mocked them. + mock.expect_execute_streaming_sql().never(); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // First query: triggers the failure and transitions the state to Failed. + let err1 = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await + .expect_err("First query should fail"); + assert!( + err1.to_string() + .contains("Fallback BeginTransaction failed") + ); + + // Second query: starts AFTER the failure is already cached. + // It should immediately return the same error without invoking the mock server. 
+ let err2 = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await + .expect_err("Late query should fail immediately"); + assert!( + err2.to_string() + .contains("Fallback BeginTransaction failed") + ); + + Ok(()) + } + + #[tokio::test] + async fn execute_concurrent_reads_inline_begin() -> anyhow::Result<()> { + use crate::client::{KeySet, ReadRequest}; + let mut mock = create_session_mock(); + mock.expect_begin_transaction().never(); + + let mut seq = mockall::Sequence::new(); + let (tx_sender, rx_receiver) = mpsc::channel(1); + let rx_receiver = Arc::new(Mutex::new(Some(rx_receiver))); + + let task1_ready = Arc::new(Notify::new()); + let task1_ready_clone = Arc::clone(&task1_ready); + let tasks_started = Arc::new(Barrier::new(3)); + + // 1. First read: should include Selector::Begin + mock.expect_streaming_read() + .times(1) + .in_sequence(&mut seq) + .returning(move |req| { + task1_ready_clone.notify_one(); + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Begin(_) => {} + _ => panic!("Expected Selector::Begin for first read"), + } + + let rx = rx_receiver + .try_lock() + .expect("mutex poisoned") + .take() + .unwrap(); + Ok(Response::new(Box::pin(ReceiverStream(rx)))) + }); + + // 2. 
The other reads: should include populated Selector::Id + mock.expect_streaming_read() + .times(2) + .in_sequence(&mut seq) + .returning(move |req| { + let req = req.into_inner(); + match req.transaction.unwrap().selector.unwrap() { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, vec![4, 5, 6]); + } + _ => panic!("Expected Selector::Id for other reads"), + } + + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + + let read_req = ReadRequest::builder("Table", vec!["Col"]) + .with_keys(KeySet::all()) + .build(); + + // Spawn 3 concurrent reads. + let tx1 = Arc::clone(&tx); + let read1 = read_req.clone(); + let handle1 = tokio::spawn(async move { + let mut rs = tx1.execute_read(read1).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + task1_ready.notified().await; + + let tx2 = Arc::clone(&tx); + let read2 = read_req.clone(); + let tasks_started2 = Arc::clone(&tasks_started); + let handle2 = tokio::spawn(async move { + tasks_started2.wait().await; + let mut rs = tx2.execute_read(read2).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + let tx3 = Arc::clone(&tx); + let read3 = read_req.clone(); + let tasks_started3 = Arc::clone(&tasks_started); + let handle3 = tokio::spawn(async move { + tasks_started3.wait().await; + let mut rs = tx3.execute_read(read3).await?; + let _ = rs.next().await; + Ok::<_, crate::Error>(rs) + }); + + tasks_started.wait().await; + tokio::task::yield_now().await; + + // Provide the transaction ID. 
+ let mut rs = setup_select1(); + rs.metadata.as_mut().unwrap().transaction = Some(mock_v1::Transaction { + id: vec![4, 5, 6], + ..Default::default() + }); + tx_sender.send(Ok(rs)).await.expect("send failed"); + drop(tx_sender); + + let mut rs1 = handle1.await.expect("Task 1 panicked")?; + let mut rs2 = handle2.await.expect("Task 2 panicked")?; + let mut rs3 = handle3.await.expect("Task 3 panicked")?; + + assert!(rs1.next().await.is_none()); + assert!(rs2.next().await.is_none()); + assert!(rs3.next().await.is_none()); + + Ok(()) + } + + #[tokio::test] + async fn execute_inline_begin_idempotent_update() -> anyhow::Result<()> { + let (db_client, _server) = setup_db_client(create_session_mock()).await; + // Access internal state for unit testing. + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let id1 = bytes::Bytes::from_static(b"tx1"); + let id2 = bytes::Bytes::from_static(b"tx2"); + + // 1. Initial update. + tx.context.transaction_selector.update(id1.clone(), None)?; + assert_eq!( + tx.context + .transaction_selector + .selector() + .await? + .id() + .unwrap(), + &id1 + ); + + // 2. Redundant update with same ID should result in an error. + // The implementation explicitly prevents redundant updates to ensure state consistency. + let err1 = tx + .context + .transaction_selector + .update(id1.clone(), None) + .expect_err("Redundant update should fail"); + assert!(err1.to_string().contains("already Started or Failed")); + + // 3. Update with DIFFERENT ID after already Started should also fail. + let err2 = tx + .context + .transaction_selector + .update(id2, None) + .expect_err("Update after Started should fail"); + assert!(err2.to_string().contains("already Started or Failed")); + + Ok(()) + } + + #[tokio::test] + async fn execute_inline_begin_with_transient_failure() -> anyhow::Result<()> { + let mut mock = create_session_mock(); + let mut seq = mockall::Sequence::new(); + + // 1. 
First attempt fails transiently. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| Err(Status::new(Code::Unavailable, "Transient 1"))); + + // 2. Fallback BeginTransaction succeeds. + mock.expect_begin_transaction() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(mock_v1::Transaction { + id: vec![7, 8, 9], + ..Default::default() + })) + }); + + // 3. The manual retry of the query (which happens after explicit begin fallback). + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_| { + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok( + setup_select1(), + )])))) + }); + + let (db_client, _server) = setup_db_client(mock).await; + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + let mut rs = tx + .execute_query(Statement::builder("SELECT 1").build()) + .await?; + assert!(rs.next().await.is_some()); + assert!(rs.next().await.is_none()); + + Ok(()) + } } diff --git a/src/spanner/src/read_write_transaction.rs b/src/spanner/src/read_write_transaction.rs index 920d69bfda..dd7f9e0553 100644 --- a/src/spanner/src/read_write_transaction.rs +++ b/src/spanner/src/read_write_transaction.rs @@ -148,7 +148,7 @@ impl ReadWriteTransaction { .into() .into_request() .set_session(self.context.client.session.name.clone()) - .set_transaction(self.context.transaction_selector.selector()) + .set_transaction(self.context.transaction_selector.selector().await?) .set_seqno(seqno); request.request_options = self.context.amend_request_options(request.request_options); @@ -249,7 +249,7 @@ impl ReadWriteTransaction { let request = ExecuteBatchDmlRequest::default() .set_session(self.context.client.session.name.clone()) - .set_transaction(self.context.transaction_selector.selector()) + .set_transaction(self.context.transaction_selector.selector().await?) 
.set_seqno(seqno) .set_statements(statements) .set_or_clear_request_options( @@ -274,8 +274,8 @@ impl ReadWriteTransaction { } } - pub(crate) fn transaction_id(&self) -> crate::Result { - match &self.context.transaction_selector.selector().selector { + pub(crate) async fn transaction_id(&self) -> crate::Result { + match &self.context.transaction_selector.selector().await?.selector { Some(Selector::Id(id)) => Ok(id.clone()), _ => Err(internal_error("Transaction ID is missing")), } @@ -283,7 +283,7 @@ impl ReadWriteTransaction { /// Commits the transaction. pub(crate) async fn commit(self) -> crate::Result { - let transaction_id = self.transaction_id()?; + let transaction_id = self.transaction_id().await?; let precommit_token = self.context.precommit_token_tracker.get(); let request = CommitRequest::default() .set_session(self.context.client.session.name.clone()) @@ -323,7 +323,7 @@ impl ReadWriteTransaction { /// Rolls back the transaction. pub(crate) async fn rollback(self) -> crate::Result<()> { - let transaction_id = self.transaction_id()?; + let transaction_id = self.transaction_id().await?; let request = RollbackRequest::default() .set_session(self.context.client.session.name.clone()) diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 2d775b7412..1e18c4e8c8 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -310,7 +310,7 @@ impl ResultSet { transaction .read_timestamp .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), - ); + )?; } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { let is_started = matches!( &*lazy.lock().expect("transaction state mutex poisoned"), @@ -368,8 +368,16 @@ impl ResultSet { } async fn restart_stream(&mut self) -> crate::Result<()> { + if let Some(s) = &self.transaction_selector { + s.maybe_reset_starting(); + } + // Get the latest transaction selector for this transaction. 
- let transaction_selector = self.transaction_selector.as_ref().map(|s| s.selector()); + let transaction_selector = if let Some(s) = &self.transaction_selector { + Some(s.selector().await?) + } else { + None + }; match &mut self.operation { StreamOperation::Query(req) => { diff --git a/src/spanner/src/transaction_runner.rs b/src/spanner/src/transaction_runner.rs index e6f5b4da6d..089e6f7d65 100644 --- a/src/spanner/src/transaction_runner.rs +++ b/src/spanner/src/transaction_runner.rs @@ -223,7 +223,7 @@ impl TransactionRunner { let mut current_tx_id = None; let attempt_result = async { let transaction = self.builder.begin_transaction().await?; - current_tx_id = transaction.transaction_id().ok(); + current_tx_id = transaction.transaction_id().await.ok(); let result = match work(transaction.clone()).await { Ok(res) => res, diff --git a/tests/spanner/Cargo.toml b/tests/spanner/Cargo.toml index 30b6625f58..2869708160 100644 --- a/tests/spanner/Cargo.toml +++ b/tests/spanner/Cargo.toml @@ -33,10 +33,13 @@ google-cloud-gax = { workspace = true } google-cloud-lro = { workspace = true } google-cloud-spanner = { workspace = true, features = ["unstable-stream"] } google-cloud-test-utils = { workspace = true } +google-cloud-wkt = { workspace = true } prost-types.workspace = true +rand = { workspace = true } reqwest = { workspace = true, features = ["json"] } serde_json = { workspace = true } spanner-grpc-mock = { path = "../../src/spanner/grpc-mock" } +time = { workspace = true } tokio = { workspace = true, features = ["sync"] } tokio-stream = { workspace = true } tonic = { workspace = true } diff --git a/tests/spanner/src/concurrent_inline_begin.rs b/tests/spanner/src/concurrent_inline_begin.rs new file mode 100644 index 0000000000..c191a4bd73 --- /dev/null +++ b/tests/spanner/src/concurrent_inline_begin.rs @@ -0,0 +1,264 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance 
with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::client::{get_database_id, get_emulator_host, provision_emulator, update_database_ddl}; +use crate::test_proxy::{InterceptedSpanner, SpannerInterceptor}; +use futures::stream::{self, StreamExt}; +use google_cloud_spanner::client::{ResultSet, Row, Spanner, TimestampBound}; +use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; +use spanner_grpc_mock::google::spanner::v1 as spanner_v1; +use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; +use std::collections::HashMap; +use std::sync::Arc; +use time::OffsetDateTime; +use tokio::net::TcpListener; +use tokio::sync::{Barrier, Mutex}; +use tokio_stream::wrappers::TcpListenerStream; +use tonic::transport::{Channel, Server}; + +/// An interceptor that injects transient (Unavailable) and permanent (Internal) failures +/// into streaming SQL responses for specific query patterns. +pub struct ConcurrentFaultInterceptor { + emulator_client: SpannerClient, + /// Tracks failure counts to allow transient recovery. 
+ failure_counts: Arc>>, +} + +impl ConcurrentFaultInterceptor { + pub fn new(emulator_client: SpannerClient) -> Self { + Self { + emulator_client, + failure_counts: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +#[tonic::async_trait] +impl SpannerInterceptor for ConcurrentFaultInterceptor { + fn emulator_client(&self) -> SpannerClient { + self.emulator_client.clone() + } + + async fn execute_streaming_sql( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + let sql = request.get_ref().sql.clone(); + + // Emulates a transient stream failure. + if sql.starts_with("SELECT 'Transient-") { + let mut counts = self.failure_counts.lock().await; + let count = counts.entry(sql.clone()).or_insert(0); + if *count == 0 { + *count += 1; + // Return a stream that fails immediately with Unavailable. + let stream = stream::once(async { + Err(tonic::Status::unavailable("Transient stream failure")) + }); + return Ok(tonic::Response::new(stream.boxed())); + } + // Second attempt succeeds (fall through to emulator). + } + + // Emulates a permanent stream failure. + if sql == "SELECT 'Permanent'" { + // Returns a stream that always fails with an Internal error. + let stream = + stream::once(async { Err(tonic::Status::internal("Permanent stream failure")) }); + return Ok(tonic::Response::new(stream.boxed())); + } + + // Forward other queries to the emulator. 
+ let res = self + .emulator_client() + .execute_streaming_sql(request) + .await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) + } +} + +pub async fn test_concurrent_inline_begin_with_snapshot_consistency() -> anyhow::Result<()> { + let emulator_host = match get_emulator_host() { + Some(host) => host, + None => return Ok(()), + }; + provision_emulator(&emulator_host).await; + let db_id = get_database_id().await; + let db_path = format!( + "projects/test-project/instances/test-instance/databases/{}", + db_id + ); + + // 1. Setup Table 1 (Exists at snapshot time) + let suffix = LowercaseAlphanumeric.random_string(6); + let table_success = format!("TableSuccess_{}", suffix); + let table_not_found = format!("TableNotFound_{}", suffix); + + let statement = format!("CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", table_success); + update_database_ddl(statement).await?; + + // 2. Capture snapshot time. + let spanner = Spanner::builder() + .with_endpoint(format!("http://{}", emulator_host)) + .build() + .await?; + let db_client = spanner.database_client(&db_path).build().await?; + + let mut rs: ResultSet = db_client + .single_use() + .build() + .execute_query("SELECT CURRENT_TIMESTAMP") + .await?; + let row: Row = rs.next().await.unwrap().unwrap(); + let snapshot_time: OffsetDateTime = row.try_get(0)?; + + // 3. Setup Table 2 (Does NOT exist at snapshot time) + let statement = format!( + "CREATE TABLE {} (Id INT64) PRIMARY KEY (Id)", + table_not_found + ); + update_database_ddl(statement).await?; + + // 4. Start the Intercepted Server + let listener = TcpListener::bind("127.0.0.1:0").await?; + let local_addr = listener.local_addr()?; + let emulator_channel = Channel::from_shared(format!("http://{}", emulator_host))? 
+ .connect() + .await?; + let interceptor = ConcurrentFaultInterceptor::new(SpannerClient::new(emulator_channel)); + let service = InterceptedSpanner(interceptor); + + tokio::spawn(async move { + Server::builder() + .add_service(spanner_v1::spanner_server::SpannerServer::new(service)) + .serve_with_incoming(TcpListenerStream::new(listener)) + .await + .expect("Server failed"); + }); + + // 5. Build Client pointing to Interceptor + let intercepted_spanner = Spanner::builder() + .with_endpoint(format!("http://{}", local_addr)) + .build() + .await?; + let intercepted_db = intercepted_spanner + .database_client(&db_path) + .build() + .await?; + + // 6. Spawn 20 tasks with random workloads + let tx = intercepted_db + .read_only_transaction() + .with_timestamp_bound(TimestampBound::read_timestamp(snapshot_time)) + .with_explicit_begin_transaction(false) + .build() + .await?; + let tx = Arc::new(tx); + let barrier = Arc::new(Barrier::new(20)); + let mut handles = Vec::new(); + + for i in 0..20 { + let role = rand::random_range(0..4); + let tx = Arc::clone(&tx); + let barrier = Arc::clone(&barrier); + let table_success = table_success.clone(); + let table_not_found = table_not_found.clone(); + + handles.push(tokio::spawn(async move { + barrier.wait().await; + match role { + 0 => { + // Success + let mut result_set: ResultSet = tx + .execute_query(format!("SELECT * FROM {}", table_success)) + .await?; + while let Some(row) = result_set.next().await { + row?; + } + Ok::<_, anyhow::Error>(format!("Task {} Success: OK", i)) + } + 1 => { + // Table not found + let res: Result = tx + .execute_query(format!("SELECT * FROM {}", table_not_found)) + .await; + match res { + Err(e) + if e.to_string().contains("not found") + || e.to_string().contains("NotFound") => + { + Ok(format!("Task {} NotFound: OK", i)) + } + Ok(_) => anyhow::bail!("Task {} expected NotFound but got Success", i), + Err(e) => anyhow::bail!("Task {} expected NotFound but got: {:?}", i, e), + } + } + 2 => { + // 
Transient stream error. This will trigger a retry of the stream. + let sql = format!("SELECT 'Transient-{}'", i); + let mut result_set: ResultSet = tx.execute_query(sql).await?; + while let Some(row) = result_set.next().await { + row?; + } + Ok(format!("Task {} Transient: OK", i)) + } + 3 => { + // Permanent stream error. + let result_set_res: Result = + tx.execute_query("SELECT 'Permanent'").await; + let mut result_set = match result_set_res { + Ok(rs) => rs, + Err(e) => anyhow::bail!( + "Task {} expected successful RPC initiation but got: {:?}", + i, + e + ), + }; + + let next = result_set.next().await; + match next { + Some(Err(e)) + if e.to_string().contains("Permanent") + || e.to_string().contains("Internal") => + { + Ok(format!("Task {} Permanent: OK", i)) + } + Some(Ok(_)) => { + anyhow::bail!("Task {} expected Permanent error but got a valid row", i) + } + _ => anyhow::bail!( + "Task {} expected Permanent error but succeeded or got empty results", + i + ), + } + } + _ => unreachable!(), + } + })); + } + + for handle in handles { + handle.await??; + } + + Ok(()) +} diff --git a/tests/spanner/src/lib.rs b/tests/spanner/src/lib.rs index d8a95bf9c1..ad0413d66a 100644 --- a/tests/spanner/src/lib.rs +++ b/tests/spanner/src/lib.rs @@ -14,6 +14,7 @@ pub mod batch_read_only_transaction; pub mod client; +pub mod concurrent_inline_begin; pub mod partitioned_dml; pub mod query; pub mod read; diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index 7532a80084..58c785c0e1 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -12,11 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-<<<<<<< HEAD use crate::client::{get_database_id, get_emulator_host}; -======= -use crate::client::{get_database_id, get_emulator_host, get_emulator_rest_endpoint}; ->>>>>>> 255a00797 (test(spanner): add integration test for inline-begin error handling) use crate::test_proxy::{InterceptedSpanner, SpannerInterceptor}; use google_cloud_spanner::client::{DatabaseClient, Kind, Spanner, Statement}; use google_cloud_test_utils::resource_names::LowercaseAlphanumeric; diff --git a/tests/spanner/src/test_proxy.rs b/tests/spanner/src/test_proxy.rs index 7f0a5837a2..f7a07d881f 100644 --- a/tests/spanner/src/test_proxy.rs +++ b/tests/spanner/src/test_proxy.rs @@ -12,9 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use futures::stream::{BoxStream, StreamExt}; use spanner_grpc_mock::google::spanner::v1 as spanner_v1; use spanner_grpc_mock::google::spanner::v1::spanner_client::SpannerClient; +pub type ExecuteStreamingSqlStream = + BoxStream<'static, std::result::Result>; + #[tonic::async_trait] pub trait SpannerInterceptor: Send + Sync + 'static { fn emulator_client(&self) -> SpannerClient; @@ -66,10 +70,21 @@ pub trait SpannerInterceptor: Send + Sync + 'static { &self, request: tonic::Request, ) -> std::result::Result< - tonic::Response>, + tonic::Response< + BoxStream<'static, std::result::Result>, + >, tonic::Status, > { - self.emulator_client().execute_streaming_sql(request).await + let res = self + .emulator_client() + .execute_streaming_sql(request) + .await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) } async fn execute_batch_dml( @@ -91,10 +106,18 @@ pub trait SpannerInterceptor: Send + Sync + 'static { &self, request: tonic::Request, ) -> std::result::Result< - tonic::Response>, + tonic::Response< + BoxStream<'static, std::result::Result>, + >, tonic::Status, > { - 
self.emulator_client().streaming_read(request).await + let res = self.emulator_client().streaming_read(request).await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) } async fn begin_transaction( @@ -136,10 +159,18 @@ pub trait SpannerInterceptor: Send + Sync + 'static { &self, request: tonic::Request, ) -> std::result::Result< - tonic::Response>, + tonic::Response< + BoxStream<'static, std::result::Result>, + >, tonic::Status, > { - self.emulator_client().batch_write(request).await + let res = self.emulator_client().batch_write(request).await?; + let (metadata, stream, extensions) = res.into_parts(); + Ok(tonic::Response::from_parts( + metadata, + stream.boxed(), + extensions, + )) } } @@ -190,7 +221,8 @@ impl spanner_v1::spanner_server::Spanner for InterceptedS self.0.execute_sql(request).await } - type ExecuteStreamingSqlStream = tonic::codec::Streaming; + type ExecuteStreamingSqlStream = + BoxStream<'static, std::result::Result>; async fn execute_streaming_sql( &self, @@ -214,7 +246,8 @@ impl spanner_v1::spanner_server::Spanner for InterceptedS self.0.read(request).await } - type StreamingReadStream = tonic::codec::Streaming; + type StreamingReadStream = + BoxStream<'static, std::result::Result>; async fn streaming_read( &self, @@ -258,7 +291,8 @@ impl spanner_v1::spanner_server::Spanner for InterceptedS self.0.partition_read(request).await } - type BatchWriteStream = tonic::codec::Streaming; + type BatchWriteStream = + BoxStream<'static, std::result::Result>; async fn batch_write( &self, diff --git a/tests/spanner/tests/driver.rs b/tests/spanner/tests/driver.rs index 4d45413e13..ff90e2dd92 100644 --- a/tests/spanner/tests/driver.rs +++ b/tests/spanner/tests/driver.rs @@ -115,4 +115,9 @@ mod spanner { Ok(()) } + + #[tokio::test] + async fn run_concurrent_inline_begin_tests() -> anyhow::Result<()> { + 
integration_tests_spanner::concurrent_inline_begin::test_concurrent_inline_begin_with_snapshot_consistency().await + } } From 2161268781cfecdc38b2f29a21650adb29dcfbe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 2 Apr 2026 11:26:16 +0200 Subject: [PATCH 14/17] perf(spanner): inline begin transaction error handling Adds error handling for inline-begin-transaction. If the first statement in a transaction fails, and that statement included a BeginTransaction option, then the transaction has not been started. In order to keep the semantics of the transaction consistent for an 'outside observer', we need to do the following: 1. Catch the error that was thrown by the initial statement. 2. Start the transaction using an explicit BeginTransaction RPC. 3. Retry the initial statement, but now using the transaction ID from step 2. 4. Return the error or result for the retried initial statement. The above makes sure that: 1. The transaction is actually started when the first statement is executed, also when the statement failed. 2. The statement becomes part of the transaction, and the result of the statement is consistent with the read-timestamp of the transaction. The second part is important in order to comply with Spanner's strong consistency guarantees; If for example a statement returns a 'Table not found' error, then that error is only valid for the read timestamp that was used for executing the statement. This is the reason that we retry the statement after the BeginTransaction RPC to be able to return a result that is guaranteed to be consistent with any other queries/reads that will be executed in the same transaction. 
--- src/spanner/src/read_only_transaction.rs | 2 +- tests/spanner/src/query.rs | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 871a6d8624..6c27d100b1 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -1405,8 +1405,8 @@ pub(crate) mod tests { async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { use crate::client::{KeySet, ReadRequest}; use crate::value::Value; + use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; - use tonic::Response; let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index 58c785c0e1..8625543115 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -275,10 +275,7 @@ pub async fn multi_use_read_only_transaction_invalid_query_fallback( .execute_query(Statement::builder("SELECT * FROM NonExistentTable").build()) .await; - assert!( - rs_result.is_err(), - "Expected an error from an invalid query" - ); + assert!(rs_result.is_err(), "Expected an error from an invalid query"); // The read timestamp should now be available because the transaction // fell back to an explicit BeginTransaction. From 09c1c9c0ef2d704850fe6cc3cca6295e08a74a33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 2 Apr 2026 20:49:19 +0200 Subject: [PATCH 15/17] test(spanner): add integration test for inline-begin error handling Adds an integration test for error handling for inline-begin-transaction. This test uses a gRPC proxy to intercept calls from the client to Spanner to be able to deterministically emulate specific concurrency issues. 
This test shows how a query that failed during the first attempt, and thereby also failed to start the transaction, could succeed during a retry after the transaction has been started with an explicit BeginTransaction RPC. --- tests/spanner/src/query.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index 8625543115..58c785c0e1 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -275,7 +275,10 @@ pub async fn multi_use_read_only_transaction_invalid_query_fallback( .execute_query(Statement::builder("SELECT * FROM NonExistentTable").build()) .await; - assert!(rs_result.is_err(), "Expected an error from an invalid query"); + assert!( + rs_result.is_err(), + "Expected an error from an invalid query" + ); // The read timestamp should now be available because the transaction // fell back to an explicit BeginTransaction. From 5cad64d67fc9cc6c784b65e94b11ccb5332baa99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 3 Apr 2026 18:10:18 +0200 Subject: [PATCH 16/17] perf(spanner): support concurrent queries with inline-begin-transaction Adds support for running concurrent queries in combination with inline-begin-transaction. Only one of the queries will include the BeginTransaction option. The other queries will wait until the first query has returned a transaction ID. 
--- src/spanner/src/read_only_transaction.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spanner/src/read_only_transaction.rs b/src/spanner/src/read_only_transaction.rs index 6c27d100b1..871a6d8624 100644 --- a/src/spanner/src/read_only_transaction.rs +++ b/src/spanner/src/read_only_transaction.rs @@ -1405,8 +1405,8 @@ pub(crate) mod tests { async fn inline_begin_read_failure_retry_success() -> anyhow::Result<()> { use crate::client::{KeySet, ReadRequest}; use crate::value::Value; - use gaxi::grpc::tonic::Response; use gaxi::grpc::tonic::Status; + use tonic::Response; let mut mock = create_session_mock(); let mut seq = mockall::Sequence::new(); From 0b12ac03438ffa180abb4937c01bfe7ec36b38c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Sat, 4 Apr 2026 11:43:30 +0200 Subject: [PATCH 17/17] fix(spanner): prevent deadlock if ResultSet#next() is never called If a query included a BeginTransaction option and the application never called ResultSet#next(), then the transaction ID would never be returned. This would block any other query from using the transaction. This change refactors ResultSet to use a background worker to read from the stream. This prevents that a deadlock can happen if the application does not call ResultSet#next(). It also allows the application to call ResultSet#metadata() without first calling ResultSet#next(). Finally, it also allows the ResultSet to decode data from the server asynchronously while the application processes rows that it has already read from the ResultSet. 
--- src/spanner/src/result_set.rs | 556 +++++++++++++++++++------ src/spanner/src/result_set_metadata.rs | 4 +- tests/spanner/src/query.rs | 18 +- tests/spanner/src/write.rs | 1 + 4 files changed, 450 insertions(+), 129 deletions(-) diff --git a/src/spanner/src/result_set.rs b/src/spanner/src/result_set.rs index 1e18c4e8c8..02a2333ec4 100644 --- a/src/spanner/src/result_set.rs +++ b/src/spanner/src/result_set.rs @@ -25,6 +25,10 @@ use gaxi::prost::FromProto; use google_cloud_gax::error::rpc::Code; use std::collections::VecDeque; use std::mem::take; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use tokio::sync::mpsc; +use tokio::sync::watch; #[cfg(feature = "unstable-stream")] use futures::Stream; @@ -44,11 +48,20 @@ use futures::Stream; /// ``` #[derive(Debug)] pub struct ResultSet { + receiver: mpsc::Receiver>, + metadata: watch::Receiver>, + // This field is only modified in tests to set a small buffer size. + #[allow(dead_code)] + max_buffered_partial_result_sets: Arc, +} + +#[derive(Debug)] +struct ResultSetWorker { stream: PartialResultSetStream, buffered_values: Vec, chunked: bool, ready_rows: VecDeque, - metadata: Option, + metadata: watch::Sender>, precommit_token_tracker: PrecommitTokenTracker, // Fields for retries and buffering of a stream of PartialResultSets. 
@@ -57,7 +70,7 @@ pub struct ResultSet { last_resume_token: Bytes, partial_result_sets_buffer: VecDeque, safe_to_retry: bool, - max_buffered_partial_result_sets: usize, + max_buffered_partial_result_sets: Arc, retry_count: usize, transaction_selector: Option, } @@ -91,21 +104,33 @@ impl ResultSet { client: DatabaseClient, operation: StreamOperation, ) -> Self { - Self { + let (sender, receiver) = mpsc::channel(4); + let (metadata_sender, metadata_receiver) = watch::channel(None); + let max_buffered_partial_result_sets = + Arc::new(AtomicUsize::new(MAX_BUFFERED_PARTIAL_RESULT_SETS)); + + let mut worker = ResultSetWorker::new( stream, - buffered_values: Vec::new(), - chunked: false, - ready_rows: VecDeque::new(), - metadata: None, + transaction_selector, precommit_token_tracker, client, operation, - last_resume_token: Bytes::new(), - partial_result_sets_buffer: VecDeque::new(), - safe_to_retry: true, - max_buffered_partial_result_sets: MAX_BUFFERED_PARTIAL_RESULT_SETS, - retry_count: 0, - transaction_selector, + metadata_sender, + Arc::clone(&max_buffered_partial_result_sets), + ); + + tokio::spawn(async move { + while let Some(row) = worker.next().await { + if sender.send(row).await.is_err() { + break; // Receiver dropped + } + } + }); + + Self { + receiver, + metadata: metadata_receiver, + max_buffered_partial_result_sets, } } @@ -114,21 +139,29 @@ impl ResultSet { /// # Example /// ``` /// # use google_cloud_spanner::client::{ResultSet, Row}; - /// # async fn fetch_metadata(mut rs: ResultSet) -> Result<(), Box> { - /// if let Some(row) = rs.next().await.transpose()? 
{ - /// let metadata = rs.metadata()?; - /// for column in metadata.column_names() { - /// println!("Column name: {}", column); - /// } + /// # async fn fetch_metadata(mut result_set: ResultSet) -> Result<(), Box> { + /// let metadata = result_set.metadata().await?; + /// for column in metadata.column_names() { + /// println!("Column name: {}", column); /// } /// # Ok(()) /// # } /// ``` /// - /// The metadata is only available after the first call to [`next`](Self::next). - /// If called before the first `next()` call, it returns a [`ResultSetError::MetadataNotAvailable`] error. - pub fn metadata(&self) -> Result { - self.metadata + /// This method blocks until the metadata is available, which is after the + /// first chunk is received from the server. If the stream ends or fails + /// before metadata is available, it returns [`ResultSetError::MetadataNotAvailable`]. + pub async fn metadata(&self) -> Result { + let mut receiver = self.metadata.clone(); + if let Some(metadata) = &*receiver.borrow() { + return Ok(metadata.clone()); + } + receiver + .changed() + .await + .map_err(|_| ResultSetError::MetadataNotAvailable)?; + receiver + .borrow() .clone() .ok_or(ResultSetError::MetadataNotAvailable) } @@ -148,6 +181,73 @@ impl ResultSet { /// /// Returns `None` when all rows have been retrieved. pub async fn next(&mut self) -> Option> { + self.receiver.recv().await + } + + /// Converts the [`ResultSet`] into a [`Stream`]. + /// + /// # Example + /// + /// ``` + /// # use google_cloud_spanner::client::ResultSet; + /// # use futures::TryStreamExt; + /// # use std::future::ready; + /// # async fn example(result_set: ResultSet) -> Result<(), google_cloud_spanner::Error> { + /// let rows: Vec<_> = result_set + /// .into_stream() + /// .try_filter(|row| { + /// let id = row.get::("Id"); + /// ready(id == "id1") + /// }) + /// .try_collect() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// + /// This consumes the [`ResultSet`] and returns a stream of rows. 
+ #[cfg(feature = "unstable-stream")] + pub fn into_stream(self) -> impl Stream> + Unpin { + use futures::stream::unfold; + Box::pin(unfold(self, |mut result_set| async move { + result_set.next().await.map(|row| (row, result_set)) + })) + } +} + +impl ResultSetWorker { + /// Creates a new result set worker. + pub(crate) fn new( + stream: PartialResultSetStream, + transaction_selector: Option, + precommit_token_tracker: PrecommitTokenTracker, + client: DatabaseClient, + operation: StreamOperation, + metadata: watch::Sender>, + max_buffered_partial_result_sets: Arc, + ) -> Self { + Self { + stream, + buffered_values: Vec::new(), + chunked: false, + ready_rows: VecDeque::new(), + metadata, + precommit_token_tracker, + client, + operation, + last_resume_token: Bytes::new(), + partial_result_sets_buffer: VecDeque::new(), + safe_to_retry: true, + max_buffered_partial_result_sets, + retry_count: 0, + transaction_selector, + } + } + + /// Fetches the next row from the result set. + /// + /// Returns `None` when all rows have been retrieved. + pub(crate) async fn next(&mut self) -> Option> { if let Some(row) = self.ready_rows.pop_front() { return Some(Ok(row)); } @@ -210,7 +310,11 @@ impl ResultSet { // The PartialResultSet did not have a resume_token. Buffer the result // and continue with the next PartialResultSet, unless the buffer is full. - if self.partial_result_sets_buffer.len() >= self.max_buffered_partial_result_sets { + if self.partial_result_sets_buffer.len() + >= self + .max_buffered_partial_result_sets + .load(Ordering::Relaxed) + { // Mark this stream as 'unsafe to retry', meaning that any transient error // that we see will not be retried. We will instead propagate the error. 
self.safe_to_retry = false; @@ -290,37 +394,47 @@ impl ResultSet { &mut self, partial_result_set: PartialResultSet, ) -> crate::Result<()> { - match (self.metadata.as_ref(), partial_result_set.metadata) { - (Some(_), None) => {} - (None, None) => { - return Err(internal_error( - "First PartialResultSet did not contain metadata", - )); - } - (Some(_), Some(_)) => { - return Err(internal_error("Additional metadata after first result set")); + let update_selector = { + let metadata_ref = self.metadata.borrow(); + match (&*metadata_ref, partial_result_set.metadata) { + (Some(_), None) => None, + (None, None) => { + return Err(internal_error( + "First PartialResultSet did not contain metadata", + )); + } + (Some(_), Some(_)) => { + return Err(internal_error("Additional metadata after first result set")); + } + (None, Some(mut m)) => { + let transaction = m.transaction.take(); + Some((ResultSetMetadata::new(Some(m)), transaction)) + } } - (None, Some(mut m)) => { - let transaction = m.transaction.take(); - self.metadata = Some(ResultSetMetadata::new(Some(m))); - if let Some(selector) = &self.transaction_selector { - if let Some(transaction) = transaction { - selector.update( - transaction.id, - transaction - .read_timestamp - .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), - )?; - } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { - let is_started = matches!( - &*lazy.lock().expect("transaction state mutex poisoned"), - crate::read_only_transaction::TransactionState::Started(_, _) - ); - if !is_started { - return Err(internal_error( - "Spanner failed to return a transaction ID for a query that included a BeginTransaction option", - )); - } + }; + + if let Some((metadata, transaction)) = update_selector { + self.metadata + .send(Some(metadata)) + .map_err(|_| internal_error("Failed to send metadata"))?; + + if let Some(selector) = &self.transaction_selector { + if let Some(transaction) = transaction { + selector.update( + transaction.id, + 
transaction + .read_timestamp + .and_then(|t| wkt::Timestamp::new(t.seconds, t.nanos).ok()), + )?; + } else if let ReadContextTransactionSelector::Lazy(lazy) = selector { + let is_started = matches!( + &*lazy.lock().expect("transaction state mutex poisoned"), + crate::read_only_transaction::TransactionState::Started(_, _) + ); + if !is_started { + return Err(internal_error( + "Spanner failed to return a transaction ID for a query that included a BeginTransaction option", + )); } } } @@ -329,7 +443,8 @@ impl ResultSet { if partial_result_set.values.is_empty() { return Ok(()); } - let metadata = self.metadata.as_ref().unwrap(); + + let metadata = self.metadata.borrow().as_ref().unwrap().clone(); if metadata.column_types.is_empty() { return Err(internal_error( "PartialResultSet contained values but no column metadata was provided", @@ -418,36 +533,6 @@ impl ResultSet { e.status() .is_some_and(|status| status.code == Code::Unavailable) } - - /// Converts the [`ResultSet`] into a [`Stream`]. - /// - /// # Example - /// - /// ``` - /// # use google_cloud_spanner::client::ResultSet; - /// # use futures::TryStreamExt; - /// # use std::future::ready; - /// # async fn example(result_set: ResultSet) -> Result<(), google_cloud_spanner::Error> { - /// let rows: Vec<_> = result_set - /// .into_stream() - /// .try_filter(|row| { - /// let id = row.get::("Id"); - /// ready(id == "id1") - /// }) - /// .try_collect() - /// .await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// This consumes the [`ResultSet`] and returns a stream of rows. - #[cfg(feature = "unstable-stream")] - pub fn into_stream(self) -> impl Stream> + Unpin { - use futures::stream::unfold; - Box::pin(unfold(self, |mut result_set| async move { - result_set.next().await.map(|row| (row, result_set)) - })) - } } /// Merges two values from successive `PartialResultSet`s into a single value. 
@@ -505,7 +590,8 @@ fn merge_values(target: &mut prost_types::Value, source: prost_types::Value) -> #[cfg(test)] impl ResultSet { pub(crate) fn set_max_buffered_partial_result_sets(&mut self, limit: usize) { - self.max_buffered_partial_result_sets = limit; + self.max_buffered_partial_result_sets + .store(limit, Ordering::Relaxed); } } @@ -624,19 +710,8 @@ pub(crate) mod tests { }]) .await; - // Called before next() -> returns MetadataNotAvailable - let meta_err = rs.metadata(); - assert!(meta_err.is_err()); - assert!(matches!( - meta_err.unwrap_err(), - ResultSetError::MetadataNotAvailable - )); - - // Advance to fetch metadata - let _next = rs.next().await.expect("Expected a row")?; - - // Called after next() -> returns metadata - let meta = rs.metadata(); + // Called before next() -> blocks and returns metadata + let meta = rs.metadata().await; assert!(meta.is_ok()); let meta = meta.unwrap(); assert_eq!( @@ -644,6 +719,9 @@ pub(crate) mod tests { &["col0".to_string(), "col1".to_string()] ); + // Now consume the row + let _next = rs.next().await.expect("Expected a row")?; + Ok(()) } @@ -1107,21 +1185,62 @@ pub(crate) mod tests { } #[tokio::test] - async fn test_result_set_precommit_token_tracked() { - let mut rs = run_mock_query(vec![PartialResultSet { - metadata: metadata(1), - precommit_token: Some( - spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken { - precommit_token: b"test_token".to_vec(), - seq_num: 99, - }, - ), - ..Default::default() - }]) - .await; + async fn test_result_set_precommit_token_tracked() -> anyhow::Result<()> { + let mut mock = MockSpanner::new(); + mock.expect_execute_streaming_sql() + .returning(move |_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + precommit_token: Some( + spanner_grpc_mock::google::spanner::v1::MultiplexedSessionPrecommitToken { + precommit_token: b"test_token".to_vec(), + seq_num: 99, + }, + ), + ..Default::default() + })]); + Ok(Response::new( + 
Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; - // Force tracking mode since run_mock_query uses a ReadOnly transaction (NoOp). - rs.precommit_token_tracker = PrecommitTokenTracker::new(); + let db_client = client.database_client("db").build().await?; + + let req = crate::model::ExecuteSqlRequest::default() + .set_session(db_client.session.name.clone()) + .set_sql("SELECT 1".to_string()); + + let stream = db_client + .spanner + .execute_streaming_sql(req.clone(), crate::RequestOptions::default()) + .send() + .await?; + + let tracker = PrecommitTokenTracker::new(); // Track mode! + + let mut rs = ResultSet::new( + stream, + None, + tracker.clone(), + db_client.clone(), + StreamOperation::Query(req), + ); // Read a row to trigger precommit token extraction assert!( @@ -1130,12 +1249,11 @@ pub(crate) mod tests { ); // Validate the tracker correctly intercepted and preserved the token - let token = rs - .precommit_token_tracker - .get() - .expect("token should be tracked"); + let token = tracker.get().expect("token should be tracked"); assert_eq!(token.seq_num, 99); assert_eq!(token.precommit_token, bytes::Bytes::from("test_token")); + + Ok(()) } #[tokio::test] @@ -1923,4 +2041,208 @@ pub(crate) mod tests { Ok(()) } + + #[tokio::test] + async fn test_lazy_begin_deadlock_fixed() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::google::spanner::v1 as mock_v1; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + let mut seq = mockall::Sequence::new(); + + // Setup mock to return metadata with transaction 
ID on first query. + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|_request| { + let mut meta = metadata(1).expect("failed to create metadata"); + meta.transaction = Some(mock_v1::Transaction { + id: b"lazy_tx_id".to_vec(), + ..Default::default() + }); + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: Some(meta), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + // Mock call for second query which must carry the returned transaction ID + mock.expect_execute_streaming_sql() + .times(1) + .in_sequence(&mut seq) + .returning(|req| { + let req = req.into_inner(); + let selector = req + .transaction + .expect("missing transaction component") + .selector + .expect("missing selector component"); + + match selector { + mock_v1::transaction_selector::Selector::Id(id) => { + assert_eq!(id, b"lazy_tx_id".to_vec()); + } + _ => panic!("Expected Selector::Id"), + } + + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("2")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + + // Use inline begin transaction + let tx = db_client + .read_only_transaction() + .with_explicit_begin_transaction(false) + .build() + .await?; + + // Execute query but DO NOT call rs.next() + let _rs = tx.execute_query("SELECT 1").await?; + + // Execute second query against same transaction + let 
mut rs2 = tx.execute_query("SELECT 2").await?; + + // Assert it does not hang and yielded elements properly + let row2 = rs2.next().await; + assert!( + row2.is_some(), + "Implicit deadlock encountered; query 2 stalled!" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_result_set_metadata_not_available() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use gaxi::grpc::tonic::Status; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + + // Setup mock to return a stream that fails immediately. + mock.expect_execute_streaming_sql().returning(|_request| { + let stream = tokio_stream::iter(vec![Err(Status::internal("Internal error"))]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + use spanner_grpc_mock::google::spanner::v1::Session; + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let rs = tx.execute_query("SELECT 1").await?; + + // Call metadata() immediately. It should fail because the stream ends without metadata. 
+ let result = rs.metadata().await; + assert!(result.is_err(), "Expected error but got Ok"); + assert!( + matches!(result.unwrap_err(), ResultSetError::MetadataNotAvailable), + "Expected MetadataNotAvailable error" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_result_set_metadata_available_before_next() -> anyhow::Result<()> { + use gaxi::grpc::tonic::Response; + use spanner_grpc_mock::MockSpanner; + use spanner_grpc_mock::start; + + let mut mock = MockSpanner::new(); + + // Setup mock to return metadata in first chunk. + mock.expect_execute_streaming_sql().returning(|_request| { + let stream = tokio_stream::iter(vec![Ok(PartialResultSet { + metadata: metadata(1), + values: vec![string_val("1")], + ..Default::default() + })]); + Ok(Response::new( + Box::pin(stream) as ::ExecuteStreamingSqlStream, + )) + }); + + mock.expect_create_session().returning(|_| { + use spanner_grpc_mock::google::spanner::v1::Session; + Ok(Response::new(Session { + name: "session".to_string(), + multiplexed: true, + ..Default::default() + })) + }); + + let (address, _server) = start("127.0.0.1:0", mock).await?; + + let client: Spanner = Spanner::builder() + .with_endpoint(address) + .with_credentials(Anonymous::new().build()) + .build() + .await?; + + let db_client = client.database_client("db").build().await?; + let tx = db_client.single_use().build(); + + let mut rs = tx.execute_query("SELECT 1").await?; + + // Call metadata() BEFORE next(). It should succeed. 
+ let metadata = rs.metadata().await?; + assert_eq!(metadata.column_names().len(), 1); + assert_eq!(metadata.column_names()[0], "col0"); + + // Now consume the row + let row = rs.next().await; + assert!(row.is_some()); + + Ok(()) + } } diff --git a/src/spanner/src/result_set_metadata.rs b/src/spanner/src/result_set_metadata.rs index 2a6cd2bf1e..3e48cf2b09 100644 --- a/src/spanner/src/result_set_metadata.rs +++ b/src/spanner/src/result_set_metadata.rs @@ -26,9 +26,7 @@ use std::sync::Arc; /// let tx = db.single_use().build(); /// let mut rs = tx.execute_query(Statement::builder("SELECT 1 AS Number").build()).await?; /// -/// // Metadata is available after the first `next` call -/// let _ = rs.next().await.transpose()?; -/// let metadata = rs.metadata()?; +/// let metadata = rs.metadata().await?; /// /// for (name, type_) in metadata.column_names().iter().zip(metadata.column_types().iter()) { /// println!("Column: {} has type: {:?}", name, type_.code()); diff --git a/tests/spanner/src/query.rs b/tests/spanner/src/query.rs index 58c785c0e1..badb161f37 100644 --- a/tests/spanner/src/query.rs +++ b/tests/spanner/src/query.rs @@ -158,14 +158,14 @@ pub async fn result_set_metadata(db_client: &DatabaseClient) -> anyhow::Result<( // 1. Simple normal query let sql = "SELECT 1 as num, 'Alice' as name"; - let mut rs = rot.execute_query(Statement::builder(sql).build()).await?; + let mut result_set = rot.execute_query(Statement::builder(sql).build()).await?; - assert!(rs.next().await.transpose()?.is_some()); - let metadata = rs.metadata()?; + let metadata = result_set.metadata().await?; assert_eq!( metadata.column_names(), &["num".to_string(), "name".to_string()] ); + assert!(result_set.next().await.transpose()?.is_some()); // 2. 
Query that returns zero rows let sql_zero_rows = r#" @@ -174,25 +174,25 @@ pub async fn result_set_metadata(db_client: &DatabaseClient) -> anyhow::Result<( ) SELECT num, name FROM Data WHERE 1=0 "#; - let mut rs_zero_rows = rot + let mut result_set_zero_rows = rot .execute_query(Statement::builder(sql_zero_rows).build()) .await?; - assert!(rs_zero_rows.next().await.transpose()?.is_none()); - let metadata_zero_rows = rs_zero_rows.metadata()?; + let metadata_zero_rows = result_set_zero_rows.metadata().await?; assert_eq!( metadata_zero_rows.column_names(), &["num".to_string(), "name".to_string()] ); + assert!(result_set_zero_rows.next().await.transpose()?.is_none()); // 3. Query with duplicate aliases let sql_dup = "SELECT 1 as dup, 2 as dup"; - let mut rs_dup = rot + let mut result_set_dup = rot .execute_query(Statement::builder(sql_dup).build()) .await?; - let row_dup = rs_dup.next().await.transpose()?.unwrap(); - let metadata_dup = rs_dup.metadata()?; + let row_dup = result_set_dup.next().await.transpose()?.unwrap(); + let metadata_dup = result_set_dup.metadata().await?; assert_eq!( metadata_dup.column_names(), &["dup".to_string(), "dup".to_string()] diff --git a/tests/spanner/src/write.rs b/tests/spanner/src/write.rs index cd2f9ae485..f0007df60b 100644 --- a/tests/spanner/src/write.rs +++ b/tests/spanner/src/write.rs @@ -526,6 +526,7 @@ async fn write_internal( let metadata = rs .metadata() + .await .expect("result set metadata is unexpectedly missing"); let column_count = metadata.column_names().len(); assert_eq!(row2.raw_values().len(), column_count);