diff --git a/cpp/src/arrow/flight/flight-sql/client.h b/cpp/src/arrow/flight/flight-sql/client.h index cb7ba88ae9ab1..b74bed35570f2 100644 --- a/cpp/src/arrow/flight/flight-sql/client.h +++ b/cpp/src/arrow/flight/flight-sql/client.h @@ -178,6 +178,24 @@ class FlightSqlClientT { Status GetTableTypes(const FlightCallOptions& options, std::unique_ptr* flight_info) const; + /// \brief Request a list of SQL information. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] sql_info the SQL info required. + /// \param[out] flight_info The FlightInfo describing where to access the dataset. + /// \return Status. + Status GetSqlInfo(const FlightCallOptions& options, + const std::vector& sql_info, + std::unique_ptr* flight_info) const; + + /// \brief Request a list of SQL information. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] sql_info the SQL info required. + /// \param[out] flight_info The FlightInfo describing where to access the dataset. + /// \return Status. + Status GetSqlInfo(const FlightCallOptions& options, + const std::vector& sql_info, + std::unique_ptr* flight_info) const; + /// \brief Create a prepared statement object. /// \param[in] options RPC-layer hints for this call. /// \param[in] query The query that will be executed. diff --git a/cpp/src/arrow/flight/flight-sql/client_impl.h b/cpp/src/arrow/flight/flight-sql/client_impl.h index fb39ff5c17d50..33f5933b6738a 100644 --- a/cpp/src/arrow/flight/flight-sql/client_impl.h +++ b/cpp/src/arrow/flight/flight-sql/client_impl.h @@ -325,6 +325,22 @@ Status PreparedStatementT::Close() { return Status::OK(); } +template +Status FlightSqlClientT::GetSqlInfo( + const FlightCallOptions& options, const std::vector& sql_info, + std::unique_ptr* flight_info) const { + pb::sql::CommandGetSqlInfo command; + for (const int& info : sql_info) command.add_info(info); + return GetFlightInfoForCommand(client, options, flight_info, command); +} + +template +Status FlightSqlClientT::GetSqlInfo( + const FlightCallOptions& options, const std::vector& sql_info, + std::unique_ptr* flight_info) const { + return GetSqlInfo(options, reinterpret_cast&>(sql_info), flight_info); +} + } // namespace internal } // namespace sql } // namespace flight diff --git a/cpp/src/arrow/flight/flight-sql/client_test.cc b/cpp/src/arrow/flight/flight-sql/client_test.cc index 820a831eda3ed..64809a49e39ac 100644 --- a/cpp/src/arrow/flight/flight-sql/client_test.cc +++ b/cpp/src/arrow/flight/flight-sql/client_test.cc @@ -317,6 +317,30 @@ TEST(TestFlightSqlClient, TestExecuteUpdate) { ASSERT_EQ(num_rows, 100); } + +TEST(TestFlightSqlClient, TestGetSqlInfo) { + auto* client_mock = new FlightClientMock(); + std::unique_ptr client_mock_ptr(client_mock); + FlightSqlClientT sql_client(client_mock_ptr); + + std::vector sql_info{ + pb::sql::SqlInfo::FLIGHT_SQL_SERVER_NAME, + pb::sql::SqlInfo::FLIGHT_SQL_SERVER_VERSION, + pb::sql::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION}; + std::unique_ptr flight_info; + pb::sql::CommandGetSqlInfo command; + + for (const auto& info : sql_info) command.add_info(info); + google::protobuf::Any any; + any.PackFrom(command); + const FlightDescriptor& descriptor = FlightDescriptor::Command(any.SerializeAsString()); + + FlightCallOptions call_options; + EXPECT_CALL(*client_mock, + GetFlightInfo(Ref(call_options), descriptor, &flight_info)); + (void) sql_client.GetSqlInfo(call_options, sql_info, &flight_info); +} + } // namespace sql } // namespace flight } // namespace arrow