@@ -18,14 +18,11 @@
 import itertools
 import numpy as np

-from pyspark import SparkContext
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
-from pyspark.ml.common import inherit_doc, _py2java

 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']
@@ -232,8 +229,9 @@ def _fit(self, dataset):
             condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
             validation = df.filter(condition)
             train = df.filter(~condition)
+            models = est.fit(train, epm)
             for j in range(numModels):
-                model = est.fit(train, epm[j])
+                model = models[j]
                 # TODO: duplicate evaluator to take extra params from input
                 metric = eva.evaluate(model.transform(validation, epm[j]))
                 metrics[j] += metric/nFolds
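
The batched call above relies on the documented contract of pyspark.ml.Estimator.fit: when the params argument is a list or tuple of param maps, fit returns a list of models, one per map, instead of a single model. A minimal sketch of that contract, assuming a labeled DataFrame named train is already in scope:

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder

lr = LogisticRegression(maxIter=10)
# Two param maps in the grid, so fit() returns a list of two fitted models.
epm = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()

models = lr.fit(train, epm)  # train is an assumed, pre-built DataFrame
assert len(models) == len(epm)

This means each cross-validation fold issues one fit() call per fold covering the whole grid, rather than one call per grid cell, so an estimator can share work across the param maps it receives.
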
@@ -388,8 +386,9 @@ def _fit(self, dataset):
         condition = (df[randCol] >= tRatio)
         validation = df.filter(condition)
         train = df.filter(~condition)
+        models = est.fit(train, epm)
         for j in range(numModels):
-            model = est.fit(train, epm[j])
+            model = models[j]
             metric = eva.evaluate(model.transform(validation, epm[j]))
             metrics[j] += metric
         if eva.isLargerBetter():
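
The same batching applies to TrainValidationSplit: its single train/validation split now hands the full grid to the estimator in one fit() call. A hedged usage sketch, assuming dataset is a labeled input DataFrame:

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1, 1.0]).build()
tvs = TrainValidationSplit(estimator=lr,
                           estimatorParamMaps=grid,
                           evaluator=BinaryClassificationEvaluator(),
                           trainRatio=0.8)
# With the change above, the underlying LogisticRegression receives all
# three param maps in a single fit() call instead of three separate ones.
tvsModel = tvs.fit(dataset)  # dataset is an assumed input DataFrame
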