[SPARK-2523] [SQL] Hadoop table scan bug fixing
In HiveTableScan.scala, a single ObjectInspector was created and used for all of the partition-based records, which can cause a ClassCastException if the object inspector is not identical across the table and its partitions.
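
For illustration, here is a minimal, hypothetical sketch in the same style as the benchmark attached below (the table name, partition spec, and choice of SerDes are invented for illustration, and it assumes the same LocalHiveContext setup as the benchmark) of how a partition can end up with a SerDe, and therefore an ObjectInspector, that differs from the table's:
```
package org.apache.spark.sql.hive

import org.apache.spark.SparkContext
import org.apache.spark.SparkConf

// Hypothetical repro sketch: the table is declared with ColumnarSerDe, but one
// partition is switched to LazyBinaryColumnarSerDe, so that partition's
// ObjectInspector no longer matches the table-level one.
object SerDeMismatchSketch extends App {
  case class Record(key: String, value: String)

  val sparkContext = new SparkContext(
    new SparkConf()
      .setMaster("local")
      .setAppName(getClass.getSimpleName.stripSuffix("$")))

  val hiveContext = new LocalHiveContext(sparkContext)

  import hiveContext._

  hql("DROP TABLE IF EXISTS serde_mismatch")
  hql("""CREATE TABLE serde_mismatch (key STRING, value STRING)
                 | PARTITIONED BY (part STRING)
                 | ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe'
                 | STORED AS RCFILE
               """.stripMargin)
  hql("ALTER TABLE serde_mismatch ADD PARTITION (part='p1')")
  hql("""ALTER TABLE serde_mismatch PARTITION (part='p1')
                 | SET SERDE 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'
               """.stripMargin)

  // Rows written into this partition go through the partition's SerDe, not the table's.
  val rdd = sparkContext.parallelize((1 to 100).map(i => Record(s"$i", s"val_$i")))
  rdd.registerAsTable("records")
  hql("""FROM records
                 | INSERT INTO TABLE serde_mismatch PARTITION (part='p1')
                 | SELECT key, value
               """.stripMargin)

  // Reading the partition back through a single table-level ObjectInspector is the
  // situation that could fail; per-partition deserializers avoid it.
  hql("SELECT key, value FROM serde_mismatch").collect().foreach(println)
}
```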

This is a follow-up to:
apache#1408
apache#1390

I ran a micro-benchmark locally with 15,000,000 records in total and got the results below:

With This Patch  |  Partition-Based Table  |  Non-Partition-Based Table
------------ | ------------- | -------------
No  |  1927 ms  |  1885 ms
Yes  | 1541 ms  |  1524 ms

The results show that this patch also improves performance.

PS: the benchmark code is attached below (thanks liancheng).
```
package org.apache.spark.sql.hive

import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql._

object HiveTableScanPrepare extends App {
  case class Record(key: String, value: String)

  val sparkContext = new SparkContext(
    new SparkConf()
      .setMaster("local")
      .setAppName(getClass.getSimpleName.stripSuffix("$")))

  val hiveContext = new LocalHiveContext(sparkContext)

  val rdd = sparkContext.parallelize((1 to 3000000).map(i => Record(s"$i", s"val_$i")))

  import hiveContext._

  hql("SHOW TABLES")
  hql("DROP TABLE if exists part_scan_test")
  hql("DROP TABLE if exists scan_test")
  hql("DROP TABLE if exists records")
  rdd.registerAsTable("records")

  hql("""CREATE TABLE part_scan_test (key STRING, value STRING) PARTITIONED BY (part1 string, part2 STRING)
                 | ROW FORMAT SERDE
                 | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'
                 | STORED AS RCFILE
               """.stripMargin)
  hql("""CREATE TABLE scan_test (key STRING, value STRING)
                 | ROW FORMAT SERDE
                 | 'org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe'
                 | STORED AS RCFILE
               """.stripMargin)

  for (part1 <- 2000 until 2001) {
    for (part2 <- 1 to 5) {
      hql(s"""from records
                 | insert into table part_scan_test PARTITION (part1='$part1', part2='2010-01-$part2')
                 | select key, value
               """.stripMargin)
      hql(s"""from records
                 | insert into table scan_test select key, value
               """.stripMargin)
    }
  }
}

object HiveTableScanTest extends App {
  val sparkContext = new SparkContext(
    new SparkConf()
      .setMaster("local")
      .setAppName(getClass.getSimpleName.stripSuffix("$")))

  val hiveContext = new LocalHiveContext(sparkContext)

  import hiveContext._

  hql("SHOW TABLES")
  val part_scan_test = hql("select key, value from part_scan_test")
  val scan_test = hql("select key, value from scan_test")

  val r_part_scan_test = (0 to 5).map(i => benchmark(part_scan_test))
  val r_scan_test = (0 to 5).map(i => benchmark(scan_test))
  println("Scanning Partition-Based Table")
  r_part_scan_test.foreach(printResult)
  println("Scanning Non-Partition-Based Table")
  r_scan_test.foreach(printResult)

  def printResult(result: (Long, Long)) {
    println(s"Duration: ${result._1} ms Result: ${result._2}")
  }

  def benchmark(srdd: SchemaRDD) = {
    val begin = System.currentTimeMillis()
    val result = srdd.count()
    val end = System.currentTimeMillis()
    ((end - begin), result)
  }
}
```

Author: Cheng Hao <hao.cheng@intel.com>

Closes apache#1439 from chenghao-intel/hadoop_table_scan and squashes the following commits:

888968f [Cheng Hao] Fix issues in code style
27540ba [Cheng Hao] Fix the TableScan Bug while partition serde differs
40a24a7 [Cheng Hao] Add Unit Test
chenghao-intel authored and marmbrus committed Jul 28, 2014
1 parent a7d145e commit 2b8d89e
Showing 4 changed files with 138 additions and 115 deletions.
113 changes: 81 additions & 32 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -24,20 +24,25 @@ import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde2.Deserializer
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector

import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}

import org.apache.spark.SerializableWritable
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}

import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast}
import org.apache.spark.sql.catalyst.types.DataType

/**
* A trait for subclasses that handle table scans.
*/
private[hive] sealed trait TableReader {
def makeRDDForTable(hiveTable: HiveTable): RDD[_]
def makeRDDForTable(hiveTable: HiveTable): RDD[Row]

def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_]
def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row]
}


@@ -46,7 +51,10 @@ private[hive] sealed trait TableReader {
* data warehouse directory.
*/
private[hive]
class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext)
class HadoopTableReader(
@transient attributes: Seq[Attribute],
@transient relation: MetastoreRelation,
@transient sc: HiveContext)
extends TableReader {

// Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless
@@ -63,10 +71,10 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon

def hiveConf = _broadcastedHiveConf.value.value

override def makeRDDForTable(hiveTable: HiveTable): RDD[_] =
override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] =
makeRDDForTable(
hiveTable,
_tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
relation.tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]],
filterOpt = None)

/**
@@ -81,14 +89,14 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
def makeRDDForTable(
hiveTable: HiveTable,
deserializerClass: Class[_ <: Deserializer],
filterOpt: Option[PathFilter]): RDD[_] = {
filterOpt: Option[PathFilter]): RDD[Row] = {

assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table,
since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""")

// Create local references to member variables, so that the entire `this` object won't be
// serialized in the closure below.
val tableDesc = _tableDesc
val tableDesc = relation.tableDesc
val broadcastedHiveConf = _broadcastedHiveConf

val tablePath = hiveTable.getPath
@@ -99,23 +107,20 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
.asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)

val attrsWithIndex = attributes.zipWithIndex
val mutableRow = new GenericMutableRow(attrsWithIndex.length)
val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val deserializer = deserializerClass.newInstance()
deserializer.initialize(hconf, tableDesc.getProperties)

// Deserialize each Writable to get the row value.
iter.map {
case v: Writable => deserializer.deserialize(v)
case value =>
sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}")
}
HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow)
}

deserializedHadoopRDD
}

override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = {
override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[Row] = {
val partitionToDeserializer = partitions.map(part =>
(part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap
makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None)
@@ -132,9 +137,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
* subdirectory of each partition being read. If None, then all files are accepted.
*/
def makeRDDForPartitionedTable(
partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]],
filterOpt: Option[PathFilter]): RDD[_] = {

partitionToDeserializer: Map[HivePartition,
Class[_ <: Deserializer]],
filterOpt: Option[PathFilter]): RDD[Row] = {
val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) =>
val partDesc = Utilities.getPartitionDesc(partition)
val partPath = partition.getPartitionPath
@@ -156,33 +161,42 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
}

// Create local references so that the outer object isn't serialized.
val tableDesc = _tableDesc
val tableDesc = relation.tableDesc
val broadcastedHiveConf = _broadcastedHiveConf
val localDeserializer = partDeserializer
val mutableRow = new GenericMutableRow(attributes.length)

// split the attributes (output schema) into 2 categories:
// (partition keys, ordinal), (normal attributes, ordinal), where the ordinal is the
// index of the attribute in the output Row.
val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => {
relation.partitionKeys.indexOf(attr._1) >= 0
})

def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = {
partitionKeys.foreach { case (attr, ordinal) =>
// get partition key ordinal for a given attribute
val partOridinal = relation.partitionKeys.indexOf(attr)
row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null)
}
}
// fill the partition key for the given MutableRow Object
fillPartitionKeys(partValues, mutableRow)

val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
hivePartitionRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val rowWithPartArr = new Array[Object](2)

// The update and deserializer initialization are intentionally
// kept out of the below iter.map loop to save performance.
rowWithPartArr.update(1, partValues)
val deserializer = localDeserializer.newInstance()
deserializer.initialize(hconf, partProps)

// Map each tuple to a row object
iter.map { value =>
val deserializedRow = deserializer.deserialize(value)
rowWithPartArr.update(0, deserializedRow)
rowWithPartArr.asInstanceOf[Object]
}
// fill the non partition key attributes
HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow)
}
}.toSeq

// Even if we don't use any partitions, we still need an empty RDD
if (hivePartitionRDDs.size == 0) {
new EmptyRDD[Object](sc.sparkContext)
new EmptyRDD[Row](sc.sparkContext)
} else {
new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs)
}
@@ -225,10 +239,9 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
// Only take the value (skip the key) because Hive works only with values.
rdd.map(_._2)
}

}

private[hive] object HadoopTableReader {
private[hive] object HadoopTableReader extends HiveInspectors {
/**
* Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to
* instantiate a HadoopRDD.
@@ -241,4 +254,40 @@ private[hive] object HadoopTableReader {
val bufferSize = System.getProperty("spark.buffer.size", "65536")
jobConf.set("io.file.buffer.size", bufferSize)
}

/**
* Transform the raw data (Writable objects) into Row objects for an iterable input
* @param iter Iterator over the input, each element a Writable object
* @param deserializer Deserializer associated with the input Writable objects
* @param attrs The required row attributes, each paired with its zero-based position in the MutableRow
* @param row Reusable MutableRow object
*
* @return An iterator of Row objects transformed from the given input.
*/
def fillObject(
iter: Iterator[Writable],
deserializer: Deserializer,
attrs: Seq[(Attribute, Int)],
row: GenericMutableRow): Iterator[Row] = {
val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
// get the field references according to the attributes(output of the reader) required
val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) }

// Map each tuple to a row object
iter.map { value =>
val raw = deserializer.deserialize(value)
var idx = 0;
while (idx < fieldRefs.length) {
val fieldRef = fieldRefs(idx)._1
val fieldIdx = fieldRefs(idx)._2
val fieldValue = soi.getStructFieldData(raw, fieldRef)

row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector())

idx += 1
}

row: Row
}
}
}
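
As a side note, the core trick in `fillObject` above is row reuse: a single `GenericMutableRow` is filled in place for every record, and partition key columns are written only once per partition. Below is a minimal, self-contained sketch of that pattern in plain Scala (the class and object names are invented for illustration and are not part of this patch); it is safe only because each row is consumed before the next record overwrites it:
```
// Simplified illustration of the row-reuse pattern used by HadoopTableReader.fillObject:
// one mutable row buffer is updated in place for each input record, so the scan
// allocates no per-row objects. This works only because downstream consumers read
// each row before the iterator produces the next one.
object RowReuseSketch extends App {
  final class MutableRow(size: Int) {
    private val values = new Array[Any](size)
    def update(i: Int, v: Any): Unit = values(i) = v
    override def toString: String = values.mkString("[", ", ", "]")
  }

  // Pretend these are deserialized records coming out of one HadoopRDD partition.
  val records = Iterator("1,val_1", "2,val_2", "3,val_3")
  val row = new MutableRow(2)

  val rows = records.map { record =>
    val fields = record.split(",")
    var i = 0
    while (i < fields.length) {
      row(i) = fields(i)   // fill the reused buffer in place
      i += 1
    }
    row                    // the same object is returned for every record
  }

  rows.foreach(println)    // each row is consumed before the next overwrite
}
```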
sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{BooleanType, DataType}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive._
import org.apache.spark.util.MutablePair

/**
* :: DeveloperApi ::
@@ -50,8 +49,7 @@ case class HiveTableScan(
relation: MetastoreRelation,
partitionPruningPred: Option[Expression])(
@transient val context: HiveContext)
extends LeafNode
with HiveInspectors {
extends LeafNode {

require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
@@ -67,42 +65,7 @@ }
}

@transient
private[this] val hadoopReader = new HadoopTableReader(relation.tableDesc, context)

/**
* The hive object inspector for this table, which can be used to extract values from the
* serialized row representation.
*/
@transient
private[this] lazy val objectInspector =
relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector]

/**
* Functions that extract the requested attributes from the hive output. Partitioned values are
* casted from string to its declared data type.
*/
@transient
protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = {
attributes.map { a =>
val ordinal = relation.partitionKeys.indexOf(a)
if (ordinal >= 0) {
val dataType = relation.partitionKeys(ordinal).dataType
(_: Any, partitionKeys: Array[String]) => {
castFromString(partitionKeys(ordinal), dataType)
}
} else {
val ref = objectInspector.getAllStructFieldRefs
.find(_.getFieldName == a.name)
.getOrElse(sys.error(s"Can't find attribute $a"))
val fieldObjectInspector = ref.getFieldObjectInspector

(row: Any, _: Array[String]) => {
val data = objectInspector.getStructFieldData(row, ref)
unwrapData(data, fieldObjectInspector)
}
}
}
}
private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context)

private[this] def castFromString(value: String, dataType: DataType) = {
Cast(Literal(value), dataType).eval(null)
@@ -114,6 +77,7 @@ case class HiveTableScan(
val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",")

if (attributes.size == relation.output.size) {
// SQLContext#pruneFilterProject guarantees no duplicated value in `attributes`
ColumnProjectionUtils.setFullyReadColumns(hiveConf)
} else {
ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs)
@@ -140,12 +104,6 @@ case class HiveTableScan(

addColumnMetadataToConf(context.hiveconf)

private def inputRdd = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
}

/**
* Prunes partitions not involve the query plan.
*
@@ -169,44 +127,10 @@ case class HiveTableScan(
}
}

override def execute() = {
inputRdd.mapPartitions { iterator =>
if (iterator.isEmpty) {
Iterator.empty
} else {
val mutableRow = new GenericMutableRow(attributes.length)
val mutablePair = new MutablePair[Any, Array[String]]()
val buffered = iterator.buffered

// NOTE (lian): Critical path of Hive table scan, unnecessary FP style code and pattern
// matching are avoided intentionally.
val rowsAndPartitionKeys = buffered.head match {
// With partition keys
case _: Array[Any] =>
buffered.map { case array: Array[Any] =>
val deserializedRow = array(0)
val partitionKeys = array(1).asInstanceOf[Array[String]]
mutablePair.update(deserializedRow, partitionKeys)
}

// Without partition keys
case _ =>
val emptyPartitionKeys = Array.empty[String]
buffered.map { deserializedRow =>
mutablePair.update(deserializedRow, emptyPartitionKeys)
}
}

rowsAndPartitionKeys.map { pair =>
var i = 0
while (i < attributes.length) {
mutableRow(i) = attributeFunctions(i)(pair._1, pair._2)
i += 1
}
mutableRow: Row
}
}
}
override def execute() = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
}

override def output = attributes
@@ -0,0 +1,2 @@
100 100 2010-01-01
200 200 2010-01-02