[SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit #15414
Test build #66629 has finished for PR 15414 at commit
|
Test build #66635 has finished for PR 15414 at commit
|
Jenkins, test this please |
Test build #66637 has finished for PR 15414 at commit
|
Jenkins, retest this please |
Test build #66649 has finished for PR 15414 at commit
|
   * Return the given DataFrame, with [[labelCol]] casted to DoubleType.
   */
  protected def castDataSet(dataset: Dataset[_]): DataFrame = {
    val labelMeta = dataset.schema.fields.filter(_.name == $(labelCol)).head.metadata
Maybe simplify it: dataset.schema("value").metadata
Thanks
What do you think about adding a new suite?
class MockPredictor(override val uid: String)
  extends Predictor[Vector, MockPredictor, MockPredictionModel] {
  override def train(dataset: Dataset[_]): MockPredictionModel = {
    require(dataset.schema("label").dataType == DoubleType)
    new MockPredictionModel(uid)
  }
  override def copy(extra: ParamMap): MockPredictor = defaultCopy(extra)
}
class MockPredictionModel(override val uid: String)
  extends PredictionModel[Vector, MockPredictionModel] {
  override def predict(features: Vector): Double = 1.0
  override def copy(extra: ParamMap): MockPredictionModel = defaultCopy(extra)
}
Then we just have a test that calls |
Ok, I will create this Suite. |
6c2a8d0 to 6c61e73
Test build #66710 has finished for PR 15414 at commit
|
  import testImplicits._

  class MockPredictor(override val uid: String)
move into companion object.
  /**
   * Return the given DataFrame, with [[labelCol]] casted to DoubleType.
   */
  protected def castDataSet(dataset: Dataset[_]): DataFrame = {
let's just put this logic directly in fit
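As a rough illustration of the logic under discussion (not the code that was merged), casting the label column while carrying its metadata over could look like the sketch below. It uses only public Spark SQL calls; the `LabelCast` object and the `castLabelToDouble` name are hypothetical, and the label column name is passed in explicitly rather than read from the `labelCol` param.

```scala
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Hypothetical helper, for illustration only; not part of Spark.
object LabelCast {
  def castLabelToDouble(dataset: Dataset[_], labelCol: String): DataFrame = {
    // Look the field up by name; this throws IllegalArgumentException if it is missing.
    val labelMeta = dataset.schema(labelCol).metadata
    // Assumption: column names are simple identifiers (no dots or backticks).
    val projected = dataset.columns.map { name =>
      if (name == labelCol) {
        // Cast the label and re-attach the metadata captured before the cast.
        col(name).cast(DoubleType).as(name, labelMeta)
      } else {
        col(name)
      }
    }
    dataset.select(projected: _*)
  }
}
```

The metadata is captured before the cast and attached explicitly, so the result does not depend on whether a plain cast would have kept it.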
@@ -117,7 +117,7 @@ object MLTestingUtils extends SparkFunSuite {
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
     types.map { t =>
       val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
-      t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
+      t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
What is this for? If the intent is to force `getNumClasses` to infer the number of classes, then you're no longer testing the not-inferred case. Further, the point of this PR is to eliminate the need to do that since it is not a robust solution, IMO.
Also, I'd like to remove the dependence on `TreeTests` here (and `genRegressionDF`) and just explicitly set the attributes in the functions.
Ok, I will revert this
  test("should support all NumericType labels and not support other types") {
    val predictor = new MockPredictor("mock")
    MLTestingUtils.checkNumericTypes[MockPredictionModel, MockPredictor](
Why don't we just cycle through the types here and call `fit`? I think it's a bit confusing the way it is now.
OK, I will update this.
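The suggestion to cycle through the types directly might be sketched as follows. This is an assumption about the shape of such a test, not the code that landed: `NumericLabelLoop` and `forEachLabelType` are hypothetical names, and the type list is copied from the `MLTestingUtils` snippet quoted earlier in this review.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

// Hypothetical test helper, for illustration only.
object NumericLabelLoop {
  // Same numeric types as the list in the MLTestingUtils snippet.
  val numericTypes: Seq[DataType] =
    Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))

  // Cast the label column to each numeric type in turn and hand the result to `check`.
  def forEachLabelType(df: DataFrame, labelCol: String)(
      check: (DataType, DataFrame) => Unit): Unit = {
    numericTypes.foreach { t =>
      check(t, df.withColumn(labelCol, col(labelCol).cast(t)))
    }
  }
}
```

In a suite, the callback would call `predictor.fit(castDF)` and rely on the `require` inside the mock's `train` to confirm that the label arrived as `DoubleType`.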
  class MockPredictionModel(override val uid: String)
    extends PredictionModel[Vector, MockPredictionModel] {

    override def predict(features: Vector): Double = 1.0
override def predict(features: Vector): Double = throw new NotImplementedError()
We can do this for everything except `train`.
6c61e73 to 6ef17b7
Test build #66814 has finished for PR 15414 at commit
|
@sethah I have made some modifications according to the comments.
      new MockPredictionModel(uid)
    }

    override def copy(extra: ParamMap): MockPredictor = defaultCopy(extra)
change the copy methods to throw NotImplementedError
Thanks, I'll take a more detailed look in the next couple of days. Let's also wait and see if we can get @yanboliang or @jkbradley to give an opinion. |
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
don't need DefaultReadWriteTest
Test build #66872 has finished for PR 15414 at commit
|
Test build #66880 has finished for PR 15414 at commit
|
@jkbradley @yanboliang Could you please review this? This PR unifies the usage of labelCol casting and fixes a bug described in [https://issues.apache.org/jira/browse/SPARK-17797]
@jkbradley @yanboliang Just re-pinging for your opinions. |
Can you please document in Predictor that it accepts all NumericType labels? Other than that, this LGTM. Thanks! |
LGTM as well after adding @jkbradley's suggestion. |
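For reference, a check for "any `NumericType` label" can be written against the schema alone. The sketch below is illustrative only; `LabelTypeCheck` and `requireNumericLabel` are hypothetical names, not Spark APIs.

```scala
import org.apache.spark.sql.types.{NumericType, StructType}

// Hypothetical helper, for illustration only; not part of Spark.
object LabelTypeCheck {
  // Fails with IllegalArgumentException unless the label column has a numeric type.
  def requireNumericLabel(schema: StructType, labelCol: String): Unit = {
    val actual = schema(labelCol).dataType
    require(actual.isInstanceOf[NumericType],
      s"Column $labelCol must be of a NumericType but was actually of type $actual.")
  }
}
```

Because `DecimalType` and the integral and floating-point types all extend `NumericType`, a single `isInstanceOf` test covers every type in the list used by the tests in this PR.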
7cb4510 to 810c973
@jkbradley @sethah I added a comment. Thanks for the reviews.
Test build #67861 has finished for PR 15414 at commit
|
LGTM |
## What changes were proposed in this pull request?
1. Move the cast to `Predictor`.
2. Then remove the unnecessary casts.

## How was this patch tested?
Existing tests.

Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes apache#15414 from zhengruifeng/move_cast.