Skip to content

Commit c81072d

Browse files
committed
addressed comments
1 parent 99c2ebf commit c81072d

File tree

3 files changed

+8
-12
lines changed

3 files changed

+8
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -220,19 +220,15 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
220220
}
221221

222222
/** Specialized version of [[Param[Array[T]]]] for Java. */
223-
class ArrayParam[T : ClassTag](
224-
parent: Params,
225-
name: String,
226-
doc: String,
227-
isValid: Array[T] => Boolean)
228-
extends Param[Array[T]](parent, name, doc, isValid) {
223+
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
224+
extends Param[Array[String]](parent, name, doc, isValid) {
229225

230226
def this(parent: Params, name: String, doc: String) =
231227
this(parent, name, doc, ParamValidators.alwaysTrue)
232228

233-
override def w(value: Array[T]): ParamPair[Array[T]] = super.w(value)
229+
override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
234230

235-
private[param] def wCast(value: Seq[T]): ParamPair[Array[T]] = w(value.toArray)
231+
private[param] def wCast(value: Seq[String]): ParamPair[Array[String]] = w(value.toArray)
236232
}
237233

238234
/**
@@ -328,8 +324,8 @@ trait Params extends Identifiable with Serializable {
328324
*/
329325
protected final def set[T](param: Param[T], value: T): this.type = {
330326
shouldOwn(param)
331-
if (param.isInstanceOf[ArrayParam[_]] && value.isInstanceOf[Seq[_]]) {
332-
paramMap.put(param.asInstanceOf[ArrayParam[Any]].wCast(value.asInstanceOf[Seq[Any]]))
327+
if (param.isInstanceOf[StringArrayParam] && value.isInstanceOf[Seq[_]]) {
328+
paramMap.put(param.asInstanceOf[StringArrayParam].wCast(value.asInstanceOf[Seq[String]]))
333329
} else {
334330
paramMap.put(param.w(value))
335331
}

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ private[shared] object SharedParamsCodeGen {
8383
case _ if c == classOf[Float] => "FloatParam"
8484
case _ if c == classOf[Double] => "DoubleParam"
8585
case _ if c == classOf[Boolean] => "BooleanParam"
86-
case _ if c.isArray => s"ArrayParam[${getTypeString(c.getComponentType)}]"
86+
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
8787
case _ => s"Param[${getTypeString(c)}]"
8888
}
8989
}

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
178178
* Param for input column names.
179179
* @group param
180180
*/
181-
final val inputCols: ArrayParam[String] = new ArrayParam[String](this, "inputCols", "input column names")
181+
final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")
182182

183183
/** @group getParam */
184184
final def getInputCols: Array[String] = $(inputCols)

0 commit comments

Comments
 (0)