Skip to content

Commit b907314

Browse files
committed
recreate pr
1 parent fddb63f commit b907314

File tree

7 files changed

+82
-42
lines changed

7 files changed

+82
-42
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
2727
import org.apache.spark.ml.feature.LabeledPoint
2828
import org.apache.spark.ml.linalg.{Vector, Vectors}
2929
import org.apache.spark.ml.param._
30-
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol}
30+
import org.apache.spark.ml.param.shared._
3131
import org.apache.spark.ml.util._
3232
import org.apache.spark.sql.Dataset
3333

3434
/** Params for Multilayer Perceptron. */
3535
private[classification] trait MultilayerPerceptronParams extends PredictorParams
36-
with HasSeed with HasMaxIter with HasTol with HasStepSize {
36+
with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver {
37+
38+
import MultilayerPerceptronClassifier._
39+
3740
/**
3841
* Layer sizes including input size and output size.
3942
*
@@ -78,10 +81,10 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
7881
* @group expertParam
7982
*/
8083
@Since("2.0.0")
81-
final val solver: Param[String] = new Param[String](this, "solver",
84+
final override val solver: Param[String] = new Param[String](this, "solver",
8285
"The solver algorithm for optimization. Supported options: " +
83-
s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)",
84-
ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers))
86+
s"${supportedSolvers.mkString(", ")}. (Default l-bfgs)",
87+
ParamValidators.inArray[String](supportedSolvers))
8588

8689
/** @group expertGetParam */
8790
@Since("2.0.0")
@@ -101,7 +104,7 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
101104
final def getInitialWeights: Vector = $(initialWeights)
102105

103106
setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128,
104-
solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03)
107+
solver -> LBFGS, stepSize -> 0.03)
105108
}
106109

107110
/** Label to vector converter. */

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ private[shared] object SharedParamsCodeGen {
8080
" 0)", isValid = "ParamValidators.gt(0)"),
8181
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
8282
"all instance weights as 1.0"),
83-
ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
84-
"empty, default value is 'auto'", Some("\"auto\"")),
83+
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
8584
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
8685
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
8786

@@ -99,6 +98,7 @@ private[shared] object SharedParamsCodeGen {
9998
defaultValueStr: Option[String] = None,
10099
isValid: String = "",
101100
finalMethods: Boolean = true,
101+
finalFields: Boolean = true,
102102
isExpertParam: Boolean = false) {
103103

104104
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
@@ -167,6 +167,11 @@ private[shared] object SharedParamsCodeGen {
167167
} else {
168168
"def"
169169
}
170+
val fieldStr = if (param.finalFields) {
171+
"final val"
172+
} else {
173+
"val"
174+
}
170175

171176
val htmlCompliantDoc = Utility.escape(doc)
172177

@@ -180,7 +185,7 @@ private[shared] object SharedParamsCodeGen {
180185
| * Param for $htmlCompliantDoc.
181186
| * @group ${groupStr(0)}
182187
| */
183-
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
188+
| $fieldStr $name: $Param = new $Param(this, "$name", "$doc"$isValid)
184189
|$setDefault
185190
| /** @group ${groupStr(1)} */
186191
| $methodStr get$Name: $T = $$($name)

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -374,17 +374,15 @@ private[ml] trait HasWeightCol extends Params {
374374
}
375375

376376
/**
377-
* Trait for shared param solver (default: "auto").
377+
* Trait for shared param solver.
378378
*/
379379
private[ml] trait HasSolver extends Params {
380380

381381
/**
382-
* Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
382+
* Param for the solver algorithm for optimization.
383383
* @group param
384384
*/
385-
final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'")
386-
387-
setDefault(solver, "auto")
385+
final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization")
388386

389387
/** @group getParam */
390388
final def getSolver: String = $(solver)

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
143143
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
144144
}
145145

146-
import GeneralizedLinearRegression._
146+
/**
147+
* The solver algorithm for optimization.
148+
* Supported options: "irls" (iteratively reweighted least squares).
149+
* Default: "irls"
150+
*
151+
* @group expertParam
152+
*/
153+
@Since("2.3.0")
154+
final override val solver: Param[String] = new Param[String](this, "solver",
155+
"The solver algorithm for optimization. Supported options: " +
156+
s"${supportedSolvers.mkString(", ")}. (Default irls)",
157+
ParamValidators.inArray[String](supportedSolvers))
147158

148159
@Since("2.0.0")
149160
override def validateAndTransformSchema(
@@ -314,7 +325,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
314325
*/
315326
@Since("2.0.0")
316327
def setSolver(value: String): this.type = set(solver, value)
317-
setDefault(solver -> "irls")
328+
setDefault(solver -> IRLS)
318329

319330
/**
320331
* Sets the link prediction (linear predictor) column name.
@@ -400,6 +411,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
400411
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
401412
)
402413

414+
/** String name for "irls" (iteratively reweighted least squares) solver. */
415+
private[regression] val IRLS = "irls"
416+
417+
/** Set of solvers that GeneralizedLinearRegression supports. */
418+
private[regression] val supportedSolvers = Array(IRLS)
419+
403420
/** Set of family names that GeneralizedLinearRegression supports. */
404421
private[regression] lazy val supportedFamilyNames =
405422
supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie"

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares
3434
import org.apache.spark.ml.PredictorParams
3535
import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator
3636
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
37-
import org.apache.spark.ml.param.ParamMap
37+
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
3838
import org.apache.spark.ml.param.shared._
3939
import org.apache.spark.ml.util._
4040
import org.apache.spark.mllib.evaluation.RegressionMetrics
@@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel
5353
private[regression] trait LinearRegressionParams extends PredictorParams
5454
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
5555
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
56-
with HasAggregationDepth
56+
with HasAggregationDepth {
57+
58+
import LinearRegression._
59+
60+
/**
61+
* The solver algorithm for optimization.
62+
* Supported options: "l-bfgs", "normal" and "auto".
63+
* Default: "auto"
64+
*
65+
* @group expertParam
66+
*/
67+
@Since("2.3.0")
68+
final override val solver: Param[String] = new Param[String](this, "solver",
69+
"The solver algorithm for optimization. Supported options: " +
70+
s"${supportedSolvers.mkString(", ")}. (Default auto)",
71+
ParamValidators.inArray[String](supportedSolvers))
72+
}
5773

5874
/**
5975
* Linear regression.
@@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
7894
extends Regressor[Vector, LinearRegression, LinearRegressionModel]
7995
with LinearRegressionParams with DefaultParamsWritable with Logging {
8096

97+
import LinearRegression._
98+
8199
@Since("1.4.0")
82100
def this() = this(Identifiable.randomUID("linReg"))
83101

@@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
175193
* @group setParam
176194
*/
177195
@Since("1.6.0")
178-
def setSolver(value: String): this.type = {
179-
require(Set("auto", "l-bfgs", "normal").contains(value),
180-
s"Solver $value was not supported. Supported options: auto, l-bfgs, normal")
181-
set(solver, value)
182-
}
183-
setDefault(solver -> "auto")
196+
def setSolver(value: String): this.type = set(solver, value)
197+
setDefault(solver -> AUTO)
184198

185199
/**
186200
* Suggested depth for treeAggregate (greater than or equal to 2).
@@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
210224
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
211225
instr.logNumFeatures(numFeatures)
212226

213-
if (($(solver) == "auto" &&
214-
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
227+
if (($(solver) == AUTO &&
228+
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) {
215229
// For low dimensional data, WeightedLeastSquares is more efficient since the
216230
// training algorithm only requires one pass through the data. (SPARK-10668)
217231

@@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
444458
*/
445459
@Since("2.1.0")
446460
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
461+
462+
/** String name for "auto". */
463+
private[regression] val AUTO = "auto"
464+
465+
/** String name for "normal". */
466+
private[regression] val NORMAL = "normal"
467+
468+
/** String name for "l-bfgs". */
469+
private[regression] val LBFGS = "l-bfgs"
470+
471+
/** Set of solvers that LinearRegression supports. */
472+
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS)
447473
}
448474

449475
/**

python/pyspark/ml/classification.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,8 +1265,8 @@ def theta(self):
12651265

12661266
@inherit_doc
12671267
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
1268-
HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable,
1269-
JavaMLReadable):
1268+
HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
1269+
JavaMLWritable, JavaMLReadable):
12701270
"""
12711271
Classifier trainer based on the Multilayer Perceptron.
12721272
Each layer has sigmoid activation function, output layer has softmax.
@@ -1407,20 +1407,6 @@ def getStepSize(self):
14071407
"""
14081408
return self.getOrDefault(self.stepSize)
14091409

1410-
@since("2.0.0")
1411-
def setSolver(self, value):
1412-
"""
1413-
Sets the value of :py:attr:`solver`.
1414-
"""
1415-
return self._set(solver=value)
1416-
1417-
@since("2.0.0")
1418-
def getSolver(self):
1419-
"""
1420-
Gets the value of solver or its default value.
1421-
"""
1422-
return self.getOrDefault(self.solver)
1423-
14241410
@since("2.0.0")
14251411
def setInitialWeights(self, value):
14261412
"""

python/pyspark/ml/regression.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
9595
.. versionadded:: 1.4.0
9696
"""
9797

98+
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
99+
"options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)
100+
98101
@keyword_only
99102
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
100103
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
@@ -1371,6 +1374,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
13711374
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
13721375
"Only applicable to the Tweedie family.",
13731376
typeConverter=TypeConverters.toFloat)
1377+
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
1378+
"options: irls.", typeConverter=TypeConverters.toString)
13741379

13751380
@keyword_only
13761381
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",

0 commit comments

Comments
 (0)