Commit e0d8077

[comet-parquet-exec] Simplify schema logic for CometNativeScan (#1142)
* Serialize original data schema and required schema, generate projection vector on the Java side.
* Sending over more schema info like column names and nullability.
* Using the new stuff in the proto. About to take the old out.
* Remove old logic.
* remove errant print.
* Serialize original data schema and required schema, generate projection vector on the Java side.
* Sending over more schema info like column names and nullability.
* Using the new stuff in the proto. About to take the old out.
* Remove old logic.
* remove errant print.
* Remove commented print. format.
* Remove commented print. format.
* Fix projection_vector to include partition_schema cols correctly.
* Rename variable.
1 parent e3672f7 commit e0d8077
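For context, a minimal standalone sketch (Scala, with made-up column names; not code from this patch) of how the projection vector mentioned above is assembled on the JVM side: indices of the required columns within the data schema, followed by the positions of the partition columns, which are appended after all data-schema columns.

    // Hypothetical schemas; in Comet these come from the Spark scan node.
    val dataSchema = Seq("a", "b", "c")   // file columns in on-disk order
    val requiredSchema = Seq("c", "a")    // columns the query actually reads
    val partitionSchema = Seq("p")        // partition columns

    // Index of each required column inside the data schema.
    val dataIdxs = requiredSchema.map(dataSchema.indexOf)   // Seq(2, 0)
    // Partition columns sit after all data-schema columns.
    val partitionIdxs =
      dataSchema.length until (dataSchema.length + partitionSchema.length)   // contains 3
    val projectionVector = dataIdxs ++ partitionIdxs        // Seq(2, 0, 3)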

File tree

4 files changed: +82, -70 lines changed

native/core/src/execution/datafusion/planner.rs

Lines changed: 40 additions & 56 deletions
@@ -121,7 +121,6 @@ use datafusion_physical_expr::LexOrdering;
 use itertools::Itertools;
 use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
-use parquet::schema::parser::parse_message_type;
 use std::cmp::max;
 use std::{collections::HashMap, sync::Arc};
 use url::Url;
@@ -950,50 +949,28 @@ impl PhysicalPlanner {
                 ))
             }
             OpStruct::NativeScan(scan) => {
-                let data_schema = parse_message_type(&scan.data_schema).unwrap();
-                let required_schema = parse_message_type(&scan.required_schema).unwrap();
-
-                let data_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema));
-                let data_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(&data_schema_descriptor, None)
-                        .unwrap(),
-                );
-
-                let required_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(required_schema));
-                let required_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(
-                        &required_schema_descriptor,
-                        None,
-                    )
-                    .unwrap(),
-                );
-
-                let partition_schema_arrow = scan
-                    .partition_schema
+                let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
+                let required_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
+                let partition_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
+                let projection_vector: Vec<usize> = scan
+                    .projection_vector
                     .iter()
-                    .map(to_arrow_datatype)
-                    .collect_vec();
-                let partition_fields: Vec<_> = partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .map(|(idx, data_type)| {
-                        Field::new(format!("part_{}", idx), data_type.clone(), true)
-                    })
+                    .map(|offset| *offset as usize)
                     .collect();

                 // Convert the Spark expressions to Physical expressions
                 let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
                     .data_filters
                     .iter()
-                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema_arrow)))
+                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
                     .collect();

                 // Create a conjunctive form of the vector because ParquetExecBuilder takes
                 // a single expression
                 let data_filters = data_filters?;
-                let test_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
+                let cnf_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
                     Arc::new(BinaryExpr::new(
                         left,
                         datafusion::logical_expr::Operator::And,
@@ -1064,29 +1041,21 @@ impl PhysicalPlanner {
                 assert_eq!(file_groups.len(), partition_count);

                 let object_store_url = ObjectStoreUrl::local_filesystem();
+                let partition_fields: Vec<Field> = partition_schema
+                    .fields()
+                    .iter()
+                    .map(|field| {
+                        Field::new(field.name(), field.data_type().clone(), field.is_nullable())
+                    })
+                    .collect_vec();
                 let mut file_scan_config =
-                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
+                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema))
                         .with_file_groups(file_groups)
                         .with_table_partition_cols(partition_fields);

-                // Check for projection, if so generate the vector and add to FileScanConfig.
-                let mut projection_vector: Vec<usize> =
-                    Vec::with_capacity(required_schema_arrow.fields.len());
-                // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
-                required_schema_arrow.fields.iter().for_each(|field| {
-                    projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
-                });
-
-                partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .for_each(|(idx, _)| {
-                        projection_vector.push(idx + data_schema_arrow.fields.len());
-                    });
-
                 assert_eq!(
                     projection_vector.len(),
-                    required_schema_arrow.fields.len() + partition_schema_arrow.len()
+                    required_schema.fields.len() + partition_schema.fields.len()
                 );
                 file_scan_config = file_scan_config.with_projection(Some(projection_vector));

@@ -1095,13 +1064,11 @@ impl PhysicalPlanner {
                 table_parquet_options.global.pushdown_filters = true;
                 table_parquet_options.global.reorder_filters = true;

-                let mut builder = ParquetExecBuilder::new(file_scan_config)
-                    .with_table_parquet_options(table_parquet_options)
-                    .with_schema_adapter_factory(
-                        Arc::new(CometSchemaAdapterFactory::default()),
-                    );
+                let mut builder = ParquetExecBuilder::new(file_scan_config)
+                    .with_table_parquet_options(table_parquet_options)
+                    .with_schema_adapter_factory(Arc::new(CometSchemaAdapterFactory::default()));

-                if let Some(filter) = test_data_filters {
+                if let Some(filter) = cnf_data_filters {
                     builder = builder.with_predicate(filter);
                 }

@@ -2309,6 +2276,23 @@ fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, prost::DecodeError> {
     }
 }

+fn convert_spark_types_to_arrow_schema(
+    spark_types: &[spark_operator::SparkStructField],
+) -> SchemaRef {
+    let arrow_fields = spark_types
+        .iter()
+        .map(|spark_type| {
+            Field::new(
+                String::clone(&spark_type.name),
+                to_arrow_datatype(spark_type.data_type.as_ref().unwrap()),
+                spark_type.nullable,
+            )
+        })
+        .collect_vec();
+    let arrow_schema: SchemaRef = Arc::new(Schema::new(arrow_fields));
+    arrow_schema
+}
+
 #[cfg(test)]
 mod tests {
     use std::{sync::Arc, task::Poll};

native/core/src/execution/datafusion/schema_adapter.rs

Lines changed: 2 additions & 1 deletion
@@ -259,7 +259,8 @@ impl SchemaMapper for SchemaMapping {
                     EvalMode::Legacy,
                     "UTC",
                     false,
-                )?.into_array(batch_col.len())
+                )?
+                .into_array(batch_col.len())
                 // and if that works, return the field and column.
                 .map(|new_col| (new_col, table_field.clone()))
             })

native/proto/src/proto/operator.proto

Lines changed: 10 additions & 3 deletions
@@ -61,6 +61,12 @@ message SparkFilePartition {
   repeated SparkPartitionedFile partitioned_file = 1;
 }

+message SparkStructField {
+  string name = 1;
+  spark.spark_expression.DataType data_type = 2;
+  bool nullable = 3;
+}
+
 message Scan {
   repeated spark.spark_expression.DataType fields = 1;
   // The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This
@@ -75,11 +81,12 @@ message NativeScan {
   // is purely for informational purposes when viewing native query plans in
   // debug mode.
   string source = 2;
-  string required_schema = 3;
-  string data_schema = 4;
-  repeated spark.spark_expression.DataType partition_schema = 5;
+  repeated SparkStructField required_schema = 3;
+  repeated SparkStructField data_schema = 4;
+  repeated SparkStructField partition_schema = 5;
   repeated spark.spark_expression.Expr data_filters = 6;
   repeated SparkFilePartition file_partitions = 7;
+  repeated int64 projection_vector = 8;
 }

 message Projection {
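To illustrate the new message, here is a hedged sketch (assuming the generated OperatorOuterClass.SparkStructField builder and the serializeDataType helper already used in QueryPlanSerde; the column itself is invented) of how one field would be encoded on the Spark side:

    // Hypothetical column: a nullable LongType field named "id".
    import org.apache.spark.sql.types.LongType

    val field = OperatorOuterClass.SparkStructField
      .newBuilder()
      .setName("id")
      .setDataType(serializeDataType(LongType).get) // protobuf DataType, as in schema2Proto below
      .setNullable(true)
      .build()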

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 30 additions & 10 deletions
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD}
-import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
@@ -2520,18 +2519,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           case _ =>
         }

-        val requiredSchemaParquet =
-          new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
-        val dataSchemaParquet =
-          new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
-        val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
-          serializeDataType(field.dataType)
-        }
+        val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
+        val requiredSchema = schema2Proto(scan.requiredSchema.fields)
+        val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
+
+        val data_schema_idxs = scan.requiredSchema.fields.map(field => {
+          scan.relation.dataSchema.fieldIndex(field.name)
+        })
+        val partition_schema_idxs = Array
+          .range(
+            scan.relation.dataSchema.fields.length,
+            scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length)
+
+        val projection_vector = (data_schema_idxs ++ partition_schema_idxs).map(idx =>
+          idx.toLong.asInstanceOf[java.lang.Long])
+
+        nativeScanBuilder.addAllProjectionVector(projection_vector.toIterable.asJava)
+
         // In `CometScanRule`, we ensure partitionSchema is supported.
         assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)

-        nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
-        nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+        nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
+        nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
         nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)

         Some(result.setNativeScan(nativeScanBuilder).build())
@@ -3198,6 +3207,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     true
   }

+  private def schema2Proto(
+      fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = {
+    val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
+    fields.map(field => {
+      fieldBuilder.setName(field.name)
+      fieldBuilder.setDataType(serializeDataType(field.dataType).get)
+      fieldBuilder.setNullable(field.nullable)
+      fieldBuilder.build()
+    })
+  }
+
   private def partition2Proto(
       partition: FilePartition,
       nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
