|
17 | 17 |
|
18 | 18 | package org.apache.spark.ml.param
|
19 | 19 |
|
| 20 | +import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream} |
| 21 | + |
20 | 22 | import org.apache.spark.SparkFunSuite
|
| 23 | +import org.apache.spark.ml.util.MyParams |
21 | 24 | import org.apache.spark.mllib.linalg.{Vector, Vectors}
|
22 | 25 |
|
23 | 26 | class ParamsSuite extends SparkFunSuite {
|
@@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite {
|
349 | 352 | val t3 = t.copy(ParamMap(t.maxIter -> 20))
|
350 | 353 | assert(t3.isSet(t3.maxIter))
|
351 | 354 | }
|
| 355 | + |
test("Filtering ParamMap") {
  // Two parent objects carrying identically-named params, so the filter
  // must discriminate by parent (uid), not by param name.
  val params1 = new MyParams("my_params1")
  val params2 = new MyParams("my_params2")
  val paramMap = ParamMap(
    params1.intParam -> 1,
    params2.intParam -> 1,
    params1.doubleParam -> 0.2,
    params2.doubleParam -> 0.2)
  val filteredParamMap = paramMap.filter(params1)

  // Only params1's two pairs should survive; every surviving pair must
  // belong to the parent we filtered on.
  assert(filteredParamMap.size === 2)
  filteredParamMap.toSeq.foreach {
    case ParamPair(p, _) =>
      assert(p.parent === params1.uid)
  }

  // At the previous implementation of ParamMap#filter,
  // mutable.Map#filterKeys was used internally but
  // the return type of the method is not serializable (see SI-6654).
  // Now mutable.Map#filter is used instead of filterKeys and the return type is serializable.
  // So let's ensure serializability: writeObject would throw
  // NotSerializableException on a regression.
  val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
  try {
    objOut.writeObject(filteredParamMap)
  } finally {
    // Close the stream even if writeObject throws (original leaked it).
    objOut.close()
  }
}
352 | 380 | }
|
353 | 381 |
|
354 | 382 | object ParamsSuite extends SparkFunSuite {
|
|
0 commit comments