Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement FlightSQL spec change to support stateless prepared statements #5433

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use arrow_flight::sql::server::PeekableFlightDataStream;
use arrow_flight::sql::DoPutPreparedStatementResult;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use futures::{stream, Stream, TryStreamExt};
Expand Down Expand Up @@ -619,7 +620,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
&self,
_query: CommandPreparedStatementQuery,
_request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
) -> Result<DoPutPreparedStatementResult, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query not implemented",
))
Expand Down
20 changes: 20 additions & 0 deletions arrow-flight/src/sql/arrow.flight.protocol.sql.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 27 additions & 5 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ use crate::sql::{
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys,
CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
CommandStatementUpdate, DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt,
SqlInfo,
};
use crate::trailers::extract_lazy_trailers;
use crate::{
Expand Down Expand Up @@ -501,6 +502,7 @@ impl PreparedStatement<Channel> {
}

/// Submit parameters to the server, if any have been set on this prepared statement instance
/// Updates our stored prepared statement handle with the handle given by the server response.
async fn write_bind_params(&mut self) -> Result<(), ArrowError> {
if let Some(ref params_batch) = self.parameter_binding {
let cmd = CommandPreparedStatementQuery {
Expand All @@ -519,17 +521,37 @@ impl PreparedStatement<Channel> {
.await
.map_err(flight_error_to_arrow_error)?;

self.flight_sql_client
// Attempt to update the stored handle with any updated handle in the DoPut result.
// Not all servers support this, so ignore any errors when attempting to decode.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says errors are ignored, but the code doesn't seem to ignore errors. I wonder if I am misreading this or if the comment or code should be updated 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says errors are ignored, but the code doesn't seem to ignore errors. I wonder if I am misreading this or if the comment or code should be updated 🤔

comments never lie ;)

updated this to explain that we ignore the lack of a response from legacy servers, rather than any error.

if let Some(result) = self
.flight_sql_client
.do_put(stream::iter(flight_data))
.await?
.try_collect::<Vec<_>>()
.message()
.await
.map_err(status_to_arrow_error)?;
.map_err(status_to_arrow_error)?
{
if let Some(handle) = self.unpack_prepared_statement_handle(&result)? {
self.handle = handle;
}
}
}

Ok(())
}

/// Decodes the app_metadata stored in a [`PutResult`] as a
/// [`DoPutPreparedStatementResult`] and then returns
/// the inner prepared statement handle as [`Bytes`]
fn unpack_prepared_statement_handle(
&self,
put_result: &PutResult,
) -> Result<Option<Bytes>, ArrowError> {
let any = Any::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?;
Ok(any
.unpack::<DoPutPreparedStatementResult>()?
.and_then(|result| result.prepared_statement_handle))
}

/// Close the prepared statement, so that this PreparedStatement can not used
/// anymore and server can free up any resources.
pub async fn close(mut self) -> Result<(), ArrowError> {
Expand Down
2 changes: 2 additions & 0 deletions arrow-flight/src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub use gen::CommandPreparedStatementUpdate;
pub use gen::CommandStatementQuery;
pub use gen::CommandStatementSubstraitPlan;
pub use gen::CommandStatementUpdate;
pub use gen::DoPutPreparedStatementResult;
pub use gen::DoPutUpdateResult;
pub use gen::Nullable;
pub use gen::Searchable;
Expand Down Expand Up @@ -251,6 +252,7 @@ prost_message_ext!(
CommandStatementSubstraitPlan,
CommandStatementUpdate,
DoPutUpdateResult,
DoPutPreparedStatementResult,
TicketStatementQuery,
);

Expand Down
17 changes: 14 additions & 3 deletions arrow-flight/src/sql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ use super::{
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate,
DoPutUpdateResult, ProstMessageExt, SqlInfo, TicketStatementQuery,
DoPutPreparedStatementResult, DoPutUpdateResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
};
use crate::{
flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty,
Expand Down Expand Up @@ -397,11 +398,15 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
}

/// Bind parameters to given prepared statement.
///
/// Returns an opaque handle that the client should pass
erratic-pattern marked this conversation as resolved.
Show resolved Hide resolved
/// back to the server during subsequent requests with this
/// prepared statement.
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
) -> Result<DoPutPreparedStatementResult, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query has no default implementation",
))
Expand Down Expand Up @@ -709,7 +714,13 @@ where
Ok(Response::new(Box::pin(output)))
}
Command::CommandPreparedStatementQuery(command) => {
self.do_put_prepared_statement_query(command, request).await
let result = self
.do_put_prepared_statement_query(command, request)
.await?;
let output = futures::stream::iter(vec![Ok(PutResult {
app_metadata: result.as_any().encode_to_vec().into(),
})]);
Ok(Response::new(Box::pin(output)))
}
Command::CommandStatementSubstraitPlan(command) => {
let record_count = self.do_put_substrait_plan(command, request).await?;
Expand Down
77 changes: 59 additions & 18 deletions arrow-flight/tests/flight_sql_client_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ use arrow_flight::{
CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
CommandStatementUpdate, ProstMessageExt, SqlInfo, TicketStatementQuery,
CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
},
utils::batches_to_flight_data,
Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest,
HandshakeResponse, IpcMessage, PutResult, SchemaAsIpc, Ticket,
HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket,
};
use arrow_ipc::writer::IpcWriteOptions;
use arrow_schema::{ArrowError, DataType, Field, Schema};
use assert_cmd::Command;
use bytes::Bytes;
use futures::{Stream, StreamExt, TryStreamExt};
use futures::{Stream, TryStreamExt};
use prost::Message;
use tokio::{net::TcpListener, task::JoinHandle};
use tonic::{Request, Response, Status, Streaming};
Expand All @@ -51,7 +52,7 @@ const QUERY: &str = "SELECT * FROM table;";

#[tokio::test]
async fn test_simple() {
let test_server = FlightSqlServiceImpl {};
let test_server = FlightSqlServiceImpl::default();
let fixture = TestFixture::new(&test_server).await;
let addr = fixture.addr;

Expand Down Expand Up @@ -92,10 +93,9 @@ async fn test_simple() {

const PREPARED_QUERY: &str = "SELECT * FROM table WHERE field = $1";
const PREPARED_STATEMENT_HANDLE: &str = "prepared_statement_handle";
const UPDATED_PREPARED_STATEMENT_HANDLE: &str = "updated_prepared_statement_handle";

#[tokio::test]
async fn test_do_put_prepared_statement() {
let test_server = FlightSqlServiceImpl {};
async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) {
let fixture = TestFixture::new(&test_server).await;
let addr = fixture.addr;

Expand Down Expand Up @@ -136,11 +136,40 @@ async fn test_do_put_prepared_statement() {
);
}

#[tokio::test]
pub async fn test_do_put_prepared_statement_stateless() {
test_do_put_prepared_statement(FlightSqlServiceImpl {
stateless_prepared_statements: true,
})
.await
}

#[tokio::test]
pub async fn test_do_put_prepared_statement_stateful() {
test_do_put_prepared_statement(FlightSqlServiceImpl {
stateless_prepared_statements: false,
})
.await
}

/// All tests must complete within this many seconds or else the test server is shutdown
const DEFAULT_TIMEOUT_SECONDS: u64 = 30;

#[derive(Clone, Default)]
pub struct FlightSqlServiceImpl {}
#[derive(Clone)]
pub struct FlightSqlServiceImpl {
/// Whether to emulate stateless (true) or stateful (false) behavior for
/// prepared statements. stateful servers will not return an updated
/// handle after executing `DoPut(CommandPreparedStatementQuery)`
stateless_prepared_statements: bool,
}

impl Default for FlightSqlServiceImpl {
fn default() -> Self {
Self {
stateless_prepared_statements: true,
}
}
}

impl FlightSqlServiceImpl {
/// Return an [`FlightServiceServer`] that can be used with a
Expand Down Expand Up @@ -274,10 +303,17 @@ impl FlightSqlService for FlightSqlServiceImpl {
cmd: CommandPreparedStatementQuery,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
assert_eq!(
cmd.prepared_statement_handle,
PREPARED_STATEMENT_HANDLE.as_bytes()
);
if self.stateless_prepared_statements {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to do this kind of testing (simulate both stateful and stateless server behavior) in the Go implementation as well, but I am less confident around the Arrow Go test suite and general Go testing practices.

assert_eq!(
cmd.prepared_statement_handle,
UPDATED_PREPARED_STATEMENT_HANDLE.as_bytes()
);
} else {
assert_eq!(
cmd.prepared_statement_handle,
PREPARED_STATEMENT_HANDLE.as_bytes()
);
}
let resp = Response::new(self.fake_flight_info().unwrap());
Ok(resp)
}
Expand Down Expand Up @@ -524,7 +560,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
&self,
_query: CommandPreparedStatementQuery,
request: Request<PeekableFlightDataStream>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
) -> Result<DoPutPreparedStatementResult, Status> {
// just make sure decoding the parameters works
let parameters = FlightRecordBatchStream::new_from_flight_data(
request.into_inner().map_err(|e| e.into()),
Expand All @@ -543,10 +579,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
)));
}
}

Ok(Response::new(
futures::stream::once(async { Ok(PutResult::default()) }).boxed(),
))
let handle = if self.stateless_prepared_statements {
UPDATED_PREPARED_STATEMENT_HANDLE.to_string().into()
} else {
PREPARED_STATEMENT_HANDLE.to_string().into()
};
let result = DoPutPreparedStatementResult {
prepared_statement_handle: Some(handle),
};
Ok(result)
}

async fn do_put_prepared_statement_update(
Expand Down
23 changes: 22 additions & 1 deletion format/FlightSql.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,28 @@
// an unknown updated record count.
int64 record_count = 1;
}


/* An *optional* response returned when `DoPut` is called with `CommandPreparedStatementQuery`.
*
* *Note on legacy behavior*: previous versions of the protocol did not return any result for
* this command, and that behavior should still be supported by clients. See documentation
* of individual fields for more details on expected client behavior in this case.
*/
message DoPutPreparedStatementResult {
alamb marked this conversation as resolved.
Show resolved Hide resolved
option (experimental) = true;

// Represents a (potentially updated) opaque handle for the prepared statement on the server.
// Because the handle could potentially be updated, any previous handles for this prepared
// statement should be considered invalid, and all subsequent requests for this prepared
// statement must use this new handle, if specified.
// The updated handle allows implementing query parameters with stateless services
// as described in https://github.com/apache/arrow/issues/37720.
//
// When an updated handle is not provided by the server, clients should contiue
// using the previous handle provided by `ActionCreatePreparedStatementResonse`.
optional bytes prepared_statement_handle = 1;
}

/*
* Request message for the "CancelQuery" action.
*
Expand Down
Loading