Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-38255: [Java] Implement Flight SQL Bulk Ingestion #43551

Merged
merged 15 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev/archery/archery/integration/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
Scenario(
"flight_sql:ingestion",
description="Ensure Flight SQL ingestion works as expected.",
skip_testers={"JS", "C#", "Rust", "Java"}
skip_testers={"JS", "C#", "Rust"}
),
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,17 @@
*/
package org.apache.arrow.flight.integration.tests;

import java.util.HashMap;
import java.util.Map;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.SchemaResult;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.CancelResult;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.DenseUnionVector;
import org.apache.arrow.vector.types.pojo.Schema;

/**
* Integration test scenario for validating Flight SQL specs across multiple implementations. This
Expand All @@ -53,69 +46,32 @@ public void client(BufferAllocator allocator, Location location, FlightClient cl
}

private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception {
FlightInfo info = sqlClient.getSqlInfo();
Ticket ticket = info.getEndpoints().get(0).getTicket();

Map<Integer, Object> infoValues = new HashMap<>();
try (FlightStream stream = sqlClient.getStream(ticket)) {
Schema actualSchema = stream.getSchema();
IntegrationAssertions.assertEquals(
FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, actualSchema);

while (stream.next()) {
UInt4Vector infoName = (UInt4Vector) stream.getRoot().getVector(0);
DenseUnionVector value = (DenseUnionVector) stream.getRoot().getVector(1);

for (int i = 0; i < stream.getRoot().getRowCount(); i++) {
final int code = infoName.get(i);
if (infoValues.containsKey(code)) {
throw new AssertionError("Duplicate SqlInfo value: " + code);
}
Object object;
byte typeId = value.getTypeId(i);
switch (typeId) {
case 0: // string
object =
Preconditions.checkNotNull(
value.getVarCharVector(typeId).getObject(value.getOffset(i)))
.toString();
break;
case 1: // bool
object = value.getBitVector(typeId).getObject(value.getOffset(i));
break;
case 2: // int64
object = value.getBigIntVector(typeId).getObject(value.getOffset(i));
break;
case 3: // int32
object = value.getIntVector(typeId).getObject(value.getOffset(i));
break;
default:
throw new AssertionError("Decoding SqlInfo of type code " + typeId);
}
infoValues.put(code, object);
}
}
}

IntegrationAssertions.assertEquals(
Boolean.FALSE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE));
IntegrationAssertions.assertEquals(
Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE));
IntegrationAssertions.assertEquals(
"min_version",
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE));
IntegrationAssertions.assertEquals(
"max_version",
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE));
IntegrationAssertions.assertEquals(
FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT_VALUE,
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE));
IntegrationAssertions.assertEquals(
Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE));
IntegrationAssertions.assertEquals(
42, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE));
IntegrationAssertions.assertEquals(
7, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE));
validate(
FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA,
sqlClient.getSqlInfo(),
sqlClient,
s -> {
Map<Integer, Object> infoValues = readSqlInfoStream(s);
IntegrationAssertions.assertEquals(
Boolean.FALSE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE));
IntegrationAssertions.assertEquals(
Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE));
IntegrationAssertions.assertEquals(
"min_version",
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE));
IntegrationAssertions.assertEquals(
"max_version",
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE));
IntegrationAssertions.assertEquals(
FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT_VALUE,
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE));
IntegrationAssertions.assertEquals(
Boolean.TRUE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE));
IntegrationAssertions.assertEquals(
42, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE));
IntegrationAssertions.assertEquals(
7, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE));
});
}

private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight.integration.tests;

import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Map;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.FlightSqlClient.ExecuteIngestOptions;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;

/**
* Integration test scenario for validating Flight SQL specs across multiple implementations. This
* should ensure that RPC objects are being built and parsed correctly for multiple languages and
* that the Arrow schemas are returned as expected.
*/
public class FlightSqlIngestionScenario extends FlightSqlScenario {

@Override
public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception {
FlightSqlScenarioProducer producer =
(FlightSqlScenarioProducer) super.producer(allocator, location);
producer
.getSqlInfoBuilder()
.withFlightSqlServerBulkIngestionTransaction(true)
.withFlightSqlServerBulkIngestion(true);
return producer;
}

@Override
public void client(BufferAllocator allocator, Location location, FlightClient client)
throws Exception {
try (final FlightSqlClient sqlClient = new FlightSqlClient(client)) {
validateMetadataRetrieval(sqlClient);
validateIngestion(allocator, sqlClient);
}
}

private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception {
validate(
FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA,
sqlClient.getSqlInfo(
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED,
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_BULK_INGESTION),
sqlClient,
s -> {
Map<Integer, Object> infoValues = readSqlInfoStream(s);
IntegrationAssertions.assertEquals(
Boolean.TRUE,
infoValues.get(
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_INGEST_TRANSACTIONS_SUPPORTED_VALUE));
IntegrationAssertions.assertEquals(
Boolean.TRUE,
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_BULK_INGESTION_VALUE));
});
}

private VectorSchemaRoot getIngestVectorRoot(BufferAllocator allocator) {
Schema schema = FlightSqlScenarioProducer.getIngestSchema();
VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
root.setRowCount(3);
return root;
}

private void validateIngestion(BufferAllocator allocator, FlightSqlClient sqlClient) {
try (VectorSchemaRoot data = getIngestVectorRoot(allocator)) {
TableDefinitionOptions tableDefinitionOptions =
TableDefinitionOptions.newBuilder()
.setIfExists(TableDefinitionOptions.TableExistsOption.TABLE_EXISTS_OPTION_REPLACE)
.setIfNotExist(
TableDefinitionOptions.TableNotExistOption.TABLE_NOT_EXIST_OPTION_CREATE)
.build();
Map<String, String> options = new HashMap<>(ImmutableMap.of("key1", "val1", "key2", "val2"));
ExecuteIngestOptions executeIngestOptions =
new ExecuteIngestOptions(
"test_table", tableDefinitionOptions, true, "test_catalog", "test_schema", options);
FlightSqlClient.Transaction transaction = new FlightSqlClient.Transaction(TRANSACTION_ID);
long updatedRows = sqlClient.executeIngest(data, executeIngestOptions, transaction);

IntegrationAssertions.assertEquals(3L, updatedRows);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
*/
package org.apache.arrow.flight.integration.tests;

import static java.util.Objects.isNull;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
Expand All @@ -32,7 +37,10 @@
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.flight.sql.util.TableRef;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.DenseUnionVector;
import org.apache.arrow.vector.types.pojo.Schema;

/**
Expand Down Expand Up @@ -158,7 +166,15 @@ private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Excepti
FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY
},
options),
sqlClient);
sqlClient,
s -> {
Map<Integer, Object> infoValues = readSqlInfoStream(s);
IntegrationAssertions.assertEquals(
Boolean.FALSE, infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE));
IntegrationAssertions.assertEquals(
FlightSqlScenarioProducer.SERVER_NAME,
infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE));
});
validateSchema(
FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, sqlClient.getSqlInfoSchema(options));
}
Expand Down Expand Up @@ -194,14 +210,64 @@ private void validatePreparedStatementExecution(

/**
 * Validates only the schema of the stream behind {@code flightInfo}; no extra checks are run on
 * the stream contents. Delegates to the four-argument overload with a {@code null} consumer.
 *
 * @param expectedSchema schema the stream is expected to report
 * @param flightInfo flight info whose first endpoint is fetched
 * @param sqlClient client used to open the stream
 * @throws Exception if fetching the stream or the schema comparison fails
 */
protected void validate(Schema expectedSchema, FlightInfo flightInfo, FlightSqlClient sqlClient)
    throws Exception {
  final Consumer<FlightStream> noStreamChecks = null;
  validate(expectedSchema, flightInfo, sqlClient, noStreamChecks);
}

/**
 * Opens the stream for the first endpoint of {@code flightInfo}, asserts that its schema equals
 * {@code expectedSchema}, and then, if a consumer was supplied, passes the still-open stream to
 * it. The stream is closed before this method returns.
 *
 * @param expectedSchema schema the stream is expected to report
 * @param flightInfo flight info whose first endpoint's ticket is fetched
 * @param sqlClient client used to open the stream
 * @param streamConsumer optional extra checks on the stream; may be {@code null}
 * @throws Exception if fetching the stream, the schema comparison, or the consumer fails
 */
protected void validate(
    Schema expectedSchema,
    FlightInfo flightInfo,
    FlightSqlClient sqlClient,
    Consumer<FlightStream> streamConsumer)
    throws Exception {
  final Ticket firstTicket = flightInfo.getEndpoints().get(0).getTicket();
  try (FlightStream resultStream = sqlClient.getStream(firstTicket)) {
    IntegrationAssertions.assertEquals(expectedSchema, resultStream.getSchema());
    if (isNull(streamConsumer)) {
      // Schema-only validation was requested.
      return;
    }
    streamConsumer.accept(resultStream);
  }
}

/**
 * Asserts that the schema carried by {@code actual} equals {@code expected}.
 *
 * @param expected schema the server is expected to return
 * @param actual schema result received from the server
 */
protected void validateSchema(Schema expected, SchemaResult actual) {
  final Schema receivedSchema = actual.getSchema();
  IntegrationAssertions.assertEquals(expected, receivedSchema);
}

/**
 * Reads every batch of a SqlInfo stream into a map from info code to decoded value.
 *
 * <p>Column 0 is read as the info code and column 1 as a dense union holding the value. Union
 * type ids are decoded as: 0 string, 1 bool, 2 int64, 3 int32.
 *
 * @param stream stream to drain; it is advanced until {@code next()} returns {@code false}
 * @return decoded values keyed by info code
 * @throws AssertionError if an info code appears more than once or a union type id outside 0-3
 *     is encountered
 */
protected Map<Integer, Object> readSqlInfoStream(FlightStream stream) {
  final Map<Integer, Object> valuesByCode = new HashMap<>();
  while (stream.next()) {
    final VectorSchemaRoot batch = stream.getRoot();
    final UInt4Vector codes = (UInt4Vector) batch.getVector(0);
    final DenseUnionVector payloads = (DenseUnionVector) batch.getVector(1);

    for (int row = 0; row < batch.getRowCount(); row++) {
      final int code = codes.get(row);
      if (valuesByCode.containsKey(code)) {
        throw new AssertionError("Duplicate SqlInfo value: " + code);
      }

      final byte typeId = payloads.getTypeId(row);
      final Object decoded;
      if (typeId == 0) {
        // string: must be present, stored as its String form
        decoded =
            Preconditions.checkNotNull(
                    payloads.getVarCharVector(typeId).getObject(payloads.getOffset(row)))
                .toString();
      } else if (typeId == 1) {
        // bool
        decoded = payloads.getBitVector(typeId).getObject(payloads.getOffset(row));
      } else if (typeId == 2) {
        // int64
        decoded = payloads.getBigIntVector(typeId).getObject(payloads.getOffset(row));
      } else if (typeId == 3) {
        // int32
        decoded = payloads.getIntVector(typeId).getObject(payloads.getOffset(row));
      } else {
        throw new AssertionError("Decoding SqlInfo of type code " + typeId);
      }
      valuesByCode.put(code, decoded);
    }
  }
  return valuesByCode;
}
}
Loading
Loading