
Commit a6b4f05

Cleaning up ArrayConverter, moving classTag to NativeType, adding NativeRow

1 parent 431f00f, commit a6b4f05

3 files changed: +101 -43 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 61 additions & 0 deletions
@@ -206,6 +206,67 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
   override def copy() = new GenericRow(values.clone())
 }
 
+// TODO: this is an awful lot of code duplication. If values would be covariant we could reuse
+// much of GenericRow
+class NativeRow[T](protected[catalyst] val values: Array[T]) extends Row {
+
+  /** No-arg constructor for serialization. */
+  def this() = this(null)
+
+  def this(elementType: NativeType, size: Int) =
+    this(elementType.classTag.newArray(size).asInstanceOf[Array[T]])
+
+  def iterator = values.iterator
+
+  def length = values.length
+
+  def apply(i: Int) = values(i)
+
+  def isNullAt(i: Int) = values(i) == null
+
+  def getInt(i: Int): Int = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
+    values(i).asInstanceOf[Int]
+  }
+
+  def getLong(i: Int): Long = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive long value.")
+    values(i).asInstanceOf[Long]
+  }
+
+  def getDouble(i: Int): Double = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive double value.")
+    values(i).asInstanceOf[Double]
+  }
+
+  def getFloat(i: Int): Float = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive float value.")
+    values(i).asInstanceOf[Float]
+  }
+
+  def getBoolean(i: Int): Boolean = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.")
+    values(i).asInstanceOf[Boolean]
+  }
+
+  def getShort(i: Int): Short = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive short value.")
+    values(i).asInstanceOf[Short]
+  }
+
+  def getByte(i: Int): Byte = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
+    values(i).asInstanceOf[Byte]
+  }
+
+  def getString(i: Int): String = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive String value.")
+    values(i).asInstanceOf[String]
+  }
+
+  def copy() = this
+}
+
 
 class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
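
NativeRow's typed constructor leans on the fact that ClassTag.newArray allocates a real JVM primitive array when the tag describes a primitive JvmType, so the row data stays unboxed. A minimal standalone sketch (plain Scala, not part of this commit; TypedArrayDemo and alloc are illustrative names):

import scala.reflect.ClassTag

object TypedArrayDemo extends App {
  // A ClassTag-created array for a primitive type is a primitive array,
  // which is what lets NativeRow(elementType, size) back a row with an
  // int[] instead of a boxing Object[].
  def alloc[T: ClassTag](size: Int): Array[T] = implicitly[ClassTag[T]].newArray(size)

  val ints = alloc[Int](4)      // an int[4] on the JVM
  val anys = new Array[Any](4)  // an Object[4]; storing an Int here boxes it

  println(ints.getClass.getSimpleName)  // prints: int[]
  println(anys.getClass.getSimpleName)  // prints: Object[]
}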

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 8 additions & 1 deletion
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.types
 
 import java.sql.Timestamp
 
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror}
 
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.util.Utils
 
 abstract class DataType {
   /** Matches any expression that evaluates to this DataType */
@@ -43,6 +45,11 @@ abstract class NativeType extends DataType {
   type JvmType
   @transient val tag: TypeTag[JvmType]
   val ordering: Ordering[JvmType]
+
+  @transient val classTag = {
+    val mirror = runtimeMirror(Utils.getSparkClassLoader)
+    ClassTag[JvmType](mirror.runtimeClass(tag.tpe))
+  }
 }
 
 case object StringType extends NativeType with PrimitiveType {
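
The classTag added to NativeType above derives a ClassTag from the existing TypeTag by resolving the runtime class through a reflection mirror. A standalone sketch of the same derivation, assuming the current class loader stands in for Spark's Utils.getSparkClassLoader:

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror}

object ClassTagFromTypeTag extends App {
  // Resolve the runtime Class behind a TypeTag and wrap it in a ClassTag,
  // mirroring NativeType.classTag in the diff above.
  def classTagOf[T](implicit tag: TypeTag[T]): ClassTag[T] = {
    val mirror = runtimeMirror(getClass.getClassLoader)
    ClassTag[T](mirror.runtimeClass(tag.tpe))
  }

  println(classTagOf[Int])     // prints: Int
  println(classTagOf[String])  // prints: java.lang.String
}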

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala

Lines changed: 32 additions & 42 deletions
@@ -18,16 +18,13 @@
 package org.apache.spark.sql.parquet
 
 import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap}
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.runtimeMirror
 
 import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
 import parquet.schema.MessageType
 
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{NativeRow, GenericRow, Row, Attribute}
 import org.apache.spark.sql.parquet.CatalystConverter.FieldType
-import org.apache.spark.util.Utils
 
 private[parquet] object CatalystConverter {
   // The type internally used for fields
@@ -83,7 +80,7 @@ private[parquet] object CatalystConverter {
     val attributes = ParquetTypesConverter.convertToAttributes(parquetSchema)
     // For non-nested types we use the optimized Row converter
     if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) {
-      new MutableRowGroupConverter(attributes)
+      new PrimitiveRowGroupConverter(attributes)
     } else {
       new CatalystGroupConverter(attributes)
     }
@@ -170,6 +167,9 @@ private[parquet] class CatalystGroupConverter(
   def getCurrentRecord: Row = {
     assert(isRootConverter, "getCurrentRecord should only be called in root group converter!")
     // TODO: use iterators if possible
+    // Note: this will ever only be called in the root converter when the record has been
+    // fully processed. Therefore it will be difficult to use mutable rows instead, since
+    // any non-root converter never would be sure when it would be safe to re-use the buffer.
     new GenericRow(current.toArray)
   }
 
@@ -180,14 +180,9 @@ private[parquet] class CatalystGroupConverter(
     current.update(fieldIndex, value)
   }
 
-  override protected[parquet] def clearBuffer(): Unit = {
-    // TODO: reuse buffer?
-    buffer = new ArrayBuffer[Row](CatalystArrayConverter.INITIAL_ARRAY_SIZE)
-  }
+  override protected[parquet] def clearBuffer(): Unit = buffer.clear()
 
   override def start(): Unit = {
-    // TODO: reuse buffer?
-    // Allocate new array in the root converter (others will be called clearBuffer() on)
     current = ArrayBuffer.fill(schema.length)(null)
     converters.foreach {
       converter => if (!converter.isPrimitive) {
@@ -196,12 +191,10 @@ private[parquet] class CatalystGroupConverter(
     }
   }
 
-  // TODO: think about reusing the buffer
   override def end(): Unit = {
     if (!isRootConverter) {
       assert(current!=null) // there should be no empty groups
       buffer.append(new GenericRow(current.toArray))
-      // TODO: use iterators if possible, avoid Row wrapping
       parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]]))
     }
   }
@@ -212,7 +205,7 @@ private[parquet] class CatalystGroupConverter(
  * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. Note that his
  * converter is optimized for rows of primitive types (non-nested records).
  */
-private[parquet] class MutableRowGroupConverter(
+private[parquet] class PrimitiveRowGroupConverter(
     protected[parquet] val schema: Seq[FieldType],
     protected[parquet] var current: ParquetRelation.RowType)
   extends GroupConverter with CatalystConverter {
@@ -334,7 +327,7 @@ object CatalystArrayConverter {
  * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an
  * [[org.apache.spark.sql.catalyst.types.ArrayType]].
  *
- * @param elementType The type of the array elements
+ * @param elementType The type of the array elements (complex or primitive)
  * @param index The position of this (array) field inside its parent converter
  * @param parent The parent converter
  * @param buffer A data buffer
@@ -345,8 +338,6 @@ private[parquet] class CatalystArrayConverter(
     protected[parquet] val parent: CatalystConverter,
     protected[parquet] var buffer: Buffer[Any])
   extends GroupConverter with CatalystConverter {
-  // TODO: In the future consider using native arrays instead of buffer for
-  // primitive types for performance reasons
 
   def this(elementType: DataType, index: Int, parent: CatalystConverter) =
     this(
@@ -374,8 +365,7 @@ private[parquet] class CatalystArrayConverter(
   }
 
   override protected[parquet] def clearBuffer(): Unit = {
-    // TODO: reuse buffer?
-    buffer = new ArrayBuffer[Any](CatalystArrayConverter.INITIAL_ARRAY_SIZE)
+    buffer.clear()
   }
 
   override def start(): Unit = {
@@ -384,10 +374,8 @@ private[parquet] class CatalystArrayConverter(
     }
   }
 
-  // TODO: think about reusing the buffer
   override def end(): Unit = {
     assert(parent != null)
-    // TODO: use iterators if possible, avoid Row wrapping
     parent.updateField(index, new GenericRow(buffer.toArray))
     clearBuffer()
   }
@@ -396,20 +384,27 @@ private[parquet] class CatalystArrayConverter(
   override def getCurrentRecord: Row = throw new UnsupportedOperationException
 }
 
-private[parquet] class CatalystNativeArrayConverter[T <: NativeType](
+/**
+ * A `parquet.io.api.GroupConverter` that converts a single-element groups that
+ * match the characteristics of an array (see
+ * [[org.apache.spark.sql.parquet.ParquetTypesConverter]]) into an
+ * [[org.apache.spark.sql.catalyst.types.ArrayType]].
+ *
+ * @param elementType The type of the array elements (native)
+ * @param index The position of this (array) field inside its parent converter
+ * @param parent The parent converter
+ * @param capacity The (initial) capacity of the buffer
+ */
+private[parquet] class CatalystNativeArrayConverter(
     val elementType: NativeType,
     val index: Int,
     protected[parquet] val parent: CatalystConverter,
    protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE)
   extends GroupConverter with CatalystConverter {
 
-  // similar comment as in [[Decoder]]: this should probably be in NativeType
-  private val classTag = {
-    val mirror = runtimeMirror(Utils.getSparkClassLoader)
-    ClassTag[T#JvmType](mirror.runtimeClass(elementType.tag.tpe))
-  }
+  type nativeType = elementType.JvmType
 
-  private var buffer: Array[T#JvmType] = classTag.newArray(capacity)
+  private var buffer: Array[nativeType] = elementType.classTag.newArray(capacity)
 
   private var elements: Int = 0
 
@@ -432,43 +427,43 @@ private[parquet] class CatalystNativeArrayConverter[T <: NativeType](
   // Overriden here to avoid auto-boxing for primitive types
   override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = {
     checkGrowBuffer()
-    buffer(elements) = value.asInstanceOf[T#JvmType]
+    buffer(elements) = value.asInstanceOf[nativeType]
     elements += 1
   }
 
   override protected[parquet] def updateInt(fieldIndex: Int, value: Int): Unit = {
     checkGrowBuffer()
-    buffer(elements) = value.asInstanceOf[T#JvmType]
+    buffer(elements) = value.asInstanceOf[nativeType]
    elements += 1
   }
 
   override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = {
     checkGrowBuffer()
-    buffer(elements) = value.asInstanceOf[T#JvmType]
+    buffer(elements) = value.asInstanceOf[nativeType]
     elements += 1
   }
 
   override protected[parquet] def updateDouble(fieldIndex: Int, value: Double): Unit = {
     checkGrowBuffer()
-    buffer(elements) = value.asInstanceOf[T#JvmType]
+    buffer(elements) = value.asInstanceOf[nativeType]
     elements += 1
   }
 
   override protected[parquet] def updateFloat(fieldIndex: Int, value: Float): Unit = {
     checkGrowBuffer()
-    buffer(elements) = value.asInstanceOf[T#JvmType]
+    buffer(elements) = value.asInstanceOf[nativeType]
     elements += 1
   }
 
   override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = {
     checkGrowBuffer()
-    buffer(elements) = value.getBytes.asInstanceOf[T#JvmType]
+    buffer(elements) = value.getBytes.asInstanceOf[nativeType]
     elements += 1
   }
 
   override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = {
     checkGrowBuffer()
-    buffer(elements) = value.toStringUsingUTF8.asInstanceOf[T#JvmType]
+    buffer(elements) = value.toStringUsingUTF8.asInstanceOf[nativeType]
     elements += 1
   }
 
@@ -482,12 +477,7 @@ private[parquet] class CatalystNativeArrayConverter[T <: NativeType](
     assert(parent != null)
     parent.updateField(
       index,
-      new GenericRow {
-        // TODO: it would be much nicer to use a view here but GenericRow requires an Array
-        // TODO: we should avoid using GenericRow as a wrapper but [[GetField]] current
-        // requires that
-        override val values = buffer.slice(0, elements).map(_.asInstanceOf[Any])
-      })
+      new NativeRow[nativeType](buffer.slice(0, elements)))
     clearBuffer()
   }
 
@@ -497,7 +487,7 @@ private[parquet] class CatalystNativeArrayConverter[T <: NativeType](
   private def checkGrowBuffer(): Unit = {
     if (elements >= capacity) {
       val newCapacity = 2 * capacity
-      val tmp: Array[T#JvmType] = classTag.newArray(newCapacity)
+      val tmp: Array[nativeType] = elementType.classTag.newArray(newCapacity)
       Array.copy(buffer, 0, tmp, 0, capacity)
       buffer = tmp
       capacity = newCapacity
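
The rewritten CatalystNativeArrayConverter buffers elements in an unboxed array obtained from elementType.classTag, doubles the capacity in checkGrowBuffer() when full, and finally wraps the filled prefix in a NativeRow. A simplified, hypothetical sketch of that buffering scheme (NativeBuffer and its members are illustrative names, not Spark API; a ClassTag type parameter stands in for the commit's path-dependent elementType.JvmType):

import scala.reflect.ClassTag

class NativeBuffer[T: ClassTag](private var capacity: Int = 4) {
  private var buffer: Array[T] = implicitly[ClassTag[T]].newArray(capacity)
  private var elements = 0

  def append(value: T): Unit = {
    // Same growth strategy as checkGrowBuffer(): double the capacity when full.
    if (elements >= capacity) {
      val newCapacity = 2 * capacity
      val tmp: Array[T] = implicitly[ClassTag[T]].newArray(newCapacity)
      Array.copy(buffer, 0, tmp, 0, capacity)
      buffer = tmp
      capacity = newCapacity
    }
    buffer(elements) = value
    elements += 1
  }

  // Like end(): hand off only the filled prefix of the buffer.
  def result(): Array[T] = buffer.slice(0, elements)
}

object NativeBufferDemo extends App {
  val buf = new NativeBuffer[Int]()
  (1 to 10).foreach(buf.append)
  println(buf.result().mkString(","))  // 1,2,...,10, backed by an int[]
}

Keeping the buffer unboxed matters because array elements arrive one primitive at a time through Parquet's update* callbacks; with an Array[Any] every element would be boxed before it ever reached the Row.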
