Commit 9434280

MrBago authored and jkbradley committed
[SPARK-20861][ML][PYTHON] Delegate looping over paramMaps to estimators
Changes:

pyspark.ml Estimators can take either a list of param maps or a dict of params. This change allows the CrossValidator and TrainValidationSplit Estimators to pass lists of param maps through to the underlying estimators, so that those estimators can handle parallelization when appropriate (e.g. distributed hyperparameter tuning).

Testing:

Existing unit tests.

Author: Bago Amirbekian <bago@databricks.com>

Closes #18077 from MrBago/delegate_params.
1 parent 4816c2e commit 9434280
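
The commit message relies on `fit` accepting either a dict of params or a list of param maps. A minimal sketch of that dispatch follows, assuming a toy stand-in class: `ToyEstimator`, its `copy` and `_fit` helpers, and the dict-shaped "model" are hypothetical and only mimic the behavior described above. They are not the pyspark.ml implementation.

```python
# Toy stand-in that mimics the dict-vs-list dispatch described in the commit
# message. ToyEstimator is hypothetical; it is not the real pyspark.ml class.
class ToyEstimator:
    def __init__(self, **params):
        self.params = dict(params)

    def copy(self, extra):
        # Return a new estimator whose params are overridden by `extra`.
        merged = dict(self.params)
        merged.update(extra)
        return ToyEstimator(**merged)

    def _fit(self, dataset):
        # A "model" here is just the params used plus the dataset size.
        return {"params": dict(self.params), "n": len(dataset)}

    def fit(self, dataset, params=None):
        if params is None:
            params = {}
        if isinstance(params, (list, tuple)):
            # A list of param maps yields one model per map, in order.
            return [self.fit(dataset, pm) for pm in params]
        if isinstance(params, dict):
            # A single param map yields a single model.
            return self.copy(params)._fit(dataset) if params else self._fit(dataset)
        raise ValueError("params must be a dict or a list/tuple of dicts")


est = ToyEstimator(regParam=0.0)
epm = [{"regParam": 0.1}, {"regParam": 1.0}]
data = [1.0, 2.0, 3.0]

models = est.fit(data, epm)      # list of two models
single = est.fit(data, epm[0])   # one model
```

Because the list form returns models in the same order as the param maps, a caller can replace per-map `fit` calls with one call and index the result, which is what the diff does.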

1 file changed: +4 -5 lines changed

python/pyspark/ml/tuning.py

Lines changed: 4 additions & 5 deletions
@@ -18,14 +18,11 @@
 import itertools
 import numpy as np
 
-from pyspark import SparkContext
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
-from pyspark.ml.common import inherit_doc, _py2java
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']
@@ -232,8 +229,9 @@ def _fit(self, dataset):
             condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
             validation = df.filter(condition)
             train = df.filter(~condition)
+            models = est.fit(train, epm)
             for j in range(numModels):
-                model = est.fit(train, epm[j])
+                model = models[j]
                 # TODO: duplicate evaluator to take extra params from input
                 metric = eva.evaluate(model.transform(validation, epm[j]))
                 metrics[j] += metric/nFolds
@@ -388,8 +386,9 @@ def _fit(self, dataset):
         condition = (df[randCol] >= tRatio)
         validation = df.filter(condition)
         train = df.filter(~condition)
+        models = est.fit(train, epm)
         for j in range(numModels):
-            model = est.fit(train, epm[j])
+            model = models[j]
             metric = eva.evaluate(model.transform(validation, epm[j]))
             metrics[j] += metric
         if eva.isLargerBetter():
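
Both hunks apply the same refactoring: fit every param map in one call per training split, then index into the returned list. The sketch below compares the old and new loop shapes on plain Python lists. It is an illustration under stated assumptions: `MeanEstimator`, `evaluate`, `cross_validate`, the `shrink` parameter, and the strided fold split are all hypothetical simplifications, not the pyspark code (the real `_fit` splits on a random column of a DataFrame).

```python
# Hypothetical toy estimator: the "model" is shrink * mean(train).
class MeanEstimator:
    def fit(self, train, params):
        if isinstance(params, (list, tuple)):
            # The estimator owns the loop over param maps, so it could
            # parallelize here; this toy version just loops serially.
            return [self.fit(train, pm) for pm in params]
        return params["shrink"] * (sum(train) / len(train))


def evaluate(model, validation):
    # Mean squared error of a constant prediction (smaller is better).
    return sum((y - model) ** 2 for y in validation) / len(validation)


def cross_validate(est, data, epm, n_folds, delegate):
    metrics = [0.0] * len(epm)
    for i in range(n_folds):
        # Simplified deterministic split: every n_folds-th item is validation.
        validation = data[i::n_folds]
        train = [x for k, x in enumerate(data) if k % n_folds != i]
        if delegate:
            models = est.fit(train, epm)  # new shape: one call per fold
        for j in range(len(epm)):
            # Old shape: one fit call per param map inside the loop.
            model = models[j] if delegate else est.fit(train, epm[j])
            metrics[j] += evaluate(model, validation) / n_folds
    return metrics


data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
epm = [{"shrink": 1.0}, {"shrink": 0.5}]
new_metrics = cross_validate(MeanEstimator(), data, epm, 3, delegate=True)
old_metrics = cross_validate(MeanEstimator(), data, epm, 3, delegate=False)
```

In this toy setting the two shapes produce identical metrics. The difference is only in who owns the loop over param maps, which is what lets an estimator parallelize it.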