[SPARK-7387][ml][doc] CrossValidator example code in Python #6358

Closed
wants to merge 6 commits into from

harsha2010

No description provided.

@harsha2010
Author

@jkbradley @mengxr, please review. I wanted to add an example of how to introspect the best model's parameters in Python (in Scala, I was able to get the pipeline model's stages and extract hashingTF.numFeatures and lrModel.regParam), but the Python implementation doesn't seem to expose lrModel.getParam (I might be mistaken here).

@SparkQA

SparkQA commented May 22, 2015

Test build #33344 has finished for PR 6358 at commit 7246d35.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

@harsha2010 You're correct that the Python API is incomplete. Most of the spark.ml Models are simple wrappers around Java objects and don't expose much functionality other than transform() right now. Next release...
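
For readers unfamiliar with what "the best model's parameters" refers to, here is a minimal pure-Python sketch of a grid search with k-fold cross-validation that keeps the winning parameter map next to its score. It is not the spark.ml API: `expand_grid`, `k_folds`, `cross_validate`, and the toy `fit`/`score` functions are hypothetical names invented for illustration, and the toy "model" is just a threshold rule.

```python
from itertools import product
from statistics import mean


def expand_grid(grid):
    """Expand {"name": [values, ...]} into a list of param maps (dicts)."""
    names = sorted(grid)
    return [dict(zip(names, combo)) for combo in product(*(grid[n] for n in names))]


def k_folds(rows, k):
    """Yield (train, validation) splits; row i is validated in fold i % k."""
    for fold in range(k):
        train = [r for i, r in enumerate(rows) if i % k != fold]
        valid = [r for i, r in enumerate(rows) if i % k == fold]
        yield train, valid


def cross_validate(rows, grid, fit, score, k=2):
    """Return (best_params, best_score), averaging the score over k folds."""
    best_params, best_score = None, float("-inf")
    for params in expand_grid(grid):
        avg = mean(score(fit(train, params), valid) for train, valid in k_folds(rows, k))
        if avg > best_score:
            best_params, best_score = params, avg
    return best_params, best_score


def fit(train, params):
    """Toy 'training' (assumption): ignores the rows and applies a threshold rule."""
    threshold = params["threshold"]
    return lambda x: 1.0 if x > threshold else 0.0


def score(model, valid):
    """Accuracy of the model on (x, label) rows."""
    return mean(1.0 if model(x) == label else 0.0 for x, label in valid)


rows = [(0.1, 0.0), (0.2, 0.0), (0.8, 1.0), (0.9, 1.0)]
best_params, best_score = cross_validate(rows, {"threshold": [0.0, 0.5, 1.0]}, fit, score)
print(best_params, best_score)
```

Returning the parameter map together with the score is what makes the chosen values inspectable afterwards, which is the kind of introspection the comment above asks about.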


# Prepare training documents, which are labeled.
LabeledDocument = Row("id", "text", "label")
training = sc.parallelize([(0, "a b c d e spark", 1.0),
Member

This does not give the hoped-for predictions on the test data (that is, predicting 1.0 for everything containing "Spark"). In the Scala and Java examples, I added more training examples to make it more likely that the chosen model would look for "Spark." Could you please add those other training examples here too?

Author

@jkbradley, sure. Is this checked into master yet? On master, I notice the most recent change was to the regularization and one of the test examples.

Member

Yep, it's from Spark 1.3: [https://github.com/apache/spark/blob/1bb5d716c0351cd0b4c11b397fd778f30db39bd9/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala]

@jkbradley
Member

@harsha2010 Also, would you mind updating (either in this PR or another) the forest & GBT examples you just wrote? After a little investigation, I realized the explanation I gave for why Spark SQL can infer the schema from LabeledPoint is wrong. It's because LabeledPoint is an Object, and Spark SQL identifies the 2 fields "label" and "features" from the Object.
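
As a rough analogy for the point above, the sketch below derives field names from a plain Python object's attributes. It is not Spark SQL's actual implementation: `LabeledPointLike` and `infer_fields` are hypothetical names used only to show how two fields, "label" and "features", can be read off an object.

```python
class LabeledPointLike:
    """Hypothetical stand-in for LabeledPoint: an object with two attributes."""

    def __init__(self, label, features):
        self.label = label
        self.features = features


def infer_fields(obj):
    """Return (field name, Python type name) pairs from the object's attributes."""
    return [(name, type(value).__name__) for name, value in sorted(vars(obj).items())]


point = LabeledPointLike(1.0, [0.0, 2.0, 3.0])
fields = infer_fields(point)
print(fields)
```

The field names come from the object itself, so no explicit schema or column list needs to be supplied.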

@SparkQA

SparkQA commented Jun 2, 2015

Test build #33971 has finished for PR 6358 at commit 54a500c.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jun 2, 2015

Test build #33975 has finished for PR 6358 at commit aeb6bb6.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

Except for the issue with needing more training examples, this looks good to me.

@harsha2010
Author

@jkbradley Thanks, I have updated the cross validator example with enough training examples now.

@jkbradley
Member

@harsha2010 Thanks! (Does it give the hoped-for predictions on the test data now when you run it?)

@harsha2010
Author

@jkbradley Yes, it does much better now:

Row(id=4, text=u'spark i j k', probability=DenseVector([0.248, 0.752]), prediction=1.0)
Row(id=5, text=u'l m n', probability=DenseVector([0.9647, 0.0353]), prediction=0.0)
Row(id=6, text=u'mapreduce spark', probability=DenseVector([0.4248, 0.5752]), prediction=1.0)
Row(id=7, text=u'apache hadoop', probability=DenseVector([0.69, 0.31]), prediction=0.0)
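
The rows above can be checked against the hoped-for behavior with a small standalone sketch. It assumes the prediction is the index of the larger probability (equivalent to a 0.5 threshold in the binary case) and that the hoped-for label is 1.0 exactly when the text contains the word "spark"; `predict` and `hoped_for` are hypothetical helpers, and the tuples are copied from the output above.

```python
# (id, text, probability, prediction) copied from the test output above.
rows = [
    (4, "spark i j k", [0.248, 0.752], 1.0),
    (5, "l m n", [0.9647, 0.0353], 0.0),
    (6, "mapreduce spark", [0.4248, 0.5752], 1.0),
    (7, "apache hadoop", [0.69, 0.31], 0.0),
]


def predict(probability):
    """Return the index of the largest probability as a float label (assumption)."""
    return float(max(range(len(probability)), key=probability.__getitem__))


def hoped_for(text):
    """Hoped-for label: 1.0 if the document mentions "spark", else 0.0."""
    return 1.0 if "spark" in text.lower().split() else 0.0


# True for a row when the reported prediction matches both checks.
checks = [
    (doc_id, predict(prob) == pred == hoped_for(text))
    for doc_id, text, prob, pred in rows
]
print(checks)
```

Note that document 6 ("mapreduce spark") is the closest call, with probability 0.5752 for class 1.0.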

@jkbradley
Member

Haha good. LGTM pending tests

@SparkQA

SparkQA commented Jun 2, 2015

Test build #34008 has finished for PR 6358 at commit 63efda2.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

Merging with master and branch-1.4

asfgit pushed a commit that referenced this pull request Jun 3, 2015
Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #6358 from harsha2010/SPARK-7387 and squashes the following commits:

63efda2 [Ram Sriharsha] more examples for classifier to distinguish mapreduce from spark properly
aeb6bb6 [Ram Sriharsha] Python Style Fix
54a500c [Ram Sriharsha] Merge branch 'master' into SPARK-7387
615e91c [Ram Sriharsha] cleanup
204c4e3 [Ram Sriharsha] Merge branch 'master' into SPARK-7387
7246d35 [Ram Sriharsha] [SPARK-7387][ml][doc] CrossValidator example code in Python

(cherry picked from commit c3f4c32)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
@asfgit asfgit closed this in c3f4c32 Jun 3, 2015
jeanlyn pushed a commit to jeanlyn/spark that referenced this pull request Jun 12, 2015
nemccarthy pushed a commit to nemccarthy/spark that referenced this pull request Jun 19, 2015
3 participants