Skip to content

Commit e2313df

Browse files
committed
Merge pull request #1 from mengxr/SPARK-1406
Update code style
2 parents 1676e15 + 472d757 commit e2313df

10 files changed

+273
-405
lines changed

mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ import java.io.File
2121
import java.io.OutputStream
2222
import java.io.StringWriter
2323
import javax.xml.transform.stream.StreamResult
24+
2425
import org.jpmml.model.JAXBUtil
26+
2527
import org.apache.spark.SparkContext
26-
import org.apache.spark.mllib.pmml.export.PMMLModelExport
2728
import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
2829

2930
/**
@@ -34,42 +35,42 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
3435
trait PMMLExportable {
3536

3637
/**
37-
* Export the model to the stream result in PMML format
38-
*/
38+
* Export the model to the stream result in PMML format
39+
*/
3940
private def toPMML(streamResult: StreamResult): Unit = {
4041
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
41-
JAXBUtil.marshalPMML(pmmlModelExport.getPmml(), streamResult)
42+
JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
4243
}
4344

4445
/**
45-
* Export the model to a local File in PMML format
46-
*/
46+
* Export the model to a local file in PMML format
47+
*/
4748
def toPMML(localPath: String): Unit = {
4849
toPMML(new StreamResult(new File(localPath)))
4950
}
5051

5152
/**
52-
* Export the model to a distributed file in PMML format
53-
*/
53+
* Export the model to a directory on a distributed file system in PMML format
54+
*/
5455
def toPMML(sc: SparkContext, path: String): Unit = {
5556
val pmml = toPMML()
56-
sc.parallelize(Array(pmml),1).saveAsTextFile(path)
57+
sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
5758
}
5859

5960
/**
60-
* Export the model to the Outputtream in PMML format
61-
*/
61+
* Export the model to the OutputStream in PMML format
62+
*/
6263
def toPMML(outputStream: OutputStream): Unit = {
6364
toPMML(new StreamResult(outputStream))
6465
}
6566

6667
/**
67-
* Export the model to a String in PMML format
68-
*/
68+
* Export the model to a String in PMML format
69+
*/
6970
def toPMML(): String = {
70-
var writer = new StringWriter();
71+
val writer = new StringWriter
7172
toPMML(new StreamResult(writer))
72-
return writer.toString();
73+
writer.toString
7374
}
7475

7576
}

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala

Lines changed: 25 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,10 @@
1717

1818
package org.apache.spark.mllib.pmml.export
1919

20-
import org.dmg.pmml.DataDictionary
21-
import org.dmg.pmml.DataField
22-
import org.dmg.pmml.DataType
23-
import org.dmg.pmml.FieldName
24-
import org.dmg.pmml.FieldUsageType
25-
import org.dmg.pmml.MiningField
26-
import org.dmg.pmml.MiningFunctionType
27-
import org.dmg.pmml.MiningSchema
28-
import org.dmg.pmml.NumericPredictor
29-
import org.dmg.pmml.OpType
30-
import org.dmg.pmml.RegressionModel
31-
import org.dmg.pmml.RegressionTable
20+
import scala.{Array => SArray}
21+
22+
import org.dmg.pmml._
23+
3224
import org.apache.spark.mllib.regression.GeneralizedLinearModel
3325

3426
/**
@@ -39,55 +31,43 @@ private[mllib] class GeneralizedLinearPMMLModelExport(
3931
description : String)
4032
extends PMMLModelExport{
4133

34+
populateGeneralizedLinearPMML(model)
35+
4236
/**
43-
* Export the input GeneralizedLinearModel model to PMML format
37+
* Export the input GeneralizedLinearModel model to PMML format.
4438
*/
45-
populateGeneralizedLinearPMML(model)
46-
47-
private def populateGeneralizedLinearPMML(model : GeneralizedLinearModel): Unit = {
39+
private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = {
40+
pmml.getHeader.setDescription(description)
4841

49-
pmml.getHeader().setDescription(description)
50-
5142
if(model.weights.size > 0){
52-
53-
val fields = new Array[FieldName](model.weights.size)
54-
55-
val dataDictionary = new DataDictionary()
56-
57-
val miningSchema = new MiningSchema()
58-
43+
val fields = new SArray[FieldName](model.weights.size)
44+
val dataDictionary = new DataDictionary
45+
val miningSchema = new MiningSchema
5946
val regressionTable = new RegressionTable(model.intercept)
60-
6147
val regressionModel = new RegressionModel(miningSchema,MiningFunctionType.REGRESSION)
62-
.withModelName(description).withRegressionTables(regressionTable)
63-
64-
for ( i <- 0 until model.weights.size) {
48+
.withModelName(description)
49+
.withRegressionTables(regressionTable)
50+
51+
for (i <- 0 until model.weights.size) {
6552
fields(i) = FieldName.create("field_" + i)
66-
dataDictionary
67-
.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
53+
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
6854
miningSchema
69-
.withMiningFields(new MiningField(fields(i))
70-
.withUsageType(FieldUsageType.ACTIVE))
55+
.withMiningFields(new MiningField(fields(i))
56+
.withUsageType(FieldUsageType.ACTIVE))
7157
regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i)))
7258
}
7359

7460
// for completeness add target field
75-
val targetField = FieldName.create("target");
76-
dataDictionary
77-
.withDataFields(
78-
new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)
79-
)
80-
miningSchema
61+
val targetField = FieldName.create("target")
62+
dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE))
63+
miningSchema
8164
.withMiningFields(new MiningField(targetField)
8265
.withUsageType(FieldUsageType.TARGET))
83-
84-
dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())
85-
66+
67+
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
68+
8669
pmml.setDataDictionary(dataDictionary)
8770
pmml.withModels(regressionModel)
88-
8971
}
90-
9172
}
92-
9373
}

mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala

Lines changed: 48 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,90 +17,64 @@
1717

1818
package org.apache.spark.mllib.pmml.export
1919

20-
import org.dmg.pmml.Array.Type
21-
import org.dmg.pmml.Cluster
22-
import org.dmg.pmml.ClusteringField
23-
import org.dmg.pmml.ClusteringModel
24-
import org.dmg.pmml.ClusteringModel.ModelClass
25-
import org.dmg.pmml.CompareFunctionType
26-
import org.dmg.pmml.ComparisonMeasure
27-
import org.dmg.pmml.ComparisonMeasure.Kind
28-
import org.dmg.pmml.DataDictionary
29-
import org.dmg.pmml.DataField
30-
import org.dmg.pmml.DataType
31-
import org.dmg.pmml.FieldName
32-
import org.dmg.pmml.FieldUsageType
33-
import org.dmg.pmml.MiningField
34-
import org.dmg.pmml.MiningFunctionType
35-
import org.dmg.pmml.MiningSchema
36-
import org.dmg.pmml.OpType
37-
import org.dmg.pmml.SquaredEuclidean
20+
import scala.{Array => SArray}
21+
22+
import org.dmg.pmml._
23+
3824
import org.apache.spark.mllib.clustering.KMeansModel
3925

4026
/**
4127
* PMML Model Export for KMeansModel class
4228
*/
4329
private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{
4430

31+
populateKMeansPMML(model)
32+
4533
/**
46-
* Export the input KMeansModel model to PMML format
34+
* Export the input KMeansModel model to PMML format.
4735
*/
48-
populateKMeansPMML(model)
49-
5036
private def populateKMeansPMML(model : KMeansModel): Unit = {
51-
52-
pmml.getHeader().setDescription("k-means clustering")
53-
54-
if(model.clusterCenters.length > 0){
55-
56-
val clusterCenter = model.clusterCenters(0)
57-
58-
val fields = new Array[FieldName](clusterCenter.size)
59-
60-
val dataDictionary = new DataDictionary()
61-
62-
val miningSchema = new MiningSchema()
63-
64-
val comparisonMeasure = new ComparisonMeasure()
65-
.withKind(Kind.DISTANCE)
66-
.withMeasure(new SquaredEuclidean()
67-
)
68-
69-
val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure,
70-
MiningFunctionType.CLUSTERING, ModelClass.CENTER_BASED, model.clusterCenters.length)
37+
pmml.getHeader.setDescription("k-means clustering")
38+
39+
if (model.clusterCenters.length > 0) {
40+
val clusterCenter = model.clusterCenters(0)
41+
val fields = new SArray[FieldName](clusterCenter.size)
42+
val dataDictionary = new DataDictionary
43+
val miningSchema = new MiningSchema
44+
val comparisonMeasure = new ComparisonMeasure()
45+
.withKind(ComparisonMeasure.Kind.DISTANCE)
46+
.withMeasure(new SquaredEuclidean())
47+
val clusteringModel = new ClusteringModel(miningSchema, comparisonMeasure,
48+
MiningFunctionType.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED,
49+
model.clusterCenters.length)
7150
.withModelName("k-means")
72-
73-
for ( i <- 0 until clusterCenter.size) {
74-
fields(i) = FieldName.create("field_" + i)
75-
dataDictionary
76-
.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
77-
miningSchema
78-
.withMiningFields(new MiningField(fields(i))
79-
.withUsageType(FieldUsageType.ACTIVE))
80-
clusteringModel.withClusteringFields(
81-
new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)
82-
)
83-
}
84-
85-
dataDictionary.withNumberOfFields((dataDictionary.getDataFields()).size())
86-
87-
for ( i <- 0 until model.clusterCenters.size ) {
88-
val cluster = new Cluster()
89-
.withName("cluster_" + i)
90-
.withArray(new org.dmg.pmml.Array()
91-
.withType(Type.REAL)
92-
.withN(clusterCenter.size)
93-
.withValue(model.clusterCenters(i).toArray.mkString(" ")))
94-
// we don't have the size of the single cluster but only the centroids (withValue)
95-
// .withSize(value)
96-
clusteringModel.withClusters(cluster)
97-
}
98-
99-
pmml.setDataDictionary(dataDictionary)
100-
pmml.withModels(clusteringModel)
101-
102-
}
103-
51+
52+
for (i <- 0 until clusterCenter.size) {
53+
fields(i) = FieldName.create("field_" + i)
54+
dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE))
55+
miningSchema
56+
.withMiningFields(new MiningField(fields(i))
57+
.withUsageType(FieldUsageType.ACTIVE))
58+
clusteringModel.withClusteringFields(
59+
new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF))
60+
}
61+
62+
dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size)
63+
64+
for (i <- 0 until model.clusterCenters.length) {
65+
val cluster = new Cluster()
66+
.withName("cluster_" + i)
67+
.withArray(new org.dmg.pmml.Array()
68+
.withType(Array.Type.REAL)
69+
.withN(clusterCenter.size)
70+
.withValue(model.clusterCenters(i).toArray.mkString(" ")))
71+
// we don't have the size of the single cluster but only the centroids (withValue)
72+
// .withSize(value)
73+
clusteringModel.withClusters(cluster)
74+
}
75+
76+
pmml.setDataDictionary(dataDictionary)
77+
pmml.withModels(clusteringModel)
78+
}
10479
}
105-
10680
}

0 commit comments

Comments
 (0)