|
27 | 27 |
|
28 | 28 |
|
29 | 29 | class ParamGridBuilder(object):
|
30 |
| - """ |
| 30 | + r""" |
31 | 31 | Builder for a param grid used in grid search-based model selection.
|
32 | 32 |
|
33 |
| - >>> from classification import LogisticRegression |
| 33 | + >>> from pyspark.ml.classification import LogisticRegression |
34 | 34 | >>> lr = LogisticRegression()
|
35 |
| - >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \ |
36 |
| - .baseOn([lr.predictionCol, 'p']) \ |
37 |
| - .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \ |
38 |
| - .addGrid(lr.maxIter, [1, 5]) \ |
39 |
| - .addGrid(lr.featuresCol, ['f']) \ |
40 |
| - .build() |
41 |
| - >>> expected = [ \ |
42 |
| -{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
43 |
| -{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
44 |
| -{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
45 |
| -{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
46 |
| -{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ |
47 |
| -{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] |
| 35 | + >>> output = ParamGridBuilder() \ |
| 36 | + ... .baseOn({lr.labelCol: 'l'}) \ |
| 37 | + ... .baseOn([lr.predictionCol, 'p']) \ |
| 38 | + ... .addGrid(lr.regParam, [1.0, 2.0]) \ |
| 39 | + ... .addGrid(lr.maxIter, [1, 5]) \ |
| 40 | + ... .build() |
| 41 | + >>> expected = [ |
| 42 | + ... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
| 43 | + ... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
| 44 | + ... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, |
| 45 | + ... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] |
48 | 46 | >>> len(output) == len(expected)
|
49 | 47 | True
|
50 | 48 | >>> all([m in expected for m in output])
|
|
0 commit comments