Skip to content

Commit d32ed25

Browse files
huaxingao authored and srowen committed
[SPARK-30144][ML][PYSPARK] Make MultilayerPerceptronClassificationModel extend MultilayerPerceptronParams
### What changes were proposed in this pull request?
Make ```MultilayerPerceptronClassificationModel``` extend ```MultilayerPerceptronParams```.

### Why are the changes needed?
Make ```MultilayerPerceptronClassificationModel``` extend ```MultilayerPerceptronParams``` to expose the training params, so that users can see these params when calling ```extractParamMap```.

### Does this PR introduce any user-facing change?
Yes. The ```MultilayerPerceptronParams``` such as ```seed```, ```maxIter```, ... are now available in ```MultilayerPerceptronClassificationModel```.

### How was this patch tested?
Manually tested ```MultilayerPerceptronClassificationModel.extractParamMap()``` to verify that all the new params are there.

Closes #26838 from huaxingao/spark-30144.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
1 parent 6196c20 commit d32ed25

30 files changed

+58
-44
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.ml.param.shared._
3030
import org.apache.spark.ml.util._
3131
import org.apache.spark.ml.util.Instrumentation.instrumented
3232
import org.apache.spark.sql.{Dataset, Row}
33+
import org.apache.spark.util.VersionUtils.majorMinorVersion
3334

3435
/** Params for Multilayer Perceptron. */
3536
private[classification] trait MultilayerPerceptronParams extends ProbabilisticClassifierParams
@@ -247,7 +248,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
247248
}
248249
trainer.setStackSize($(blockSize))
249250
val mlpModel = trainer.train(data)
250-
new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights)
251+
new MultilayerPerceptronClassificationModel(uid, mlpModel.weights)
251252
}
252253
}
253254

@@ -273,31 +274,22 @@ object MultilayerPerceptronClassifier
273274
* Each layer has sigmoid activation function, output layer has softmax.
274275
*
275276
* @param uid uid
276-
* @param layers array of layer sizes including input and output layers
277277
* @param weights the weights of layers
278278
*/
279279
@Since("1.5.0")
280280
class MultilayerPerceptronClassificationModel private[ml] (
281281
@Since("1.5.0") override val uid: String,
282-
@Since("1.5.0") val layers: Array[Int],
283282
@Since("2.0.0") val weights: Vector)
284283
extends ProbabilisticClassificationModel[Vector, MultilayerPerceptronClassificationModel]
285-
with Serializable with MLWritable {
284+
with MultilayerPerceptronParams with Serializable with MLWritable {
286285

287286
@Since("1.6.0")
288-
override val numFeatures: Int = layers.head
287+
override lazy val numFeatures: Int = $(layers).head
289288

290-
private[ml] val mlpModel = FeedForwardTopology
291-
.multiLayerPerceptron(layers, softmaxOnTop = true)
289+
@transient private[ml] lazy val mlpModel = FeedForwardTopology
290+
.multiLayerPerceptron($(layers), softmaxOnTop = true)
292291
.model(weights)
293292

294-
/**
295-
* Returns layers in a Java List.
296-
*/
297-
private[ml] def javaLayers: java.util.List[Int] = {
298-
layers.toList.asJava
299-
}
300-
301293
/**
302294
* Predict label for the given features.
303295
* This internal method is used to implement `transform()` and output [[predictionCol]].
@@ -308,7 +300,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
308300

309301
@Since("1.5.0")
310302
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
311-
val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent)
303+
val copied = new MultilayerPerceptronClassificationModel(uid, weights)
304+
.setParent(parent)
312305
copyValues(copied, extra)
313306
}
314307

@@ -323,11 +316,11 @@ class MultilayerPerceptronClassificationModel private[ml] (
323316
@Since("3.0.0")
324317
override def predictRaw(features: Vector): Vector = mlpModel.predictRaw(features)
325318

326-
override def numClasses: Int = layers.last
319+
override def numClasses: Int = $(layers).last
327320

328321
@Since("3.0.0")
329322
override def toString: String = {
330-
s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${layers.length}, " +
323+
s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${$(layers).length}, " +
331324
s"numClasses=$numClasses, numFeatures=$numFeatures"
332325
}
333326
}
@@ -348,13 +341,13 @@ object MultilayerPerceptronClassificationModel
348341
class MultilayerPerceptronClassificationModelWriter(
349342
instance: MultilayerPerceptronClassificationModel) extends MLWriter {
350343

351-
private case class Data(layers: Array[Int], weights: Vector)
344+
private case class Data(weights: Vector)
352345

353346
override protected def saveImpl(path: String): Unit = {
354347
// Save metadata and Params
355348
DefaultParamsWriter.saveMetadata(instance, path, sc)
356-
// Save model data: layers, weights
357-
val data = Data(instance.layers, instance.weights)
349+
// Save model data: weights
350+
val data = Data(instance.weights)
358351
val dataPath = new Path(path, "data").toString
359352
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
360353
}
@@ -368,13 +361,21 @@ object MultilayerPerceptronClassificationModel
368361

369362
override def load(path: String): MultilayerPerceptronClassificationModel = {
370363
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
364+
val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)
371365

372366
val dataPath = new Path(path, "data").toString
373-
val data = sparkSession.read.parquet(dataPath).select("layers", "weights").head()
374-
val layers = data.getAs[Seq[Int]](0).toArray
375-
val weights = data.getAs[Vector](1)
376-
val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
377-
367+
val df = sparkSession.read.parquet(dataPath)
368+
val model = if (majorVersion < 3) { // model prior to 3.0.0
369+
val data = df.select("layers", "weights").head()
370+
val layers = data.getAs[Seq[Int]](0).toArray
371+
val weights = data.getAs[Vector](1)
372+
val model = new MultilayerPerceptronClassificationModel(metadata.uid, weights)
373+
model.set("layers", layers)
374+
} else {
375+
val data = df.select("weights").head()
376+
val weights = data.getAs[Vector](0)
377+
new MultilayerPerceptronClassificationModel(metadata.uid, weights)
378+
}
378379
metadata.getAndSetParams(model)
379380
model
380381
}

mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ private[r] class MultilayerPerceptronClassifierWrapper private (
4040
pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel]
4141

4242
lazy val weights: Array[Double] = mlpModel.weights.toArray
43-
lazy val layers: Array[Int] = mlpModel.layers
43+
lazy val layers: Array[Int] = mlpModel.getLayers
4444

4545
def transform(dataset: Dataset[_]): DataFrame = {
4646
pipeline.transform(dataset)
Binary file not shown.
Binary file not shown.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"class":"org.apache.spark.ml.feature.HashingTF","timestamp":1577833408759,"sparkVersion":"2.4.4","uid":"hashingTF_f4565fe7f7da","paramMap":{"numFeatures":100,"outputCol":"features","inputCol":"words","binary":true},"defaultParamMap":{"numFeatures":262144,"outputCol":"hashingTF_f4565fe7f7da__output","binary":false}}
Binary file not shown.

mllib/src/test/resources/ml-models/mlp-2.4.4/data/_SUCCESS

Whitespace-only changes.
Binary file not shown.
Binary file not shown.

mllib/src/test/resources/ml-models/mlp-2.4.4/metadata/_SUCCESS

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"class":"org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel","timestamp":1577833765310,"sparkVersion":"2.4.4","uid":"mlpc_30aa2f44dacc","paramMap":{},"defaultParamMap":{"rawPredictionCol":"rawPrediction","predictionCol":"prediction","probabilityCol":"probability","labelCol":"label","featuresCol":"features"}}
Binary file not shown.

mllib/src/test/resources/ml-models/strIndexerModel-2.4.4/data/_SUCCESS

Whitespace-only changes.
Binary file not shown.
Binary file not shown.

mllib/src/test/resources/ml-models/strIndexerModel-2.4.4/metadata/_SUCCESS

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"class":"org.apache.spark.ml.feature.StringIndexerModel","timestamp":1577831053235,"sparkVersion":"2.4.4","uid":"myStringIndexerModel","paramMap":{"inputCol":"myInputCol","outputCol":"myOutputCol","handleInvalid":"skip"},"defaultParamMap":{"outputCol":"myStringIndexerModel__output","handleInvalid":"error"}}
Binary file not shown.

mllib/src/test/resources/test-data/hashingTF-pre3.0/metadata/part-00000

Lines changed: 0 additions & 1 deletion
This file was deleted.

mllib/src/test/resources/test-data/strIndexerModel/metadata/part-00000

Lines changed: 0 additions & 1 deletion
This file was deleted.

mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,17 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe
229229
assert(expected.weights === actual.weights)
230230
}
231231
}
232+
233+
test("Load MultilayerPerceptronClassificationModel prior to Spark 3.0") {
234+
val mlpPath = testFile("ml-models/mlp-2.4.4")
235+
val model = MultilayerPerceptronClassificationModel.load(mlpPath)
236+
val layers = model.getLayers
237+
assert(layers(0) === 4)
238+
assert(layers(1) === 5)
239+
assert(layers(2) === 2)
240+
241+
val metadata = spark.read.json(s"$mlpPath/metadata")
242+
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
243+
assert(sparkVersionStr == "2.4.4")
244+
}
232245
}

mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {
8989
}
9090

9191
test("SPARK-23469: Load HashingTF prior to Spark 3.0") {
92-
val hashingTFPath = testFile("test-data/hashingTF-pre3.0")
92+
val hashingTFPath = testFile("ml-models/hashingTF-2.4.4")
9393
val loadedHashingTF = HashingTF.load(hashingTFPath)
9494
val mLlibHashingTF = new MLlibHashingTF(100)
9595
assert(loadedHashingTF.indexOf("a") === mLlibHashingTF.indexOf("a"))
@@ -99,7 +99,7 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {
9999

100100
val metadata = spark.read.json(s"$hashingTFPath/metadata")
101101
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
102-
assert(sparkVersionStr == "2.3.0-SNAPSHOT")
102+
assert(sparkVersionStr == "2.4.4")
103103
}
104104

105105
test("read/write") {

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,13 +459,13 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
459459
}
460460

461461
test("Load StringIndexderModel prior to Spark 3.0") {
462-
val modelPath = testFile("test-data/strIndexerModel")
462+
val modelPath = testFile("ml-models/strIndexerModel-2.4.4")
463463

464464
val loadedModel = StringIndexerModel.load(modelPath)
465465
assert(loadedModel.labelsArray === Array(Array("b", "c", "a")))
466466

467467
val metadata = spark.read.json(s"$modelPath/metadata")
468468
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
469-
assert(sparkVersionStr == "2.4.1-SNAPSHOT")
469+
assert(sparkVersionStr == "2.4.4")
470470
}
471471
}

project/MimaExcludes.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ object MimaExcludes {
332332
// [SPARK-26457] Show hadoop configurations in HistoryServer environment tab
333333
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this"),
334334

335+
// [SPARK-30144][ML] Make MultilayerPerceptronClassificationModel extend MultilayerPerceptronParams
336+
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.layers"),
337+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"),
338+
335339
// Data Source V2 API changes
336340
(problem: Problem) => problem match {
337341
case MissingClassProblem(cls) =>

python/pyspark/ml/classification.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,7 +2201,9 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
22012201
>>> model = mlp.fit(df)
22022202
>>> model.setFeaturesCol("features")
22032203
MultilayerPerceptronClassificationModel...
2204-
>>> model.layers
2204+
>>> model.getMaxIter()
2205+
100
2206+
>>> model.getLayers()
22052207
[2, 2, 2]
22062208
>>> model.weights.size
22072209
12
@@ -2230,15 +2232,15 @@ class MultilayerPerceptronClassifier(JavaProbabilisticClassifier, _MultilayerPer
22302232
>>> model_path = temp_path + "/mlp_model"
22312233
>>> model.save(model_path)
22322234
>>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
2233-
>>> model.layers == model2.layers
2235+
>>> model.getLayers() == model2.getLayers()
22342236
True
22352237
>>> model.weights == model2.weights
22362238
True
22372239
>>> mlp2 = mlp2.setInitialWeights(list(range(0, 12)))
22382240
>>> model3 = mlp2.fit(df)
22392241
>>> model3.weights != model2.weights
22402242
True
2241-
>>> model3.layers == model.layers
2243+
>>> model3.getLayers() == model.getLayers()
22422244
True
22432245
22442246
.. versionadded:: 1.6.0
@@ -2334,22 +2336,15 @@ def setSolver(self, value):
23342336
return self._set(solver=value)
23352337

23362338

2337-
class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationModel, JavaMLWritable,
2339+
class MultilayerPerceptronClassificationModel(JavaProbabilisticClassificationModel,
2340+
_MultilayerPerceptronParams, JavaMLWritable,
23382341
JavaMLReadable):
23392342
"""
23402343
Model fitted by MultilayerPerceptronClassifier.
23412344
23422345
.. versionadded:: 1.6.0
23432346
"""
23442347

2345-
@property
2346-
@since("1.6.0")
2347-
def layers(self):
2348-
"""
2349-
array of layer sizes including input and output layers.
2350-
"""
2351-
return self._call_java("javaLayers")
2352-
23532348
@property
23542349
@since("2.0.0")
23552350
def weights(self):

0 commit comments

Comments
 (0)