
Commit 2686a4b

feat: [comet-parquet-exec] Use Datafusion based record batch reader for use in iceberg reads (#1174)
* wip. Use DF's ParquetExec for Iceberg API
* wip - await??
* wip
* wip -
* fix shading issue
* fix shading issue
* fixes
* refactor to remove arrow based reader
* rename config
* Fix config defaults

---------

Co-authored-by: Andy Grove <agrove@apache.org>
1 parent 8563edf commit 2686a4b

File tree: 15 files changed (+396, -160 lines)


common/src/main/java/org/apache/comet/parquet/Native.java

Lines changed: 1 addition & 6 deletions
@@ -246,15 +246,10 @@ public static native void setPageV2(
    * @param filePath
    * @param start
    * @param length
-   * @param required_columns array of names of fields to read
    * @return a handle to the record batch reader, used in subsequent calls.
    */
   public static native long initRecordBatchReader(
-      String filePath, long start, long length, Object[] required_columns);
-
-  public static native int numRowGroups(long handle);
-
-  public static native long numTotalRows(long handle);
+      String filePath, long fileSize, long start, long length, byte[] requiredSchema);
 
   // arrow native version of read batch
   /**
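The native entry point now takes the file's total size and an Arrow IPC-serialized schema instead of an array of column names, and the numRowGroups / numTotalRows helpers are dropped. A minimal sketch of the new call shape with placeholder values (everything except the Native methods is hypothetical; the real call site is NativeBatchReader.init(), shown in the next file):

import org.apache.comet.parquet.Native

// All values below are illustrative placeholders; a real caller derives them from the
// PartitionedFile and serializes the requested Spark schema to Arrow IPC bytes
// (see the NativeBatchReader changes that follow).
val filePath = "file:///tmp/data.parquet"
val fileSize = 4096L          // total length of the Parquet file in bytes
val start = 0L                // start offset of this split
val length = fileSize         // number of bytes in this split
val requiredSchema: Array[Byte] = Array.emptyByteArray // stand-in for the IPC-encoded schema

// Opaque handle consumed by subsequent Native.readNextRecordBatch(handle) calls.
val handle: Long = Native.initRecordBatchReader(filePath, fileSize, start, length, requiredSchema)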

common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java

Lines changed: 22 additions & 14 deletions
@@ -19,11 +19,13 @@
 
 package org.apache.comet.parquet;
 
+import java.io.ByteArrayOutputStream;
 import java.io.Closeable;
 import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.net.URISyntaxException;
+import java.nio.channels.Channels;
 import java.util.*;
 
 import scala.Option;
@@ -36,6 +38,9 @@
 import org.apache.arrow.c.CometSchemaImporter;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.mapreduce.InputSplit;
@@ -52,6 +57,7 @@
 import org.apache.spark.TaskContext$;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.comet.CometArrowUtils;
 import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
 import org.apache.spark.sql.execution.datasources.PartitionedFile;
 import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
@@ -99,7 +105,6 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private PartitionedFile file;
   private final Map<String, SQLMetric> metrics;
 
-  private long rowsRead;
   private StructType sparkSchema;
   private MessageType requestedSchema;
   private CometVector[] vectors;
@@ -111,9 +116,6 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private boolean isInitialized;
   private ParquetMetadata footer;
 
-  /** The total number of rows across all row groups of the input split. */
-  private long totalRowCount;
-
   /**
    * Whether the native scan should always return decimal represented by 128 bits, regardless of its
    * precision. Normally, this should be true if native execution is enabled, since Arrow compute
@@ -224,6 +226,7 @@ public void init() throws URISyntaxException, IOException {
     long start = file.start();
     long length = file.length();
     String filePath = file.filePath().toString();
+    long fileSize = file.fileSize();
 
     requestedSchema = footer.getFileMetaData().getSchema();
     MessageType fileSchema = requestedSchema;
@@ -254,6 +257,13 @@ public void init() throws URISyntaxException, IOException {
       }
     } ////// End get requested schema
 
+    String timeZoneId = conf.get("spark.sql.session.timeZone");
+    Schema arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, timeZoneId);
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
+    MessageSerializer.serialize(writeChannel, arrowSchema);
+    byte[] serializedRequestedArrowSchema = out.toByteArray();
+
     //// Create Column readers
     List<ColumnDescriptor> columns = requestedSchema.getColumns();
     int numColumns = columns.size();
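The requested Spark schema is converted to an Arrow schema and serialized as a standard Arrow IPC message before being handed to the native reader. A small round-trip sketch (the field list is illustrative, not taken from the commit) showing that the resulting byte[] is an ordinary IPC-encoded schema that the same Arrow APIs can decode again:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

// Serialize a schema the same way init() does above ...
val schema = new Schema(Seq(
  Field.nullable("id", new ArrowType.Int(64, true)),
  Field.nullable("name", ArrowType.Utf8.INSTANCE)).asJava)
val out = new ByteArrayOutputStream()
MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema)
val serialized = out.toByteArray // this is what gets passed to Native.initRecordBatchReader

// ... and decode it back to show it is a plain Arrow IPC schema message.
val roundTripped = MessageSerializer.deserializeSchema(
  new ReadChannel(Channels.newChannel(new ByteArrayInputStream(serialized))))
assert(roundTripped.getFields.size == 2)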
@@ -334,13 +344,9 @@ public void init() throws URISyntaxException, IOException {
       }
     }
 
-    // TODO: (ARROW NATIVE) Use a ProjectionMask here ?
-    ArrayList<String> requiredColumns = new ArrayList<>();
-    for (Type col : requestedSchema.asGroupType().getFields()) {
-      requiredColumns.add(col.getName());
-    }
-    this.handle = Native.initRecordBatchReader(filePath, start, length, requiredColumns.toArray());
-    totalRowCount = Native.numRowGroups(handle);
+    this.handle =
+        Native.initRecordBatchReader(
+            filePath, fileSize, start, length, serializedRequestedArrowSchema);
     isInitialized = true;
   }
 
@@ -375,7 +381,7 @@ public ColumnarBatch getCurrentValue() {
 
   @Override
   public float getProgress() {
-    return (float) rowsRead / totalRowCount;
+    return 0;
   }
 
   /**
@@ -395,7 +401,7 @@ public ColumnarBatch currentBatch() {
   public boolean nextBatch() throws IOException {
     Preconditions.checkState(isInitialized, "init() should be called first!");
 
-    if (rowsRead >= totalRowCount) return false;
+    // if (rowsRead >= totalRowCount) return false;
     int batchSize;
 
     try {
@@ -432,7 +438,6 @@ public boolean nextBatch() throws IOException {
     }
 
     currentBatch.setNumRows(batchSize);
-    rowsRead += batchSize;
     return true;
   }
 
@@ -457,6 +462,9 @@ private int loadNextBatch() throws Throwable {
     long startNs = System.nanoTime();
 
     int batchSize = Native.readNextRecordBatch(this.handle);
+    if (batchSize == 0) {
+      return batchSize;
+    }
     if (importer != null) importer.close();
     importer = new CometSchemaImporter(ALLOCATOR);
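With the DataFusion-based reader there is no up-front row count, so end-of-input is signalled by the native side returning an empty batch, which loadNextBatch() now short-circuits on. A hedged consumption sketch (assuming, as with the previous Arrow-based reader, that nextBatch() returns false once a zero-row batch comes back):

import org.apache.comet.parquet.NativeBatchReader

// Count all rows produced by an already-initialized reader; construction and init() are elided.
def countRows(reader: NativeBatchReader): Long = {
  var rows = 0L
  while (reader.nextBatch()) { // false once the native reader yields an empty batch
    rows += reader.currentBatch().numRows()
  }
  rows
}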

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 6 additions & 6 deletions
@@ -75,7 +75,7 @@ object CometConf extends ShimCometConf {
       "that to enable native vectorized execution, both this config and " +
         "'spark.comet.exec.enabled' need to be enabled.")
     .booleanConf
-    .createWithDefault(true)
+    .createWithDefault(false)
 
   val COMET_FULL_NATIVE_SCAN_ENABLED: ConfigEntry[Boolean] = conf(
     "spark.comet.native.scan.enabled")
@@ -85,15 +85,15 @@ object CometConf extends ShimCometConf {
       "read supported data sources (currently only Parquet is supported natively)." +
         " By default, this config is true.")
     .booleanConf
-    .createWithDefault(false)
+    .createWithDefault(true)
 
-  val COMET_NATIVE_ARROW_SCAN_ENABLED: ConfigEntry[Boolean] = conf(
+  val COMET_NATIVE_RECORDBATCH_READER_ENABLED: ConfigEntry[Boolean] = conf(
     "spark.comet.native.arrow.scan.enabled")
     .internal()
     .doc(
-      "Whether to enable the fully native arrow based scan. When this is turned on, Spark will " +
-        "use Comet to read Parquet files natively via the Arrow based Parquet reader." +
-        " By default, this config is false.")
+      "Whether to enable the fully native datafusion based column reader. When this is turned on," +
+        " Spark will use Comet to read Parquet files natively via the Datafusion based Parquet" +
+        " reader. By default, this config is false.")
     .booleanConf
     .createWithDefault(false)
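For reference, a hypothetical session setup that flips these flags (only the two config keys shown in the diff are taken from the commit; the rest of the setup, including any Comet plugin wiring, is illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("comet-datafusion-reader-example")
  // Fully native DataFusion scan, enabled by default as of this commit:
  .config("spark.comet.native.scan.enabled", "true")
  // Internal DataFusion-based record batch reader (COMET_NATIVE_RECORDBATCH_READER_ENABLED),
  // still keyed by the old name and disabled by default:
  .config("spark.comet.native.arrow.scan.enabled", "true")
  .getOrCreate()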

org/apache/spark/sql/comet/CometArrowUtils.scala (new file)

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet

import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

object CometArrowUtils {

  val rootAllocator = new RootAllocator(Long.MaxValue)

  // todo: support more types.

  /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
  def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
    case BooleanType => ArrowType.Bool.INSTANCE
    case ByteType => new ArrowType.Int(8, true)
    case ShortType => new ArrowType.Int(8 * 2, true)
    case IntegerType => new ArrowType.Int(8 * 4, true)
    case LongType => new ArrowType.Int(8 * 8, true)
    case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
    case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
    case StringType => ArrowType.Utf8.INSTANCE
    case BinaryType => ArrowType.Binary.INSTANCE
    case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
    case DateType => new ArrowType.Date(DateUnit.DAY)
    case TimestampType if timeZoneId == null =>
      throw new IllegalStateException("Missing timezoneId where it is mandatory.")
    case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
    case TimestampNTZType =>
      new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
    case NullType => ArrowType.Null.INSTANCE
    case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
    case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
    case _ =>
      throw new IllegalArgumentException()
  }

  def fromArrowType(dt: ArrowType): DataType = dt match {
    case ArrowType.Bool.INSTANCE => BooleanType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
    case float: ArrowType.FloatingPoint
        if float.getPrecision() == FloatingPointPrecision.SINGLE =>
      FloatType
    case float: ArrowType.FloatingPoint
        if float.getPrecision() == FloatingPointPrecision.DOUBLE =>
      DoubleType
    case ArrowType.Utf8.INSTANCE => StringType
    case ArrowType.Binary.INSTANCE => BinaryType
    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
    case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
    case ts: ArrowType.Timestamp
        if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
      TimestampNTZType
    case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
    case ArrowType.Null.INSTANCE => NullType
    case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
      YearMonthIntervalType()
    case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
    case _ => throw new IllegalArgumentException()
    // throw QueryExecutionErrors.unsupportedArrowTypeError(dt)
  }

  /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
  def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
    dt match {
      case ArrayType(elementType, containsNull) =>
        val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
        new Field(
          name,
          fieldType,
          Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava)
      case StructType(fields) =>
        val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
        new Field(
          name,
          fieldType,
          fields
            .map { field =>
              toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
            }
            .toSeq
            .asJava)
      case MapType(keyType, valueType, valueContainsNull) =>
        val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
        // Note: Map Type struct can not be null, Struct Type key field can not be null
        new Field(
          name,
          mapType,
          Seq(
            toArrowField(
              MapVector.DATA_VECTOR_NAME,
              new StructType()
                .add(MapVector.KEY_NAME, keyType, nullable = false)
                .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
              nullable = false,
              timeZoneId)).asJava)
      case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId)
      case dataType =>
        val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
        new Field(name, fieldType, Seq.empty[Field].asJava)
    }
  }

  def fromArrowField(field: Field): DataType = {
    field.getType match {
      case _: ArrowType.Map =>
        val elementField = field.getChildren.get(0)
        val keyType = fromArrowField(elementField.getChildren.get(0))
        val valueType = fromArrowField(elementField.getChildren.get(1))
        MapType(keyType, valueType, elementField.getChildren.get(1).isNullable)
      case ArrowType.List.INSTANCE =>
        val elementField = field.getChildren().get(0)
        val elementType = fromArrowField(elementField)
        ArrayType(elementType, containsNull = elementField.isNullable)
      case ArrowType.Struct.INSTANCE =>
        val fields = field.getChildren().asScala.map { child =>
          val dt = fromArrowField(child)
          StructField(child.getName, dt, child.isNullable)
        }
        StructType(fields.toArray)
      case arrowType => fromArrowType(arrowType)
    }
  }

  /**
   * Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType
   */
  def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
    new Schema(schema.map { field =>
      toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
    }.asJava)
  }

  def fromArrowSchema(schema: Schema): StructType = {
    StructType(schema.getFields.asScala.map { field =>
      val dt = fromArrowField(field)
      StructField(field.getName, dt, field.isNullable)
    }.toArray)
  }

  /** Return Map with conf settings to be used in ArrowPythonRunner */
  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
    val pandasColsByName = Seq(
      SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
        conf.pandasGroupedMapAssignColumnsByName.toString)
    val arrowSafeTypeCheck = Seq(
      SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
        conf.arrowSafeTypeConversion.toString)
    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
  }

}
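A quick round-trip sketch of the new helper (the schema contents are illustrative); toArrowSchema needs a time zone ID whenever a TimestampType field is present:

import org.apache.spark.sql.comet.CometArrowUtils
import org.apache.spark.sql.types._

val sparkSchema = new StructType()
  .add("id", LongType, nullable = false)
  .add("name", StringType)
  .add("ts", TimestampType) // requires the timeZoneId argument below

val arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, "UTC")
val roundTripped = CometArrowUtils.fromArrowSchema(arrowSchema)
assert(roundTripped == sparkSchema)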

native/Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default.

native/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -37,7 +37,9 @@ arrow = { version = "53.2.0", features = ["prettyprint", "ffi", "chrono-tz"] }
 arrow-array = { version = "53.2.0" }
 arrow-buffer = { version = "53.2.0" }
 arrow-data = { version = "53.2.0" }
+arrow-ipc = { version = "53.2.0" }
 arrow-schema = { version = "53.2.0" }
+flatbuffers = { version = "24.3.25" }
 parquet = { version = "53.2.0", default-features = false, features = ["experimental"] }
 datafusion-common = { version = "43.0.0" }
 datafusion = { version = "43.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions", "parquet"] }

native/core/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ arrow-array = { workspace = true }
 arrow-buffer = { workspace = true }
 arrow-data = { workspace = true }
 arrow-schema = { workspace = true }
+arrow-ipc = { workspace = true }
+flatbuffers = { workspace = true }
 parquet = { workspace = true, default-features = false, features = ["experimental"] }
 half = { version = "2.4.1", default-features = false }
 futures = "0.3.28"

native/core/src/execution/datafusion/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -20,6 +20,6 @@
 pub mod expressions;
 mod operators;
 pub mod planner;
-mod schema_adapter;
+pub(crate) mod schema_adapter;
 pub mod shuffle_writer;
 mod util;
