Skip to content

Commit 65afd3c

Browse files
committed
[SPARK-7474] [MLLIB] update ParamGridBuilder doctest
Multiline commands are properly handled in this PR. oefirouz ![screen shot 2015-05-07 at 10 53 25 pm](https://cloud.githubusercontent.com/assets/829644/7531290/02ad2fd4-f50c-11e4-8c04-e58d1a61ad69.png) Author: Xiangrui Meng <meng@databricks.com> Closes #6001 from mengxr/SPARK-7474 and squashes the following commits: b94b11d [Xiangrui Meng] update ParamGridBuilder doctest
1 parent f5ff4a8 commit 65afd3c

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

python/pyspark/ml/tuning.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,22 @@
2727

2828

2929
class ParamGridBuilder(object):
30-
"""
30+
r"""
3131
Builder for a param grid used in grid search-based model selection.
3232
33-
>>> from classification import LogisticRegression
33+
>>> from pyspark.ml.classification import LogisticRegression
3434
>>> lr = LogisticRegression()
35-
>>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \
36-
.baseOn([lr.predictionCol, 'p']) \
37-
.addGrid(lr.regParam, [1.0, 2.0, 3.0]) \
38-
.addGrid(lr.maxIter, [1, 5]) \
39-
.addGrid(lr.featuresCol, ['f']) \
40-
.build()
41-
>>> expected = [ \
42-
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
43-
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
44-
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
45-
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
46-
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
47-
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
35+
>>> output = ParamGridBuilder() \
36+
... .baseOn({lr.labelCol: 'l'}) \
37+
... .baseOn([lr.predictionCol, 'p']) \
38+
... .addGrid(lr.regParam, [1.0, 2.0]) \
39+
... .addGrid(lr.maxIter, [1, 5]) \
40+
... .build()
41+
>>> expected = [
42+
... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
43+
... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
44+
... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
45+
... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
4846
>>> len(output) == len(expected)
4947
True
5048
>>> all([m in expected for m in output])

0 commit comments

Comments
 (0)