Skip to content

Commit

Permalink
Fallible stream for arrow-flight do_exchange call (#3462)
Browse files Browse the repository at this point in the history
Signed-off-by: Praveen Kumar <praveen@bit2byte.net>
  • Loading branch information
bitpacker committed Apr 27, 2024
1 parent a61f1dc commit 283a7f1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
43 changes: 31 additions & 12 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use futures::{
FutureExt, Stream, StreamExt, TryStreamExt,
};
use prost::Message;
use tonic::{metadata::MetadataMap, transport::Channel};
use tonic::{metadata::MetadataMap, transport::Channel, Status};

use crate::error::{FlightError, Result};

Expand Down Expand Up @@ -418,8 +418,7 @@ impl FlightClient {
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
/// .build(futures::stream::iter(vec![Ok(batch)]))
/// // data encoder return Results, but do_exchange requires FlightData
/// .map(|batch|batch.unwrap());
/// ;
///
/// // send the stream and get the results as `RecordBatches`
/// let response: Vec<RecordBatch> = client
Expand All @@ -431,20 +430,40 @@ impl FlightClient {
/// .expect("error calling do_exchange");
/// # }
/// ```
pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
let request = self.make_request(request);
let (sender, mut receiver) = futures::channel::oneshot::channel();

let response = self
.inner
.do_exchange(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_exchange(request).await?.into_inner();

Ok(FlightRecordBatchStream::new_from_flight_data(response))
// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| -> Poll<Option<Result<FlightData>>> {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next: Option<std::result::Result<FlightData, Status>> = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});

// combine the response from the server and any error from the client
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
}

/// Make a `ListFlights` call to the server with the provided
Expand Down
6 changes: 3 additions & 3 deletions arrow-flight/tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ async fn test_do_exchange() {
.set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect());

let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");

Expand Down Expand Up @@ -528,7 +528,7 @@ async fn test_do_exchange_error() {
let input_flight_data = test_flight_data().await;

let response = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Expand Down Expand Up @@ -572,7 +572,7 @@ async fn test_do_exchange_error_stream() {
test_server.set_do_exchange_response(response);

let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");

Expand Down

0 comments on commit 283a7f1

Please sign in to comment.