diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index bd2e04d0c..af0c65f00 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -337,6 +337,25 @@ object CometConf { "enabled when reading from Iceberg tables.") .booleanConf .createWithDefault(false) + + val COMET_ROW_TO_COLUMNAR_ENABLED: ConfigEntry[Boolean] = conf( + "spark.comet.rowToColumnar.enabled") + .internal() + .doc("Whether to enable row to columnar conversion in Comet. When this is turned on, " + + "Comet will convert row-based operators in spark.comet.rowToColumnar.sourceNodeList into " + + "columnar based before processing.") + .booleanConf + .createWithDefault(false) + + val COMET_ROW_TO_COLUMNAR_SOURCE_NODE_LIST: ConfigEntry[Seq[String]] = + conf("spark.comet.rowToColumnar.sourceNodeList") + .doc( + "A comma-separated list of row-based data sources that will be converted to columnar " + + "format when 'spark.comet.rowToColumnar.enabled' is true") + .stringConf + .toSequence + .createWithDefault(Seq("Range,InMemoryTableScan")) + } object ConfigHelpers { diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 3756da963..763ccff7f 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -23,6 +23,8 @@ import scala.collection.mutable import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data} import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.spark.SparkException import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.vectorized.ColumnarBatch @@ -132,3 +134,18 @@ class NativeUtil { new ColumnarBatch(arrayVectors.toArray, maxNumRows) } } + +object NativeUtil { + def rootAsBatch(arrowRoot: VectorSchemaRoot): ColumnarBatch = { + rootAsBatch(arrowRoot, null) + } + + def rootAsBatch(arrowRoot: VectorSchemaRoot, provider: DictionaryProvider): ColumnarBatch = { + val vectors = (0 until arrowRoot.getFieldVectors.size()).map { i => + val vector = arrowRoot.getFieldVectors.get(i) + // Native shuffle always uses decimal128. + CometVector.getVector(vector, true, provider) + } + new ColumnarBatch(vectors.toArray, arrowRoot.getRowCount) + } +} diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala index da72383e8..61d800bfb 100644 --- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala +++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala @@ -21,13 +21,11 @@ package org.apache.comet.vector import java.nio.channels.ReadableByteChannel -import scala.collection.JavaConverters.collectionAsScalaIterableConverter - import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.{ArrowStreamReader, ReadChannel} import org.apache.arrow.vector.ipc.message.MessageChannelReader -import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.sql.vectorized.ColumnarBatch /** * A reader that consumes Arrow data from an input channel, and produces Comet batches. 
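A note on usage: the two `spark.comet.rowToColumnar.*` entries added above gate the new conversion path. Below is a minimal sketch of enabling them on a Comet session; it is illustrative rather than part of this diff, the session-builder boilerplate is assumed context, and `spark.comet.rowToColumnar.enabled` is an internal config that defaults to false.

import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .master("local[*]")
  .config("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
  .config("spark.comet.enabled", "true")
  .config("spark.comet.exec.enabled", "true")
  // Added in this patch: convert the listed row-based leaf operators into Arrow batches.
  .config("spark.comet.rowToColumnar.enabled", "true")
  // Entries are matched against the operator class name with the "Exec" suffix stripped.
  .config("spark.comet.rowToColumnar.sourceNodeList", "Range,InMemoryTableScan")
  .getOrCreate()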
@@ -47,14 +45,7 @@ case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable { } private def rootAsBatch(root: VectorSchemaRoot): ColumnarBatch = { - val columns = root.getFieldVectors.asScala.map { vector => - // Native shuffle always uses decimal128. - CometVector.getVector(vector, true, arrowReader).asInstanceOf[ColumnVector] - }.toArray - - val batch = new ColumnarBatch(columns) - batch.setNumRows(root.getRowCount) - batch + NativeUtil.rootAsBatch(root, arrowReader) } override def close(): Unit = { diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala new file mode 100644 index 000000000..8d9f373fe --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala @@ -0,0 +1,472 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types._ + +/** + * This file is mostly copied from Spark SQL's + * org.apache.spark.sql.execution.arrow.ArrowWriter.scala. Comet shadows Arrow classes to avoid + * potential conflicts with Spark's Arrow dependencies, hence we cannot reuse Spark's ArrowWriter + * directly. 
+ */ +private[arrow] object ArrowWriter { + def create(root: VectorSchemaRoot): ArrowWriter = { + val children = root.getFieldVectors().asScala.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + new ArrowWriter(root, children.toArray) + } + + private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + val field = vector.getField() + (Utils.fromArrowField(field), vector) match { + case (BooleanType, vector: BitVector) => new BooleanWriter(vector) + case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) + case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) + case (IntegerType, vector: IntVector) => new IntegerWriter(vector) + case (LongType, vector: BigIntVector) => new LongWriter(vector) + case (FloatType, vector: Float4Vector) => new FloatWriter(vector) + case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) + case (DecimalType.Fixed(precision, scale), vector: DecimalVector) => + new DecimalWriter(vector, precision, scale) + case (StringType, vector: VarCharVector) => new StringWriter(vector) + case (StringType, vector: LargeVarCharVector) => new LargeStringWriter(vector) + case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) + case (BinaryType, vector: LargeVarBinaryVector) => new LargeBinaryWriter(vector) + case (DateType, vector: DateDayVector) => new DateWriter(vector) + case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) + case (TimestampNTZType, vector: TimeStampMicroVector) => new TimestampNTZWriter(vector) + case (ArrayType(_, _), vector: ListVector) => + val elementVector = createFieldWriter(vector.getDataVector()) + new ArrayWriter(vector, elementVector) + case (MapType(_, _, _), vector: MapVector) => + val structVector = vector.getDataVector.asInstanceOf[StructVector] + val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME)) + val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME)) + new MapWriter(vector, structVector, keyWriter, valueWriter) + case (StructType(_), vector: StructVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new StructWriter(vector, children.toArray) + case (NullType, vector: NullVector) => new NullWriter(vector) + case (_: YearMonthIntervalType, vector: IntervalYearVector) => + new IntervalYearWriter(vector) + case (_: DayTimeIntervalType, vector: DurationVector) => new DurationWriter(vector) +// case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) => +// new IntervalMonthDayNanoWriter(vector) + case (dt, _) => + throw QueryExecutionErrors.notSupportTypeError(dt) + } + } +} + +class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { + + def schema: StructType = Utils.fromArrowSchema(root.getSchema()) + + private var count: Int = 0 + + def write(row: InternalRow): Unit = { + var i = 0 + while (i < fields.length) { + fields(i).write(row, i) + i += 1 + } + count += 1 + } + + def finish(): Unit = { + root.setRowCount(count) + fields.foreach(_.finish()) + } + + def reset(): Unit = { + root.setRowCount(0) + count = 0 + fields.foreach(_.reset()) + } +} + +private[arrow] abstract class ArrowFieldWriter { + + def valueVector: ValueVector + + def name: String = valueVector.getField().getName() + def dataType: DataType = Utils.fromArrowField(valueVector.getField()) + def nullable: Boolean = valueVector.getField().isNullable() + + def setNull(): Unit + def setValue(input: 
SpecializedGetters, ordinal: Int): Unit + + private[arrow] var count: Int = 0 + + def write(input: SpecializedGetters, ordinal: Int): Unit = { + if (input.isNullAt(ordinal)) { + setNull() + } else { + setValue(input, ordinal) + } + count += 1 + } + + def finish(): Unit = { + valueVector.setValueCount(count) + } + + def reset(): Unit = { + valueVector.reset() + count = 0 + } +} + +private[arrow] class BooleanWriter(val valueVector: BitVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) + } +} + +private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getByte(ordinal)) + } +} + +private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getShort(ordinal)) + } +} + +private[arrow] class IntegerWriter(val valueVector: IntVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getInt(ordinal)) + } +} + +private[arrow] class LongWriter(val valueVector: BigIntVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getLong(ordinal)) + } +} + +private[arrow] class FloatWriter(val valueVector: Float4Vector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getFloat(ordinal)) + } +} + +private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getDouble(ordinal)) + } +} + +private[arrow] class DecimalWriter(val valueVector: DecimalVector, precision: Int, scale: Int) + extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val decimal = input.getDecimal(ordinal, precision, scale) + if (decimal.changePrecision(precision, scale)) { + valueVector.setSafe(count, decimal.toJavaBigDecimal) + } else { + setNull() + } + } +} + +private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val utf8 = input.getUTF8String(ordinal) + val utf8ByteBuffer = utf8.getByteBuffer + // todo: for off-heap UTF8String, how to pass in to arrow without copy? 
+ valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) + } +} + +private[arrow] class LargeStringWriter(val valueVector: LargeVarCharVector) + extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val utf8 = input.getUTF8String(ordinal) + val utf8ByteBuffer = utf8.getByteBuffer + // todo: for off-heap UTF8String, how to pass in to arrow without copy? + valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) + } +} + +private[arrow] class BinaryWriter(val valueVector: VarBinaryVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val bytes = input.getBinary(ordinal) + valueVector.setSafe(count, bytes, 0, bytes.length) + } +} + +private[arrow] class LargeBinaryWriter(val valueVector: LargeVarBinaryVector) + extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val bytes = input.getBinary(ordinal) + valueVector.setSafe(count, bytes, 0, bytes.length) + } +} + +private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getInt(ordinal)) + } +} + +private[arrow] class TimestampWriter(val valueVector: TimeStampMicroTZVector) + extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getLong(ordinal)) + } +} + +private[arrow] class TimestampNTZWriter(val valueVector: TimeStampMicroVector) + extends ArrowFieldWriter { + + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getLong(ordinal)) + } +} + +private[arrow] class ArrayWriter(val valueVector: ListVector, val elementWriter: ArrowFieldWriter) + extends ArrowFieldWriter { + + override def setNull(): Unit = {} + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val array = input.getArray(ordinal) + var i = 0 + valueVector.startNewValue(count) + while (i < array.numElements()) { + elementWriter.write(array, i) + i += 1 + } + valueVector.endValue(count, array.numElements()) + } + + override def finish(): Unit = { + super.finish() + elementWriter.finish() + } + + override def reset(): Unit = { + super.reset() + elementWriter.reset() + } +} + +private[arrow] class StructWriter( + val valueVector: StructVector, + children: Array[ArrowFieldWriter]) + extends ArrowFieldWriter { + + override def setNull(): Unit = { + var i = 0 + while (i < children.length) { + children(i).setNull() + children(i).count += 1 + i += 1 + } + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val struct = input.getStruct(ordinal, children.length) + var i = 0 + valueVector.setIndexDefined(count) + while (i < struct.numFields) { + children(i).write(struct, i) + i += 1 + } + } + + override def finish(): Unit = { + super.finish() + children.foreach(_.finish()) + } + + override def 
reset(): Unit = { + super.reset() + children.foreach(_.reset()) + } +} + +private[arrow] class MapWriter( + val valueVector: MapVector, + val structVector: StructVector, + val keyWriter: ArrowFieldWriter, + val valueWriter: ArrowFieldWriter) + extends ArrowFieldWriter { + + override def setNull(): Unit = {} + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val map = input.getMap(ordinal) + valueVector.startNewValue(count) + val keys = map.keyArray() + val values = map.valueArray() + var i = 0 + while (i < map.numElements()) { + structVector.setIndexDefined(keyWriter.count) + keyWriter.write(keys, i) + valueWriter.write(values, i) + i += 1 + } + + valueVector.endValue(count, map.numElements()) + } + + override def finish(): Unit = { + super.finish() + keyWriter.finish() + valueWriter.finish() + } + + override def reset(): Unit = { + super.reset() + keyWriter.reset() + valueWriter.reset() + } +} + +private[arrow] class NullWriter(val valueVector: NullVector) extends ArrowFieldWriter { + + override def setNull(): Unit = {} + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {} +} + +private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector) + extends ArrowFieldWriter { + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getInt(ordinal)); + } +} + +private[arrow] class DurationWriter(val valueVector: DurationVector) extends ArrowFieldWriter { + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueVector.setSafe(count, input.getLong(ordinal)) + } +} + +private[arrow] class IntervalMonthDayNanoWriter(val valueVector: IntervalMonthDayNanoVector) + extends ArrowFieldWriter { + override def setNull(): Unit = { + valueVector.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val ci = input.getInterval(ordinal) + valueVector.setSafe(count, ci.months, ci.days, Math.multiplyExact(ci.microseconds, 1000L)) + } +} diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala new file mode 100644 index 000000000..9dbd8dcdf --- /dev/null +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
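A note on the writer API defined in ArrowWriters.scala above: a caller builds the writer tree from a VectorSchemaRoot, appends InternalRows, calls finish() to seal the batch, and reset() before reusing the root. The sketch below shows that flow only for illustration; ArrowWriter.create is private[arrow], so it assumes code living in the same package (as ArrowBatchIterator below does), and the object and method names here are made up.

package org.apache.spark.sql.comet.execution.arrow

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.vector.NativeUtil

object ArrowWriterSketch {
  // Converts an iterator of rows into a single Comet ColumnarBatch. Real usage
  // (ArrowBatchIterator below) also enforces a max batch size, resets the writer
  // between batches, and ties allocator/root lifetimes to the task.
  def rowsToBatch(
      rows: Iterator[InternalRow],
      schema: StructType,
      timeZoneId: String): ColumnarBatch = {
    val allocator = new RootAllocator(Long.MaxValue)
    val root = VectorSchemaRoot.create(Utils.toArrowSchema(schema, timeZoneId), allocator)
    val writer = ArrowWriter.create(root) // builds one ArrowFieldWriter per field
    rows.foreach(writer.write)            // appends each InternalRow into the Arrow vectors
    writer.finish()                       // sets the row count on the root
    NativeUtil.rootAsBatch(root)          // wraps the vectors as a Comet ColumnarBatch
  }
}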
+ */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.vector.NativeUtil + +object CometArrowConverters extends Logging { + // TODO: we should reuse the same root allocator in the comet code base? + val rootAllocator: BufferAllocator = new RootAllocator(Long.MaxValue) + + // This is similar to how Spark converts internal rows to Arrow format, except that it transforms + // the result batch into Comet's ColumnarBatch instead of serialized bytes. + // Another big difference is that Comet may consume the ColumnarBatch by exporting it to + // the native side. Hence, we need to: + // 1. reset the Arrow writer after the ColumnarBatch is consumed + // 2. close the allocator when the task finishes, not when the iterator is fully consumed + // The reason for the second point is that when a ColumnarBatch is exported to the native side, the + // export increases the reference count of the Arrow vectors. The reference count is + // only decreased when the native plan is done with the vectors, which is usually after + // all the ColumnarBatches have been consumed. + private[sql] class ArrowBatchIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + timeZoneId: String, + context: TaskContext) + extends Iterator[ColumnarBatch] + with AutoCloseable { + + private val arrowSchema = Utils.toArrowSchema(schema, timeZoneId) + // Reuse the same root allocator here.
+ private val allocator = + rootAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + private val arrowWriter = ArrowWriter.create(root) + + private var currentBatch: ColumnarBatch = null + private var closed: Boolean = false + + Option(context).foreach { + _.addTaskCompletionListener[Unit] { _ => + close(true) + } + } + + override def hasNext: Boolean = rowIter.hasNext || { + close(false) + false + } + + override def next(): ColumnarBatch = { + currentBatch = nextBatch() + currentBatch + } + + override def close(): Unit = { + close(false) + } + + private def nextBatch(): ColumnarBatch = { + if (rowIter.hasNext) { + // the arrow writer shall be reset before writing the next batch + arrowWriter.reset() + var rowCount = 0L + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + NativeUtil.rootAsBatch(root) + } else { + null + } + } + + private def close(closeAllocator: Boolean): Unit = { + try { + if (!closed) { + if (currentBatch != null) { + arrowWriter.reset() + currentBatch.close() + currentBatch = null + } + root.close() + closed = true + } + } finally { + // the allocator shall be closed when the task is finished + if (closeAllocator) { + allocator.close() + } + } + } + } + + def toArrowBatchIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + timeZoneId: String, + context: TaskContext): Iterator[ColumnarBatch] = { + new ArrowBatchIterator(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) + } +} diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 684d7783a..7d920e1be 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -54,6 +54,11 @@ object Utils { str.split(",").map(_.trim()).filter(_.nonEmpty) } + /** bridges the function call to Spark's Util */ + def getSimpleName(cls: Class[_]): String = { + org.apache.spark.util.Utils.getSimpleName(cls) + } + def fromArrowField(field: Field): DataType = { field.getType match { case _: ArrowType.Map => @@ -90,6 +95,9 @@ object Utils { case _: ArrowType.FixedSizeBinary => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp + if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => + TimestampNTZType case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case ArrowType.Null.INSTANCE => NullType case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => @@ -98,6 +106,13 @@ object Utils { case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.toString}") } + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }.toArray) + } + /** Maps data type from Spark to Arrow. 
NOTE: timeZoneId required for TimestampTypes */ def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 77951943f..5be6f1cd0 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -43,7 +44,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported} +import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar} import org.apache.comet.parquet.{CometParquetScan, SupportsComet} import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde @@ -68,7 +69,7 @@ class CometSparkSessionExtensions override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session) override def postColumnarTransitions: Rule[SparkPlan] = - EliminateRedundantColumnarToRow(session) + EliminateRedundantTransitions(session) } case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] { @@ -238,6 +239,11 @@ class CometSparkSessionExtensions val nativeOp = QueryPlanSerde.operator2Proto(op).get CometScanWrapper(nativeOp, op) + case op if shouldApplyRowToColumnar(conf, op) => + val cometOp = CometRowToColumnarExec(op) + val nativeOp = QueryPlanSerde.operator2Proto(cometOp).get + CometScanWrapper(nativeOp, cometOp) + case op: ProjectExec => val newOp = transform1(op) newOp match { @@ -592,18 +598,26 @@ class CometSparkSessionExtensions } } - // CometExec already wraps a `ColumnarToRowExec` for row-based operators. Therefore, - // `ColumnarToRowExec` is redundant and can be eliminated. + // This rule is responsible for eliminating redundant transitions between row-based and + // columnar-based operators for Comet. Currently, two potential redundant transitions are: + // 1. ColumnarToRowExec at the end of a Spark operator, which is redundant for Comet operators as + // CometExec already wraps a `ColumnarToRowExec` for row-based operators. + // 2. Consecutive operators of CometRowToColumnarExec and ColumnarToRowExec, which might be + // possible for Comet to add a `CometRowToColumnarExec` for row-based operators first, then + // Spark only requests row-based output. // - // It was added during ApplyColumnarRulesAndInsertTransitions' insertTransitions phase when Spark - // requests row-based output such as `collect` call. 
It's correct to add a redundant - // `ColumnarToRowExec` for `CometExec`. However, for certain operators such as - // `CometCollectLimitExec` which overrides `executeCollect`, the redundant `ColumnarToRowExec` - // makes the override ineffective. The purpose of this rule is to eliminate the redundant - // `ColumnarToRowExec` for such operators. - case class EliminateRedundantColumnarToRow(session: SparkSession) extends Rule[SparkPlan] { + // The `ColumnarToRowExec` was added during ApplyColumnarRulesAndInsertTransitions' + // insertTransitions phase when Spark requests row-based output such as a `collect` call. It's + // correct to add a redundant `ColumnarToRowExec` for `CometExec`. However, for certain operators + // such as `CometCollectLimitExec` which overrides `executeCollect`, the redundant + // `ColumnarToRowExec` makes the override ineffective. + case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] { override def apply(plan: SparkPlan): SparkPlan = { - plan match { + val eliminatedPlan = plan transformUp { + case ColumnarToRowExec(rowToColumnar: CometRowToColumnarExec) => rowToColumnar.child + } + + eliminatedPlan match { case ColumnarToRowExec(child: CometCollectLimitExec) => child case other => @@ -716,6 +730,18 @@ object CometSparkSessionExtensions extends Logging { op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec] } + private def shouldApplyRowToColumnar(conf: SQLConf, op: SparkPlan): Boolean = { + // Only consider converting leaf nodes to columnar currently, so that all the following + // operators can have a chance to be converted to columnar. + // TODO: consider converting other intermediate operators to columnar. + op.isInstanceOf[LeafExecNode] && !op.supportsColumnar && isSchemaSupported(op.schema) && + COMET_ROW_TO_COLUMNAR_ENABLED.get(conf) && { + val simpleClassName = Utils.getSimpleName(op.getClass) + val nodeName = simpleClassName.replaceAll("Exec$", "") + COMET_ROW_TO_COLUMNAR_SOURCE_NODE_LIST.get(conf).contains(nodeName) + } + } + /** Used for operations that weren't available in Spark 3.2 */ def isSpark32: Boolean = { org.apache.spark.SPARK_VERSION.matches("3\\.2\\..*") diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b98c4388e..26fc708ff 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ @@ -2064,6 +2064,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { private def isCometSink(op: SparkPlan): Boolean = { op match { case s if isCometScan(s) => true + case _: CometRowToColumnarExec => true case _: CometSinkPlaceHolder => true case _: CoalesceExec => true case _: CollectLimitExec => true diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala new file mode 100644 index 000000000..5679e865c --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.execution.{RowToColumnarTransition, SparkPlan} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.vectorized.ColumnarBatch + +case class CometRowToColumnarExec(child: SparkPlan) + extends RowToColumnarTransition + with CometPlan { + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering + + override protected def doExecute(): RDD[InternalRow] = { + child.execute() + } + + override def doExecuteBroadcast[T](): Broadcast[T] = { + child.executeBroadcast() + } + + override def supportsColumnar: Boolean = true + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches")) + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numInputRows = longMetric("numInputRows") + val numOutputBatches = longMetric("numOutputBatches") + val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch + val timeZoneId = conf.sessionLocalTimeZone + val schema = child.schema + + child + .execute() + .mapPartitionsInternal { iter => + val context = TaskContext.get() + CometArrowConverters.toArrowBatchIterator( + iter, + schema, + maxRecordsPerBatch, + timeZoneId, + context) + } + .map { batch => + numInputRows += batch.numRows() + numOutputBatches += 1 + batch + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): CometRowToColumnarExec = + copy(child = newChild) + +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 520f2395b..8545eee90 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -270,7 +270,7 @@ abstract class CometNativeExec extends CometExec { } if (inputs.isEmpty) { - throw new CometRuntimeException(s"No input for CometNativeExec: $this") + throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_)) @@ -300,7 +300,8 @@ abstract class CometNativeExec extends CometExec { case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec | _: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec | _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec | - _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec => + _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec | + _: CometRowToColumnarExec => func(plan) case _: CometPlan => // Other Comet operators, continue to traverse the tree. diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index b2c4fd6e4..0bb21aba7 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Hex import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometTakeOrderedAndProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec @@ -1118,6 +1118,58 @@ class CometExecSuite extends CometTestBase { } }) } + + test("RowToColumnar over RangeExec") { + Seq("true", "false").foreach(aqe => { + Seq(500, 900).foreach { batchSize => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe, + SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> batchSize.toString) { + val df = spark.range(1000).selectExpr("id", "id % 8 as k").groupBy("k").sum("id") + checkSparkAnswerAndOperator(df) + // empty record batch should also be handled + val df2 = spark.range(0).selectExpr("id", "id % 8 as k").groupBy("k").sum("id") + checkSparkAnswerAndOperator(df2, includeClasses = Seq(classOf[CometRowToColumnarExec])) + } + } + }) + } + + test("RowToColumnar over RangeExec directly is eliminated for row output") { + Seq("true", "false").foreach(aqe => { + Seq(500, 900).foreach { batchSize => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe, + SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> batchSize.toString) { + val df = spark.range(1000) + val qe = df.queryExecution + qe.executedPlan.collectFirst({ case r: CometRowToColumnarExec => r }) match { + case Some(_) => fail("CometRowToColumnarExec should be eliminated") + case _ => + } + } + } + }) + } + + test("RowToColumnar over InMemoryTableScanExec") { + Seq("true", "false").foreach(aqe => { + 
withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe, + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "false") { + spark + .range(1000) + .selectExpr("id as key", "id % 8 as value") + .toDF("key", "value") + .selectExpr("key", "value", "key+1") + .createOrReplaceTempView("abc") + spark.catalog.cacheTable("abc") + val df = spark.sql("SELECT * FROM abc").groupBy("key").count() + checkSparkAnswerAndOperator(df, includeClasses = Seq(classOf[CometRowToColumnarExec])) + } + }) + } } case class BucketedTableTestSpec( diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 6fb81bc43..de5866580 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -34,7 +34,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark._ import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER} -import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometRowToColumnarExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -75,6 +75,7 @@ abstract class CometTestBase conf.set(CometConf.COMET_EXEC_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") + conf.set(CometConf.COMET_ROW_TO_COLUMNAR_ENABLED.key, "true") conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g") conf } @@ -155,9 +156,11 @@ abstract class CometTestBase } protected def checkCometOperators(plan: SparkPlan, excludedClasses: Class[_]*): Unit = { - plan.foreach { + val wrapped = wrapCometRowToColumnar(plan) + wrapped.foreach { case _: CometScanExec | _: CometBatchScanExec => true case _: CometSinkPlaceHolder | _: CometScanWrapper => false + case _: CometRowToColumnarExec => false case _: CometExec | _: CometShuffleExchangeExec => true case _: CometBroadcastExchangeExec => true case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter => true @@ -184,6 +187,14 @@ abstract class CometTestBase } } + /** Wraps the CometRowToColumn as ScanWrapper, so the child operators will not be checked */ + private def wrapCometRowToColumnar(plan: SparkPlan): SparkPlan = { + plan.transformDown { + // don't care the native operators + case p: CometRowToColumnarExec => CometScanWrapper(null, p) + } + } + /** * Check the answer of a Comet SQL query with Spark result using absolute tolerance. */
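Taken together, the feature can be exercised end to end in much the same way as the new tests above. The following is a hedged sketch, not part of this diff: it assumes a Comet-enabled session as configured earlier, reuses the config constants the tests rely on, and disables AQE only so the final plan can be inspected directly.

import org.apache.comet.CometConf
import org.apache.spark.sql.comet.CometRowToColumnarExec

spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
spark.conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
spark.conf.set(CometConf.COMET_ROW_TO_COLUMNAR_ENABLED.key, "true")

// Range is row-based; with the flag on, Comet inserts CometRowToColumnarExec between
// RangeExec and the native aggregate instead of falling back to Spark for the query.
val df = spark.range(1000).selectExpr("id", "id % 8 as k").groupBy("k").sum("id")
df.collect()

val converted = df.queryExecution.executedPlan.collectFirst {
  case r: CometRowToColumnarExec => r
}
assert(converted.isDefined, "expected RangeExec to feed Comet through CometRowToColumnarExec")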