
Commit 7f6a8ab

[SPARK-31777][ML][PYSPARK] Add user-specified fold column to CrossValidator
### What changes were proposed in this pull request?

This patch adds user-specified fold column support to `CrossValidator`. Users can assign fold numbers to the dataset instead of letting Spark do random splits.

### Why are the changes needed?

This gives `CrossValidator` users more flexibility in splitting folds.

### Does this PR introduce _any_ user-facing change?

Yes, a new `foldCol` param is added to `CrossValidator`. Users can use it to specify custom fold splitting.

### How was this patch tested?

Added unit tests.

Closes #28704 from viirya/SPARK-31777.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Liang-Chi Hsieh <liangchi@uber.com>
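In practice, the new param slots into the existing `CrossValidator` builder chain. A minimal usage sketch (the DataFrame `df` and its `fold` column are illustrative assumptions, not part of this patch):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Assumption: `df` already carries an IntegerType column "fold" with
// values in [0, 3); out-of-range fold numbers make the fit throw.
val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setNumFolds(3)
  .setFoldCol("fold")  // skip the random k-fold split, use user-supplied folds

val cvModel = cv.fit(df)
```

For each fold `i`, the rows whose fold number equals `i` form the validation set and all remaining rows form the training set; the fold column itself is dropped before fitting.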
1 parent: 2ec9b86

File tree: 6 files changed, +308 −25 lines


mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 36 additions & 7 deletions
```diff
@@ -30,13 +30,13 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism}
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -56,6 +56,19 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
   def getNumFolds: Int = $(numFolds)
 
   setDefault(numFolds -> 3)
+
+  /**
+   * Param for the column name of user specified fold number. Once this is specified,
+   * `CrossValidator` won't do random k-fold split. Note that this column should be
+   * integer type with range [0, numFolds) and Spark will throw exception on out-of-range
+   * fold numbers.
+   */
+  val foldCol: Param[String] = new Param[String](this, "foldCol",
+    "the column name of user specified fold number")
+
+  def getFoldCol: String = $(foldCol)
+
+  setDefault(foldCol, "")
 }
 
 /**
@@ -94,6 +107,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /** @group setParam */
+  @Since("3.1.0")
+  def setFoldCol(value: String): this.type = set(foldCol, value)
+
   /**
    * Set the maximum level of parallelism to evaluate models in parallel.
    * Default is 1 for serial evaluation
@@ -132,7 +149,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
 
     instr.logPipelineStage(this)
     instr.logDataset(dataset)
-    instr.logParams(this, numFolds, seed, parallelism)
+    instr.logParams(this, numFolds, seed, parallelism, foldCol)
     logTuningParams(instr)
 
     val collectSubModelsParam = $(collectSubModels)
@@ -142,10 +159,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     } else None
 
     // Compute metrics for each model over each split
-    val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
+    val (splits, schemaWithoutFold) = if ($(foldCol) == "") {
+      (MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)), schema)
+    } else {
+      val filteredSchema = StructType(schema.filter(_.name != $(foldCol)).toArray)
+      (MLUtils.kFold(dataset.toDF, $(numFolds), $(foldCol)), filteredSchema)
+    }
     val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
-      val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
-      val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
+      val trainingDataset = sparkSession.createDataFrame(training, schemaWithoutFold).cache()
+      val validationDataset = sparkSession.createDataFrame(validation, schemaWithoutFold).cache()
       instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")
 
       // Fit models in a Future for training in parallel
@@ -183,7 +205,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   }
 
   @Since("1.4.0")
-  override def transformSchema(schema: StructType): StructType = transformSchemaImpl(schema)
+  override def transformSchema(schema: StructType): StructType = {
+    if ($(foldCol) != "") {
+      val foldColDt = schema.apply($(foldCol)).dataType
+      require(foldColDt.isInstanceOf[IntegerType],
+        s"The specified `foldCol` column ${$(foldCol)} must be integer type, but got $foldColDt.")
+    }
+    transformSchemaImpl(schema)
+  }
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): CrossValidator = {
```
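Since the split is now driven entirely by the user's fold assignments, the caller has to produce a valid fold column up front. A hedged sketch of one way to do that, mirroring what the test suite below does (the `dataset` and column names are illustrative):

```scala
import org.apache.spark.sql.functions.{col, rand, when}

// Bucket a seeded uniform draw into folds 0, 1, 2 of roughly equal size.
// The integer literals make the column IntegerType, which transformSchema requires.
val foldCol = when(col("random") < 0.33, 0)
  .when(col("random") < 0.66, 1)
  .otherwise(2)
val datasetWithFold = dataset
  .withColumn("random", rand(100L))
  .withColumn("fold", foldCol)
  .drop("random")
```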

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 32 additions & 2 deletions
```diff
@@ -20,15 +20,15 @@ package org.apache.spark.mllib.util
 import scala.annotation.varargs
 import scala.reflect.ClassTag
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkException}
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.linalg.{MatrixUDT => MLMatrixUDT, VectorUDT => MLVectorUDT}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.functions._
@@ -248,6 +248,36 @@ object MLUtils extends Logging {
     }.toArray
   }
 
+  /**
+   * Version of `kFold()` taking a fold column name.
+   */
+  @Since("3.1.0")
+  def kFold(df: DataFrame, numFolds: Int, foldColName: String): Array[(RDD[Row], RDD[Row])] = {
+    val foldCol = df.col(foldColName)
+    val checker = udf { foldNum: Int =>
+      // Valid fold number is in range [0, numFolds).
+      if (foldNum < 0 || foldNum >= numFolds) {
+        throw new SparkException(s"Fold number must be in range [0, $numFolds), but got $foldNum.")
+      }
+      true
+    }
+    (0 until numFolds).map { fold =>
+      val training = df
+        .filter(checker(foldCol) && foldCol =!= fold)
+        .drop(foldColName).rdd
+      val validation = df
+        .filter(checker(foldCol) && foldCol === fold)
+        .drop(foldColName).rdd
+      if (training.isEmpty()) {
+        throw new SparkException(s"The training data at fold $fold is empty.")
+      }
+      if (validation.isEmpty()) {
+        throw new SparkException(s"The validation data at fold $fold is empty.")
+      }
+      (training, validation)
+    }.toArray
+  }
+
   /**
    * Returns a new vector with `1.0` (bias) appended to the input vector.
    */
```
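The new overload can also be called directly: each array element pairs a training `RDD[Row]` (rows with fold number != i) with a validation `RDD[Row]` (rows with fold number == i), and the fold column is dropped from both sides. A small sketch, assuming an active `SparkSession` named `spark` (the toy DataFrame is illustrative):

```scala
import org.apache.spark.mllib.util.MLUtils

// 100 rows, alternately assigned to folds 0 and 1 via an IntegerType column.
val data = spark.range(100).selectExpr("id", "cast(id % 2 as int) as fold")

val folds = MLUtils.kFold(data, numFolds = 2, foldColName = "fold")
folds.zipWithIndex.foreach { case ((train, validation), i) =>
  println(s"fold $i: train=${train.count()}, validation=${validation.count()}")
}
```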

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 64 additions & 0 deletions
```diff
@@ -32,6 +32,7 @@ import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.util.LinearDataGenerator
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.StructType
 
 class CrossValidatorSuite
@@ -40,10 +41,14 @@ class CrossValidatorSuite
   import testImplicits._
 
   @transient var dataset: Dataset[_] = _
+  @transient var datasetWithFold: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
+    val dfWithRandom = dataset.repartition(1).withColumn("random", rand(100L))
+    val foldCol = when(col("random") < 0.33, 0).when(col("random") < 0.66, 1).otherwise(2)
+    datasetWithFold = dfWithRandom.withColumn("fold", foldCol).drop("random").repartition(2)
   }
 
   test("cross validation with logistic regression") {
@@ -75,6 +80,65 @@ class CrossValidatorSuite
     }
   }
 
+  test("cross validation with logistic regression with fold col") {
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 10))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setNumFolds(3)
+      .setFoldCol("fold")
+    val cvModel = cv.fit(datasetWithFold)
+
+    MLTestingUtils.checkCopyAndUids(cv, cvModel)
+
+    val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    assert(parent.getRegParam === 0.001)
+    assert(parent.getMaxIter === 10)
+    assert(cvModel.avgMetrics.length === lrParamMaps.length)
+
+    val result = cvModel.transform(dataset).select("prediction").as[Double].collect()
+    testTransformerByGlobalCheckFunc[(Double, Vector)](dataset.toDF(), cvModel, "prediction") {
+      rows =>
+        val result2 = rows.map(_.getDouble(0))
+        assert(result === result2)
+    }
+  }
+
+  test("cross validation with logistic regression with wrong fold col") {
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 10))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setNumFolds(3)
+      .setFoldCol("fold1")
+    val err1 = intercept[IllegalArgumentException] {
+      cv.fit(datasetWithFold)
+    }
+    assert(err1.getMessage.contains("fold1 does not exist. Available: label, features, fold"))
+
+    // Fold column must be integer type.
+    val foldCol = udf(() => 1L)
+    val datasetWithWrongFoldType = dataset.withColumn("fold1", foldCol())
+    val err2 = intercept[IllegalArgumentException] {
+      cv.fit(datasetWithWrongFoldType)
+    }
+    assert(err2
+      .getMessage
+      .contains("The specified `foldCol` column fold1 must be integer type, but got LongType."))
+  }
+
   test("cross validation with linear regression") {
     val dataset = sc.parallelize(
       LinearDataGenerator.generateLinearInput(
```

mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

Lines changed: 30 additions & 0 deletions
```diff
@@ -353,4 +353,34 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
       convertMatrixColumnsFromML(df, "p._2")
     }
   }
+
+  test("kFold with fold column") {
+    val data = sc.parallelize(1 to 100, 2).map(x => (x, if (x <= 50) 0 else 1)).toDF("i", "fold")
+    val collectedData = data.collect().map(_.getInt(0)).sorted
+    val twoFoldedRdd = kFold(data, 2, "fold")
+    assert(twoFoldedRdd(0)._1.collect().map(_.getInt(0)).sorted ===
+      twoFoldedRdd(1)._2.collect().map(_.getInt(0)).sorted)
+    assert(twoFoldedRdd(0)._2.collect().map(_.getInt(0)).sorted ===
+      twoFoldedRdd(1)._1.collect().map(_.getInt(0)).sorted)
+
+    val result1 = twoFoldedRdd(0)._1.union(twoFoldedRdd(0)._2).collect().map(_.getInt(0)).sorted
+    assert(result1 === collectedData,
+      "Each training+validation set combined should contain all of the data.")
+    val result2 = twoFoldedRdd(1)._1.union(twoFoldedRdd(1)._2).collect().map(_.getInt(0)).sorted
+    assert(result2 === collectedData,
+      "Each training+validation set combined should contain all of the data.")
+  }
+
+  test("kFold with fold column: invalid fold numbers") {
+    val data = sc.parallelize(Seq(0, 1, 2), 2).toDF("fold")
+    val err1 = intercept[SparkException] {
+      kFold(data, 2, "fold")(0)._1.collect()
+    }
+    assert(err1.getMessage.contains("Fold number must be in range [0, 2), but got 2."))
+
+    val err2 = intercept[SparkException] {
+      kFold(data, 4, "fold")(0)._1.collect()
+    }
+    assert(err2.getMessage.contains("The validation data at fold 3 is empty."))
+  }
 }
```

python/pyspark/ml/tests/test_tuning.py

Lines changed: 74 additions & 0 deletions
```diff
@@ -380,6 +380,80 @@ def test_save_load_pipeline_estimator(self):
                                               original_nested_pipeline_model.stages):
             self.assertEqual(loadedStage.uid, originalStage.uid)
 
+    def test_user_specified_folds(self):
+        from pyspark.sql import functions as F
+
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"]).repartition(2, "features")
+
+        dataset_with_folds = dataset.repartition(1).withColumn("random", rand(100)) \
+            .withColumn("fold", F.when(F.col("random") < 0.33, 0)
+                        .when(F.col("random") < 0.66, 1)
+                        .otherwise(2)).repartition(2, "features")
+
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [20]).build()
+        evaluator = BinaryClassificationEvaluator()
+
+        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, numFolds=3)
+        cv_with_user_folds = CrossValidator(estimator=lr,
+                                            estimatorParamMaps=grid,
+                                            evaluator=evaluator,
+                                            numFolds=3,
+                                            foldCol="fold")
+
+        self.assertEqual(cv.getEstimator().uid, cv_with_user_folds.getEstimator().uid)
+
+        cvModel1 = cv.fit(dataset)
+        cvModel2 = cv_with_user_folds.fit(dataset_with_folds)
+        for index in range(len(cvModel1.avgMetrics)):
+            print(abs(cvModel1.avgMetrics[index] - cvModel2.avgMetrics[index]))
+            self.assertTrue(abs(cvModel1.avgMetrics[index] - cvModel2.avgMetrics[index])
+                            < 0.1)
+
+        # test save/load of CrossValidator
+        temp_path = tempfile.mkdtemp()
+        cvPath = temp_path + "/cv"
+        cv_with_user_folds.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assertEqual(loadedCV.getFoldCol(), cv_with_user_folds.getFoldCol())
+
+    def test_invalid_user_specified_folds(self):
+        from pyspark.sql import functions as F
+
+        dataset_with_folds = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0, 0),
+             (Vectors.dense([0.4]), 1.0, 1),
+             (Vectors.dense([0.5]), 0.0, 2),
+             (Vectors.dense([0.6]), 1.0, 0),
+             (Vectors.dense([1.0]), 1.0, 1)] * 10,
+            ["features", "label", "fold"])
+
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [20]).build()
+        evaluator = BinaryClassificationEvaluator()
+
+        cv = CrossValidator(estimator=lr,
+                            estimatorParamMaps=grid,
+                            evaluator=evaluator,
+                            numFolds=2,
+                            foldCol="fold")
+        with self.assertRaisesRegexp(Exception, "Fold number must be in range"):
+            cv.fit(dataset_with_folds)
+
+        cv = CrossValidator(estimator=lr,
+                            estimatorParamMaps=grid,
+                            evaluator=evaluator,
+                            numFolds=4,
+                            foldCol="fold")
+        with self.assertRaisesRegexp(Exception, "The validation data at fold 3 is empty"):
+            cv.fit(dataset_with_folds)
+
 
 class TrainValidationSplitTests(SparkSessionTestCase):
 
```
