
Commit 3f78f60

[SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly
## What changes were proposed in this pull request?

Accumulators written against the Spark 1.x API may no longer work with Spark 2.x, because `LegacyAccumulatorWrapper.isZero` can return the wrong answer if the `AccumulableParam` doesn't define `equals`/`hashCode`. This PR fixes that by using a reference-equality check in `LegacyAccumulatorWrapper.isZero`.

## How was this patch tested?

A new test.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #21229 from cloud-fan/accumulator.

(cherry picked from commit 4d5de4d)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
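To make the bug concrete: for a plain class with no `equals`/`hashCode`, Scala's `==` falls back to reference equality, so comparing `_value` against a freshly constructed zero object is always false. A minimal sketch (the object and names below are illustrative, not part of the patch):

```scala
// Minimal sketch (illustrative names, not from the patch) of the bug:
// a class with no equals/hashCode makes == fall back to reference equality.
object IsZeroSketch extends App {
  class MyData(val i: Int) extends Serializable

  // Stand-in for param.zero(initialValue): a *fresh* object on every call.
  def zero(): MyData = new MyData(0)

  val value = zero()
  println(value == zero()) // false: distinct objects, default equals is reference equality
  println(value eq value)  // true: caching one zero instance and comparing
                           // references, as the patch does, works
}
```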
1 parent: d35eb2f

File tree

2 files changed (+23, -2 lines)


core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala

Lines changed: 4 additions & 2 deletions
```diff
@@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T](
     param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
   private[spark] var _value = initialValue // Current value on driver
 
-  override def isZero: Boolean = _value == param.zero(initialValue)
+  @transient private lazy val _zero = param.zero(initialValue)
+
+  override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
 
   override def copy(): LegacyAccumulatorWrapper[R, T] = {
     val acc = new LegacyAccumulatorWrapper(initialValue, param)
@@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T](
   }
 
   override def reset(): Unit = {
-    _value = param.zero(initialValue)
+    _value = _zero
   }
 
   override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
```
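Two details of the fix worth noting: `eq` is the reference-equality method defined on `AnyRef`, hence the `asInstanceOf[AnyRef]` casts, which box `_value` and `_zero` so the check compiles for an unconstrained `R`. And because `reset()` now assigns the cached `_zero` instance itself, a freshly reset accumulator satisfies `isZero` by construction, while `@transient lazy val` keeps the cached zero out of serialization and recomputes it where needed.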

core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala

Lines changed: 19 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import org.apache.spark._
+import org.apache.spark.serializer.JavaSerializer
 
 class AccumulatorV2Suite extends SparkFunSuite {
 
@@ -162,4 +163,22 @@
     assert(acc3.isZero)
     assert(acc3.value === "")
   }
+
+  test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") {
+    class MyData(val i: Int) extends Serializable
+    val param = new AccumulatorParam[MyData] {
+      override def zero(initialValue: MyData): MyData = new MyData(0)
+      override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i)
+    }
+
+    val acc = new LegacyAccumulatorWrapper(new MyData(0), param)
+    acc.metadata = AccumulatorMetadata(
+      AccumulatorContext.newId(),
+      Some("test"),
+      countFailedValues = false)
+    AccumulatorContext.register(acc)
+
+    val ser = new JavaSerializer(new SparkConf).newInstance()
+    ser.serialize(acc)
+  }
 }
```
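The test exercises the bug path via serialization: serializing a registered driver-side accumulator goes through `AccumulatorV2.writeReplace`, which copies-and-resets the accumulator and asserts that the copy `isZero`. Since `MyData` defines no `equals`, the old equality-based `isZero` failed that assertion; with the reference check against the cached `_zero`, `ser.serialize(acc)` completes.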
