[SPARK-27650][SQL] separate the row iterator functionality from ColumnarBatch #24546

Closed · wants to merge 2 commits
6 changes: 5 additions & 1 deletion project/MimaExcludes.scala
@@ -321,7 +321,11 @@ object MimaExcludes {

// [SPARK-26616][MLlib] Expose document frequency in IDFModel
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf")
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf"),

// [SPARK-27650][SQL] separate the row iterator functionality from ColumnarBatch
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.vectorized.ColumnarBatch.getRow"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.vectorized.ColumnarBatch.rowIterator")
)

// Exclude rules for 2.4.x
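The two exclusions above reflect that ColumnarBatch.getRow and ColumnarBatch.rowIterator are removed from the public API. A minimal sketch of how a caller migrates, based on the call-site changes later in this diff (the helper class and method names here are illustrative, not from the PR):

import java.util.Iterator;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnarBatchRowView;
import org.apache.spark.sql.vectorized.ColumnarBatch;

class RowAccessMigration {
  // Counts the rows of a batch row by row.
  // Before this PR:  batch.getRow(i) / batch.rowIterator() lived on ColumnarBatch.
  // After this PR:   the same methods live on ColumnarBatchRowView.
  static long countRows(ColumnarBatch batch) {
    ColumnarBatchRowView view = new ColumnarBatchRowView(batch);
    Iterator<InternalRow> it = view.rowIterator();   // was: batch.rowIterator()
    long n = 0;
    while (it.hasNext()) {
      InternalRow row = it.next();  // the same mutable row instance is reused,
      n += 1;                       // so call row.copy() before buffering it
    }
    return n;
  }
}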
VectorizedParquetRecordReader.java
@@ -30,10 +30,7 @@

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.*;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -104,6 +101,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
*/
private ColumnarBatch columnarBatch;

private ColumnarBatchRowView rowView;

private WritableColumnVector[] columnVectors;

/**
@@ -168,7 +167,7 @@ public boolean nextKeyValue() throws IOException {
@Override
public Object getCurrentValue() {
if (returnColumnarBatch) return columnarBatch;
return columnarBatch.getRow(batchIdx - 1);
return rowView.getRow(batchIdx - 1);
}

@Override
@@ -202,6 +201,7 @@ private void initBatch(
columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema);
}
columnarBatch = new ColumnarBatch(columnVectors);
rowView = new ColumnarBatchRowView(columnarBatch);
if (partitionColumns != null) {
int partitionIdx = sparkSchema.fields().length;
for (int i = 0; i < partitionColumns.fields().length; i++) {
ColumnarBatchRowView.java (new file)
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.vectorized;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.vectorized.ColumnarBatch;

import java.util.Iterator;
import java.util.NoSuchElementException;

/**
* This class provides a row view of a {@link ColumnarBatch}, so that Spark can access the data
* row by row.
*/
public final class ColumnarBatchRowView {

private final ColumnarBatch batch;

// Staging row returned from `getRow`.
private final MutableColumnarRow row;

public ColumnarBatchRowView(ColumnarBatch batch) {
@kiszk (Member) commented on May 8, 2019:

Does it make sense to add another constructor that does not allocate MutableColumnarRow, as an optimization? Most of the use cases immediately call rowIterator(), which obviously never calls getRow().

this.batch = batch;
this.row = new MutableColumnarRow(batch.columns());
}

/**
* Returns an iterator over the rows in this batch.
*/
public Iterator<InternalRow> rowIterator() {
final int maxRows = batch.numRows();
final MutableColumnarRow row = new MutableColumnarRow(batch.columns());
return new Iterator<InternalRow>() {
int rowId = 0;

@Override
public boolean hasNext() {
return rowId < maxRows;
}

@Override
public InternalRow next() {
if (rowId >= maxRows) {
throw new NoSuchElementException();
}
row.rowId = rowId++;
return row;
}

@Override
public void remove() {
throw new UnsupportedOperationException();
}
};
}

/**
* Returns the row in this batch at `rowId`. Returned row is reused across calls.
*/
public InternalRow getRow(int rowId) {
assert(rowId >= 0 && rowId < batch.numRows());
row.rowId = rowId;
return row;
}
}
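Picking up the review suggestion above (kiszk, May 8, 2019), one way to avoid the eager MutableColumnarRow allocation would be a lazily initializing variant along these lines; this is only a hypothetical sketch, not code from this PR:

package org.apache.spark.sql.execution.vectorized;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.vectorized.ColumnarBatch;

// Hypothetical variant of ColumnarBatchRowView: the staging row is only
// allocated on the first getRow() call, so callers that go straight to
// rowIterator() never pay for it.
public final class LazyColumnarBatchRowView {

  private final ColumnarBatch batch;

  // Staging row returned from getRow(); created on demand.
  private MutableColumnarRow row;

  public LazyColumnarBatchRowView(ColumnarBatch batch) {
    this.batch = batch;
  }

  // rowIterator() would be identical to the implementation above (it already
  // allocates its own MutableColumnarRow) and is omitted here for brevity.

  public InternalRow getRow(int rowId) {
    assert(rowId >= 0 && rowId < batch.numRows());
    if (row == null) {
      row = new MutableColumnarRow(batch.columns());
    }
    row.rowId = rowId;
    return row;
  }
}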
ColumnarBatch.java
@@ -16,25 +16,17 @@
*/
package org.apache.spark.sql.vectorized;

import java.util.*;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;

/**
* This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this
* batch so that Spark can access the data row by row. Instance of it is meant to be reused during
* the entire data loading process.
* This class wraps multiple {@link ColumnVector}s as a table-like data batch. Instance of it is
* meant to be reused during the entire data loading process.
*/
@Evolving
public final class ColumnarBatch {
private int numRows;
A reviewer (Member) commented:

Is it still appropriate to keep this field here? Row-wise access no longer lives in ColumnarBatch.

The PR author (Contributor) replied:

Spark still needs to know the row count in order to read the columnar data.

private final ColumnVector[] columns;

// Staging row returned from `getRow`.
private final MutableColumnarRow row;

/**
* Called to close all the columns in this batch. It is not valid to access the data after
* calling this. This must be called at the end to clean up memory allocations.
@@ -45,36 +37,6 @@ public void close() {
}
}

/**
* Returns an iterator over the rows in this batch.
*/
public Iterator<InternalRow> rowIterator() {
final int maxRows = numRows;
final MutableColumnarRow row = new MutableColumnarRow(columns);
return new Iterator<InternalRow>() {
int rowId = 0;

@Override
public boolean hasNext() {
return rowId < maxRows;
}

@Override
public InternalRow next() {
if (rowId >= maxRows) {
throw new NoSuchElementException();
}
row.rowId = rowId++;
return row;
}

@Override
public void remove() {
throw new UnsupportedOperationException();
}
};
}

/**
* Sets the number of rows in this batch.
*/
@@ -98,16 +60,13 @@ public void setNumRows(int numRows) {
public ColumnVector column(int ordinal) { return columns[ordinal]; }

/**
* Returns the row in this batch at `rowId`. Returned row is reused across calls.
* Returns all the columns of this batch.
*/
public InternalRow getRow(int rowId) {
assert(rowId >= 0 && rowId < numRows);
row.rowId = rowId;
return row;
public ColumnVector[] columns() {
return columns;
}

public ColumnarBatch(ColumnVector[] columns) {
this.columns = columns;
this.row = new MutableColumnarRow(columns);
}
}
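As the exchange above notes, numRows stays on ColumnarBatch because even a purely columnar consumer needs the row count to know how many entries of each vector are valid. A minimal sketch of such a consumer (illustrative only, not from this PR):

import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;

class ColumnarSumExample {
  // Sums an int column without ever materializing rows; numRows() bounds the scan.
  static long sumIntColumn(ColumnarBatch batch, int ordinal) {
    ColumnVector col = batch.column(ordinal);
    long sum = 0;
    for (int i = 0; i < batch.numRows(); i++) {
      if (!col.isNullAt(i)) {
        sum += col.getInt(i);
      }
    }
    return sum;
  }
}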
VectorizedHashMapGenerator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector}
import org.apache.spark.sql.execution.vectorized.{ColumnarBatchRowView, MutableColumnarRow, OnHeapColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -212,7 +212,7 @@ class VectorizedHashMapGenerator(
s"""
|public java.util.Iterator<${classOf[InternalRow].getName}> rowIterator() {
| batch.setNumRows(numRows);
| return batch.rowIterator();
| return new ${classOf[ColumnarBatchRowView].getName}(batch).rowIterator();
|}
""".stripMargin
}
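Rendered, the generated rowIterator() above expands to roughly the following; the enclosing class and its batch/numRows fields are simplified stand-ins for what the hash-map codegen actually emits:

import java.util.Iterator;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnarBatchRowView;
import org.apache.spark.sql.vectorized.ColumnarBatch;

class GeneratedVectorizedHashMap {
  private ColumnarBatch batch;  // backing batch of the generated aggregate hash map
  private int numRows;          // number of rows written into the batch so far

  public Iterator<InternalRow> rowIterator() {
    batch.setNumRows(numRows);
    return new ColumnarBatchRowView(batch).rowIterator();
  }
}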
ArrowConverters.scala
@@ -33,6 +33,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.vectorized.ColumnarBatchRowView
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.{ByteBufferOutputStream, Utils}
@@ -172,7 +173,7 @@ private[sql] object ArrowConverters {

val batch = new ColumnarBatch(columns)
batch.setNumRows(root.getRowCount)
batch.rowIterator().asScala
new ColumnarBatchRowView(batch).rowIterator().asScala
}
}
}
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.python.BatchIterator
import org.apache.spark.sql.execution.r.ArrowRRunner
import org.apache.spark.sql.execution.streaming.GroupStateImpl
import org.apache.spark.sql.execution.vectorized.ColumnarBatchRowView
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.types._
@@ -247,17 +248,17 @@ case class MapPartitionsInRWithArrowExec(

private var currentIter = if (columnarBatchIter.hasNext) {
val batch = columnarBatchIter.next()
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
val actualDataTypes = batch.columns().map(_.dataType()).toSeq
assert(outputTypes == actualDataTypes, "Invalid schema from dapply(): " +
s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
batch.rowIterator.asScala
new ColumnarBatchRowView(batch).rowIterator.asScala
} else {
Iterator.empty
}

override def hasNext: Boolean = currentIter.hasNext || {
if (columnarBatchIter.hasNext) {
currentIter = columnarBatchIter.next().rowIterator.asScala
currentIter = new ColumnarBatchRowView(columnarBatchIter.next()).rowIterator.asScala
hasNext
} else {
false
@@ -587,7 +588,9 @@ case class FlatMapGroupsInRWithArrowExec(
// binary in a batch due to the limitation of R API. See also ARROW-4512.
val columnarBatchIter = runner.compute(groupedByRKey, -1)
val outputProject = UnsafeProjection.create(output, output)
columnarBatchIter.flatMap(_.rowIterator().asScala).map(outputProject)
columnarBatchIter
.flatMap(new ColumnarBatchRowView(_).rowIterator().asScala)
.map(outputProject)
}
}
}
AggregateInPandasExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.execution.vectorized.ColumnarBatchRowView
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils

@@ -146,7 +147,7 @@ case class AggregateInPandasExec(
val joined = new JoinedRow
val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes)

columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow =>
columnarBatchIter.map(new ColumnarBatchRowView(_).getRow(0)).map { aggOutputRow =>
val leftRow = queue.remove()
val joinedRow = joined(leftRow, aggOutputRow)
resultProj(joinedRow)
ArrowEvalPythonExec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.execution.vectorized.ColumnarBatchRowView
import org.apache.spark.sql.types.StructType

/**
@@ -101,17 +102,17 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi

private var currentIter = if (columnarBatchIter.hasNext) {
val batch = columnarBatchIter.next()
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())
val actualDataTypes = batch.columns().map(_.dataType()).toSeq
assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " +
s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}")
batch.rowIterator.asScala
new ColumnarBatchRowView(batch).rowIterator().asScala
} else {
Iterator.empty
}

override def hasNext: Boolean = currentIter.hasNext || {
if (columnarBatchIter.hasNext) {
currentIter = columnarBatchIter.next().rowIterator.asScala
currentIter = new ColumnarBatchRowView(columnarBatchIter.next()).rowIterator.asScala
hasNext
} else {
false
FlatMapGroupsInPandasExec.scala
@@ -28,8 +28,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.execution.vectorized.ColumnarBatchRowView
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}

/**
* Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
@@ -154,7 +155,7 @@ case class FlatMapGroupsInPandasExec(
val outputVectors = output.indices.map(structVector.getChild)
val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
flattenedBatch.setNumRows(batch.numRows())
flattenedBatch.rowIterator.asScala
new ColumnarBatchRowView(flattenedBatch).rowIterator.asScala
}.map(unsafeProj)
}
}
WindowInPandasExec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.execution.vectorized.ColumnarBatchRowView
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -394,11 +395,13 @@ case class WindowInPandasExec(

val joined = new JoinedRow

windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
val leftRow = queue.remove()
val joinedRow = joined(leftRow, windowOutput)
resultProj(joinedRow)
}
windowFunctionResult
.flatMap(new ColumnarBatchRowView(_).rowIterator.asScala)
.map { windowOutput =>
val leftRow = queue.remove()
val joinedRow = joined(leftRow, windowOutput)
resultProj(joinedRow)
}
}
}
}