Skip to content

Commit 3c22f79

Browse files
committed
more code style
1 parent e2313df commit 3c22f79

File tree

7 files changed

+73
-83
lines changed

7 files changed

+73
-83
lines changed

mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717

1818
package org.apache.spark.mllib.pmml
1919

20-
import java.io.File
21-
import java.io.OutputStream
22-
import java.io.StringWriter
20+
import java.io.{File, OutputStream, StringWriter}
2321
import javax.xml.transform.stream.StreamResult
2422

2523
import org.jpmml.model.JAXBUtil
@@ -33,37 +31,37 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
3331
* developed by the Data Mining Group (www.dmg.org).
3432
*/
3533
trait PMMLExportable {
36-
34+
3735
/**
3836
* Export the model to the stream result in PMML format
3937
*/
4038
private def toPMML(streamResult: StreamResult): Unit = {
4139
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
4240
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
4341
}
44-
42+
4543
/**
4644
* Export the model to a local file in PMML format
4745
*/
4846
def toPMML(localPath: String): Unit = {
4947
toPMML(new StreamResult(new File(localPath)))
5048
}
51-
49+
5250
/**
5351
* Export the model to a directory on a distributed file system in PMML format
5452
*/
5553
def toPMML(sc: SparkContext, path: String): Unit = {
5654
val pmml = toPMML()
5755
sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
5856
}
59-
57+
6058
/**
6159
* Export the model to the OutputStream in PMML format
6260
*/
6361
def toPMML(outputStream: OutputStream): Unit = {
6462
toPMML(new StreamResult(outputStream))
6563
}
66-
64+
6765
/**
6866
* Export the model to a String in PMML format
6967
*/
@@ -72,5 +70,5 @@ trait PMMLExportable {
7270
toPMML(new StreamResult(writer))
7371
writer.toString
7472
}
75-
73+
7674
}

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,47 +27,47 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel
2727
* PMML Model Export for GeneralizedLinearModel abstract class
2828
*/
2929
private[mllib] class GeneralizedLinearPMMLModelExport(
30-
model : GeneralizedLinearModel,
31-
description : String)
32-
extends PMMLModelExport{
30+
model: GeneralizedLinearModel,
31+
description: String)
32+
extends PMMLModelExport {
3333

3434
populateGeneralizedLinearPMML(model)
3535

3636
/**
3737
* Export the input GeneralizedLinearModel model to PMML format.
3838
*/
3939
private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
40-
pmml.getHeader.setDescription(description)
40+
pmml.getHeader.setDescription(description)
4141

42-
if(model.weights.size > 0){
43-
val fields = new SArray[FieldName](model.weights.size)
44-
val dataDictionary = new DataDictionary
45-
val miningSchema = new MiningSchema
46-
val regressionTable = new RegressionTable(model.intercept)
47-
val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.REGRESSION)
48-
.withModelName(description)
49-
.withRegressionTables(regressionTable)
42+
if (model.weights.size > 0) {
43+
val fields = new SArray[FieldName](model.weights.size)
44+
val dataDictionary = new DataDictionary
45+
val miningSchema = new MiningSchema
46+
val regressionTable = new RegressionTable(model.intercept)
47+
val regressionModel = new RegressionModel(miningSchema, MiningFunctionType.REGRESSION)
48+
.withModelName(description)
49+
.withRegressionTables(regressionTable)
5050

51-
for (i <- 0 until model.weights.size) {
52-
fields(i) = FieldName.create("field_" + i)
53-
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
54-
miningSchema
55-
.withMiningFields(new MiningField(fields(i))
56-
.withUsageType(FieldUsageType.ACTIVE))
57-
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
58-
}
59-
60-
// for completeness add target field
61-
val targetField = FieldName.create("target")
62-
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
63-
miningSchema
64-
.withMiningFields(new MiningField(targetField)
65-
.withUsageType(FieldUsageType.TARGET))
51+
for (i <- 0 until model.weights.size) {
52+
fields(i) = FieldName.create("field_" + i)
53+
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
54+
miningSchema
55+
.withMiningFields(new MiningField(fields(i))
56+
.withUsageType(FieldUsageType.ACTIVE))
57+
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
58+
}
6659

67-
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
60+
// for completeness add target field
61+
val targetField = FieldName.create("target")
62+
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
63+
miningSchema
64+
.withMiningFields(new MiningField(targetField)
65+
.withUsageType(FieldUsageType.TARGET))
6866

69-
pmml.setDataDictionary(dataDictionary)
70-
pmml.withModels(regressionModel)
71-
}
67+
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
68+
69+
pmml.setDataDictionary(dataDictionary)
70+
pmml.withModels(regressionModel)
71+
}
7272
}
7373
}

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExport.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionModel
2929
private[mllib] class LogisticRegressionPMMLModelExport(
3030
model : LogisticRegressionModel,
3131
description : String)
32-
extends PMMLModelExport{
32+
extends PMMLModelExport {
3333

3434
populateLogisticRegressionPMML(model)
3535

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ import org.dmg.pmml.RegressionModel
2121
import org.scalatest.FunSuite
2222

2323
import org.apache.spark.mllib.classification.SVMModel
24-
import org.apache.spark.mllib.regression.LassoModel
25-
import org.apache.spark.mllib.regression.LinearRegressionModel
26-
import org.apache.spark.mllib.regression.RidgeRegressionModel
24+
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
2725
import org.apache.spark.mllib.util.LinearDataGenerator
2826

2927
class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
@@ -87,7 +85,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite {
8785
test("svm pmml export") {
8886
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
8987
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
90-
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
88+
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
9189
// assert that the PMML format is as expected
9290
assert(svmModelExport.isInstanceOf[PMMLModelExport])
9391
val pmml = svmModelExport.getPmml

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,14 @@ import org.apache.spark.mllib.linalg.Vectors
2626
class KMeansPMMLModelExportSuite extends FunSuite {
2727

2828
test("KMeansPMMLModelExport generate PMML format") {
29-
// arrange model to test
3029
val clusterCenters = Array(
3130
Vectors.dense(1.0, 2.0, 6.0),
3231
Vectors.dense(1.0, 3.0, 0.0),
3332
Vectors.dense(1.0, 4.0, 6.0))
3433
val kmeansModel = new KMeansModel(clusterCenters)
35-
34+
3635
val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)
37-
36+
3837
// assert that the PMML format is as expected
3938
assert(modelExport.isInstanceOf[PMMLModelExport])
4039
val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/LogisticRegressionPMMLModelExportSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ import org.scalatest.FunSuite
2323
import org.apache.spark.mllib.classification.LogisticRegressionModel
2424
import org.apache.spark.mllib.util.LinearDataGenerator
2525

26-
class LogisticRegressionPMMLModelExportSuite extends FunSuite{
26+
class LogisticRegressionPMMLModelExportSuite extends FunSuite {
2727

2828
test("LogisticRegressionPMMLModelExport generate PMML format") {
2929
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
3030
val logisticRegressionModel =
3131
new LogisticRegressionModel(linearInput(0).features, linearInput(0).label)
32-
32+
3333
val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
3434

3535
// assert that the PMML format is as expected
@@ -48,5 +48,5 @@ class LogisticRegressionPMMLModelExportSuite extends FunSuite{
4848
// verify if there is a second table with target category 0 and no predictors
4949
assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0")
5050
assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0)
51-
}
51+
}
5252
}

mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,10 @@ package org.apache.spark.mllib.pmml.export
1919

2020
import org.scalatest.FunSuite
2121

22-
import org.apache.spark.mllib.classification.LogisticRegressionModel
23-
import org.apache.spark.mllib.classification.SVMModel
22+
import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel}
2423
import org.apache.spark.mllib.clustering.KMeansModel
2524
import org.apache.spark.mllib.linalg.Vectors
26-
import org.apache.spark.mllib.regression.LassoModel
27-
import org.apache.spark.mllib.regression.LinearRegressionModel
28-
import org.apache.spark.mllib.regression.RidgeRegressionModel
25+
import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel}
2926
import org.apache.spark.mllib.util.LinearDataGenerator
3027

3128
class PMMLModelExportFactorySuite extends FunSuite {
@@ -38,33 +35,32 @@ class PMMLModelExportFactorySuite extends FunSuite {
3835
val kmeansModel = new KMeansModel(clusterCenters)
3936

4037
val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel)
41-
38+
4239
assert(modelExport.isInstanceOf[KMeansPMMLModelExport])
43-
}
44-
45-
test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
46-
+ "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") {
47-
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
40+
}
4841

49-
val linearRegressionModel =
50-
new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
51-
val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
52-
assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
42+
test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a "
43+
+ "LinearRegressionModel, RidgeRegressionModel, LassoModel or SVMModel") {
44+
val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17)
5345

54-
val ridgeRegressionModel =
55-
new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
56-
val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
57-
assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
46+
val linearRegressionModel =
47+
new LinearRegressionModel(linearInput(0).features, linearInput(0).label)
48+
val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
49+
assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
5850

51+
val ridgeRegressionModel =
52+
new RidgeRegressionModel(linearInput(0).features, linearInput(0).label)
53+
val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
54+
assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
5955

60-
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
61-
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
62-
assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
56+
val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label)
57+
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
58+
assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
6359

64-
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
65-
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
66-
assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
67-
}
60+
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label)
61+
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
62+
assert(svmModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport])
63+
}
6864

6965
test("PMMLModelExportFactory create LogisticRegressionPMMLModelExport "
7066
+ "when passing a LogisticRegressionModel") {
@@ -76,14 +72,13 @@ class PMMLModelExportFactorySuite extends FunSuite {
7672
PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel)
7773

7874
assert(logisticRegressionModelExport.isInstanceOf[LogisticRegressionPMMLModelExport])
79-
}
80-
81-
test("PMMLModelExportFactory throw IllegalArgumentException "
82-
+ "when passing an unsupported model") {
75+
}
76+
77+
test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") {
8378
val invalidModel = new Object
84-
79+
8580
intercept[IllegalArgumentException] {
8681
PMMLModelExportFactory.createPMMLModelExport(invalidModel)
8782
}
88-
}
83+
}
8984
}

0 commit comments

Comments
 (0)