Skip to content

Commit 01f1f04

Browse files
committed
Address all comments and add unit test
1 parent 635f6fb commit 01f1f04

File tree

4 files changed

+44
-28
lines changed

4 files changed

+44
-28
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -135,18 +135,15 @@ private[execution] object HashedRelation {
135135
private[joins] class UnsafeHashedRelation(
136136
private var numKeys: Int,
137137
private var numFields: Int,
138-
private var binaryMap: BytesToBytesMap,
139-
private val isLookupAware: Boolean = false)
138+
private var binaryMap: BytesToBytesMap)
140139
extends HashedRelation with Externalizable with KryoSerializable {
141140

142-
private[joins] def this() = this(0, 0, null, false) // Needed for serialization
141+
private[joins] def this() = this(0, 0, null) // Needed for serialization
143142

144-
override def keyIsUnique: Boolean = {
145-
binaryMap.numKeys() == binaryMap.numValues()
146-
}
143+
override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues()
147144

148145
override def asReadOnlyCopy(): UnsafeHashedRelation = {
149-
new UnsafeHashedRelation(numKeys, numFields, binaryMap, isLookupAware)
146+
new UnsafeHashedRelation(numKeys, numFields, binaryMap)
150147
}
151148

152149
override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption
@@ -317,23 +314,19 @@ private[joins] class UnsafeHashedRelation(
317314
}
318315

319316
override def values(): Iterator[InternalRow] = {
320-
if (isLookupAware) {
321-
val iter = binaryMap.iterator()
317+
val iter = binaryMap.iterator()
322318

323-
new Iterator[InternalRow] {
324-
override def hasNext: Boolean = iter.hasNext
319+
new Iterator[InternalRow] {
320+
override def hasNext: Boolean = iter.hasNext
325321

326-
override def next(): InternalRow = {
327-
if (!hasNext) {
328-
throw new NoSuchElementException("End of the iterator")
329-
}
330-
val loc = iter.next()
331-
resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
332-
resultRow
322+
override def next(): InternalRow = {
323+
if (!hasNext) {
324+
throw new NoSuchElementException("End of the iterator")
333325
}
326+
val loc = iter.next()
327+
resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
328+
resultRow
334329
}
335-
} else {
336-
throw new UnsupportedOperationException
337330
}
338331
}
339332
}
@@ -405,7 +398,7 @@ private[joins] object UnsafeHashedRelation {
405398
}
406399
}
407400

408-
new UnsafeHashedRelation(key.size, numFields, binaryMap, isLookupAware)
401+
new UnsafeHashedRelation(key.size, numFields, binaryMap)
409402
}
410403
}
411404

@@ -945,7 +938,9 @@ class LongHashedRelation(
945938
override def keys(): Iterator[InternalRow] = map.keys()
946939

947940
override def values(): Iterator[InternalRow] = {
948-
throw new UnsupportedOperationException
941+
keys().flatMap { key =>
942+
get(key.getLong(0))
943+
}
949944
}
950945
}
951946

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,8 @@ case class ShuffledHashJoinExec(
134134
val buildNullRow = new GenericInternalRow(buildOutput.length)
135135
val streamNullRow = new GenericInternalRow(streamedOutput.length)
136136

137-
def markRowLookedUp(row: UnsafeRow): Unit = {
138-
if (!row.getBoolean(row.numFields() - 1)) {
139-
row.setBoolean(row.numFields() - 1, true)
140-
}
141-
}
137+
def markRowLookedUp(row: UnsafeRow): Unit =
138+
row.setBoolean(row.numFields() - 1, true)
142139

143140
// Process stream side with looking up hash relation
144141
val streamResultIter =

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
11891189
}
11901190
}
11911191

1192-
test("Full outer shuffled hash join") {
1192+
test("SPARK-32399: Full outer shuffled hash join") {
11931193
val inputDFs = Seq(
11941194
// Test unique join key
11951195
(spark.range(10).selectExpr("id as k1"),

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,4 +580,28 @@ class HashedRelationSuite extends SharedSparkSession {
580580
assert(proj(packedKeys).get(0, dt) == -i - 1)
581581
}
582582
}
583+
584+
test("SPARK-32399: test values() method for HashedRelation") {
585+
val key = Seq(BoundReference(0, LongType, false))
586+
val value = Seq(BoundReference(0, IntegerType, true))
587+
val unsafeProj = UnsafeProjection.create(value)
588+
val rows = (0 until 100).map(i => unsafeProj(InternalRow(i + 1)).copy())
589+
590+
// test LongHashedRelation
591+
val longRelation = LongHashedRelation(rows.iterator, key, 10, mm)
592+
var values = longRelation.values()
593+
assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 100).map(i => i + 1))
594+
595+
// test UnsafeHashedRelation
596+
val unsafeRelation = UnsafeHashedRelation(rows.iterator, key, 10, mm)
597+
values = unsafeRelation.values()
598+
assert(values.map(_.getInt(0)).toArray.sortWith(_ < _) === (0 until 100).map(i => i + 1))
599+
600+
// test lookup-aware UnsafeHashedRelation
601+
val lookupAwareUnsafeRelation = UnsafeHashedRelation(
602+
rows.iterator, key, 10, mm, isLookupAware = true, value = Some(value))
603+
values = lookupAwareUnsafeRelation.values()
604+
assert(values.map(v => (v.getInt(0), v.getBoolean(1))).toArray.sortWith(_._1 < _._1)
605+
=== (0 until 100).map(i => (i + 1, false)))
606+
}
583607
}

0 commit comments

Comments
 (0)