Skip to content

[SPARK-17848][ML] Move LabelCol datatype cast into Predictor.fit #15414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from

Conversation

zhengruifeng
Copy link
Contributor

What changes were proposed in this pull request?

1, move cast to Predictor
2, and then, remove unnecessary cast

How was this patch tested?

existing tests

@SparkQA
Copy link

SparkQA commented Oct 10, 2016

Test build #66629 has finished for PR 15414 at commit 5cb06fc.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Oct 10, 2016

Test build #66635 has finished for PR 15414 at commit 6c2a8d0.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng
Copy link
Contributor Author

zhengruifeng commented Oct 10, 2016

Jenkins, test this please

@SparkQA
Copy link

SparkQA commented Oct 10, 2016

Test build #66637 has finished for PR 15414 at commit 6c2a8d0.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng
Copy link
Contributor Author

Jenkins, retest this please

@SparkQA
Copy link

SparkQA commented Oct 10, 2016

Test build #66649 has finished for PR 15414 at commit 6c2a8d0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

* Return the given DataFrame, with [[labelCol]] casted to DoubleType.
*/
protected def castDataSet(dataset: Dataset[_]): DataFrame = {
val labelMeta = dataset.schema.fields.filter(_.name == $(labelCol)).head.metadata
Copy link
Contributor

@hhbyyh hhbyyh Oct 10, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe simplify it: dataset.schema("value").metadata

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

@sethah
Copy link
Contributor

sethah commented Oct 10, 2016

What do you think about adding a new suite PredictorSuite where we can create a mock predictor, and call train on data of various types. The train method can just require that the label column is DoubleType:

class MockPredictor(override val uid: String)
  extends Predictor[Vector, MockPredictor, MockPredictionModel] {

  override def train(dataset: Dataset[_]): MockPredictionModel = {
    require(dataset.schema("label").dataType == DoubleType)
    new MockPredictionModel(uid)
  }

  override def copy(extra: ParamMap): MockPredictor = defaultCopy(extra)
}

class MockPredictionModel(override val uid: String)
  extends PredictionModel[Vector, MockPredictionModel] {

  override def predict(features: Vector): Double = 1.0

  override def copy(extra: ParamMap): MockPredictionModel = defaultCopy(extra)
}

Then we just have a test that calls fit for each type of data.

@zhengruifeng
Copy link
Contributor Author

Ok, I will create this Suite.

@SparkQA
Copy link

SparkQA commented Oct 11, 2016

Test build #66710 has finished for PR 15414 at commit 6c61e73.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


import testImplicits._

class MockPredictor(override val uid: String)
Copy link
Contributor

@sethah sethah Oct 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move into companion object.

/**
* Return the given DataFrame, with [[labelCol]] casted to DoubleType.
*/
protected def castDataSet(dataset: Dataset[_]): DataFrame = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just put this logic directly in fit

@@ -117,7 +117,7 @@ object MLTestingUtils extends SparkFunSuite {
Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
types.map { t =>
val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
Copy link
Contributor

@sethah sethah Oct 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this for? If the intent is to force getNumClasses to infer the number of classes, then you're no longer testing the not inferred case. Further, the point of this PR is to eliminate the need to do that since it is not a robust solution, IMO.

Also, I'd like to remove the dependence on TreeTests here (and genRegressionDF) and just explicitly set the attributes in the functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will revert this


test("should support all NumericType labels and not support other types") {
val predictor = new MockPredictor("mock")
MLTestingUtils.checkNumericTypes[MockPredictionModel, MockPredictor](
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we just cycle through the types here and call fit. I think it's a bit confusing the way it is now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I will update this.

class MockPredictionModel(override val uid: String)
extends PredictionModel[Vector, MockPredictionModel] {

override def predict(features: Vector): Double = 1.0
Copy link
Contributor

@sethah sethah Oct 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override def predict(features: Vector): Double = throw new NotImplementedError() We can do this for everything except train.

@SparkQA
Copy link

SparkQA commented Oct 12, 2016

Test build #66814 has finished for PR 15414 at commit 6ef17b7.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng
Copy link
Contributor Author

@sethah I have maken some modification according to the comments

new MockPredictionModel(uid)
}

override def copy(extra: ParamMap): MockPredictor = defaultCopy(extra)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the copy methods to throw NotImplementedError

@sethah
Copy link
Contributor

sethah commented Oct 13, 2016

Thanks, I'll take a more detailed look in the next couple of days. Let's also wait and see if we can get @yanboliang or @jkbradley to give an opinion.

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't need DefaultReadWriteTest

@SparkQA
Copy link

SparkQA commented Oct 13, 2016

Test build #66872 has finished for PR 15414 at commit 7e2d501.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext

@SparkQA
Copy link

SparkQA commented Oct 13, 2016

Test build #66880 has finished for PR 15414 at commit 7cb4510.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng
Copy link
Contributor Author

@jkbradley @yanboliang Could you please have a review of this? This PR unify usage of labelCol casting and fixs a bug described in [https://issues.apache.org/jira/browse/SPARK-17797]

@zhengruifeng
Copy link
Contributor Author

@jkbradley @yanboliang Just re-pinging for your opinions.

@jkbradley
Copy link
Member

Can you please document in Predictor that it accepts all NumericType labels? Other than that, this LGTM. Thanks!

@sethah
Copy link
Contributor

sethah commented Nov 1, 2016

LGTM as well after adding @jkbradley's suggestion.

@zhengruifeng
Copy link
Contributor Author

@jkbradley @sethah I add a comment, thanks for reviews.

@SparkQA
Copy link

SparkQA commented Nov 1, 2016

Test build #67861 has finished for PR 15414 at commit 810c973.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Copy link
Member

LGTM
Merging with master
Thanks!

@asfgit asfgit closed this in 8ac0910 Nov 1, 2016
@zhengruifeng zhengruifeng deleted the move_cast branch November 2, 2016 01:32
uzadude pushed a commit to uzadude/spark that referenced this pull request Jan 27, 2017
## What changes were proposed in this pull request?

1, move cast to `Predictor`
2, and then, remove unnecessary cast
## How was this patch tested?

existing tests

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes apache#15414 from zhengruifeng/move_cast.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants