Skip to content

Commit cfcb596

Browse files
committed
[SPARK-1406] Throw IllegalArgumentException when exporting a multinomial
logistic regression
1 parent 25dce33 commit cfcb596

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ private[mllib] object PMMLModelExportFactory {
4444
new GeneralizedLinearPMMLModelExport(svm,
4545
"linear SVM: if predicted value > 0, the outcome is positive, or negative otherwise")
4646
case logistic: LogisticRegressionModel =>
47-
new LogisticRegressionPMMLModelExport(logistic, "logistic regression")
47+
if(logistic.numClasses == 2)
48+
new LogisticRegressionPMMLModelExport(logistic, "logistic regression")
49+
else
50+
throw new IllegalArgumentException(
51+
"PMML Export not supported for Multinomial Logistic Regression")
4852
case _ =>
4953
throw new IllegalArgumentException(
5054
"PMML Export not supported for model: " + model.getClass.getName)

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ class PMMLModelExportFactorySuite extends FunSuite {
7373

7474
assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport])
7575
}
76+
77+
test("PMMLModelExportFactory throw IllegalArgumentException "
78+
+ "when passing a Multinomial Logistic Regression") {
79+
/** 3 classes, 2 features */
80+
val multiclassLogisticRegressionModel = new LogisticRegressionModel(
81+
weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0,
82+
numFeatures = 2, numClasses = 3)
83+
84+
intercept[IllegalArgumentException] {
85+
PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel)
86+
}
87+
}
7688

7789
test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
7890
val invalidModel = new Object

0 commit comments

Comments
 (0)