
Commit c3ad26e

fix: Support partition values in feature branch comet-parquet-exec (#1106)
* init
* more
* more
* fix clippy
* Use Spark and Arrow types for partition schema
1 parent 1cca8d6 commit c3ad26e


3 files changed · +98 -14 lines changed


native/core/src/execution/datafusion/planner.rs

Lines changed: 64 additions & 8 deletions
@@ -969,6 +969,19 @@ impl PhysicalPlanner {
                         .unwrap(),
                 );
 
+                let partition_schema_arrow = scan
+                    .partition_schema
+                    .iter()
+                    .map(to_arrow_datatype)
+                    .collect_vec();
+                let partition_fields: Vec<_> = partition_schema_arrow
+                    .iter()
+                    .enumerate()
+                    .map(|(idx, data_type)| {
+                        Field::new(format!("part_{}", idx), data_type.clone(), true)
+                    })
+                    .collect();
+
                 // Convert the Spark expressions to Physical expressions
                 let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
                     .data_filters
@@ -997,30 +1010,63 @@ impl PhysicalPlanner {
                 // Generate file groups
                 let mut file_groups: Vec<Vec<PartitionedFile>> =
                     Vec::with_capacity(partition_count);
-                scan.file_partitions.iter().for_each(|partition| {
+                scan.file_partitions.iter().try_for_each(|partition| {
                     let mut files = Vec::with_capacity(partition.partitioned_file.len());
-                    partition.partitioned_file.iter().for_each(|file| {
+                    partition.partitioned_file.iter().try_for_each(|file| {
                         assert!(file.start + file.length <= file.file_size);
-                        files.push(PartitionedFile::new_with_range(
+
+                        let mut partitioned_file = PartitionedFile::new_with_range(
                             Url::parse(file.file_path.as_ref())
                                 .unwrap()
                                 .path()
                                 .to_string(),
                             file.file_size as u64,
                             file.start,
                             file.start + file.length,
-                        ));
-                    });
+                        );
+
+                        // Process partition values
+                        // Create an empty input schema for partition values because they are all literals.
+                        let empty_schema = Arc::new(Schema::empty());
+                        let partition_values: Result<Vec<_>, _> = file
+                            .partition_values
+                            .iter()
+                            .map(|partition_value| {
+                                let literal = self.create_expr(
+                                    partition_value,
+                                    Arc::<Schema>::clone(&empty_schema),
+                                )?;
+                                literal
+                                    .as_any()
+                                    .downcast_ref::<DataFusionLiteral>()
+                                    .ok_or_else(|| {
+                                        ExecutionError::GeneralError(
+                                            "Expected literal of partition value".to_string(),
+                                        )
+                                    })
+                                    .map(|literal| literal.value().clone())
+                            })
+                            .collect();
+                        let partition_values = partition_values?;
+
+                        partitioned_file.partition_values = partition_values;
+
+                        files.push(partitioned_file);
+                        Ok::<(), ExecutionError>(())
+                    })?;
+
                     file_groups.push(files);
-                });
+                    Ok::<(), ExecutionError>(())
+                })?;
 
                 // TODO: I think we can remove partition_count in the future, but leave for testing.
                 assert_eq!(file_groups.len(), partition_count);
 
                 let object_store_url = ObjectStoreUrl::local_filesystem();
                 let mut file_scan_config =
                     FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
-                        .with_file_groups(file_groups);
+                        .with_file_groups(file_groups)
+                        .with_table_partition_cols(partition_fields);
 
                 // Check for projection, if so generate the vector and add to FileScanConfig.
                 let mut projection_vector: Vec<usize> =
@@ -1030,7 +1076,17 @@ impl PhysicalPlanner {
                     projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
                 });
 
-                assert_eq!(projection_vector.len(), required_schema_arrow.fields.len());
+                partition_schema_arrow
+                    .iter()
+                    .enumerate()
+                    .for_each(|(idx, _)| {
+                        projection_vector.push(idx + data_schema_arrow.fields.len());
+                    });
+
+                assert_eq!(
+                    projection_vector.len(),
+                    required_schema_arrow.fields.len() + partition_schema_arrow.len()
+                );
                 file_scan_config = file_scan_config.with_projection(Some(projection_vector));
 
                 let mut table_parquet_options = TableParquetOptions::new();
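
For orientation, below is a minimal sketch (not from this commit, assuming a recent DataFusion release) of the DataFusion surface the planner uses above: with_table_partition_cols declares the partition columns, and each PartitionedFile carries its values as ScalarValues in matching order.

    // Minimal sketch of the DataFusion API surface used in the diff above;
    // import paths assume a recent DataFusion release.
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::datasource::listing::PartitionedFile;
    use datafusion::datasource::physical_plan::FileScanConfig;
    use datafusion::execution::object_store::ObjectStoreUrl;
    use datafusion::scalar::ScalarValue;

    fn example_scan_config() -> FileScanConfig {
        // The file schema holds only the columns physically stored in Parquet.
        let file_schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, true)]));

        // Each file carries one ScalarValue per declared partition column,
        // in the same order as the partition fields declared below.
        let mut file = PartitionedFile::new("file.parquet".to_string(), 1024);
        file.partition_values = vec![ScalarValue::Int32(Some(1))];

        FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema)
            .with_file_groups(vec![vec![file]])
            // Partition columns are appended after the file columns, which is
            // why the projection indices in the diff start at
            // data_schema_arrow.fields.len().
            .with_table_partition_cols(vec![Field::new("part_0", DataType::Int32, true)])
    }

The synthetic names part_0, part_1, … are sufficient here because only positional order matters at this layer; the names are never round-tripped back to Spark.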

native/proto/src/proto/operator.proto

Lines changed: 4 additions & 2 deletions
@@ -52,6 +52,7 @@ message SparkPartitionedFile {
   int64 start = 2;
   int64 length = 3;
   int64 file_size = 4;
+  repeated spark.spark_expression.Expr partition_values = 5;
 }
 
 // This name and the one above are not great, but they correspond to the (unfortunate) Spark names.
@@ -76,8 +77,9 @@ message NativeScan {
   string source = 2;
   string required_schema = 3;
   string data_schema = 4;
-  repeated spark.spark_expression.Expr data_filters = 5;
-  repeated SparkFilePartition file_partitions = 6;
+  repeated spark.spark_expression.DataType partition_schema = 5;
+  repeated spark.spark_expression.Expr data_filters = 6;
+  repeated SparkFilePartition file_partitions = 7;
 }
 
 message Projection {
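
These messages are compiled to Rust with prost on the native side. A hypothetical sketch of the regenerated SparkPartitionedFile struct after this change (the file_path = 1 field and the module path for Expr are assumptions, not shown in the diff):

    // Hypothetical prost-generated shape of SparkPartitionedFile; illustrative
    // only, actual module paths come from Comet's build.
    #[derive(Clone, PartialEq, ::prost::Message)]
    pub struct SparkPartitionedFile {
        #[prost(string, tag = "1")]
        pub file_path: ::prost::alloc::string::String,
        #[prost(int64, tag = "2")]
        pub start: i64,
        #[prost(int64, tag = "3")]
        pub length: i64,
        #[prost(int64, tag = "4")]
        pub file_size: i64,
        /// New in this commit: one literal expression per partition column.
        #[prost(message, repeated, tag = "5")]
        pub partition_values: ::prost::alloc::vec::Vec<super::spark_expression::Expr>,
    }

Note that data_filters and file_partitions are renumbered rather than partition_schema being appended at the end. That breaks protobuf wire compatibility, which is acceptable here because the JVM and native halves of Comet are built and shipped together, so both sides always agree on the schema.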

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 30 additions & 4 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, Normalize
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometNativeScanExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
@@ -2507,12 +2507,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         partitions.foreach(p => {
           val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
           inputPartitions.foreach(partition => {
-            partition2Proto(partition.asInstanceOf[FilePartition], nativeScanBuilder)
+            partition2Proto(
+              partition.asInstanceOf[FilePartition],
+              nativeScanBuilder,
+              scan.relation.partitionSchema)
           })
         })
       case rdd: FileScanRDD =>
         rdd.filePartitions.foreach(partition => {
-          partition2Proto(partition, nativeScanBuilder)
+          partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema)
        })
       case _ =>
     }
@@ -2521,9 +2524,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
     val dataSchemaParquet =
       new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
+    val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
+      serializeDataType(field.dataType)
+    }
+    // In `CometScanRule`, we ensure partitionSchema is supported.
+    assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)
 
     nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
     nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+    nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
 
     Some(result.setNativeScan(nativeScanBuilder).build())
 
@@ -3191,10 +3200,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
   private def partition2Proto(
       partition: FilePartition,
-      nativeScanBuilder: OperatorOuterClass.NativeScan.Builder): Unit = {
+      nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
+      partitionSchema: StructType): Unit = {
     val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
     partition.files.foreach(file => {
+      // Process the partition values
+      val partitionValues = file.partitionValues
+      assert(partitionValues.numFields == partitionSchema.length)
+      val partitionVals =
+        partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) =>
+          val attr = partitionSchema(i)
+          val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty)
+          // In `CometScanRule`, we have already checked that all partition values are
+          // supported. So, we can safely use `get` here.
+          assert(
+            valueProto.isDefined,
+            s"Unsupported partition value: $value, type: ${attr.dataType}")
+          valueProto.get
+        }
+
       val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
+      partitionVals.foreach(fileBuilder.addPartitionValues)
       fileBuilder
         .setFilePath(file.pathUri.toString)
         .setStart(file.start)
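
Each partition value is serialized here as a Literal(value, attr.dataType) expression, so on the native side create_expr against an empty schema always produces a literal with no column references; the planner then downcasts it to recover a ScalarValue (DataFusionLiteral in the first diff is presumably an alias for DataFusion's Literal physical expression). A minimal sketch of that decode step, with a simplified error type:

    // Minimal sketch of recovering a ScalarValue from a literal physical
    // expression, mirroring the downcast in the planner diff above.
    use std::sync::Arc;

    use datafusion::physical_plan::expressions::Literal;
    use datafusion::physical_plan::PhysicalExpr;
    use datafusion::scalar::ScalarValue;

    fn literal_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue, String> {
        expr.as_any()
            .downcast_ref::<Literal>()
            .map(|lit| lit.value().clone())
            .ok_or_else(|| "Expected literal of partition value".to_string())
    }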
