
Commit 376db0a

pipeline and parameters
1 parent 5e73138 commit 376db0a

15 files changed: +940 -1 lines changed
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import org.apache.spark.sql.SchemaRDD

/**
 * Abstract class for estimators that fit models to data.
 */
abstract class Estimator[M <: Model] extends Identifiable with Params with PipelineStage {

  /**
   * Fits a single model to the input data with default parameters.
   *
   * @param dataset input dataset
   * @return fitted model
   */
  def fit(dataset: SchemaRDD): M = {
    fit(dataset, ParamMap.empty)
  }

  /**
   * Fits a single model to the input data with the provided parameter map.
   *
   * @param dataset input dataset
   * @param paramMap parameters
   * @return fitted model
   */
  def fit(dataset: SchemaRDD, paramMap: ParamMap): M

  /**
   * Fits a single model to the input data with the provided parameter pairs.
   *
   * @param dataset input dataset
   * @param firstParamPair first parameter pair
   * @param otherParamPairs other parameter pairs
   * @return fitted model
   */
  def fit(
      dataset: SchemaRDD,
      firstParamPair: ParamPair[_],
      otherParamPairs: ParamPair[_]*): M = {
    val map = new ParamMap()
    map.put(firstParamPair)
    otherParamPairs.foreach(map.put(_))
    fit(dataset, map)
  }

  /**
   * Fits multiple models to the input data with multiple sets of parameters.
   * The default implementation fits one model per parameter map; subclasses can
   * override this to optimize multi-model training.
   *
   * @param dataset input dataset
   * @param paramMaps an array of parameter maps
   * @return fitted models, matching the input parameter maps
   */
  def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
    paramMaps.map(fit(dataset, _))
  }

  /**
   * Parameters for the output model.
   */
  def model: Params = Params.empty
}
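
A usage sketch (not part of this commit) showing how the overloads compose. `MyEstimator`, its `maxIter` param, and direct `new ParamPair(param, value)` construction are assumptions for illustration:

  val estimator = new MyEstimator  // hypothetical concrete Estimator

  // 1. Fit with default parameters.
  val model1 = estimator.fit(trainingData)

  // 2. Fit with an explicit parameter map.
  val map = new ParamMap()
  map.put(new ParamPair(estimator.maxIter, 50))  // assumed ParamPair constructor
  val model2 = estimator.fit(trainingData, map)

  // 3. Fit with varargs parameter pairs; builds the same map internally.
  val model3 = estimator.fit(trainingData, new ParamPair(estimator.maxIter, 50))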
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import org.apache.spark.sql.SchemaRDD

/**
 * Abstract class for evaluators that compute metrics from predictions.
 */
abstract class Evaluator extends Identifiable {

  /**
   * Evaluates the output of a model.
   *
   * @param dataset a dataset that contains labels/observations and predictions
   * @param paramMap parameter map that specifies the input columns and output metrics
   * @return metric
   */
  def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
}
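
The Evaluator contract is what makes the multi-model `fit` overload on Estimator useful for model selection. A minimal sketch, not part of this commit; `estimator`, `evaluator`, and `candidates: Array[ParamMap]` are hypothetical:

  // Fit one model per candidate parameter map, ...
  val models = estimator.fit(trainingData, candidates)
  // ... score each model on held-out data, and keep the best one.
  val metrics = models.map { model =>
    evaluator.evaluate(model.transform(validationData), ParamMap.empty)
  }
  val bestModel = models(metrics.indices.maxBy(metrics))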
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import java.util.UUID

/**
 * Something with a unique id.
 */
trait Identifiable extends Serializable {

  /**
   * A unique id for the object.
   */
  val uid: String = this.getClass.getSimpleName + "-" + Identifiable.randomUid
}

object Identifiable {

  /**
   * Returns a random uid, drawn uniformly from 4+ billion candidates.
   */
  private def randomUid: String = UUID.randomUUID().toString.take(8)
}
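
A quick sketch of the resulting ids (not part of this commit); the hex suffix is random per object, so the value shown is only an example:

  val pipeline = new Pipeline  // defined later in this commit
  println(pipeline.uid)        // e.g. "Pipeline-3fa8b2c1"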
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
package org.apache.spark.ml

/**
 * Abstract class for models fitted by estimators. A model is itself a transformer.
 */
abstract class Model extends Transformer {
  // def parent: Estimator
  // def trainingParameters: ParamMap
}
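
The commented-out members appear to sketch planned additions; for now a Model is used purely as a Transformer produced by an Estimator, as in this hypothetical sketch:

  // `MyEstimator` and `MyModel` are hypothetical stand-ins.
  val model: MyModel = new MyEstimator().fit(trainingData)
  val output = model.transform(newData)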
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.SchemaRDD

trait PipelineStage extends Identifiable

/**
 * A simple pipeline, which acts as an estimator.
 */
class Pipeline extends Estimator[PipelineModel] {

  /** Param for the stages of the pipeline. */
  val stages: Param[Array[PipelineStage]] =
    new Param[Array[PipelineStage]](this, "stages", "stages of the pipeline")

  override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
    val theStages = paramMap(stages)
    // Search for the last estimator: stages after it are transformers that need
    // no fitting, so the training data never has to flow past that point.
    var lastIndexOfEstimator = -1
    theStages.view.zipWithIndex.foreach { case (stage, index) =>
      stage match {
        case _: Estimator[_] =>
          lastIndexOfEstimator = index
        case _ =>
      }
    }
    var curDataset = dataset
    val transformers = ListBuffer.empty[Transformer]
    theStages.view.zipWithIndex.foreach { case (stage, index) =>
      stage match {
        case estimator: Estimator[_] =>
          val transformer = estimator.fit(curDataset, paramMap)
          if (index < lastIndexOfEstimator) {
            curDataset = transformer.transform(curDataset, paramMap)
          }
          transformers += transformer
        case transformer: Transformer =>
          if (index < lastIndexOfEstimator) {
            curDataset = transformer.transform(curDataset, paramMap)
          }
          transformers += transformer
        case _ =>
          throw new IllegalArgumentException(
            s"Do not support stage $stage of type ${stage.getClass}.")
      }
    }
    new PipelineModel(transformers.toArray)
  }

  override def params: Array[Param[_]] = Array.empty
}

/**
 * A fitted pipeline, which transforms a dataset by applying its transformers in order.
 */
class PipelineModel(val transformers: Array[Transformer]) extends Model {

  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
    transformers.foldLeft(dataset) { (dataset, transformer) =>
      transformer.transform(dataset, paramMap)
    }
  }
}
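
Note the `lastIndexOfEstimator` scan above: every stage after the last estimator needs no fitting, so the training data is only transformed as far as that stage, and the tail transformers first run when the fitted PipelineModel is applied. A minimal assembly sketch, not part of this commit; `tokenizer`, `hashingTF`, and `lr` are hypothetical stages (two transformers followed by an estimator), and the ParamPair constructor is assumed:

  val pipeline = new Pipeline
  val map = new ParamMap()
  map.put(new ParamPair(pipeline.stages, Array[PipelineStage](tokenizer, hashingTF, lr)))

  // A Pipeline is itself an Estimator, so it is fitted like any single stage.
  val pipelineModel: PipelineModel = pipeline.fit(trainingData, map)
  val predictions: SchemaRDD = pipelineModel.transform(testData)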
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import org.apache.spark.sql.SchemaRDD

/**
 * Abstract class for transformers that transform one dataset into another.
 */
abstract class Transformer extends Identifiable with Params with PipelineStage {

  /**
   * Transforms the dataset with default parameters.
   * @param dataset input dataset
   * @return transformed dataset
   */
  def transform(dataset: SchemaRDD): SchemaRDD = {
    transform(dataset, ParamMap.empty)
  }

  /**
   * Transforms the dataset with the provided parameter map.
   * @param dataset input dataset
   * @param paramMap parameters
   * @return transformed dataset
   */
  def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD

  /**
   * Transforms the dataset with the provided parameter pairs.
   * @param dataset input dataset
   * @param firstParamPair first parameter pair
   * @param otherParamPairs other parameter pairs
   * @return transformed dataset
   */
  def transform(
      dataset: SchemaRDD,
      firstParamPair: ParamPair[_],
      otherParamPairs: ParamPair[_]*): SchemaRDD = {
    val map = new ParamMap()
    map.put(firstParamPair)
    otherParamPairs.foreach(map.put(_))
    transform(dataset, map)
  }

  /**
   * Transforms the dataset with multiple sets of parameters.
   * @param dataset input dataset
   * @param paramMaps an array of parameter maps
   * @return transformed datasets, one per parameter map
   */
  def transform(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Array[SchemaRDD] = {
    paramMaps.map(transform(dataset, _))
  }
}
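
A minimal usage sketch, not part of this commit; `scaler`, its `inputCol` param, and the ParamPair constructor are assumptions for illustration:

  // Transform with default parameters.
  val scaled = scaler.transform(dataset)

  // Override a parameter for this call only, via the varargs overload.
  val rescaled = scaler.transform(dataset, new ParamPair(scaler.inputCol, "features"))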
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.example

import org.apache.spark.ml._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.catalyst.expressions.Row

/**
 * Evaluator for binary classification, which expects score and label input columns.
 */
class BinaryClassificationEvaluator extends Evaluator with Params with OwnParamMap {

  /** metric to compute in evaluation */
  final val metricName: Param[String] =
    new Param(this, "metricName", "evaluation metric: areaUnderROC or areaUnderPR", "areaUnderROC")

  /** score column name */
  final val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", "score")

  /** label column name */
  final val labelCol: Param[String] = new Param(this, "labelCol", "label column name", "label")

  override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
    import dataset.sqlContext._
    val map = this.paramMap ++ paramMap
    import map.implicitMapping
    val scoreAndLabels = dataset.select((scoreCol: String).attr, (labelCol: String).attr)
      .map { case Row(score: Double, label: Double) =>
        (score, label)
      }.cache()
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    (metricName: String) match {
      case "areaUnderROC" =>
        metrics.areaUnderROC()
      case "areaUnderPR" =>
        metrics.areaUnderPR()
      case other =>
        throw new IllegalArgumentException(s"Do not support metric $other.")
    }
  }
}
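
A minimal usage sketch, not part of this commit; `predictions` is assumed to be a SchemaRDD with Double-valued "score" and "label" columns, matching the param defaults above:

  val evaluator = new BinaryClassificationEvaluator
  // With an empty map the param defaults apply, so this computes areaUnderROC;
  // set metricName to "areaUnderPR" for the other supported metric.
  val auc = evaluator.evaluate(predictions, ParamMap.empty)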
