Skip to content

Commit d0b8d06

Browse files
committed
address review comments
1 parent 9872bfd commit d0b8d06

File tree

4 files changed

+16
-19
lines changed

4 files changed

+16
-19
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
193193

194194
@Since("1.4.0")
195195
override def transformSchema(schema: StructType): StructType = {
196-
ParamValidators.assertColOrCols(this)
196+
ParamValidators.checkMultiColumnParams(this)
197197
if (isSet(inputCol) && isSet(splitsArray)) {
198198
ParamValidators.raiseIncompatibleParamsException("inputCol", "splitsArray")
199199
}

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ object ParamValidators {
256256
* this is not true, an `IllegalArgumentException` is raised.
257257
* @param model
258258
*/
259-
private[spark] def assertColOrCols(model: Params): Unit = {
259+
private[spark] def checkMultiColumnParams(model: Params): Unit = {
260260
model match {
261261
case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) =>
262262
raiseIncompatibleParamsException("inputCols", "inputCol")

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
402402
}
403403

404404
test("assert exception is thrown if both multi-column and single-column params are set") {
405-
ParamsSuite.checkMultiColumnParams(classOf[Bucketizer], spark)
405+
val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
406+
ParamsSuite.testMultiColumnParams(classOf[Bucketizer], df)
406407
}
407408
}
408409

mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Transformer}
2424
import org.apache.spark.ml.linalg.{Vector, Vectors}
2525
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
2626
import org.apache.spark.ml.util.MyParams
27-
import org.apache.spark.sql.{Dataset, SparkSession}
27+
import org.apache.spark.sql.Dataset
2828

2929
class ParamsSuite extends SparkFunSuite {
3030

@@ -441,24 +441,20 @@ object ParamsSuite extends SparkFunSuite {
441441
* `HasInputCol` and both `HasOutputCols` and `HasOutputCol`.
442442
*
443443
* @param paramsClass The Class to be checked
444-
* @param spark A `SparkSession` instance to use
444+
* @param dataset A `Dataset` to use in the tests
445445
*/
446-
def checkMultiColumnParams(paramsClass: Class[_ <: Params], spark: SparkSession): Unit = {
447-
import spark.implicits._
448-
// create fake input Dataset
449-
val feature1 = Array(-1.0, 0.0, 1.0)
450-
val feature2 = Array(1.0, 0.0, -1.0)
451-
val df = feature1.zip(feature2).toSeq.toDF("feature1", "feature2")
446+
def testMultiColumnParams(paramsClass: Class[_ <: Params], dataset: Dataset[_]): Unit = {
447+
val cols = dataset.columns
452448

453449
if (paramsClass.isAssignableFrom(classOf[HasInputCols])
454450
&& paramsClass.isAssignableFrom(classOf[HasInputCol])) {
455451
val model = paramsClass.newInstance()
456-
model.set(model.asInstanceOf[HasInputCols].inputCols, Array("feature1", "feature2"))
457-
model.set(model.asInstanceOf[HasInputCol].inputCol, "features1")
452+
model.set(model.asInstanceOf[HasInputCols].inputCols, cols)
453+
model.set(model.asInstanceOf[HasInputCol].inputCol, cols(0))
458454
val e = intercept[IllegalArgumentException] {
459455
model match {
460-
case t: Transformer => t.transform(df)
461-
case e: Estimator[_] => e.fit(df)
456+
case t: Transformer => t.transform(dataset)
457+
case e: Estimator[_] => e.fit(dataset)
462458
}
463459
}
464460
assert(e.getMessage.contains("cannot be both set"))
@@ -467,12 +463,12 @@ object ParamsSuite extends SparkFunSuite {
467463
if (paramsClass.isAssignableFrom(classOf[HasOutputCols])
468464
&& paramsClass.isAssignableFrom(classOf[HasOutputCol])) {
469465
val model = paramsClass.newInstance()
470-
model.set(model.asInstanceOf[HasOutputCols].outputCols, Array("result1", "result2"))
471-
model.set(model.asInstanceOf[HasOutputCol].outputCol, "result1")
466+
model.set(model.asInstanceOf[HasOutputCols].outputCols, cols)
467+
model.set(model.asInstanceOf[HasOutputCol].outputCol, cols(0))
472468
val e = intercept[IllegalArgumentException] {
473469
model match {
474-
case t: Transformer => t.transform(df)
475-
case e: Estimator[_] => e.fit(df)
470+
case t: Transformer => t.transform(dataset)
471+
case e: Estimator[_] => e.fit(dataset)
476472
}
477473
}
478474
assert(e.getMessage.contains("cannot be both set"))

0 commit comments

Comments
 (0)