Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-38255: [Go][C++] Implement Flight SQL Bulk Ingestion #38385

Merged
merged 13 commits into from
Apr 17, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ TEST(FlightIntegration, FlightSqlExtension) {
ASSERT_OK(RunScenario("flight_sql:extension"));
}

TEST(FlightIntegration, FlightSqlIngestion) {
ASSERT_OK(RunScenario("flight_sql:ingestion"));
}

} // namespace integration_tests
} // namespace flight
} // namespace arrow
125 changes: 125 additions & 0 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ constexpr int64_t kUpdateStatementExpectedRows = 10000L;
constexpr int64_t kUpdateStatementWithTransactionExpectedRows = 15000L;
constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L;
constexpr int64_t kUpdatePreparedStatementWithTransactionExpectedRows = 25000L;
constexpr int64_t kIngestStatementExpectedRows = 3L;
constexpr char kSelectStatement[] = "SELECT STATEMENT";
constexpr char kSavepointId[] = "savepoint_id";
constexpr char kSavepointName[] = "savepoint_name";
Expand Down Expand Up @@ -2123,6 +2124,127 @@ class ReuseConnectionScenario : public Scenario {
return Status::OK();
}
};

std::shared_ptr<Schema> GetIngestSchema() {
return arrow::schema({arrow::field("test_field", arrow::int64(), true)});
}

arrow::Result<std::shared_ptr<RecordBatchReader>> GetIngestRecords() {
auto schema = GetIngestSchema();
auto array = arrow::ArrayFromJSON(arrow::int64(), "[null,null,null]");
auto record_batch = arrow::RecordBatch::Make(schema, 3, {array});
return RecordBatchReader::Make({record_batch});
}

/// \brief The server used for testing bulk ingestion
class FlightSqlIngestionServer : public sql::FlightSqlServerBase {
public:
FlightSqlIngestionServer() : sql::FlightSqlServerBase() {
RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_BULK_INGESTION,
sql::SqlInfoResult(true));
RegisterSqlInfo(
sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED,
sql::SqlInfoResult(true));
}

arrow::Result<int64_t> DoPutCommandStatementIngest(
const ServerCallContext& context, const sql::StatementIngest& command,
FlightMessageReader* reader) override {
ARROW_RETURN_NOT_OK(AssertEq<bool>(
true,
sql::TableDefinitionOptionsTableNotExistOption::kCreate ==
command.table_definition_options.if_not_exist,
"Wrong TableDefinitionOptionsTableNotExistOption for ExecuteIngest"));
ARROW_RETURN_NOT_OK(AssertEq<bool>(
true,
sql::TableDefinitionOptionsTableExistsOption::kReplace ==
command.table_definition_options.if_exists,
"Wrong TableDefinitionOptionsTableExistsOption for ExecuteIngest"));
ARROW_RETURN_NOT_OK(AssertEq<std::string>("test_table", command.table,
"Wrong table for ExecuteIngest"));
ARROW_RETURN_NOT_OK(AssertEq<std::string>("test_schema", command.schema.value(),
"Wrong schema for ExecuteIngest"));
ARROW_RETURN_NOT_OK(AssertEq<std::string>("test_catalog", command.catalog.value(),
"Wrong catalog for ExecuteIngest"));
ARROW_RETURN_NOT_OK(AssertEq<bool>(true, command.temporary,
"Wrong temporary setting for ExecuteIngest"));
ARROW_RETURN_NOT_OK(AssertEq<std::string>("123", command.transaction_id.value(),
"Wrong transaction_id for ExecuteIngest"));

std::unordered_map<std::string, std::string> expected_options = {{"key1", "val1"},
{"key2", "val2"}};
ARROW_RETURN_NOT_OK(
AssertEq<std::size_t>(expected_options.size(), command.options.size(),
"Wrong number of options set for ExecuteIngest"));
for (auto it = expected_options.begin(); it != expected_options.end(); ++it) {
auto key = it->first;
auto expected_val = it->second;
ARROW_RETURN_NOT_OK(
AssertEq<std::string>(expected_val, command.options.at(key),
"Wrong option value set for ExecuteIngest"));
}

auto expected_schema = GetIngestSchema();
int64_t num_records = 0;
while (true) {
ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next());
if (chunk.data == nullptr) break;

ARROW_RETURN_NOT_OK(
AssertEq(true, expected_schema->Equals(chunk.data->schema()),
"Chunk schema does not match expected schema for ExecuteIngest"));
num_records += chunk.data->num_rows();
}

return num_records;
}
};

/// \brief The FlightSqlIngestion scenario.
///
/// This tests that the client can execute bulk ingestion against the server.
///
/// The server implements DoPutCommandStatementIngest and validates that the arguments
/// it receives are the same as those supplied to the client, or have been successfully
/// mapped to the equivalent server-side representation. The size and schema of the sent
/// and received streams are also validated against eachother.
class FlightSqlIngestionScenario : public Scenario {
Status MakeServer(std::unique_ptr<FlightServerBase>* server,
FlightServerOptions* options) override {
server->reset(new FlightSqlIngestionServer());
return Status::OK();
}

Status MakeClient(FlightClientOptions* options) override { return Status::OK(); }

Status RunClient(std::unique_ptr<FlightClient> client) override {
sql::FlightSqlClient sql_client(std::move(client));
ARROW_RETURN_NOT_OK(ValidateIngestion(&sql_client));
return Status::OK();
}

Status ValidateIngestion(sql::FlightSqlClient* sql_client) {
ARROW_ASSIGN_OR_RAISE(auto record_batch_reader, GetIngestRecords());

sql::TableDefinitionOptions table_definition_options;
table_definition_options.if_not_exist =
sql::TableDefinitionOptionsTableNotExistOption::kCreate;
table_definition_options.if_exists =
sql::TableDefinitionOptionsTableExistsOption::kReplace;
bool temporary = true;
std::unordered_map<std::string, std::string> options = {{"key1", "val1"},
{"key2", "val2"}};
ARROW_ASSIGN_OR_RAISE(
auto updated_rows,
sql_client->ExecuteIngest({}, record_batch_reader, table_definition_options,
"test_table", "test_schema", "test_catalog", temporary,
sql::Transaction("123"), options));
ARROW_RETURN_NOT_OK(AssertEq(kIngestStatementExpectedRows, updated_rows,
"Wrong number of updated rows for ExecuteIngest"));

return Status::OK();
}
};
} // namespace

Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>* out) {
Expand Down Expand Up @@ -2165,6 +2287,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>*
} else if (scenario_name == "flight_sql:extension") {
*out = std::make_shared<FlightSqlExtensionScenario>();
return Status::OK();
} else if (scenario_name == "flight_sql:ingestion") {
*out = std::make_shared<FlightSqlIngestionScenario>();
return Status::OK();
}
return Status::KeyError("Scenario not found: ", scenario_name);
}
Expand Down
108 changes: 108 additions & 0 deletions cpp/src/arrow/flight/sql/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,114 @@ arrow::Result<int64_t> FlightSqlClient::ExecuteSubstraitUpdate(
return update_result.record_count();
}

arrow::Result<int64_t> FlightSqlClient::ExecuteIngest(
const FlightCallOptions& options, const std::shared_ptr<RecordBatchReader>& reader,
const TableDefinitionOptions& table_definition_options, const std::string& table,
const std::optional<std::string>& schema, const std::optional<std::string>& catalog,
const bool temporary, const Transaction& transaction,
const std::unordered_map<std::string, std::string>& ingest_options) {
flight_sql_pb::CommandStatementIngest command;

flight_sql_pb::CommandStatementIngest_TableDefinitionOptions*
pb_table_definition_options =
new flight_sql_pb::CommandStatementIngest_TableDefinitionOptions();
switch (table_definition_options.if_not_exist) {
case TableDefinitionOptionsTableNotExistOption::kUnspecified:
pb_table_definition_options->set_if_not_exist(
flight_sql_pb::
CommandStatementIngest_TableDefinitionOptions_TableNotExistOption_TABLE_NOT_EXIST_OPTION_UNSPECIFIED); // NOLINT(whitespace/line_length)
break;
case TableDefinitionOptionsTableNotExistOption::kCreate:
pb_table_definition_options->set_if_not_exist(
flight_sql_pb::
CommandStatementIngest_TableDefinitionOptions_TableNotExistOption_TABLE_NOT_EXIST_OPTION_CREATE); // NOLINT(whitespace/line_length)
break;
case TableDefinitionOptionsTableNotExistOption::kFail:
pb_table_definition_options->set_if_not_exist(
flight_sql_pb::
CommandStatementIngest_TableDefinitionOptions_TableNotExistOption_TABLE_NOT_EXIST_OPTION_FAIL); // NOLINT(whitespace/line_length)
break;

default:
break;
}

switch (table_definition_options.if_exists) {
case TableDefinitionOptionsTableExistsOption::kUnspecified:
pb_table_definition_options->set_if_exists(
flight_sql_pb::
CommandStatementIngest_TableDefinitionOptions_TableExistsOption_TABLE_EXISTS_OPTION_UNSPECIFIED); // NOLINT(whitespace/line_length)
break;
case TableDefinitionOptionsTableExistsOption::kFail:
pb_table_definition_options->set_if_exists(
flight_sql_pb::
CommandStatementIngest_TableDefinitionOptions_TableExistsOption_TABLE_EXISTS_OPTION_FAIL); // NOLINT(whitespace/line_length)
break;
case TableDefinitionOptionsTableExistsOption::kAppend:
pb_table_definition_options->set_if_exists(
flight_sql_pb::
CommandStatementIngest_TableDefinitionOptions_TableExistsOption_TABLE_EXISTS_OPTION_APPEND); // NOLINT(whitespace/line_length)
break;
case TableDefinitionOptionsTableExistsOption::kReplace:
pb_table_definition_options->set_if_exists(
flight_sql_pb::
CommandStatementIngest_TableDefinitionOptions_TableExistsOption_TABLE_EXISTS_OPTION_REPLACE); // NOLINT(whitespace/line_length)
break;

default:
break;
}

command.set_allocated_table_definition_options(pb_table_definition_options);
command.set_table(table);

if (schema.has_value()) {
command.set_schema(schema.value());
}

if (catalog.has_value()) {
command.set_catalog(catalog.value());
}

command.set_temporary(temporary);

if (transaction.is_valid()) {
command.set_transaction_id(transaction.transaction_id());
}

auto command_options = command.mutable_options();
for (const auto& [key, val] : ingest_options) {
(*command_options)[key] = val;
}

ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
GetFlightDescriptorForCommand(command));

auto reader_ = reader.get();
ARROW_ASSIGN_OR_RAISE(auto stream, DoPut(options, descriptor, reader_->schema()));

while (true) {
ARROW_ASSIGN_OR_RAISE(auto batch, reader_->Next());
if (!batch) break;
ARROW_RETURN_NOT_OK(stream.writer->WriteRecordBatch(*batch));
}

ARROW_RETURN_NOT_OK(stream.writer->DoneWriting());
std::shared_ptr<Buffer> metadata;
ARROW_RETURN_NOT_OK(stream.reader->ReadMetadata(&metadata));
ARROW_RETURN_NOT_OK(stream.writer->Close());

if (!metadata) return Status::IOError("Server did not send a response");

flight_sql_pb::DoPutUpdateResult update_result;
if (!update_result.ParseFromArray(metadata->data(),
static_cast<int>(metadata->size()))) {
return Status::Invalid("Unable to parse DoPutUpdateResult");
}

return update_result.record_count();
}

arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::GetCatalogs(
const FlightCallOptions& options) {
flight_sql_pb::CommandGetCatalogs command;
Expand Down
18 changes: 18 additions & 0 deletions cpp/src/arrow/flight/sql/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,24 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
const FlightCallOptions& options, const SubstraitPlan& plan,
const Transaction& transaction = no_transaction());

/// \brief Execute a bulk ingestion to the server.
/// \param[in] options RPC-layer hints for this call.
/// \param[in] reader The records to ingest.
/// \param[in] table_definition_options The behavior for handling the table definition.
/// \param[in] table The destination table to load into.
/// \param[in] schema The DB schema of the destination table.
/// \param[in] catalog The catalog of the destination table.
/// \param[in] temporary Use a temporary table.
/// \param[in] transaction Ingest as part of this transaction.
/// \param[in] ingest_options Additional, backend-specific options.
/// \return The number of rows ingested to the server.
arrow::Result<int64_t> ExecuteIngest(
const FlightCallOptions& options, const std::shared_ptr<RecordBatchReader>& reader,
const TableDefinitionOptions& table_definition_options, const std::string& table,
const std::optional<std::string>& schema, const std::optional<std::string>& catalog,
const bool temporary, const Transaction& transaction = no_transaction(),
const std::unordered_map<std::string, std::string>& ingest_options = {});

/// \brief Request a list of catalogs.
/// \param[in] options RPC-layer hints for this call.
/// \return The FlightInfo describing where to access the dataset.
Expand Down
Loading
Loading