From d2ed24fe9b6da5aff838140407071e1acaf663f0 Mon Sep 17 00:00:00 2001
From: FANNG
Date: Sat, 23 Mar 2024 01:03:15 +0800
Subject: [PATCH] [#1550] feat(spark-connector) support partition, bucket, sort order table (#2540)

### What changes were proposed in this pull request?

Add partition, distribution, and sort order support for the Spark connector.

### Why are the changes needed?

Fix: #1550

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added UTs and ITs; also verified in a local environment.
---
 gradle/libs.versions.toml | 1 +
 .../integration/test/spark/SparkCommonIT.java | 166 +++++++++---
 .../integration/test/spark/SparkEnvIT.java | 32 ++-
 .../test/spark/hive/SparkHiveCatalogIT.java | 65 +++++
 .../test/util/spark/SparkTableInfo.java | 57 ++++
 .../util/spark/SparkTableInfoChecker.java | 41 +++
 .../test/util/spark/SparkUtilIT.java | 25 ++
 spark-connector/build.gradle.kts | 5 +
 .../spark/connector/ConnectorConstants.java | 7 +
 .../connector/SparkTransformConverter.java | 244 ++++++++++++++++++
 .../connector/catalog/GravitinoCatalog.java | 18 +-
 .../spark/connector/table/SparkBaseTable.java | 13 +
 .../TestSparkTransformConverter.java | 206 +++++++++++++++
 13 files changed, 834 insertions(+), 46 deletions(-)
 create mode 100644 spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java
 create mode 100644 spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java

diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 71deaac89ef..acc44b37f8d 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -31,6 +31,7 @@ iceberg = '1.3.1' # 1.4.0 causes test to fail
 trino = '426'
 spark = "3.4.1" # 3.5.0 causes tests to fail
 scala-collection-compat = "2.7.0"
+scala-java-compat = "1.0.2"
 sqlite-jdbc = "3.42.0.0"
 testng = "7.5.1"
 testcontainers = "1.19.0"
diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java
index 6b735affd69..9805c1f7554 100644
--- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java
+++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkCommonIT.java
@@ -8,12 +8,14 @@
 import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo;
 import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
 import com.google.common.collect.ImmutableMap;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+import org.apache.hadoop.fs.Path;
 import org.apache.spark.sql.AnalysisException;
 import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
@@ -23,16 +25,10 @@
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIf;
 import org.junit.platform.commons.util.StringUtils;

 public abstract class SparkCommonIT extends SparkEnvIT {
-  private static String getSelectAllSql(String tableName) {
-    return String.format("SELECT * FROM %s", tableName);
-  }
-
-  private static String getInsertWithoutPartitionSql(String tableName, String values) {
-    return String.format("INSERT INTO %s
VALUES (%s)", tableName, values); - } // To generate test data for write&read table. private static final Map typeConstant = @@ -51,8 +47,21 @@ private static String getInsertWithoutPartitionSql(String tableName, String valu DataTypes.createStructField("col2", DataTypes.StringType, true))), "struct(1, 'a')"); - // Use a custom database not the original default database because SparkCommonIT couldn't - // read&write data to tables in default database. The main reason is default database location is + private static String getInsertWithoutPartitionSql(String tableName, String values) { + return String.format("INSERT INTO %s VALUES (%s)", tableName, values); + } + + private static String getInsertWithPartitionSql( + String tableName, String partitionString, String values) { + return String.format( + "INSERT OVERWRITE %s PARTITION (%s) VALUES (%s)", tableName, partitionString, values); + } + + // Whether supports [CLUSTERED BY col_name3 SORTED BY col_name INTO num_buckets BUCKETS] + protected abstract boolean supportsSparkSQLClusteredBy(); + + // Use a custom database not the original default database because SparkIT couldn't read&write + // data to tables in default database. The main reason is default database location is // determined by `hive.metastore.warehouse.dir` in hive-site.xml which is local HDFS address // not real HDFS address. The location of tables created under default database is like // hdfs://localhost:9000/xxx which couldn't read write data from SparkCommonIT. Will use default @@ -69,10 +78,6 @@ void init() { sql("USE " + getDefaultDatabase()); } - protected String getDefaultDatabase() { - return "default_db"; - } - @Test void testLoadCatalogs() { Set catalogs = getCatalogs(); @@ -442,24 +447,97 @@ void testComplexType() { checkTableReadWrite(tableInfo); } - private void checkTableColumns( - String tableName, List columnInfos, SparkTableInfo tableInfo) { - SparkTableInfoChecker.create() - .withName(tableName) - .withColumns(columnInfos) - .withComment(null) - .check(tableInfo); + @Test + void testCreateDatasourceFormatPartitionTable() { + String tableName = "datasource_partition_table"; + + dropTableIfExists(tableName); + String createTableSQL = getCreateSimpleTableString(tableName); + createTableSQL = createTableSQL + "USING PARQUET PARTITIONED BY (name, age)"; + sql(createTableSQL); + SparkTableInfo tableInfo = getTableInfo(tableName); + SparkTableInfoChecker checker = + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(getSimpleTableColumn()) + .withIdentifyPartition(Arrays.asList("name", "age")); + checker.check(tableInfo); + checkTableReadWrite(tableInfo); + checkPartitionDirExists(tableInfo); } - private void checkTableReadWrite(SparkTableInfo table) { + @Test + @EnabledIf("supportsSparkSQLClusteredBy") + void testCreateBucketTable() { + String tableName = "bucket_table"; + + dropTableIfExists(tableName); + String createTableSQL = getCreateSimpleTableString(tableName); + createTableSQL = createTableSQL + "CLUSTERED BY (id, name) INTO 4 buckets;"; + sql(createTableSQL); + SparkTableInfo tableInfo = getTableInfo(tableName); + SparkTableInfoChecker checker = + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(getSimpleTableColumn()) + .withBucket(4, Arrays.asList("id", "name")); + checker.check(tableInfo); + checkTableReadWrite(tableInfo); + } + + @Test + @EnabledIf("supportsSparkSQLClusteredBy") + void testCreateSortBucketTable() { + String tableName = "sort_bucket_table"; + + dropTableIfExists(tableName); + String createTableSQL = 
getCreateSimpleTableString(tableName); + createTableSQL = + createTableSQL + "CLUSTERED BY (id, name) SORTED BY (name, id) INTO 4 buckets;"; + sql(createTableSQL); + SparkTableInfo tableInfo = getTableInfo(tableName); + SparkTableInfoChecker checker = + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(getSimpleTableColumn()) + .withBucket(4, Arrays.asList("id", "name"), Arrays.asList("name", "id")); + checker.check(tableInfo); + checkTableReadWrite(tableInfo); + } + + protected void checkPartitionDirExists(SparkTableInfo table) { + Assertions.assertTrue(table.isPartitionTable(), "Not a partition table"); + String tableLocation = table.getTableLocation(); + String partitionExpression = getPartitionExpression(table, "/").replace("'", ""); + Path partitionPath = new Path(tableLocation, partitionExpression); + checkDirExists(partitionPath); + } + + protected void checkDirExists(Path dir) { + try { + Assertions.assertTrue(hdfs.exists(dir), "HDFS directory not exists," + dir); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected void checkTableReadWrite(SparkTableInfo table) { String name = table.getTableIdentifier(); + boolean isPartitionTable = table.isPartitionTable(); String insertValues = - table.getColumns().stream() + table.getUnPartitionedColumns().stream() .map(columnInfo -> typeConstant.get(columnInfo.getType())) .map(Object::toString) .collect(Collectors.joining(",")); - sql(getInsertWithoutPartitionSql(name, insertValues)); + String insertDataSQL = ""; + if (isPartitionTable) { + String partitionExpressions = getPartitionExpression(table, ","); + insertDataSQL = getInsertWithPartitionSql(name, partitionExpressions, insertValues); + } else { + insertDataSQL = getInsertWithoutPartitionSql(name, insertValues); + } + sql(insertDataSQL); // do something to match the query result: // 1. remove "'" from values, such as 'a' is trans to a @@ -492,45 +570,49 @@ private void checkTableReadWrite(SparkTableInfo table) { }) .collect(Collectors.joining(",")); - List queryResult = - sql(getSelectAllSql(name)).stream() - .map( - line -> - Arrays.stream(line) - .map( - item -> { - if (item instanceof Object[]) { - return Arrays.stream((Object[]) item) - .map(Object::toString) - .collect(Collectors.joining(",")); - } else { - return item.toString(); - } - }) - .collect(Collectors.joining(","))) - .collect(Collectors.toList()); + List queryResult = getTableData(name); Assertions.assertTrue( queryResult.size() == 1, "Should just one row, table content: " + queryResult); Assertions.assertEquals(checkValues, queryResult.get(0)); } - private String getCreateSimpleTableString(String tableName) { + protected String getCreateSimpleTableString(String tableName) { return String.format( "CREATE TABLE %s (id INT COMMENT 'id comment', name STRING COMMENT '', age INT)", tableName); } - private List getSimpleTableColumn() { + protected List getSimpleTableColumn() { return Arrays.asList( SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"), SparkColumnInfo.of("name", DataTypes.StringType, ""), SparkColumnInfo.of("age", DataTypes.IntegerType, null)); } + protected String getDefaultDatabase() { + return "default_db"; + } + // Helper method to create a simple table, and could use corresponding // getSimpleTableColumn to check table column. 
private void createSimpleTable(String identifier) { String createTableSql = getCreateSimpleTableString(identifier); sql(createTableSql); } + + private void checkTableColumns( + String tableName, List columns, SparkTableInfo tableInfo) { + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(columns) + .withComment(null) + .check(tableInfo); + } + + // partition expression may contain "'", like a='s'/b=1 + private String getPartitionExpression(SparkTableInfo table, String delimiter) { + return table.getPartitionedColumns().stream() + .map(column -> column.getName() + "=" + typeConstant.get(column.getType())) + .collect(Collectors.joining(delimiter)); + } } diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkEnvIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkEnvIT.java index b0b7fd895e6..52de8da4a67 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkEnvIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/SparkEnvIT.java @@ -14,8 +14,11 @@ import com.datastrato.gravitino.spark.connector.GravitinoSparkConfig; import com.datastrato.gravitino.spark.connector.plugin.GravitinoSparkPlugin; import com.google.common.collect.Maps; +import java.io.IOException; import java.util.Collections; import java.util.Map; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; @@ -28,11 +31,12 @@ public abstract class SparkEnvIT extends SparkUtilIT { private static final Logger LOG = LoggerFactory.getLogger(SparkEnvIT.class); private static final ContainerSuite containerSuite = ContainerSuite.getInstance(); + protected FileSystem hdfs; private final String metalakeName = "test"; private SparkSession sparkSession; - private String hiveMetastoreUri; - private String gravitinoUri; + private String hiveMetastoreUri = "thrift://127.0.0.1:9083"; + private String gravitinoUri = "http://127.0.0.1:8090"; protected abstract String getCatalogName(); @@ -47,6 +51,7 @@ protected SparkSession getSparkSession() { @BeforeAll void startUp() { initHiveEnv(); + initHdfsFileSystem(); initGravitinoEnv(); initMetalakeAndCatalogs(); initSparkEnv(); @@ -58,6 +63,13 @@ void startUp() { @AfterAll void stop() { + if (hdfs != null) { + try { + hdfs.close(); + } catch (IOException e) { + LOG.warn("Close HDFS filesystem failed,", e); + } + } if (sparkSession != null) { sparkSession.close(); } @@ -92,6 +104,22 @@ private void initHiveEnv() { HiveContainer.HIVE_METASTORE_PORT); } + private void initHdfsFileSystem() { + Configuration conf = new Configuration(); + conf.set( + "fs.defaultFS", + String.format( + "hdfs://%s:%d", + containerSuite.getHiveContainer().getContainerIpAddress(), + HiveContainer.HDFS_DEFAULTFS_PORT)); + try { + hdfs = FileSystem.get(conf); + } catch (IOException e) { + LOG.error("Create HDFS filesystem failed", e); + throw new RuntimeException(e); + } + } + private void initSparkEnv() { sparkSession = SparkSession.builder() diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java index bce6cb212bf..d69030ab31e 100644 --- 
a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/hive/SparkHiveCatalogIT.java @@ -5,7 +5,17 @@ package com.datastrato.gravitino.integration.test.spark.hive; import com.datastrato.gravitino.integration.test.spark.SparkCommonIT; +import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo; +import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo; +import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.hadoop.fs.Path; +import org.apache.spark.sql.types.DataTypes; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @Tag("gravitino-docker-it") @@ -21,4 +31,59 @@ protected String getCatalogName() { protected String getProvider() { return "hive"; } + + @Override + protected boolean supportsSparkSQLClusteredBy() { + return true; + } + + @Test + public void testCreateHiveFormatPartitionTable() { + String tableName = "hive_partition_table"; + + dropTableIfExists(tableName); + String createTableSQL = getCreateSimpleTableString(tableName); + createTableSQL = createTableSQL + "PARTITIONED BY (age_p1 INT, age_p2 STRING)"; + sql(createTableSQL); + + List columns = new ArrayList<>(getSimpleTableColumn()); + columns.add(SparkColumnInfo.of("age_p1", DataTypes.IntegerType)); + columns.add(SparkColumnInfo.of("age_p2", DataTypes.StringType)); + + SparkTableInfo tableInfo = getTableInfo(tableName); + SparkTableInfoChecker checker = + SparkTableInfoChecker.create() + .withName(tableName) + .withColumns(columns) + .withIdentifyPartition(Arrays.asList("age_p1", "age_p2")); + checker.check(tableInfo); + // write to static partition + checkTableReadWrite(tableInfo); + checkPartitionDirExists(tableInfo); + } + + @Test + public void testWriteHiveDynamicPartition() { + String tableName = "hive_dynamic_partition_table"; + + dropTableIfExists(tableName); + String createTableSQL = getCreateSimpleTableString(tableName); + createTableSQL = createTableSQL + "PARTITIONED BY (age_p1 INT, age_p2 STRING)"; + sql(createTableSQL); + + SparkTableInfo tableInfo = getTableInfo(tableName); + + // write data to dynamic partition + String insertData = + String.format( + "INSERT OVERWRITE %s PARTITION(age_p1=1, age_p2) values(1,'a',3,'b');", tableName); + sql(insertData); + List queryResult = getTableData(tableName); + Assertions.assertTrue(queryResult.size() == 1); + Assertions.assertEquals("1,a,3,1,b", queryResult.get(0)); + String location = tableInfo.getTableLocation(); + String partitionExpression = "age_p1=1/age_p2=b"; + Path partitionPath = new Path(location, partitionExpression); + checkDirExists(partitionPath); + } } diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java index 65e06c977c3..449237ff157 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfo.java @@ -9,11 +9,18 @@ import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; import java.util.ArrayList; import 
java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; +import javax.ws.rs.NotSupportedException; import lombok.Data; import org.apache.commons.lang3.StringUtils; +import org.apache.spark.sql.connector.expressions.BucketTransform; +import org.apache.spark.sql.connector.expressions.IdentityTransform; +import org.apache.spark.sql.connector.expressions.SortedBucketTransform; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.DataType; import org.junit.jupiter.api.Assertions; @@ -26,6 +33,9 @@ public class SparkTableInfo { private List columns; private Map tableProperties; private List unknownItems = new ArrayList<>(); + private Transform bucket; + private List partitions = new ArrayList<>(); + private Set partitionColumnNames = new HashSet<>(); public SparkTableInfo() {} @@ -42,6 +52,28 @@ public String getTableIdentifier() { } } + public String getTableLocation() { + return tableProperties.get(ConnectorConstants.LOCATION); + } + + public boolean isPartitionTable() { + return partitions.size() > 0; + } + + void setBucket(Transform bucket) { + Assertions.assertNull(this.bucket, "Should only one distribution"); + this.bucket = bucket; + } + + void addPartition(Transform partition) { + if (partition instanceof IdentityTransform) { + partitionColumnNames.add(((IdentityTransform) partition).reference().fieldNames()[0]); + } else { + throw new NotSupportedException("Doesn't support " + partition.name()); + } + this.partitions.add(partition); + } + static SparkTableInfo create(SparkBaseTable baseTable) { SparkTableInfo sparkTableInfo = new SparkTableInfo(); String identifier = baseTable.name(); @@ -62,9 +94,34 @@ static SparkTableInfo create(SparkBaseTable baseTable) { .collect(Collectors.toList()); sparkTableInfo.comment = baseTable.properties().remove(ConnectorConstants.COMMENT); sparkTableInfo.tableProperties = baseTable.properties(); + Arrays.stream(baseTable.partitioning()) + .forEach( + transform -> { + if (transform instanceof BucketTransform + || transform instanceof SortedBucketTransform) { + sparkTableInfo.setBucket(transform); + } else if (transform instanceof IdentityTransform) { + sparkTableInfo.addPartition(transform); + } else { + throw new NotSupportedException( + "Doesn't support Spark transform: " + transform.name()); + } + }); return sparkTableInfo; } + public List getUnPartitionedColumns() { + return columns.stream() + .filter(column -> !partitionColumnNames.contains(column.name)) + .collect(Collectors.toList()); + } + + public List getPartitionedColumns() { + return columns.stream() + .filter(column -> partitionColumnNames.contains(column.name)) + .collect(Collectors.toList()); + } + @Data public static class SparkColumnInfo { private String name; diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java index e95730d1ae3..d346769281c 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkTableInfoChecker.java @@ -6,8 +6,12 @@ package com.datastrato.gravitino.integration.test.util.spark; import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo.SparkColumnInfo; +import 
com.datastrato.gravitino.spark.connector.SparkTransformConverter; import java.util.ArrayList; import java.util.List; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.IdentityTransform; +import org.apache.spark.sql.connector.expressions.Transform; import org.junit.jupiter.api.Assertions; /** @@ -27,6 +31,8 @@ public static SparkTableInfoChecker create() { private enum CheckField { NAME, COLUMN, + PARTITION, + BUCKET, COMMENT, } @@ -42,6 +48,34 @@ public SparkTableInfoChecker withColumns(List columns) { return this; } + public SparkTableInfoChecker withIdentifyPartition(List partitionColumns) { + partitionColumns.forEach( + columnName -> { + IdentityTransform identityTransform = + SparkTransformConverter.createSparkIdentityTransform(columnName); + this.expectedTableInfo.addPartition(identityTransform); + }); + this.checkFields.add(CheckField.PARTITION); + return this; + } + + public SparkTableInfoChecker withBucket(int bucketNum, List bucketColumns) { + Transform bucketTransform = Expressions.bucket(bucketNum, bucketColumns.toArray(new String[0])); + this.expectedTableInfo.setBucket(bucketTransform); + this.checkFields.add(CheckField.BUCKET); + return this; + } + + public SparkTableInfoChecker withBucket( + int bucketNum, List bucketColumns, List sortColumns) { + Transform sortBucketTransform = + SparkTransformConverter.createSortBucketTransform( + bucketNum, bucketColumns.toArray(new String[0]), sortColumns.toArray(new String[0])); + this.expectedTableInfo.setBucket(sortBucketTransform); + this.checkFields.add(CheckField.BUCKET); + return this; + } + public SparkTableInfoChecker withComment(String comment) { this.expectedTableInfo.setComment(comment); this.checkFields.add(CheckField.COMMENT); @@ -61,6 +95,13 @@ public void check(SparkTableInfo realTableInfo) { Assertions.assertEquals( expectedTableInfo.getColumns(), realTableInfo.getColumns()); break; + case PARTITION: + Assertions.assertEquals( + expectedTableInfo.getPartitions(), realTableInfo.getPartitions()); + break; + case BUCKET: + Assertions.assertEquals(expectedTableInfo.getBucket(), realTableInfo.getBucket()); + break; case COMMENT: Assertions.assertEquals( expectedTableInfo.getComment(), realTableInfo.getComment()); diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java index 36bdc22ba59..13e3b74a55e 100644 --- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java +++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/util/spark/SparkUtilIT.java @@ -21,6 +21,7 @@ import com.datastrato.gravitino.integration.test.util.AbstractIT; import com.datastrato.gravitino.spark.connector.table.SparkBaseTable; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; @@ -83,6 +84,26 @@ protected List sql(String query) { return rowsToJava(rows); } + // columns data are joined by ',' + protected List getTableData(String tableName) { + return sql(getSelectAllSql(tableName)).stream() + .map( + line -> + Arrays.stream(line) + .map( + item -> { + if (item instanceof Object[]) { + return Arrays.stream((Object[]) item) + .map(Object::toString) + .collect(Collectors.joining(",")); + } else { + return item.toString(); + } + }) + .collect(Collectors.joining(","))) + .collect(Collectors.toList()); + } + // Create 
SparkTableInfo from SparkBaseTable retrieved from LogicalPlan. protected SparkTableInfo getTableInfo(String tableName) { Dataset ds = getSparkSession().sql("DESC TABLE EXTENDED " + tableName); @@ -110,6 +131,10 @@ protected boolean tableExists(String tableName) { } } + private static String getSelectAllSql(String tableName) { + return String.format("SELECT * FROM %s", tableName); + } + private List rowsToJava(List rows) { return rows.stream().map(this::toJava).collect(Collectors.toList()); } diff --git a/spark-connector/build.gradle.kts b/spark-connector/build.gradle.kts index 245577f67de..1a03e73f34f 100644 --- a/spark-connector/build.gradle.kts +++ b/spark-connector/build.gradle.kts @@ -16,6 +16,7 @@ val scalaVersion: String = project.properties["scalaVersion"] as? String ?: extr val sparkVersion: String = libs.versions.spark.get() val icebergVersion: String = libs.versions.iceberg.get() val kyuubiVersion: String = libs.versions.kyuubi.get() +val scalaJava8CompatVersion: String = libs.versions.scala.java.compat.get() dependencies { implementation(project(":api")) @@ -27,6 +28,10 @@ dependencies { implementation("org.apache.kyuubi:kyuubi-spark-connector-hive_$scalaVersion:$kyuubiVersion") implementation("org.apache.spark:spark-catalyst_$scalaVersion:$sparkVersion") implementation("org.apache.spark:spark-sql_$scalaVersion:$sparkVersion") + implementation("org.scala-lang.modules:scala-java8-compat_$scalaVersion:$scalaJava8CompatVersion") + + annotationProcessor(libs.lombok) + compileOnly(libs.lombok) testImplementation(libs.junit.jupiter.api) testImplementation(libs.junit.jupiter.params) diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java index 40ae3b5c712..3a49a21470f 100644 --- a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java +++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/ConnectorConstants.java @@ -5,8 +5,15 @@ package com.datastrato.gravitino.spark.connector; +import com.datastrato.gravitino.rel.expressions.sorts.SortDirection; + public class ConnectorConstants { public static final String COMMENT = "comment"; + public static final SortDirection SPARK_DEFAULT_SORT_DIRECTION = SortDirection.ASCENDING; + public static final String LOCATION = "location"; + + public static final String DOT = "."; + private ConnectorConstants() {} } diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java new file mode 100644 index 00000000000..9afad670b76 --- /dev/null +++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/SparkTransformConverter.java @@ -0,0 +1,244 @@ +/* + * Copyright 2024 Datastrato Pvt Ltd. + * This software is licensed under the Apache License version 2. 
+ */ + +package com.datastrato.gravitino.spark.connector; + +import com.datastrato.gravitino.rel.expressions.Expression; +import com.datastrato.gravitino.rel.expressions.NamedReference; +import com.datastrato.gravitino.rel.expressions.distributions.Distribution; +import com.datastrato.gravitino.rel.expressions.distributions.Distributions; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrder; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrders; +import com.datastrato.gravitino.rel.expressions.transforms.Transform; +import com.datastrato.gravitino.rel.expressions.transforms.Transforms; +import com.google.common.base.Preconditions; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import javax.ws.rs.NotSupportedException; +import lombok.Getter; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.sql.connector.expressions.BucketTransform; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.IdentityTransform; +import org.apache.spark.sql.connector.expressions.LogicalExpressions; +import org.apache.spark.sql.connector.expressions.SortedBucketTransform; +import scala.collection.JavaConverters; + +/** + * SparkTransformConverter translate between Spark transform and Gravitino partition, distribution, + * sort orders. There may be multi partition transforms, but should be only one bucket transform. + * + *
<p>
Spark bucket transform is corresponding to Gravitino Hash distribution without sort orders. + * + *
<p>
Spark sorted bucket transform is corresponding to Gravitino Hash distribution with sort + * orders. + */ +public class SparkTransformConverter { + + @Getter + public static class DistributionAndSortOrdersInfo { + private Distribution distribution; + private SortOrder[] sortOrders; + + private void setDistribution(Distribution distributionInfo) { + Preconditions.checkState(distribution == null, "Should only set distribution once"); + this.distribution = distributionInfo; + } + + private void setSortOrders(SortOrder[] sortOrdersInfo) { + Preconditions.checkState(sortOrders == null, "Should only set sort orders once"); + this.sortOrders = sortOrdersInfo; + } + } + + public static Transform[] toGravitinoPartitionings( + org.apache.spark.sql.connector.expressions.Transform[] transforms) { + if (ArrayUtils.isEmpty(transforms)) { + return Transforms.EMPTY_TRANSFORM; + } + + return Arrays.stream(transforms) + .filter(transform -> !isBucketTransform(transform)) + .map( + transform -> { + if (transform instanceof IdentityTransform) { + IdentityTransform identityTransform = (IdentityTransform) transform; + return Transforms.identity(identityTransform.reference().fieldNames()); + } else { + throw new NotSupportedException( + "Doesn't support Spark transform: " + transform.name()); + } + }) + .toArray(Transform[]::new); + } + + public static DistributionAndSortOrdersInfo toGravitinoDistributionAndSortOrders( + org.apache.spark.sql.connector.expressions.Transform[] transforms) { + DistributionAndSortOrdersInfo distributionAndSortOrdersInfo = + new DistributionAndSortOrdersInfo(); + if (ArrayUtils.isEmpty(transforms)) { + return distributionAndSortOrdersInfo; + } + + Arrays.stream(transforms) + .filter(transform -> isBucketTransform(transform)) + .forEach( + transform -> { + if (transform instanceof SortedBucketTransform) { + Pair pair = + toGravitinoDistributionAndSortOrders((SortedBucketTransform) transform); + distributionAndSortOrdersInfo.setDistribution(pair.getLeft()); + distributionAndSortOrdersInfo.setSortOrders(pair.getRight()); + } else if (transform instanceof BucketTransform) { + BucketTransform bucketTransform = (BucketTransform) transform; + Distribution distribution = toGravitinoDistribution(bucketTransform); + distributionAndSortOrdersInfo.setDistribution(distribution); + } else { + throw new NotSupportedException( + "Only support BucketTransform and SortedBucketTransform, but get: " + + transform.name()); + } + }); + + return distributionAndSortOrdersInfo; + } + + public static org.apache.spark.sql.connector.expressions.Transform[] toSparkTransform( + com.datastrato.gravitino.rel.expressions.transforms.Transform[] partitions, + Distribution distribution, + SortOrder[] sortOrder) { + List sparkTransforms = new ArrayList<>(); + if (ArrayUtils.isNotEmpty(partitions)) { + Arrays.stream(partitions) + .forEach( + transform -> { + if (transform instanceof Transforms.IdentityTransform) { + Transforms.IdentityTransform identityTransform = + (Transforms.IdentityTransform) transform; + sparkTransforms.add( + createSparkIdentityTransform( + String.join(ConnectorConstants.DOT, identityTransform.fieldName()))); + } else { + throw new UnsupportedOperationException( + "Doesn't support Gravitino partition: " + + transform.name() + + ", className: " + + transform.getClass().getName()); + } + }); + } + + org.apache.spark.sql.connector.expressions.Transform bucketTransform = + toSparkBucketTransform(distribution, sortOrder); + if (bucketTransform != null) { + sparkTransforms.add(bucketTransform); + } + + 
return sparkTransforms.toArray(new org.apache.spark.sql.connector.expressions.Transform[0]); + } + + private static Distribution toGravitinoDistribution(BucketTransform bucketTransform) { + int bucketNum = (Integer) bucketTransform.numBuckets().value(); + Expression[] expressions = + JavaConverters.seqAsJavaList(bucketTransform.columns()).stream() + .map(sparkReference -> NamedReference.field(sparkReference.fieldNames())) + .toArray(Expression[]::new); + return Distributions.hash(bucketNum, expressions); + } + + // Spark datasourceV2 doesn't support specify sort order direction, use ASCENDING as default. + private static Pair toGravitinoDistributionAndSortOrders( + SortedBucketTransform sortedBucketTransform) { + int bucketNum = (Integer) sortedBucketTransform.numBuckets().value(); + Expression[] bucketColumns = + toGravitinoNamedReference(JavaConverters.seqAsJavaList(sortedBucketTransform.columns())); + + Expression[] sortColumns = + toGravitinoNamedReference( + JavaConverters.seqAsJavaList(sortedBucketTransform.sortedColumns())); + SortOrder[] sortOrders = + Arrays.stream(sortColumns) + .map( + sortColumn -> + SortOrders.of(sortColumn, ConnectorConstants.SPARK_DEFAULT_SORT_DIRECTION)) + .toArray(SortOrder[]::new); + + return Pair.of(Distributions.hash(bucketNum, bucketColumns), sortOrders); + } + + private static org.apache.spark.sql.connector.expressions.Transform toSparkBucketTransform( + Distribution distribution, SortOrder[] sortOrders) { + if (distribution == null) { + return null; + } + + switch (distribution.strategy()) { + case NONE: + return null; + case HASH: + int bucketNum = distribution.number(); + String[] bucketFields = + Arrays.stream(distribution.expressions()) + .map( + expression -> + getFieldNameFromGravitinoNamedReference((NamedReference) expression)) + .toArray(String[]::new); + if (sortOrders == null || sortOrders.length == 0) { + return Expressions.bucket(bucketNum, bucketFields); + } else { + String[] sortOrderFields = + Arrays.stream(sortOrders) + .map( + sortOrder -> + getFieldNameFromGravitinoNamedReference( + (NamedReference) sortOrder.expression())) + .toArray(String[]::new); + return createSortBucketTransform(bucketNum, bucketFields, sortOrderFields); + } + // Spark doesn't support EVEN or RANGE distribution + default: + throw new NotSupportedException( + "Doesn't support distribution strategy: " + distribution.strategy()); + } + } + + private static Expression[] toGravitinoNamedReference( + List sparkNamedReferences) { + return sparkNamedReferences.stream() + .map(sparkReference -> NamedReference.field(sparkReference.fieldNames())) + .toArray(Expression[]::new); + } + + public static org.apache.spark.sql.connector.expressions.Transform createSortBucketTransform( + int bucketNum, String[] bucketFields, String[] sortFields) { + return LogicalExpressions.bucket( + bucketNum, createSparkNamedReference(bucketFields), createSparkNamedReference(sortFields)); + } + + // columnName could be "a" or "a.b" for nested column + public static IdentityTransform createSparkIdentityTransform(String columnName) { + return IdentityTransform.apply(Expressions.column(columnName)); + } + + private static org.apache.spark.sql.connector.expressions.NamedReference[] + createSparkNamedReference(String[] fields) { + return Arrays.stream(fields) + .map(Expressions::column) + .toArray(org.apache.spark.sql.connector.expressions.NamedReference[]::new); + } + + // Gravitino use ["a","b"] for nested fields while Spark use "a.b"; + private static String 
getFieldNameFromGravitinoNamedReference( + NamedReference gravitinoNamedReference) { + return String.join(ConnectorConstants.DOT, gravitinoNamedReference.fieldName()); + } + + private static boolean isBucketTransform( + org.apache.spark.sql.connector.expressions.Transform transform) { + return transform instanceof BucketTransform || transform instanceof SortedBucketTransform; + } +} diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java index 0449a6e8c82..2e08a2a8332 100644 --- a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java +++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/catalog/GravitinoCatalog.java @@ -17,6 +17,8 @@ import com.datastrato.gravitino.spark.connector.GravitinoCatalogAdaptor; import com.datastrato.gravitino.spark.connector.GravitinoCatalogAdaptorFactory; import com.datastrato.gravitino.spark.connector.PropertiesConverter; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter.DistributionAndSortOrdersInfo; import com.datastrato.gravitino.spark.connector.SparkTypeConverter; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; @@ -113,7 +115,7 @@ public Identifier[] listTables(String[] namespace) throws NoSuchNamespaceExcepti @Override public Table createTable( - Identifier ident, Column[] columns, Transform[] partitions, Map properties) + Identifier ident, Column[] columns, Transform[] transforms, Map properties) throws TableAlreadyExistsException, NoSuchNamespaceException { NameIdentifier gravitinoIdentifier = NameIdentifier.of(metalakeName, catalogName, getDatabase(ident), ident.name()); @@ -127,11 +129,23 @@ public Table createTable( // Spark store comment in properties, we should retrieve it and pass to Gravitino explicitly. 
String comment = gravitinoProperties.remove(ConnectorConstants.COMMENT); + DistributionAndSortOrdersInfo distributionAndSortOrdersInfo = + SparkTransformConverter.toGravitinoDistributionAndSortOrders(transforms); + com.datastrato.gravitino.rel.expressions.transforms.Transform[] partitionings = + SparkTransformConverter.toGravitinoPartitionings(transforms); + try { com.datastrato.gravitino.rel.Table table = gravitinoCatalogClient .asTableCatalog() - .createTable(gravitinoIdentifier, gravitinoColumns, comment, gravitinoProperties); + .createTable( + gravitinoIdentifier, + gravitinoColumns, + comment, + gravitinoProperties, + partitionings, + distributionAndSortOrdersInfo.getDistribution(), + distributionAndSortOrdersInfo.getSortOrders()); return gravitinoAdaptor.createSparkTable(ident, table, sparkCatalog, propertiesConverter); } catch (NoSuchSchemaException e) { throw new NoSuchNamespaceException(ident.namespace()); diff --git a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java index b6ae81e4d41..0d057656e86 100644 --- a/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java +++ b/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/table/SparkBaseTable.java @@ -5,8 +5,11 @@ package com.datastrato.gravitino.spark.connector.table; +import com.datastrato.gravitino.rel.expressions.distributions.Distribution; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrder; import com.datastrato.gravitino.spark.connector.ConnectorConstants; import com.datastrato.gravitino.spark.connector.PropertiesConverter; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter; import com.datastrato.gravitino.spark.connector.SparkTypeConverter; import java.util.Arrays; import java.util.HashMap; @@ -22,6 +25,7 @@ import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; import org.apache.spark.sql.connector.catalog.TableCatalog; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.write.LogicalWriteInfo; import org.apache.spark.sql.connector.write.WriteBuilder; @@ -117,6 +121,15 @@ public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { return ((SupportsWrite) getSparkTable()).newWriteBuilder(info); } + @Override + public Transform[] partitioning() { + com.datastrato.gravitino.rel.expressions.transforms.Transform[] partitions = + gravitinoTable.partitioning(); + Distribution distribution = gravitinoTable.distribution(); + SortOrder[] sortOrders = gravitinoTable.sortOrder(); + return SparkTransformConverter.toSparkTransform(partitions, distribution, sortOrders); + } + protected Table getSparkTable() { if (lazySparkTable == null) { try { diff --git a/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java b/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java new file mode 100644 index 00000000000..ea00eeb5b58 --- /dev/null +++ b/spark-connector/src/test/java/com/datastrato/gravitino/spark/connector/TestSparkTransformConverter.java @@ -0,0 +1,206 @@ +/* + * Copyright 2024 Datastrato Pvt Ltd. + * This software is licensed under the Apache License version 2. 
+ */ + +package com.datastrato.gravitino.spark.connector; + +import com.datastrato.gravitino.rel.expressions.NamedReference; +import com.datastrato.gravitino.rel.expressions.distributions.Distribution; +import com.datastrato.gravitino.rel.expressions.distributions.Distributions; +import com.datastrato.gravitino.rel.expressions.sorts.SortDirection; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrder; +import com.datastrato.gravitino.rel.expressions.sorts.SortOrders; +import com.datastrato.gravitino.rel.expressions.transforms.Transform; +import com.datastrato.gravitino.rel.expressions.transforms.Transforms; +import com.datastrato.gravitino.spark.connector.SparkTransformConverter.DistributionAndSortOrdersInfo; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import javax.ws.rs.NotSupportedException; +import org.apache.spark.sql.connector.expressions.BucketTransform; +import org.apache.spark.sql.connector.expressions.Expressions; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.LogicalExpressions; +import org.apache.spark.sql.connector.expressions.SortedBucketTransform; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import scala.collection.JavaConverters; + +@TestInstance(Lifecycle.PER_CLASS) +public class TestSparkTransformConverter { + private Map + sparkToGravitinoPartitionTransformMaps = new HashMap<>(); + + @BeforeAll + void init() { + initSparkToGravitinoTransformMap(); + } + + @Test + void testPartition() { + sparkToGravitinoPartitionTransformMaps.forEach( + (sparkTransform, gravitinoTransform) -> { + Transform[] gravitinoPartitionings = + SparkTransformConverter.toGravitinoPartitionings( + new org.apache.spark.sql.connector.expressions.Transform[] {sparkTransform}); + Assertions.assertTrue( + gravitinoPartitionings != null && gravitinoPartitionings.length == 1); + Assertions.assertEquals(gravitinoTransform, gravitinoPartitionings[0]); + }); + + sparkToGravitinoPartitionTransformMaps.forEach( + (sparkTransform, gravitinoTransform) -> { + org.apache.spark.sql.connector.expressions.Transform[] sparkTransforms = + SparkTransformConverter.toSparkTransform( + new Transform[] {gravitinoTransform}, null, null); + Assertions.assertTrue(sparkTransforms.length == 1); + Assertions.assertEquals(sparkTransform, sparkTransforms[0]); + }); + } + + @Test + void testGravitinoToSparkDistributionWithoutSortOrder() { + int bucketNum = 16; + String[][] columnNames = createGravitinoFieldReferenceNames("a", "b.c"); + Distribution gravitinoDistribution = createHashDistribution(bucketNum, columnNames); + + org.apache.spark.sql.connector.expressions.Transform[] sparkTransforms = + SparkTransformConverter.toSparkTransform(null, gravitinoDistribution, null); + Assertions.assertTrue(sparkTransforms != null && sparkTransforms.length == 1); + Assertions.assertTrue(sparkTransforms[0] instanceof BucketTransform); + BucketTransform bucket = (BucketTransform) sparkTransforms[0]; + Assertions.assertEquals(bucketNum, (Integer) bucket.numBuckets().value()); + String[][] columns = + JavaConverters.seqAsJavaList(bucket.columns()).stream() + .map(namedReference -> namedReference.fieldNames()) + .toArray(String[][]::new); + Assertions.assertArrayEquals(columnNames, columns); + + // none and null distribution + sparkTransforms = 
SparkTransformConverter.toSparkTransform(null, null, null); + Assertions.assertEquals(0, sparkTransforms.length); + sparkTransforms = SparkTransformConverter.toSparkTransform(null, Distributions.NONE, null); + Assertions.assertEquals(0, sparkTransforms.length); + + // range and even distribution + Assertions.assertThrowsExactly( + NotSupportedException.class, + () -> SparkTransformConverter.toSparkTransform(null, Distributions.RANGE, null)); + Distribution evenDistribution = Distributions.even(bucketNum, NamedReference.field("")); + Assertions.assertThrowsExactly( + NotSupportedException.class, + () -> SparkTransformConverter.toSparkTransform(null, evenDistribution, null)); + } + + @Test + void testSparkToGravitinoDistributionWithoutSortOrder() { + int bucketNum = 16; + String[] sparkFieldReferences = new String[] {"a", "b.c"}; + + org.apache.spark.sql.connector.expressions.Transform sparkBucket = + Expressions.bucket(bucketNum, sparkFieldReferences); + DistributionAndSortOrdersInfo distributionAndSortOrdersInfo = + SparkTransformConverter.toGravitinoDistributionAndSortOrders( + new org.apache.spark.sql.connector.expressions.Transform[] {sparkBucket}); + + Assertions.assertNull(distributionAndSortOrdersInfo.getSortOrders()); + + Distribution distribution = distributionAndSortOrdersInfo.getDistribution(); + String[][] gravitinoFieldReferences = createGravitinoFieldReferenceNames(sparkFieldReferences); + Assertions.assertEquals( + createHashDistribution(bucketNum, gravitinoFieldReferences), distribution); + } + + @Test + void testSparkToGravitinoDistributionWithSortOrder() { + int bucketNum = 16; + String[][] bucketColumnNames = createGravitinoFieldReferenceNames("a", "b.c"); + String[][] sortColumnNames = createGravitinoFieldReferenceNames("f", "m.n"); + SortedBucketTransform sortedBucketTransform = + LogicalExpressions.bucket( + bucketNum, + createSparkFieldReference(bucketColumnNames), + createSparkFieldReference(sortColumnNames)); + + DistributionAndSortOrdersInfo distributionAndSortOrders = + SparkTransformConverter.toGravitinoDistributionAndSortOrders( + new org.apache.spark.sql.connector.expressions.Transform[] {sortedBucketTransform}); + Assertions.assertEquals( + createHashDistribution(bucketNum, bucketColumnNames), + distributionAndSortOrders.getDistribution()); + + SortOrder[] sortOrders = + createSortOrders(sortColumnNames, ConnectorConstants.SPARK_DEFAULT_SORT_DIRECTION); + Assertions.assertArrayEquals(sortOrders, distributionAndSortOrders.getSortOrders()); + } + + @Test + void testGravitinoToSparkDistributionWithSortOrder() { + int bucketNum = 16; + String[][] bucketColumnNames = createGravitinoFieldReferenceNames("a", "b.c"); + String[][] sortColumnNames = createGravitinoFieldReferenceNames("f", "m.n"); + Distribution distribution = createHashDistribution(bucketNum, bucketColumnNames); + SortOrder[] sortOrders = + createSortOrders(sortColumnNames, ConnectorConstants.SPARK_DEFAULT_SORT_DIRECTION); + + org.apache.spark.sql.connector.expressions.Transform[] transforms = + SparkTransformConverter.toSparkTransform(null, distribution, sortOrders); + Assertions.assertTrue(transforms.length == 1); + Assertions.assertTrue(transforms[0] instanceof SortedBucketTransform); + + SortedBucketTransform sortedBucketTransform = (SortedBucketTransform) transforms[0]; + Assertions.assertEquals(bucketNum, (Integer) sortedBucketTransform.numBuckets().value()); + String[][] sparkSortColumns = + JavaConverters.seqAsJavaList(sortedBucketTransform.sortedColumns()).stream() + .map(sparkNamedReference -> 
sparkNamedReference.fieldNames()) + .toArray(String[][]::new); + + String[][] sparkBucketColumns = + JavaConverters.seqAsJavaList(sortedBucketTransform.columns()).stream() + .map(sparkNamedReference -> sparkNamedReference.fieldNames()) + .toArray(String[][]::new); + + Assertions.assertArrayEquals(bucketColumnNames, sparkBucketColumns); + Assertions.assertArrayEquals(sortColumnNames, sparkSortColumns); + } + + private org.apache.spark.sql.connector.expressions.NamedReference[] createSparkFieldReference( + String[][] fields) { + return Arrays.stream(fields) + .map(field -> FieldReference.apply(String.join(ConnectorConstants.DOT, field))) + .toArray(org.apache.spark.sql.connector.expressions.NamedReference[]::new); + } + + // split column name for Gravitino + private String[][] createGravitinoFieldReferenceNames(String... columnNames) { + return Arrays.stream(columnNames) + .map(columnName -> columnName.split("\\.")) + .toArray(String[][]::new); + } + + private SortOrder[] createSortOrders(String[][] columnNames, SortDirection direction) { + return Arrays.stream(columnNames) + .map(columnName -> SortOrders.of(NamedReference.field(columnName), direction)) + .toArray(SortOrder[]::new); + } + + private Distribution createHashDistribution(int bucketNum, String[][] columnNames) { + NamedReference[] namedReferences = + Arrays.stream(columnNames) + .map(columnName -> NamedReference.field(columnName)) + .toArray(NamedReference[]::new); + return Distributions.hash(bucketNum, namedReferences); + } + + private void initSparkToGravitinoTransformMap() { + sparkToGravitinoPartitionTransformMaps.put( + SparkTransformConverter.createSparkIdentityTransform("a"), Transforms.identity("a")); + sparkToGravitinoPartitionTransformMaps.put( + SparkTransformConverter.createSparkIdentityTransform("a.b"), + Transforms.identity(new String[] {"a", "b"})); + } +}
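
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of how the `SparkTransformConverter` API introduced above is expected to round-trip the layout of a table such as `CREATE TABLE t (...) PARTITIONED BY (dt) CLUSTERED BY (id) SORTED BY (name) INTO 4 BUCKETS`. It only uses the public static methods added in this PR; the class name `TransformConverterExample` and the column names are illustrative only.

```java
import com.datastrato.gravitino.rel.expressions.distributions.Distribution;
import com.datastrato.gravitino.rel.expressions.sorts.SortOrder;
import com.datastrato.gravitino.spark.connector.SparkTransformConverter;
import com.datastrato.gravitino.spark.connector.SparkTransformConverter.DistributionAndSortOrdersInfo;
import org.apache.spark.sql.connector.expressions.Transform;

public class TransformConverterExample {
  public static void main(String[] args) {
    // Spark-side transforms, as Spark hands them to GravitinoCatalog#createTable for
    //   PARTITIONED BY (dt) CLUSTERED BY (id) SORTED BY (name) INTO 4 BUCKETS
    Transform[] sparkTransforms =
        new Transform[] {
          SparkTransformConverter.createSparkIdentityTransform("dt"),
          SparkTransformConverter.createSortBucketTransform(
              4, new String[] {"id"}, new String[] {"name"})
        };

    // Identity transforms become Gravitino partitionings; bucket transforms are filtered out here.
    com.datastrato.gravitino.rel.expressions.transforms.Transform[] partitionings =
        SparkTransformConverter.toGravitinoPartitionings(sparkTransforms);

    // The sorted bucket transform becomes a Hash distribution plus ASCENDING sort orders.
    DistributionAndSortOrdersInfo info =
        SparkTransformConverter.toGravitinoDistributionAndSortOrders(sparkTransforms);
    Distribution distribution = info.getDistribution();
    SortOrder[] sortOrders = info.getSortOrders();

    // Converting back is what SparkBaseTable#partitioning() does when Spark asks for the layout.
    Transform[] roundTripped =
        SparkTransformConverter.toSparkTransform(partitionings, distribution, sortOrders);
    System.out.println(roundTripped.length); // 2: one identity transform + one sorted bucket transform
  }
}
```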