
[SPARK-21745][SQL] Refactor ColumnVector hierarchy to make ColumnVector read-only and to introduce WritableColumnVector. #18958

Closed · wants to merge 18 commits
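
In brief: before this change a single ColumnVector class exposed both the get* and put* methods, so code that only needed to read a batch could also mutate it. This PR splits the hierarchy so that ColumnVector stays read-only and all mutators move to a new WritableColumnVector, with OnHeapColumnVector and OffHeapColumnVector as the concrete writable implementations, and ColumnarBatch now wrapping vectors that were allocated up front rather than allocating them itself. A minimal self-contained sketch of that split follows; the names ReadOnlyVector, WritableVector and OnHeapIntVector are made up for this toy example, standing in for ColumnVector, WritableColumnVector and OnHeapColumnVector, and the real classes of course cover all Spark SQL types, not just int.

```java
// Toy illustration of the read-only / writable split introduced by this PR.
abstract class ReadOnlyVector {                          // role of ColumnVector
  abstract boolean isNullAt(int rowId);
  abstract int getInt(int rowId);
}

abstract class WritableVector extends ReadOnlyVector {   // role of WritableColumnVector
  abstract void putNull(int rowId);
  abstract void putInt(int rowId, int value);
}

final class OnHeapIntVector extends WritableVector {     // role of OnHeapColumnVector
  private final int[] data;
  private final boolean[] nulls;
  OnHeapIntVector(int capacity) { data = new int[capacity]; nulls = new boolean[capacity]; }
  @Override boolean isNullAt(int rowId) { return nulls[rowId]; }
  @Override int getInt(int rowId) { return data[rowId]; }
  @Override void putNull(int rowId) { nulls[rowId] = true; }
  @Override void putInt(int rowId, int value) { data[rowId] = value; nulls[rowId] = false; }
}
```

In the diff below, the Parquet readers that fill batches now receive WritableColumnVector, while code that only consumes batches keeps seeing the read-only ColumnVector type.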
@@ -450,14 +450,13 @@ class CodegenContext {
/**
* Returns the specialized code to set a given value in a column vector for a given `DataType`.
*/
def setValue(batch: String, row: String, dataType: DataType, ordinal: Int,
value: String): String = {
def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) =>
s"$batch.column($ordinal).put${primitiveTypeName(jt)}($row, $value);"
case t: DecimalType => s"$batch.column($ordinal).putDecimal($row, $value, ${t.precision});"
case t: StringType => s"$batch.column($ordinal).putByteArray($row, $value.getBytes());"
s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
case _ =>
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
}
@@ -468,37 +467,36 @@ class CodegenContext {
* that could potentially be nullable.
*/
def updateColumn(
batch: String,
row: String,
vector: String,
rowId: String,
dataType: DataType,
ordinal: Int,
ev: ExprCode,
nullable: Boolean): String = {
if (nullable) {
s"""
if (!${ev.isNull}) {
${setValue(batch, row, dataType, ordinal, ev.value)}
${setValue(vector, rowId, dataType, ev.value)}
} else {
$batch.column($ordinal).putNull($row);
$vector.putNull($rowId);
}
"""
} else {
s"""${setValue(batch, row, dataType, ordinal, ev.value)};"""
s"""${setValue(vector, rowId, dataType, ev.value)};"""
}
}

/**
* Returns the specialized code to access a value from a column vector for a given `DataType`.
*/
def getValue(batch: String, row: String, dataType: DataType, ordinal: Int): String = {
def getValue(vector: String, rowId: String, dataType: DataType): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) =>
s"$batch.column($ordinal).get${primitiveTypeName(jt)}($row)"
s"$vector.get${primitiveTypeName(jt)}($rowId)"
case t: DecimalType =>
s"$batch.column($ordinal).getDecimal($row, ${t.precision}, ${t.scale})"
s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})"
case StringType =>
s"$batch.column($ordinal).getUTF8String($row)"
s"$vector.getUTF8String($rowId)"
case _ =>
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
}
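
To make the effect of the codegen change concrete: for an int column, these helpers used to emit code of the form `batch.column(ordinal).putInt(rowId, value)` guarded by a null check, whereas they now emit the same logic against a column vector variable directly. Below is a hedged sketch of the shape of that generated code, written out as ordinary Java; the class name GeneratedShapeSketch and the variable names are illustrative, not what the generator actually emits.

```java
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;

// Illustrative only: hand-written Java with the same shape as the strings produced by
// setValue / updateColumn / getValue after this refactor.
class GeneratedShapeSketch {
  // Shape emitted by updateColumn(vector, rowId, IntegerType, ev, nullable = true);
  // previously the same helper emitted batch.column(ordinal).putInt(...) / putNull(...).
  static void writeNullableInt(WritableColumnVector colVector, int rowId, boolean isNull, int value) {
    if (!isNull) {
      colVector.putInt(rowId, value);
    } else {
      colVector.putNull(rowId);
    }
  }

  // Shape emitted by getValue(vector, rowId, IntegerType): reads need only the read-only type.
  static int readInt(ColumnVector colVector, int rowId) {
    return colVector.getInt(rowId);
  }
}
```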

@@ -30,6 +30,7 @@

import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;

@@ -135,9 +136,9 @@ private boolean next() throws IOException {
/**
* Reads `total` values from this columnReader into column.
*/
void readBatch(int total, ColumnVector column) throws IOException {
void readBatch(int total, WritableColumnVector column) throws IOException {
int rowId = 0;
ColumnVector dictionaryIds = null;
WritableColumnVector dictionaryIds = null;
if (dictionary != null) {
// SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to
// decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded
@@ -219,8 +220,11 @@ void readBatch(int total, ColumnVector column) throws IOException {
/**
* Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
*/
private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
ColumnVector dictionaryIds) {
private void decodeDictionaryIds(
int rowId,
int num,
WritableColumnVector column,
ColumnVector dictionaryIds) {
switch (descriptor.getType()) {
case INT32:
if (column.dataType() == DataTypes.IntegerType ||
@@ -346,13 +350,13 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
* is guaranteed that num is smaller than the number of values left in the current page.
*/

private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
private void readBooleanBatch(int rowId, int num, WritableColumnVector column) throws IOException {
assert(column.dataType() == DataTypes.BooleanType);
defColumn.readBooleans(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
}

private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
@@ -370,7 +374,7 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
}
}

private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
if (column.dataType() == DataTypes.LongType ||
DecimalType.is64BitDecimalType(column.dataType())) {
@@ -389,7 +393,7 @@ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
}
}

private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException {
private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: support implicit cast to double?
if (column.dataType() == DataTypes.FloatType) {
@@ -400,7 +404,7 @@ private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException {
}
}

private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException {
private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.DoubleType) {
@@ -411,7 +415,7 @@ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException {
}
}

private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
@@ -432,8 +436,11 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
}
}

private void readFixedLenByteArrayBatch(int rowId, int num,
ColumnVector column, int arrayLen) throws IOException {
private void readFixedLenByteArrayBatch(
int rowId,
int num,
WritableColumnVector column,
int arrayLen) throws IOException {
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
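
Note the asymmetry in the new decodeDictionaryIds signature earlier in this file: the output column becomes a WritableColumnVector, while dictionaryIds stays a plain ColumnVector because it is only ever read from. The sketch below shows roughly what the INT32 branch of that method does, assuming the getDictId accessor on ColumnVector and Parquet's Dictionary.decodeToInt (neither appears in this diff); the class and method names are illustrative, and the real method also handles byte/short/decimal conversions.

```java
import org.apache.parquet.column.Dictionary;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;

// Hedged sketch of dictionary decoding: read codes from the read-only dictionaryIds vector,
// translate them through the Parquet dictionary, and write results into the writable column.
class DictionaryDecodeSketch {
  static void decodeInts(int rowId, int num, WritableColumnVector column,
                         ColumnVector dictionaryIds, Dictionary dictionary) {
    for (int i = rowId; i < rowId + num; ++i) {
      if (!column.isNullAt(i)) {
        column.putInt(i, dictionary.decodeToInt(dictionaryIds.getDictId(i)));
      }
    }
  }
}
```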

@@ -31,6 +31,9 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

@@ -90,6 +93,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
*/
private ColumnarBatch columnarBatch;

private WritableColumnVector[] columnVectors;

/**
* If true, this class returns batches instead of rows.
*/
@@ -172,20 +177,26 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns,
}
}

columnarBatch = ColumnarBatch.allocate(batchSchema, memMode);
int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;
if (memMode == MemoryMode.OFF_HEAP) {
columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema);
} else {
columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema);
}
columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity);
if (partitionColumns != null) {
int partitionIdx = sparkSchema.fields().length;
for (int i = 0; i < partitionColumns.fields().length; i++) {
ColumnVectorUtils.populate(columnarBatch.column(i + partitionIdx), partitionValues, i);
columnarBatch.column(i + partitionIdx).setIsConstant();
ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
columnVectors[i + partitionIdx].setIsConstant();
}
}

// Initialize missing columns with nulls.
for (int i = 0; i < missingColumns.length; i++) {
if (missingColumns[i]) {
columnarBatch.column(i).putNulls(0, columnarBatch.capacity());
columnarBatch.column(i).setIsConstant();
columnVectors[i].putNulls(0, columnarBatch.capacity());
columnVectors[i].setIsConstant();
}
}
}
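
This allocation change is the heart of the refactor on the reader side: the reader keeps its own WritableColumnVector[] for filling data, and ColumnarBatch merely wraps those vectors for consumers. The sketch below repeats the same pattern outside the reader, using only the calls visible in the hunk above; the class name, two-column schema and capacity are illustrative.

```java
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Hedged sketch of the new allocation pattern: allocate writable vectors explicitly,
// then hand them to ColumnarBatch.
class BatchAllocationSketch {
  static ColumnarBatch allocate(MemoryMode memMode, int capacity) {
    StructType schema = new StructType()
        .add("id", DataTypes.LongType)
        .add("name", DataTypes.StringType);

    WritableColumnVector[] vectors;
    if (memMode == MemoryMode.OFF_HEAP) {
      vectors = OffHeapColumnVector.allocateColumns(capacity, schema);
    } else {
      vectors = OnHeapColumnVector.allocateColumns(capacity, schema);
    }

    // Producers write through vectors (e.g. vectors[0].putLong(rowId, value));
    // consumers only see the batch, whose columns are the read-only ColumnVector type.
    return new ColumnarBatch(schema, vectors, capacity);
  }
}
```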
@@ -226,7 +237,7 @@ public boolean nextBatch() throws IOException {
int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned);
for (int i = 0; i < columnReaders.length; ++i) {
if (columnReaders[i] == null) continue;
columnReaders[i].readBatch(num, columnarBatch.column(i));
columnReaders[i].readBatch(num, columnVectors[i]);
}
rowsReturned += num;
columnarBatch.setNumRows(num);

@@ -20,7 +20,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.unsafe.Platform;

import org.apache.parquet.column.values.ValuesReader;
@@ -56,39 +56,39 @@ public void skip() {
}

@Override
public final void readBooleans(int total, ColumnVector c, int rowId) {
public final void readBooleans(int total, WritableColumnVector c, int rowId) {
// TODO: properly vectorize this
for (int i = 0; i < total; i++) {
c.putBoolean(rowId + i, readBoolean());
}
}

@Override
public final void readIntegers(int total, ColumnVector c, int rowId) {
public final void readIntegers(int total, WritableColumnVector c, int rowId) {
c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 4 * total;
}

@Override
public final void readLongs(int total, ColumnVector c, int rowId) {
public final void readLongs(int total, WritableColumnVector c, int rowId) {
c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 8 * total;
}

@Override
public final void readFloats(int total, ColumnVector c, int rowId) {
public final void readFloats(int total, WritableColumnVector c, int rowId) {
c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 4 * total;
}

@Override
public final void readDoubles(int total, ColumnVector c, int rowId) {
public final void readDoubles(int total, WritableColumnVector c, int rowId) {
c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 8 * total;
}

@Override
public final void readBytes(int total, ColumnVector c, int rowId) {
public final void readBytes(int total, WritableColumnVector c, int rowId) {
for (int i = 0; i < total; i++) {
// Bytes are stored as a 4-byte little endian int. Just read the first byte.
// TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
@@ -159,7 +159,7 @@ public final double readDouble() {
}

@Override
public final void readBinary(int total, ColumnVector v, int rowId) {
public final void readBinary(int total, WritableColumnVector v, int rowId) {
for (int i = 0; i < total; i++) {
int len = readInteger();
int start = offset;
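
For reference, the bulk paths above (readIntegers, readLongs, readFloats, readDoubles) hand the reader's raw byte buffer plus an offset to the vector and let it copy 4 * total or 8 * total bytes at once. The snippet below is a self-contained, illustrative stand-in for what a putIntsLittleEndian-style method has to do, written with java.nio instead of Platform/unsafe; it is not the Spark implementation, whose off-heap variant can reduce this to a single memory copy on little-endian platforms.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Illustrative decode of `count` 4-byte little-endian ints from src (starting at srcIndex)
// into dst (starting at dstIndex), mirroring WritableColumnVector.putIntsLittleEndian.
final class LittleEndianDecodeSketch {
  static void putIntsLittleEndian(int[] dst, int dstIndex, int count, byte[] src, int srcIndex) {
    ByteBuffer buf = ByteBuffer.wrap(src, srcIndex, count * 4).order(ByteOrder.LITTLE_ENDIAN);
    for (int i = 0; i < count; i++) {
      dst[dstIndex + i] = buf.getInt();
    }
  }
}
```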