Skip to content

Commit

Permalink
[SPARK-25122][SQL] Deduplication of supports equals code
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

The method ```*supportEquals``` determining whether elements of a data type could be used as items in a hash set or as keys in a hash map is duplicated across multiple collection and higher-order functions.

This PR suggests to deduplicate the method.

## How was this patch tested?

Run tests in:
- DataFrameFunctionsSuite
- CollectionExpressionsSuite
- HigherOrderExpressionsSuite

Closes apache#22110 from mn-mikke/SPARK-25122.

Authored-by: Marek Novotny <mn.mikke@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
mn-mikke authored and cloud-fan committed Aug 17, 2018
1 parent f161409 commit 8af61fb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1505,13 +1505,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)

@transient private lazy val elementTypeSupportEquals = elementType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

@transient private lazy val doEvaluation = if (elementTypeSupportEquals) {
@transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
fastEval _
} else {
bruteForceEval _
Expand Down Expand Up @@ -1593,7 +1587,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
nullSafeCodeGen(ctx, ev, (a1, a2) => {
val smaller = ctx.freshName("smallerArray")
val bigger = ctx.freshName("biggerArray")
val comparisonCode = if (elementTypeSupportEquals) {
val comparisonCode = if (TypeUtils.typeWithProperEquals(elementType)) {
fastCodegen(ctx, ev, smaller, bigger)
} else {
bruteForceCodegen(ctx, ev, smaller, bigger)
Expand Down Expand Up @@ -3404,12 +3398,6 @@ case class ArrayDistinct(child: Expression)
}
}

@transient private lazy val elementTypeSupportEquals = elementType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

@transient protected lazy val canUseSpecializedHashSet = elementType match {
case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
case _ => false
Expand All @@ -3434,9 +3422,13 @@ case class ArrayDistinct(child: Expression)

override def nullSafeEval(array: Any): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementTypeSupportEquals) {
new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
} else {
doEvaluation(data)
}

@transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
(data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
} else {
(data: Array[AnyRef]) => {
var foundNullElement = false
var pos = 0
for (i <- 0 until data.length) {
Expand Down Expand Up @@ -3576,12 +3568,6 @@ abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
@transient protected lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)

@transient protected lazy val elementTypeSupportEquals = elementType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

@transient protected lazy val canUseSpecializedHashSet = elementType match {
case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
case _ => false
Expand Down Expand Up @@ -3679,7 +3665,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
with ComplexTypeMergingExpression {

@transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
if (elementTypeSupportEquals) {
if (TypeUtils.typeWithProperEquals(elementType)) {
(array1, array2) =>
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val hs = new OpenHashSet[Any]
Expand Down Expand Up @@ -3896,7 +3882,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
}

@transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = {
if (elementTypeSupportEquals) {
if (TypeUtils.typeWithProperEquals(elementType)) {
(array1, array2) =>
if (array1.numElements() != 0 && array2.numElements() != 0) {
val hs = new OpenHashSet[Any]
Expand Down Expand Up @@ -4136,7 +4122,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
}

@transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
if (elementTypeSupportEquals) {
if (TypeUtils.typeWithProperEquals(elementType)) {
(array1, array2) =>
val hs = new OpenHashSet[Any]
var notFoundNullElement = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,20 +683,14 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
value2Var: NamedLambdaVariable),
_) = function

private def keyTypeSupportsEquals = keyType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

/**
* The function accepts two key arrays and returns a collection of keys with indexes
* to value arrays. Indexes are represented as an array of two items. This is a small
* optimization leveraging mutability of arrays.
*/
@transient private lazy val getKeysWithValueIndexes:
(ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = {
if (keyTypeSupportsEquals) {
if (TypeUtils.typeWithProperEquals(keyType)) {
getKeysWithIndexesFast
} else {
getKeysWithIndexesBruteForce
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

/**
* Helper functions to check for valid data types.
* Functions to help with checking for valid data types and value comparison of various types.
*/
object TypeUtils {
def checkForNumericExpr(dt: DataType, caller: String): TypeCheckResult = {
Expand Down Expand Up @@ -73,4 +73,15 @@ object TypeUtils {
}
x.length - y.length
}

/**
* Returns true if the equals method of the elements of the data type is implemented properly.
* This also means that they can be safely used in collections relying on the equals method,
* as sets or maps.
*/
def typeWithProperEquals(dataType: DataType): Boolean = dataType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}
}

0 comments on commit 8af61fb

Please sign in to comment.