72 changes: 64 additions & 8 deletions native/core/src/execution/datafusion/planner.rs
@@ -969,6 +969,19 @@ impl PhysicalPlanner {
      .unwrap(),
  );

+ let partition_schema_arrow = scan
+     .partition_schema
+     .iter()
+     .map(to_arrow_datatype)
+     .collect_vec();
+ let partition_fields: Vec<_> = partition_schema_arrow
+     .iter()
+     .enumerate()
+     .map(|(idx, data_type)| {
+         Field::new(format!("part_{}", idx), data_type.clone(), true)
+     })
+     .collect();
+
  // Convert the Spark expressions to Physical expressions
  let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
      .data_filters
@@ -997,30 +1010,63 @@ impl PhysicalPlanner {
  // Generate file groups
  let mut file_groups: Vec<Vec<PartitionedFile>> =
      Vec::with_capacity(partition_count);
- scan.file_partitions.iter().for_each(|partition| {
+ scan.file_partitions.iter().try_for_each(|partition| {
      let mut files = Vec::with_capacity(partition.partitioned_file.len());
-     partition.partitioned_file.iter().for_each(|file| {
+     partition.partitioned_file.iter().try_for_each(|file| {
          assert!(file.start + file.length <= file.file_size);
-         files.push(PartitionedFile::new_with_range(
+
+         let mut partitioned_file = PartitionedFile::new_with_range(
              Url::parse(file.file_path.as_ref())
                  .unwrap()
                  .path()
                  .to_string(),
              file.file_size as u64,
              file.start,
              file.start + file.length,
-         ));
-     });
+         );
+
+         // Process partition values
+         // Create an empty input schema for partition values because they are all literals.
+         let empty_schema = Arc::new(Schema::empty());
+         let partition_values: Result<Vec<_>, _> = file
+             .partition_values
+             .iter()
+             .map(|partition_value| {
+                 let literal = self.create_expr(
+                     partition_value,
+                     Arc::<Schema>::clone(&empty_schema),
+                 )?;
+                 literal
+                     .as_any()
+                     .downcast_ref::<DataFusionLiteral>()
+                     .ok_or_else(|| {
+                         ExecutionError::GeneralError(
+                             "Expected literal of partition value".to_string(),
+                         )
+                     })
+                     .map(|literal| literal.value().clone())
+             })
+             .collect();
+         let partition_values = partition_values?;
+
+         partitioned_file.partition_values = partition_values;
+
+         files.push(partitioned_file);
+         Ok::<(), ExecutionError>(())
+     })?;
+
      file_groups.push(files);
- });
+     Ok::<(), ExecutionError>(())
+ })?;

  // TODO: I think we can remove partition_count in the future, but leave for testing.
  assert_eq!(file_groups.len(), partition_count);

  let object_store_url = ObjectStoreUrl::local_filesystem();
  let mut file_scan_config =
      FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
-         .with_file_groups(file_groups);
+         .with_file_groups(file_groups)
+         .with_table_partition_cols(partition_fields);

  // Check for projection, if so generate the vector and add to FileScanConfig.
  let mut projection_vector: Vec<usize> =
@@ -1030,7 +1076,17 @@
          projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
      });

- assert_eq!(projection_vector.len(), required_schema_arrow.fields.len());
+ partition_schema_arrow
+     .iter()
+     .enumerate()
+     .for_each(|(idx, _)| {
+         projection_vector.push(idx + data_schema_arrow.fields.len());
+     });
+
+ assert_eq!(
+     projection_vector.len(),
+     required_schema_arrow.fields.len() + partition_schema_arrow.len()
+ );
  file_scan_config = file_scan_config.with_projection(Some(projection_vector));

  let mut table_parquet_options = TableParquetOptions::new();
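Note: the planner change above leans on a DataFusion convention: partition columns are declared once on the `FileScanConfig` via `with_table_partition_cols` and are logically appended after the data schema's fields (which is why the projection indices for them start at `data_schema_arrow.fields.len()`), while each `PartitionedFile` carries the concrete values for its own file. Below is a minimal standalone sketch of that pattern; the import paths, file path, and values are assumptions for illustration and may differ across DataFusion versions.

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::ScalarValue;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::physical_plan::FileScanConfig;
use datafusion::execution::object_store::ObjectStoreUrl;

fn build_scan_config(data_schema: Arc<Schema>) -> FileScanConfig {
    // A file under a hypothetical `part_0=42` directory: the value is not stored
    // in the Parquet pages, so it travels on the PartitionedFile itself.
    let mut file = PartitionedFile::new_with_range(
        "data/part_0=42/file.parquet".to_string(),
        1024, // file size in bytes
        0,    // range start
        1024, // range end
    );
    file.partition_values = vec![ScalarValue::Int32(Some(42))];

    // The partition column is declared once on the config; it is appended after
    // the data schema's fields, so its projection index is data_schema.fields().len().
    FileScanConfig::new(ObjectStoreUrl::local_filesystem(), data_schema)
        .with_file_groups(vec![vec![file]])
        .with_table_partition_cols(vec![Field::new("part_0", DataType::Int32, true)])
}
```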
6 changes: 4 additions & 2 deletions native/proto/src/proto/operator.proto
@@ -52,6 +52,7 @@ message SparkPartitionedFile {
  int64 start = 2;
  int64 length = 3;
  int64 file_size = 4;
+ repeated spark.spark_expression.Expr partition_values = 5;
Contributor: Aren't the partition values just strings?

Member Author: No. For a Hive-partitioned table the partition values are directory names, which are strings, but once Spark reads those strings back they are cast to the corresponding data types of the partition columns.

Contributor: Ah, makes sense.

}

// This name and the one above are not great, but they correspond to the (unfortunate) Spark names.
@@ -76,8 +77,9 @@ message NativeScan {
  string source = 2;
  string required_schema = 3;
  string data_schema = 4;
- repeated spark.spark_expression.Expr data_filters = 5;
- repeated SparkFilePartition file_partitions = 6;
+ repeated spark.spark_expression.DataType partition_schema = 5;
+ repeated spark.spark_expression.Expr data_filters = 6;
+ repeated SparkFilePartition file_partitions = 7;
  }

message Projection {
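To make the review thread above concrete: a Hive-style directory name such as `year=2024` is parsed by Spark into a value of the partition column's declared type before serialization, so `partition_values` holds typed `Expr` literals rather than strings. A hedged sketch of the distinction as seen on the native side (the values are invented for illustration):

```rust
use datafusion::common::ScalarValue;

fn main() {
    // What a string-typed proto field would force the native side to carry:
    let raw = ScalarValue::Utf8(Some("2024".to_string()));
    // What actually arrives once Spark has cast the directory name to the
    // partition column's type (an INT column here):
    let typed = ScalarValue::Int32(Some(2024));
    println!("raw = {raw}, typed = {typed}");
}
```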
34 changes: 30 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, Normalize
  import org.apache.spark.sql.catalyst.plans._
  import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
  import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
- import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometNativeScanExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
+ import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
  import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
  import org.apache.spark.sql.execution
  import org.apache.spark.sql.execution._
@@ -2507,12 +2507,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
  partitions.foreach(p => {
      val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
      inputPartitions.foreach(partition => {
-       partition2Proto(partition.asInstanceOf[FilePartition], nativeScanBuilder)
+       partition2Proto(
+         partition.asInstanceOf[FilePartition],
+         nativeScanBuilder,
+         scan.relation.partitionSchema)
      })
  })
  case rdd: FileScanRDD =>
      rdd.filePartitions.foreach(partition => {
-       partition2Proto(partition, nativeScanBuilder)
+       partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema)
      })
  case _ =>
  }
@@ -2521,9 +2524,15 @@
      new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
  val dataSchemaParquet =
      new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
+ val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
+     serializeDataType(field.dataType)
+ }
+ // In `CometScanRule`, we ensure partitionSchema is supported.
+ assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)

  nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
  nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+ nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)

  Some(result.setNativeScan(nativeScanBuilder).build())
@@ -3191,10 +3200,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim

  private def partition2Proto(
      partition: FilePartition,
-     nativeScanBuilder: OperatorOuterClass.NativeScan.Builder): Unit = {
+     nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
+     partitionSchema: StructType): Unit = {
    val partitionBuilder = OperatorOuterClass.SparkFilePartition.newBuilder()
    partition.files.foreach(file => {
+     // Process the partition values
+     val partitionValues = file.partitionValues
+     assert(partitionValues.numFields == partitionSchema.length)
+     val partitionVals =
+       partitionValues.toSeq(partitionSchema).zipWithIndex.map { case (value, i) =>
+         val attr = partitionSchema(i)
+         val valueProto = exprToProto(Literal(value, attr.dataType), Seq.empty)
+         // In `CometScanRule`, we have already checked that all partition values are
+         // supported. So, we can safely use `get` here.
+         assert(
+           valueProto.isDefined,
+           s"Unsupported partition value: $value, type: ${attr.dataType}")
+         valueProto.get
+       }
+
      val fileBuilder = OperatorOuterClass.SparkPartitionedFile.newBuilder()
+     partitionVals.foreach(fileBuilder.addPartitionValues)
      fileBuilder
        .setFilePath(file.pathUri.toString)
        .setStart(file.start)
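On the native side, each `Literal` serialized here is recovered by `create_expr` plus a downcast, as in the planner.rs hunk above (`DataFusionLiteral` there appears to be an alias for DataFusion's `Literal` physical expression). A minimal sketch of that unwrap pattern, with assumed import paths:

```rust
use std::sync::Arc;

use datafusion::common::ScalarValue;
use datafusion::physical_plan::expressions::Literal;
use datafusion::physical_plan::PhysicalExpr;

// Unwrap a PhysicalExpr that is expected to be a literal, returning an error
// instead of panicking when the expectation is violated.
fn literal_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue, String> {
    expr.as_any()
        .downcast_ref::<Literal>()
        .map(|lit| lit.value().clone())
        .ok_or_else(|| "expected a literal partition value".to_string())
}

fn main() {
    let expr: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int32(Some(42))));
    assert_eq!(literal_value(&expr).unwrap(), ScalarValue::Int32(Some(42)));
}
```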