Skip to content

Commit

Permalink
[apache#2543] feat(spark-connector): support row-level operations to …
Browse files Browse the repository at this point in the history
…iceberg Table
  • Loading branch information
caican00 committed May 5, 2024
1 parent a4733a4 commit ae71559
Show file tree
Hide file tree
Showing 16 changed files with 605 additions and 98 deletions.
4 changes: 4 additions & 0 deletions integration-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ plugins {

val scalaVersion: String = project.properties["scalaVersion"] as? String ?: extra["defaultScalaVersion"].toString()
val sparkVersion: String = libs.versions.spark.get()
val sparkMajorVersion: String = sparkVersion.substringBeforeLast(".")
val kyuubiVersion: String = libs.versions.kyuubi.get()
val icebergVersion: String = libs.versions.iceberg.get()
val scalaCollectionCompatVersion: String = libs.versions.scala.collection.compat.get()

Expand Down Expand Up @@ -114,6 +116,8 @@ dependencies {
exclude("io.dropwizard.metrics")
exclude("org.rocksdb")
}
testImplementation("org.apache.iceberg:iceberg-spark-runtime-${sparkMajorVersion}_$scalaVersion:$icebergVersion")
testImplementation("org.apache.kyuubi:kyuubi-spark-connector-hive_$scalaVersion:$kyuubiVersion")

testImplementation(libs.okhttp3.loginterceptor)
testImplementation(libs.postgresql.driver)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,39 @@ protected static String getDeleteSql(String tableName, String condition) {
return String.format("DELETE FROM %s where %s", tableName, condition);
}

private static String getUpdateTableSql(String tableName, String setClause, String whereClause) {
return String.format("UPDATE %s SET %s WHERE %s", tableName, setClause, whereClause);
}

private static String getRowLevelUpdateTableSql(
String targetTableName, String selectClause, String sourceTableName, String onClause) {
return String.format(
"MERGE INTO %s "
+ "USING (%s) %s "
+ "ON %s "
+ "WHEN MATCHED THEN UPDATE SET * "
+ "WHEN NOT MATCHED THEN INSERT *",
targetTableName, selectClause, sourceTableName, onClause);
}

private static String getRowLevelDeleteTableSql(
String targetTableName, String selectClause, String sourceTableName, String onClause) {
return String.format(
"MERGE INTO %s "
+ "USING (%s) %s "
+ "ON %s "
+ "WHEN MATCHED THEN DELETE "
+ "WHEN NOT MATCHED THEN INSERT *",
targetTableName, selectClause, sourceTableName, onClause);
}

// Whether supports [CLUSTERED BY col_name3 SORTED BY col_name INTO num_buckets BUCKETS]
protected abstract boolean supportsSparkSQLClusteredBy();

protected abstract boolean supportsPartition();

protected abstract boolean supportsDelete();

// Use a custom database not the original default database because SparkCommonIT couldn't
// read&write data to tables in default database. The main reason is default database location is
// determined by `hive.metastore.warehouse.dir` in hive-site.xml which is local HDFS address
Expand Down Expand Up @@ -702,6 +730,28 @@ void testTableOptions() {
checkTableReadWrite(tableInfo);
}

@Test
@EnabledIf("supportsDelete")
void testDeleteOperation() {
String tableName = "test_row_level_delete_table";
dropTableIfExists(tableName);
createSimpleTable(tableName);

SparkTableInfo table = getTableInfo(tableName);
checkTableColumns(tableName, getSimpleTableColumn(), table);
sql(
String.format(
"INSERT INTO %s VALUES (1, '1', 1),(2, '2', 2),(3, '3', 3),(4, '4', 4),(5, '5', 5)",
tableName));
List<String> queryResult1 = getTableData(tableName);
Assertions.assertEquals(5, queryResult1.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", queryResult1));
sql(getDeleteSql(tableName, "id <= 4"));
List<String> queryResult2 = getTableData(tableName);
Assertions.assertEquals(1, queryResult2.size());
Assertions.assertEquals("5,5,5", queryResult2.get(0));
}

protected void checkTableReadWrite(SparkTableInfo table) {
String name = table.getTableIdentifier();
boolean isPartitionTable = table.isPartitionTable();
Expand Down Expand Up @@ -760,6 +810,49 @@ protected String getExpectedTableData(SparkTableInfo table) {
.collect(Collectors.joining(","));
}

protected void checkTableRowLevelUpdate(String tableName) {
writeToEmptyTableAndCheckData(tableName);
String updatedValues = "id = 6, name = '6', age = 6";
sql(getUpdateTableSql(tableName, updatedValues, "id = 5"));
List<String> queryResult = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(5, queryResult.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;6,6,6", String.join(";", queryResult));
}

protected void checkTableRowLevelDelete(String tableName) {
writeToEmptyTableAndCheckData(tableName);
sql(getDeleteSql(tableName, "id <= 2"));
List<String> queryResult = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(3, queryResult.size());
Assertions.assertEquals("3,3,3;4,4,4;5,5,5", String.join(";", queryResult));
}

protected void checkTableDeleteByMergeInto(String tableName) {
writeToEmptyTableAndCheckData(tableName);

String sourceTableName = "source_table";
String selectClause =
"SELECT 1 AS id, '1' AS name, 1 AS age UNION ALL SELECT 6 AS id, '6' AS name, 6 AS age";
String onClause = String.format("%s.id = %s.id", tableName, sourceTableName);
sql(getRowLevelDeleteTableSql(tableName, selectClause, sourceTableName, onClause));
List<String> queryResult = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(5, queryResult.size());
Assertions.assertEquals("2,2,2;3,3,3;4,4,4;5,5,5;6,6,6", String.join(";", queryResult));
}

protected void checkTableUpdateByMergeInto(String tableName) {
writeToEmptyTableAndCheckData(tableName);

String sourceTableName = "source_table";
String selectClause =
"SELECT 1 AS id, '2' AS name, 2 AS age UNION ALL SELECT 6 AS id, '6' AS name, 6 AS age";
String onClause = String.format("%s.id = %s.id", tableName, sourceTableName);
sql(getRowLevelUpdateTableSql(tableName, selectClause, sourceTableName, onClause));
List<String> queryResult = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(6, queryResult.size());
Assertions.assertEquals("1,2,2;2,2,2;3,3,3;4,4,4;5,5,5;6,6,6", String.join(";", queryResult));
}

protected String getCreateSimpleTableString(String tableName) {
return getCreateSimpleTableString(tableName, false);
}
Expand Down Expand Up @@ -801,6 +894,16 @@ protected void checkTableColumns(
.check(tableInfo);
}

private void writeToEmptyTableAndCheckData(String tableName) {
sql(
String.format(
"INSERT INTO %s VALUES (1, '1', 1),(2, '2', 2),(3, '3', 3),(4, '4', 4),(5, '5', 5)",
tableName));
List<String> queryResult = getTableData(tableName);
Assertions.assertEquals(5, queryResult.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", queryResult));
}

// partition expression may contain "'", like a='s'/b=1
private String getPartitionExpression(SparkTableInfo table, String delimiter) {
return table.getPartitionedColumns().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ protected boolean supportsPartition() {
return true;
}

@Override
protected boolean supportsDelete() {
return false;
}

@Test
public void testCreateHiveFormatPartitionTable() {
String tableName = "hive_partition_table";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -18,10 +19,13 @@
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.Path;
import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
Expand All @@ -30,13 +34,21 @@
import org.apache.spark.sql.connector.catalog.FunctionCatalog;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.internal.StaticSQLConf;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.platform.commons.util.StringUtils;
import scala.Tuple3;

public abstract class SparkIcebergCatalogIT extends SparkCommonIT {

private static final String ICEBERG_FORMAT_VERSION = "format-version";
private static final String ICEBERG_DELETE_MODE = "write.delete.mode";
private static final String ICEBERG_UPDATE_MODE = "write.update.mode";
private static final String ICEBERG_MERGE_MODE = "write.merge.mode";

@Override
protected String getCatalogName() {
return "iceberg";
Expand All @@ -57,6 +69,11 @@ protected boolean supportsPartition() {
return true;
}

@Override
protected boolean supportsDelete() {
return true;
}

@Override
protected String getTableLocation(SparkTableInfo table) {
return String.join(File.separator, table.getTableLocation(), "data");
Expand Down Expand Up @@ -216,6 +233,24 @@ void testIcebergMetadataColumns() throws NoSuchTableException {
testDeleteMetadataColumn();
}

@Test
void testInjectSparkExtensions() {
SparkSession sparkSession = getSparkSession();
SparkConf conf = sparkSession.sparkContext().getConf();
Assertions.assertTrue(conf.contains(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key()));
String extensions = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS().key());
Assertions.assertTrue(StringUtils.isNotBlank(extensions));
Assertions.assertEquals(IcebergSparkSessionExtensions.class.getName(), extensions);
}

@Test
void testIcebergTableRowLevelOperations() {
testIcebergDeleteOperation();
testIcebergUpdateOperation();
testIcebergMergeIntoDeleteOperation();
testIcebergMergeIntoUpdateOperation();
}

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
Expand Down Expand Up @@ -386,6 +421,88 @@ private void testDeleteMetadataColumn() {
Assertions.assertEquals(0, queryResult1.size());
}

private void testIcebergDeleteOperation() {
getIcebergTablePropertyValues()
.forEach(
tuple -> {
String tableName =
String.format("test_iceberg_%s_%s_delete_operation", tuple._1(), tuple._2());
dropTableIfExists(tableName);
createIcebergTableWithTabProperties(
tableName,
tuple._1(),
ImmutableMap.of(
ICEBERG_FORMAT_VERSION,
String.valueOf(tuple._2()),
ICEBERG_DELETE_MODE,
tuple._3()));
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));
checkTableRowLevelDelete(tableName);
});
}

private void testIcebergUpdateOperation() {
getIcebergTablePropertyValues()
.forEach(
tuple -> {
String tableName =
String.format("test_iceberg_%s_%s_update_operation", tuple._1(), tuple._2());
dropTableIfExists(tableName);
createIcebergTableWithTabProperties(
tableName,
tuple._1(),
ImmutableMap.of(
ICEBERG_FORMAT_VERSION,
String.valueOf(tuple._2()),
ICEBERG_UPDATE_MODE,
tuple._3()));
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));
checkTableRowLevelUpdate(tableName);
});
}

private void testIcebergMergeIntoDeleteOperation() {
getIcebergTablePropertyValues()
.forEach(
tuple -> {
String tableName =
String.format(
"test_iceberg_%s_%s_mergeinto_delete_operation", tuple._1(), tuple._2());
dropTableIfExists(tableName);
createIcebergTableWithTabProperties(
tableName,
tuple._1(),
ImmutableMap.of(
ICEBERG_FORMAT_VERSION,
String.valueOf(tuple._2()),
ICEBERG_MERGE_MODE,
tuple._3()));
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));
checkTableDeleteByMergeInto(tableName);
});
}

private void testIcebergMergeIntoUpdateOperation() {
getIcebergTablePropertyValues()
.forEach(
tuple -> {
String tableName =
String.format(
"test_iceberg_%s_%s_mergeinto_update_operation", tuple._1(), tuple._2());
dropTableIfExists(tableName);
createIcebergTableWithTabProperties(
tableName,
tuple._1(),
ImmutableMap.of(
ICEBERG_FORMAT_VERSION,
String.valueOf(tuple._2()),
ICEBERG_MERGE_MODE,
tuple._3()));
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));
checkTableUpdateByMergeInto(tableName);
});
}

private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
return Arrays.asList(
SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
Expand Down Expand Up @@ -416,4 +533,26 @@ private SparkMetadataColumnInfo[] getIcebergMetadataColumns() {
new SparkMetadataColumnInfo("_deleted", DataTypes.BooleanType, false)
};
}

private List<Tuple3<Boolean, Integer, String>> getIcebergTablePropertyValues() {
return Arrays.asList(
new Tuple3<>(false, 1, "copy-on-write"),
new Tuple3<>(false, 2, "merge-on-read"),
new Tuple3<>(true, 1, "copy-on-write"),
new Tuple3<>(true, 2, "merge-on-read"));
}

private void createIcebergTableWithTabProperties(
String tableName, boolean isPartitioned, ImmutableMap<String, String> tblProperties) {
String partitionedClause = isPartitioned ? " PARTITIONED BY (name) " : "";
String tblPropertiesStr =
tblProperties.entrySet().stream()
.map(e -> String.format("'%s'='%s'", e.getKey(), e.getValue()))
.collect(Collectors.joining(","));
String createSql =
String.format(
"CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', age INT) %s TBLPROPERTIES(%s)",
tableName, partitionedClause, tblPropertiesStr);
sql(createSql);
}
}
Loading

0 comments on commit ae71559

Please sign in to comment.