[SPARK-7387][ml][doc] CrossValidator example code in Python #6358
@jkbradley @mengxr, please review. I wanted to add an example of how to introspect the best model's parameters in Python (in Scala, I was able to get the pipeline model's stages and extract hashingTF.numFeatures and lrModel.regParam), but the Python implementation doesn't seem to expose lrModel.getParam (I might be mistaken here).
Test build #33344 has finished for PR 6358 at commit
@harsha2010 You're correct that the Python API is incomplete. Most of the spark.ml Models are simple wrappers around Java objects and don't expose much functionality other than transform() right now. Next release...
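For context on what the example in this PR exercises: a cross validator fits the estimator once per fold for every combination in a parameter grid, averages a metric over the folds, and keeps the best-scoring combination. The plain-Python sketch below illustrates only that selection loop. It does not use Spark, and `fit`, `accuracy`, `cross_validate`, the `threshold` parameter, and all rows after the first are hypothetical stand-ins, not code or data from the PR.

```python
from collections import Counter
from itertools import product


def k_fold_splits(rows, k):
    """Yield (train, validation) pairs; fold i holds out rows i, i+k, i+2k, ..."""
    for i in range(k):
        validation = rows[i::k]
        train = [r for j, r in enumerate(rows) if j % k != i]
        yield train, validation


def fit(train, threshold):
    """Toy stand-in estimator: each word scores +1 per positive row, -1 per negative row."""
    weights = Counter()
    for text, label in train:
        for word in text.split():
            weights[word] += 1 if label == 1.0 else -1

    def predict(text):
        score = sum(weights[w] for w in text.split())
        return 1.0 if score > threshold else 0.0

    return predict


def accuracy(model, rows):
    """Fraction of rows whose predicted label matches the true label."""
    return sum(model(text) == label for text, label in rows) / len(rows)


def cross_validate(rows, param_grid, k):
    """Return (best params, best mean accuracy, model refit on all rows)."""
    names = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        scores = [accuracy(fit(train, **params), validation)
                  for train, validation in k_fold_splits(rows, k)]
        mean = sum(scores) / len(scores)
        if mean > best_score:
            best_params, best_score = params, mean
    return best_params, best_score, fit(rows, **best_params)


# Toy labeled documents (text, label); only the first row appears in the PR diff.
rows = [
    ("a b c d e spark", 1.0),
    ("b d", 0.0),
    ("spark f g h", 1.0),
    ("hadoop mapreduce", 0.0),
    ("b spark who", 1.0),
    ("g d a y", 0.0),
]

best_params, best_score, best_model = cross_validate(rows, {"threshold": [0, 1, 5]}, k=3)
print(best_params, best_score, best_model("spark i j k"))
```

Returning the winning parameters alongside the refit model is a choice made for this sketch; it sidesteps the introspection gap discussed above, where the Python model wrappers do not yet expose their parameters.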
# Prepare training documents, which are labeled.
LabeledDocument = Row("id", "text", "label")
training = sc.parallelize([(0, "a b c d e spark", 1.0),
This does not give the hoped-for predictions on the test data (predicting everything with "Spark" as 1.0). In the Scala and Java examples, I added more training examples to make it more likely that the chosen model would look for "Spark." Could you please add those other training examples here too?
@jkbradley, sure. Is this checked into master yet? On master, I notice the most recent change was to the regularization and one of the test examples.
Yep, it's from Spark 1.3: https://github.com/apache/spark/blob/1bb5d716c0351cd0b4c11b397fd778f30db39bd9/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@harsha2010 Also, would you mind updating (either in this PR or another) the forest & GBT examples you just wrote? After a little investigation, I realized the explanation I gave for why Spark SQL can infer the schema from LabeledPoint is wrong. It's because LabeledPoint is an Object, and Spark SQL identifies the 2 fields "label" and "features" from the Object.
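The idea of identifying fields from an object can be illustrated in a few lines of plain Python. This is a sketch of the concept only, not Spark SQL's inference code; `LabeledPointLike` and `infer_fields` are hypothetical names introduced for the illustration.

```python
class LabeledPointLike:
    """Hypothetical stand-in for an object with 'label' and 'features' fields."""

    def __init__(self, label, features):
        self.label = label
        self.features = features


def infer_fields(obj):
    """Return (field name, type name) pairs read from the object's attributes, sorted by name."""
    return [(name, type(value).__name__) for name, value in sorted(vars(obj).items())]


schema = infer_fields(LabeledPointLike(1.0, [0.0, 2.5]))
print(schema)  # [('features', 'list'), ('label', 'float')]
```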
Test build #33971 has finished for PR 6358 at commit
Test build #33975 has finished for PR 6358 at commit
Except for the issue with needing more training examples, this looks good to me.
@jkbradley Thanks, I have updated the cross validator example with enough training examples now.
@harsha2010 Thanks! (Does it give the hoped-for predictions on the test data now when you run it?)
@jkbradley Yes, it does much better now: Row(id=4, text=u'spark i j k', probability=DenseVector([0.248, 0.752]), prediction=1.0)
Haha good. LGTM pending tests
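To read the Row reported above: the probability vector holds one entry per class, and the prediction corresponds to the more probable class. The short sketch below reproduces that reading for the reported values, assuming the default decision rule of picking the class with the highest probability; the variable names are illustrative only.

```python
# Class probabilities reported for the test document 'spark i j k'.
probability = [0.248, 0.752]

# Pick the index of the largest probability and express it as a float label.
prediction = float(max(range(len(probability)), key=probability.__getitem__))
print(prediction)  # 1.0
```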
Test build #34008 has finished for PR 6358 at commit
Merging with master and branch-1.4
Author: Ram Sriharsha <rsriharsha@hw11853.local>
Closes #6358 from harsha2010/SPARK-7387 and squashes the following commits:
63efda2 [Ram Sriharsha] more examples for classifier to distinguish mapreduce from spark properly
aeb6bb6 [Ram Sriharsha] Python Style Fix
54a500c [Ram Sriharsha] Merge branch 'master' into SPARK-7387
615e91c [Ram Sriharsha] cleanup
204c4e3 [Ram Sriharsha] Merge branch 'master' into SPARK-7387
7246d35 [Ram Sriharsha] [SPARK-7387][ml][doc] CrossValidator example code in Python
(cherry picked from commit c3f4c32)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
Author: Ram Sriharsha <rsriharsha@hw11853.local>
Closes apache#6358 from harsha2010/SPARK-7387 and squashes the following commits:
63efda2 [Ram Sriharsha] more examples for classifier to distinguish mapreduce from spark properly
aeb6bb6 [Ram Sriharsha] Python Style Fix
54a500c [Ram Sriharsha] Merge branch 'master' into SPARK-7387
615e91c [Ram Sriharsha] cleanup
204c4e3 [Ram Sriharsha] Merge branch 'master' into SPARK-7387
7246d35 [Ram Sriharsha] [SPARK-7387][ml][doc] CrossValidator example code in Python