diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index b2abfb0c17b2..a848d2a213ef 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -34,7 +34,7 @@ use futures::{ FutureExt, Stream, StreamExt, TryStreamExt, }; use prost::Message; -use tonic::{metadata::MetadataMap, transport::Channel}; +use tonic::{metadata::MetadataMap, transport::Channel, Status}; use crate::error::{FlightError, Result}; @@ -418,8 +418,7 @@ impl FlightClient { /// // encode the batch as a stream of `FlightData` /// let flight_data_stream = FlightDataEncoderBuilder::new() /// .build(futures::stream::iter(vec![Ok(batch)])) - /// // data encoder return Results, but do_exchange requires FlightData - /// .map(|batch|batch.unwrap()); + /// ; /// /// // send the stream and get the results as `RecordBatches` /// let response: Vec = client @@ -431,20 +430,40 @@ impl FlightClient { /// .expect("error calling do_exchange"); /// # } /// ``` - pub async fn do_exchange + Send + 'static>( + pub async fn do_exchange> + Send + 'static>( &mut self, request: S, ) -> Result { - let request = self.make_request(request); + let (sender, mut receiver) = futures::channel::oneshot::channel(); - let response = self - .inner - .do_exchange(request) - .await? - .into_inner() - .map_err(FlightError::Tonic); + // Intercepts client errors and sends them to the oneshot channel above + let mut request = Box::pin(request); // Pin to heap + let mut sender = Some(sender); // Wrap into Option so can be taken + let request_stream = futures::stream::poll_fn(move |cx| { + Poll::Ready(match ready!(request.poll_next_unpin(cx)) { + Some(Ok(data)) => Some(data), + Some(Err(e)) => { + let _ = sender.take().unwrap().send(e); + None + } + None => None, + }) + }); + + let request = self.make_request(request_stream); + let mut response_stream = self.inner.do_exchange(request).await?.into_inner(); - Ok(FlightRecordBatchStream::new_from_flight_data(response)) + // Forwards errors from the error oneshot with priority over responses from server + let error_stream = futures::stream::poll_fn(move |cx| -> Poll>> { + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + } + let next: Option> = ready!(response_stream.poll_next_unpin(cx)); + Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) + }); + + // combine the response from the server and any error from the client + Ok(FlightRecordBatchStream::new_from_flight_data(error_stream)) } /// Make a `ListFlights` call to the server with the provided diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 47565334cb63..32551c8682d6 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -493,7 +493,7 @@ async fn test_do_exchange() { .set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect()); let response_stream = client - .do_exchange(futures::stream::iter(input_flight_data.clone())) + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) .await .expect("error making request"); @@ -528,7 +528,7 @@ async fn test_do_exchange_error() { let input_flight_data = test_flight_data().await; let response = client - .do_exchange(futures::stream::iter(input_flight_data.clone())) + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) .await; let response = match response { Ok(_) => panic!("unexpected success"), @@ -572,7 +572,7 @@ async fn test_do_exchange_error_stream() { test_server.set_do_exchange_response(response); let response_stream = client - .do_exchange(futures::stream::iter(input_flight_data.clone())) + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) .await .expect("error making request");