@@ -24,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Transformer}
24
24
import org .apache .spark .ml .linalg .{Vector , Vectors }
25
25
import org .apache .spark .ml .param .shared .{HasInputCol , HasInputCols , HasOutputCol , HasOutputCols }
26
26
import org .apache .spark .ml .util .MyParams
27
- import org .apache .spark .sql .{ Dataset , SparkSession }
27
+ import org .apache .spark .sql .Dataset
28
28
29
29
class ParamsSuite extends SparkFunSuite {
30
30
@@ -441,24 +441,20 @@ object ParamsSuite extends SparkFunSuite {
441
441
* `HasInputCol` and both `HasOutputCols` and `HasOutputCol`.
442
442
*
443
443
* @param paramsClass The Class to be checked
444
- * @param spark A `SparkSession` instance to use
444
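+ * @param dataset A `Dataset` to use in the tests; its column names are used as the input/output column param values, so it must have at least one column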
445
445
*/
446
- def checkMultiColumnParams (paramsClass : Class [_ <: Params ], spark : SparkSession ): Unit = {
447
- import spark .implicits ._
448
- // create fake input Dataset
449
- val feature1 = Array (- 1.0 , 0.0 , 1.0 )
450
- val feature2 = Array (1.0 , 0.0 , - 1.0 )
451
- val df = feature1.zip(feature2).toSeq.toDF(" feature1" , " feature2" )
446
+ def testMultiColumnParams (paramsClass : Class [_ <: Params ], dataset : Dataset [_]): Unit = {
447
+ val cols = dataset.columns
452
448
453
449
if (paramsClass.isAssignableFrom(classOf [HasInputCols ])
454
450
&& paramsClass.isAssignableFrom(classOf [HasInputCol ])) {
455
451
val model = paramsClass.newInstance()
456
- model.set(model.asInstanceOf [HasInputCols ].inputCols, Array ( " feature1 " , " feature2 " ) )
457
- model.set(model.asInstanceOf [HasInputCol ].inputCol, " features1 " )
452
+ model.set(model.asInstanceOf [HasInputCols ].inputCols, cols )
453
+ model.set(model.asInstanceOf [HasInputCol ].inputCol, cols( 0 ) )
458
454
val e = intercept[IllegalArgumentException ] {
459
455
model match {
460
- case t : Transformer => t.transform(df )
461
- case e : Estimator [_] => e.fit(df )
456
+ case t : Transformer => t.transform(dataset )
457
+ case e : Estimator [_] => e.fit(dataset )
462
458
}
463
459
}
464
460
assert(e.getMessage.contains(" cannot be both set" ))
@@ -467,12 +463,12 @@ object ParamsSuite extends SparkFunSuite {
467
463
if (paramsClass.isAssignableFrom(classOf [HasOutputCols ])
468
464
&& paramsClass.isAssignableFrom(classOf [HasOutputCol ])) {
469
465
val model = paramsClass.newInstance()
470
- model.set(model.asInstanceOf [HasOutputCols ].outputCols, Array ( " result1 " , " result2 " ) )
471
- model.set(model.asInstanceOf [HasOutputCol ].outputCol, " result1 " )
466
+ model.set(model.asInstanceOf [HasOutputCols ].outputCols, cols )
467
+ model.set(model.asInstanceOf [HasOutputCol ].outputCol, cols( 0 ) )
472
468
val e = intercept[IllegalArgumentException ] {
473
469
model match {
474
- case t : Transformer => t.transform(df )
475
- case e : Estimator [_] => e.fit(df )
470
+ case t : Transformer => t.transform(dataset )
471
+ case e : Estimator [_] => e.fit(dataset )
476
472
}
477
473
}
478
474
assert(e.getMessage.contains(" cannot be both set" ))
0 commit comments