Skip to content

Commit 07165ca

Browse files
committed
[SPARK-12424][ML] The implementation of ParamMap#filter is wrong.
ParamMap#filter uses `mutable.Map#filterKeys`. The return type of `filterKey` is collection.Map, not mutable.Map but the result is casted to mutable.Map using `asInstanceOf` so we get `ClassCastException`. Also, the return type of Map#filterKeys is not Serializable. It's the issue of Scala (https://issues.scala-lang.org/browse/SI-6654). Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp> Closes apache#10381 from sarutak/SPARK-12424.
1 parent e01c6c8 commit 07165ca

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -859,8 +859,12 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
859859
* Filters this param map for the given parent.
860860
*/
861861
def filter(parent: Params): ParamMap = {
862-
val filtered = map.filterKeys(_.parent == parent)
863-
new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
862+
// Don't use filterKeys because mutable.Map#filterKeys
863+
// returns the instance of collections.Map, not mutable.Map.
864+
// Otherwise, we get ClassCastException.
865+
// Not using filterKeys also avoid SI-6654
866+
val filtered = map.filter { case (k, _) => k.parent == parent.uid }
867+
new ParamMap(filtered)
864868
}
865869

866870
/**

mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
package org.apache.spark.ml.param
1919

20+
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
21+
2022
import org.apache.spark.SparkFunSuite
23+
import org.apache.spark.ml.util.MyParams
2124
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2225

2326
class ParamsSuite extends SparkFunSuite {
@@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite {
349352
val t3 = t.copy(ParamMap(t.maxIter -> 20))
350353
assert(t3.isSet(t3.maxIter))
351354
}
355+
356+
test("Filtering ParamMap") {
357+
val params1 = new MyParams("my_params1")
358+
val params2 = new MyParams("my_params2")
359+
val paramMap = ParamMap(
360+
params1.intParam -> 1,
361+
params2.intParam -> 1,
362+
params1.doubleParam -> 0.2,
363+
params2.doubleParam -> 0.2)
364+
val filteredParamMap = paramMap.filter(params1)
365+
366+
assert(filteredParamMap.size === 2)
367+
filteredParamMap.toSeq.foreach {
368+
case ParamPair(p, _) =>
369+
assert(p.parent === params1.uid)
370+
}
371+
372+
// At the previous implementation of ParamMap#filter,
373+
// mutable.Map#filterKeys was used internally but
374+
// the return type of the method is not serializable (see SI-6654).
375+
// Now mutable.Map#filter is used instead of filterKeys and the return type is serializable.
376+
// So let's ensure serializability.
377+
val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
378+
objOut.writeObject(filteredParamMap)
379+
}
352380
}
353381

354382
object ParamsSuite extends SparkFunSuite {

0 commit comments

Comments
 (0)