
Commit e80dc1c

[SPARK-4586][MLLIB] Python API for ML pipeline and parameters
This PR adds Python API for ML pipeline and parameters. The design doc can be found on the JIRA page. It includes transformers and an estimator to demo the simple text classification example code.

TODO:
- [x] handle parameters in LRModel
- [x] unit tests
- [x] missing some docs

CC: davies jkbradley

Author: Xiangrui Meng <meng@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4151 from mengxr/SPARK-4586 and squashes the following commits:

415268e [Xiangrui Meng] remove inherit_doc from __init__
edbd6fe [Xiangrui Meng] move Identifiable to ml.util
44c2405 [Xiangrui Meng] Merge pull request #2 from davies/ml
dd1256b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4586
14ae7e2 [Davies Liu] fix docs
54ca7df [Davies Liu] fix tests
78638df [Davies Liu] Merge branch 'SPARK-4586' of github.com:mengxr/spark into ml
fc59a02 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4586
1dca16a [Davies Liu] refactor
090b3a3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into ml
0882513 [Xiangrui Meng] update doc style
a4f4dbf [Xiangrui Meng] add unit test for LR
7521d1c [Xiangrui Meng] add unit tests to HashingTF and Tokenizer
ba0ba1e [Xiangrui Meng] add unit tests for pipeline
0586c7b [Xiangrui Meng] add more comments to the example
5153cff [Xiangrui Meng] simplify java models
036ca04 [Xiangrui Meng] gen numFeatures
46fa147 [Xiangrui Meng] update mllib/pom.xml to include python files in the assembly
1dcc17e [Xiangrui Meng] update code gen and make param appear in the doc
f66ba0c [Xiangrui Meng] make params a property
d5efd34 [Xiangrui Meng] update doc conf and move embedded param map to instance attribute
f4d0fe6 [Xiangrui Meng] use LabeledDocument and Document in example
05e3e40 [Xiangrui Meng] update example
d3e8dbe [Xiangrui Meng] more docs optimize pipeline.fit impl
56de571 [Xiangrui Meng] fix style
d0c5bb8 [Xiangrui Meng] a working copy
bce72f4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-4586
17ecfb9 [Xiangrui Meng] code gen for shared params
d9ea77c [Xiangrui Meng] update doc
c18dca1 [Xiangrui Meng] make the example working
dadd84e [Xiangrui Meng] add base classes and docs
a3015cf [Xiangrui Meng] add Estimator and Transformer
46eea43 [Xiangrui Meng] a pipeline in python
33b68e0 [Xiangrui Meng] a working LR
1 parent e023112 commit e80dc1c

19 files changed (+1212, -17 lines)

examples/src/main/python/ml/simple_text_classification_pipeline.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import SparkContext
from pyspark.sql import SQLContext, Row
from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.classification import LogisticRegression


"""
A simple text classification pipeline that recognizes "spark" from
input text. This is to show how to create and configure a Spark ML
pipeline in Python. Run with:

  bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py
"""


if __name__ == "__main__":
    sc = SparkContext(appName="SimpleTextClassificationPipeline")
    sqlCtx = SQLContext(sc)

    # Prepare training documents, which are labeled.
    LabeledDocument = Row('id', 'text', 'label')
    training = sqlCtx.inferSchema(
        sc.parallelize([(0L, "a b c d e spark", 1.0),
                        (1L, "b d", 0.0),
                        (2L, "spark f g h", 1.0),
                        (3L, "hadoop mapreduce", 0.0)])
          .map(lambda x: LabeledDocument(*x)))

    # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    tokenizer = Tokenizer() \
        .setInputCol("text") \
        .setOutputCol("words")
    hashingTF = HashingTF() \
        .setInputCol(tokenizer.getOutputCol()) \
        .setOutputCol("features")
    lr = LogisticRegression() \
        .setMaxIter(10) \
        .setRegParam(0.01)
    pipeline = Pipeline() \
        .setStages([tokenizer, hashingTF, lr])

    # Fit the pipeline to training documents.
    model = pipeline.fit(training)

    # Prepare test documents, which are unlabeled.
    Document = Row('id', 'text')
    test = sqlCtx.inferSchema(
        sc.parallelize([(4L, "spark i j k"),
                        (5L, "l m n"),
                        (6L, "mapreduce spark"),
                        (7L, "apache hadoop")])
          .map(lambda x: Document(*x)))

    # Make predictions on test documents and print columns of interest.
    prediction = model.transform(test)
    prediction.registerTempTable("prediction")
    selected = sqlCtx.sql("SELECT id, text, prediction from prediction")
    for row in selected.collect():
        print row

    sc.stop()
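The pipeline.fit call above drives all three stages: transformer stages rewrite the dataset in turn, and the trailing estimator is fit on the transformed data, yielding a PipelineModel whose transform replays the fitted stages on new data. A rough, hedged sketch of that control flow (illustrative only, not the actual pyspark.ml.pipeline implementation, which differs in detail):

def fit_pipeline(stages, training):
    # Illustrative stand-in for Pipeline.fit: walk the stages in order,
    # fitting estimator-like stages and transforming the running dataset.
    fitted_stages = []
    dataset = training
    for stage in stages:
        if hasattr(stage, "fit"):            # estimator-like stage
            model = stage.fit(dataset)
            fitted_stages.append(model)
            dataset = model.transform(dataset)
        else:                                # transformer-like stage
            fitted_stages.append(stage)
            dataset = stage.transform(dataset)
    return fitted_stages                     # stand-in for a PipelineModel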

mllib/pom.xml

Lines changed: 2 additions & 0 deletions
@@ -125,6 +125,8 @@
       <directory>../python</directory>
       <includes>
         <include>pyspark/mllib/*.py</include>
+        <include>pyspark/ml/*.py</include>
+        <include>pyspark/ml/param/*.py</include>
       </includes>
     </resource>
   </resources>

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 7 additions & 1 deletion
@@ -164,6 +164,13 @@ trait Params extends Identifiable with Serializable {
     this
   }
 
+  /**
+   * Sets a parameter (by name) in the embedded param map.
+   */
+  private[ml] def set(param: String, value: Any): this.type = {
+    set(getParam(param), value)
+  }
+
   /**
    * Gets the value of a parameter in the embedded param map.
    */
@@ -286,7 +293,6 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
     new ParamMap(this.map ++ other.map)
   }
 
-
   /**
    * Adds all parameters from the input param map into this param map.
    */
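The new private[ml] set(param: String, value: Any) overload lets callers (notably the Python wrappers added in this PR) set a parameter by name rather than by Param object. A minimal pure-Python sketch of the same name-based lookup, using a hypothetical SimpleParams class (illustrative only, not the real Params/ParamMap code):

class SimpleParams(object):
    """Hypothetical stand-in for a Params-like object with an embedded param map."""

    def __init__(self):
        self._declared = {"maxIter", "regParam"}   # declared parameter names
        self._paramMap = {}                        # embedded param map: name -> value

    def set(self, name, value):
        # Mirror set(getParam(param), value): resolve the name first, then store.
        if name not in self._declared:
            raise ValueError("Param %r does not exist." % name)
        self._paramMap[name] = value
        return self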

python/docs/conf.py

Lines changed: 2 additions & 2 deletions
@@ -55,9 +55,9 @@
 # built documents.
 #
 # The short X.Y version.
-version = '1.2-SNAPSHOT'
+version = '1.3-SNAPSHOT'
 # The full version, including alpha/beta/rc tags.
-release = '1.2-SNAPSHOT'
+release = '1.3-SNAPSHOT'
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.

python/docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Contents:
    pyspark
    pyspark.sql
    pyspark.streaming
+   pyspark.ml
    pyspark.mllib
 

python/docs/pyspark.ml.rst

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
pyspark.ml package
=====================

Submodules
----------

pyspark.ml module
-----------------

.. automodule:: pyspark.ml
    :members:
    :undoc-members:
    :inherited-members:

pyspark.ml.feature module
-------------------------

.. automodule:: pyspark.ml.feature
    :members:
    :undoc-members:
    :inherited-members:

pyspark.ml.classification module
--------------------------------

.. automodule:: pyspark.ml.classification
    :members:
    :undoc-members:
    :inherited-members:

python/docs/pyspark.rst

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ Subpackages
 
    pyspark.sql
    pyspark.streaming
+   pyspark.ml
    pyspark.mllib
 
 Contents

python/pyspark/ml/__init__.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.ml.param import *
from pyspark.ml.pipeline import *

__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"]

python/pyspark/ml/classification.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.ml.util import inherit_doc
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
    HasRegParam


__all__ = ['LogisticRegression', 'LogisticRegressionModel']


@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
                         HasRegParam):
    """
    Logistic regression.

    >>> from pyspark.sql import Row
    >>> from pyspark.mllib.linalg import Vectors
    >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
        Row(label=1.0, features=Vectors.dense(1.0)), \
        Row(label=0.0, features=Vectors.sparse(1, [], []))]))
    >>> lr = LogisticRegression() \
        .setMaxIter(5) \
        .setRegParam(0.01)
    >>> model = lr.fit(dataset)
    >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
    >>> print model.transform(test0).head().prediction
    0.0
    >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
    >>> print model.transform(test1).head().prediction
    1.0
    """
    _java_class = "org.apache.spark.ml.classification.LogisticRegression"

    def _create_model(self, java_model):
        return LogisticRegressionModel(java_model)


class LogisticRegressionModel(JavaModel):
    """
    Model fitted by LogisticRegression.
    """


if __name__ == "__main__":
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext("local[2]", "ml.classification tests")
    sqlCtx = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlCtx'] = sqlCtx
    (failure_count, test_count) = doctest.testmod(
        globs=globs, optionflags=doctest.ELLIPSIS)
    sc.stop()
    if failure_count:
        exit(-1)
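LogisticRegression above only has to name its Scala counterpart (_java_class) and say how to wrap the fitted JVM object (_create_model); the shared JavaEstimator machinery handles the rest. A hypothetical, simplified rendering of that division of labor (the _fit_java helper here is an illustrative placeholder, not the real wrapper API):

class ToyJavaEstimator(object):
    """Illustrative stand-in for the JavaEstimator pattern used above."""

    _java_class = None   # fully qualified name of the Scala estimator

    def fit(self, dataset, params={}):
        # A real JavaEstimator would instantiate self._java_class via Py4J,
        # push the Python-side params over by name, and call its fit method.
        java_model = self._fit_java(dataset, params)   # hypothetical helper
        return self._create_model(java_model)

    def _fit_java(self, dataset, params):
        raise NotImplementedError("placeholder for the JVM call")

    def _create_model(self, java_model):
        raise NotImplementedError("subclasses wrap the fitted JVM model here")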

python/pyspark/ml/feature.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
from pyspark.ml.util import inherit_doc
from pyspark.ml.wrapper import JavaTransformer

__all__ = ['Tokenizer', 'HashingTF']


@inherit_doc
class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
    """
    A tokenizer that converts the input string to lowercase and then
    splits it by white spaces.

    >>> from pyspark.sql import Row
    >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")]))
    >>> tokenizer = Tokenizer() \
        .setInputCol("text") \
        .setOutputCol("words")
    >>> print tokenizer.transform(dataset).head()
    Row(text=u'a b c', words=[u'a', u'b', u'c'])
    >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head()
    Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
    """

    _java_class = "org.apache.spark.ml.feature.Tokenizer"


@inherit_doc
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
    """
    Maps a sequence of terms to their term frequencies using the
    hashing trick.

    >>> from pyspark.sql import Row
    >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])]))
    >>> hashingTF = HashingTF() \
        .setNumFeatures(10) \
        .setInputCol("words") \
        .setOutputCol("features")
    >>> print hashingTF.transform(dataset).head().features
    (10,[7,8,9],[1.0,1.0,1.0])
    >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
    >>> print hashingTF.transform(dataset, params).head().vector
    (5,[2,3,4],[1.0,1.0,1.0])
    """

    _java_class = "org.apache.spark.ml.feature.HashingTF"


if __name__ == "__main__":
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = globals().copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext("local[2]", "ml.feature tests")
    sqlCtx = SQLContext(sc)
    globs['sc'] = sc
    globs['sqlCtx'] = sqlCtx
    (failure_count, test_count) = doctest.testmod(
        globs=globs, optionflags=doctest.ELLIPSIS)
    sc.stop()
    if failure_count:
        exit(-1)
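The HashingTF doctest output above, e.g. (10,[7,8,9],[1.0,1.0,1.0]), comes from the hashing trick: each term is hashed into one of numFeatures buckets and the per-bucket counts form a sparse term-frequency vector. A tiny pure-Python illustration of the idea (the bucket indices will not match Spark's, which uses its own hash function):

from collections import defaultdict

def hashing_tf(words, num_features=10):
    # Hash each term into a bucket and count occurrences; Spark's HashingTF
    # applies the same idea with its own hash function over numFeatures buckets.
    counts = defaultdict(float)
    for term in words:
        counts[hash(term) % num_features] += 1.0
    return dict(counts)

print hashing_tf(["a", "b", "c"])   # three distinct buckets, each with count 1.0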
