strengthened requirements about exclusive Params for single and multicolumn support #1

Merged
merged 1 commit into from
Jan 20, 2018
15 changes: 11 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -32,7 +32,9 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

 /**
- * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ *
+ * Since 2.3.0,
  * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
  * when both the `inputCol` and `inputCols` parameters are set, an Exception will be thrown. The
  * `splits` parameter is only used for single column usage, and `splitsArray` is for multiple
@@ -184,11 +186,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String

   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
-    ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
-    ParamValidators.checkExclusiveParams(this, "splits", "splitsArray")
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
+      Seq(outputCols, splitsArray))

     if (isSet(inputCols)) {
+      require(getInputCols.length == getOutputCols.length &&
+        getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
+        s"for multi-column transform. Params (inputCols, outputCols, splitsArray) should have " +
+        s"equal lengths, but they have different lengths: " +
+        s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")
+
       var transformedSchema = schema
       $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
         SchemaUtils.checkNumericType(transformedSchema, inputCol)
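The `require` added to `transformSchema` above enforces that the three multi-column Params line up one-to-one. A minimal standalone sketch of that check follows, with plain arrays standing in for the Spark Params. The object name `EqualLengthCheckSketch` and the helper `checkLengths` are hypothetical and not part of Spark.

```scala
// Standalone sketch; plain arrays replace the Spark Params (an assumption for illustration).
object EqualLengthCheckSketch {

  /** Throws IllegalArgumentException (via `require`) unless all three arrays have equal length. */
  def checkLengths(
      inputCols: Array[String],
      outputCols: Array[String],
      splitsArray: Array[Array[Double]]): Unit = {
    require(
      inputCols.length == outputCols.length && inputCols.length == splitsArray.length,
      "Params (inputCols, outputCols, splitsArray) should have equal lengths, but they have " +
        s"different lengths: (${inputCols.length}, ${outputCols.length}, ${splitsArray.length}).")
  }
}
```

Since `require` throws `IllegalArgumentException`, a mismatch such as two input columns paired with one output column fails immediately, and the message lists all three lengths.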
75 changes: 60 additions & 15 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -27,7 +27,6 @@ import scala.collection.mutable

 import org.json4s._
 import org.json4s.jackson.JsonMethods._
-import org.slf4j.LoggerFactory

 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{DeveloperApi, Since}
@@ -167,8 +166,6 @@ private[ml] object Param {
 @DeveloperApi
 object ParamValidators {

-  private val LOGGER = LoggerFactory.getLogger(ParamValidators.getClass)
-
   /** (private[param]) Default validation always return true */
   private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true

@@ -254,21 +251,69 @@ object ParamValidators {
   }

   /**
-   * Checks that only one of the params passed as arguments is set. If this is not true, an
-   * `IllegalArgumentException` is raised.
+   * Utility for Param validity checks for Transformers which have both single- and multi-column
+   * support. This utility assumes that `inputCol` indicates single-column usage and
+   * that `inputCols` indicates multi-column usage.
+   *
+   * This checks to ensure that exactly one set of Params has been set, and it
+   * raises an `IllegalArgumentException` if not.
+   *
+   * @param singleColumnParams Params which should be set (or have defaults) if `inputCol` has been
+   *                           set. This does not need to include `inputCol`.
+   * @param multiColumnParams Params which should be set (or have defaults) if `inputCols` has been
+   *                          set. This does not need to include `inputCols`.
    */
-  def checkExclusiveParams(model: Params, params: String*): Unit = {
-    val (existingParams, nonExistingParams) = params.partition(model.hasParam)
-    if (nonExistingParams.nonEmpty) {
-      val pronoun = if (nonExistingParams.size == 1) "It" else "They"
-      LOGGER.warn(s"Ignored ${nonExistingParams.mkString("`", "`, `", "`")} while checking " +
-        s"exclusive params. $pronoun don't exist for the specified model the model.")
+  def checkSingleVsMultiColumnParams(
+      model: Params,
+      singleColumnParams: Seq[Param[_]],
+      multiColumnParams: Seq[Param[_]]): Unit = {
+    val name = s"${model.getClass.getSimpleName} $model"
+
+    def checkExclusiveParams(
+        isSingleCol: Boolean,
+        requiredParams: Seq[Param[_]],
+        excludedParams: Seq[Param[_]]): Unit = {
+      val badParamsMsgBuilder = new mutable.StringBuilder()
Comment from the PR author:

This builder lets us include all incorrectly set Params in the error message, rather than just one.


+
+      val mustUnsetParams = excludedParams.filter(p => model.isSet(p))
+        .map(_.name).mkString(", ")
+      if (mustUnsetParams.nonEmpty)
+        badParamsMsgBuilder ++=
+          s"The following Params are not applicable and should not be set: $mustUnsetParams."
+
+      val mustSetParams = requiredParams.filter(p => !model.isDefined(p))
+        .map(_.name).mkString(", ")
+      if (mustSetParams.nonEmpty)
+        badParamsMsgBuilder ++=
+          s"The following Params must be defined but are not set: $mustSetParams."
+
+      val badParamsMsg = badParamsMsgBuilder.toString()
+
+      if (badParamsMsg.nonEmpty) {
+        val errPrefix = if (isSingleCol) {
+          s"$name has the inputCol Param set for single-column transform."
+        } else {
+          s"$name has the inputCols Param set for multi-column transform."
+        }
+        throw new IllegalArgumentException(s"$errPrefix $badParamsMsg")
+      }
     }

-    if (existingParams.count(paramName => model.isSet(model.getParam(paramName))) > 1) {
-      val paramString = existingParams.mkString("`", "`, `", "`")
-      throw new IllegalArgumentException(s"$paramString are exclusive, " +
-        "but more than one among them are set.")
+    val inputCol = model.getParam("inputCol")
+    val inputCols = model.getParam("inputCols")
+
+    if (model.isSet(inputCol)) {
+      require(!model.isSet(inputCols), s"$name requires " +
+        s"exactly one of inputCol, inputCols Params to be set, but both are set.")
+
+      checkExclusiveParams(isSingleCol = true, requiredParams = singleColumnParams,
+        excludedParams = multiColumnParams)
+    } else if (model.isSet(inputCols)) {
+      checkExclusiveParams(isSingleCol = false, requiredParams = multiColumnParams,
+        excludedParams = singleColumnParams)
+    } else {
+      throw new IllegalArgumentException(s"$name requires " +
+        s"exactly one of inputCol, inputCols Params to be set, but neither is set.")
     }
   }
 }
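The validation flow in `checkSingleVsMultiColumnParams`, including the author's point that the message builder reports every incorrectly set Param at once, can be sketched without Spark. Everything below is an illustrative assumption rather than Spark API: `FakeModel` is a hypothetical stand-in for `Params` that tracks Param names as strings, and `check` and `errorOf` are hypothetical helpers. The error wording is simplified.

```scala
// Standalone sketch: FakeModel is a hypothetical stand-in for Spark's Params.
object SingleVsMultiColumnSketch {

  // `set` holds explicitly set Param names; `defaults` holds Params that only have a default.
  final case class FakeModel(name: String, set: Set[String], defaults: Set[String] = Set.empty) {
    def isSet(p: String): Boolean = set.contains(p)
    def isDefined(p: String): Boolean = isSet(p) || defaults.contains(p)
  }

  def check(model: FakeModel, single: Seq[String], multi: Seq[String]): Unit = {
    // Collects every problem before throwing, so one exception reports them all.
    def checkExclusive(isSingleCol: Boolean, required: Seq[String], excluded: Seq[String]): Unit = {
      val problems = new StringBuilder
      val mustUnset = excluded.filter(p => model.isSet(p)).mkString(", ")
      if (mustUnset.nonEmpty) {
        problems.append(
          s"The following Params are not applicable and should not be set: $mustUnset. ")
      }
      val mustSet = required.filterNot(p => model.isDefined(p)).mkString(", ")
      if (mustSet.nonEmpty) {
        problems.append(s"The following Params must be defined but are not set: $mustSet.")
      }
      if (problems.nonEmpty) {
        val which = if (isSingleCol) "inputCol" else "inputCols"
        throw new IllegalArgumentException(
          s"${model.name} has the $which Param set. ${problems.result().trim}")
      }
    }

    (model.isSet("inputCol"), model.isSet("inputCols")) match {
      case (true, true) =>
        throw new IllegalArgumentException(s"${model.name}: both inputCol and inputCols are set.")
      case (true, false) =>
        checkExclusive(isSingleCol = true, required = single, excluded = multi)
      case (false, true) =>
        checkExclusive(isSingleCol = false, required = multi, excluded = single)
      case (false, false) =>
        throw new IllegalArgumentException(s"${model.name}: neither inputCol nor inputCols is set.")
    }
  }

  // Convenience for the example: returns the error message instead of throwing.
  def errorOf(model: FakeModel): Option[String] =
    try {
      check(model, Seq("outputCol", "splits"), Seq("outputCols", "splitsArray"))
      None
    } catch {
      case e: IllegalArgumentException => Some(e.getMessage)
    }
}
```

Under these assumptions, a model with `inputCol`, `splits`, and `outputCols` set but no `outputCol` produces a single message naming both problems: `outputCols` should not be set, and `outputCol` must be defined. A model with `inputCol` and `splits` set and a default for `outputCol` passes.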
27 changes: 13 additions & 14 deletions mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -435,24 +435,23 @@ object ParamsSuite extends SparkFunSuite {
   }

   /**
-   * Checks that the class throws an exception in case multiple exclusive params are set
+   * Checks that the class throws an exception in case multiple exclusive params are set.
    * The params to be checked are passed as arguments with their value.
-   * The checks are performed only if all the passed params are defined for the given model.
    */
-  def testExclusiveParams(model: Params, dataset: Dataset[_],
+  def testExclusiveParams(
+      model: Params,
+      dataset: Dataset[_],
       paramsAndValues: (String, Any)*): Unit = {
-    val params = paramsAndValues.map(_._1)
-    if (params.forall(model.hasParam)) {
Comment from the PR author:

I don't think we should check this. This method will be called from tests, and tests should be written carefully enough to avoid mistakes like this.

-      paramsAndValues.foreach { case (paramName, paramValue) =>
-        model.set(model.getParam(paramName), paramValue)
-      }
-      val e = intercept[IllegalArgumentException] {
-        model match {
-          case t: Transformer => t.transform(dataset)
-          case e: Estimator[_] => e.fit(dataset)
-        }
+    val m = model.copy(ParamMap.empty)
+    paramsAndValues.foreach { case (paramName, paramValue) =>
+      m.set(m.getParam(paramName), paramValue)
+    }
+    val e = intercept[IllegalArgumentException] {
+      m match {
+        case t: Transformer => t.transform(dataset)
+        case e: Estimator[_] => e.fit(dataset)
       }
-      assert(e.getMessage.contains("are exclusive, but more than one"))
     }
+    assert(e.getMessage.contains("are exclusive, but more than one"))
   }
 }