[SPARK-12182][ML] Distributed binning for trees in spark.ml #10231
Conversation
@NathanHowell would you be able to review this? cc @jkbradley
Yeah I can take a look tonight or tomorrow
Test build #47451 has finished for PR 10231 at commit
.groupByKey(numPartitions)
.map { case (idx, samples) =>
  val thresholds = findSplitsForContinuousFeature(samples.toArray, metadata, idx)
  val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
(as mentioned in jenkins): scala style long line
Fixed.
At first glance this seems to share a lot of code with the original implementation in MLlib (they both even work with RDDs of LabeledPoint) - maybe we could move much of this to a common util class or similar?
This JIRA was actually created as a blocker JIRA for SPARK-12183 which is for removing the MLlib code entirely and wrapping to spark.ml. So, the code duplication should be very short-lived.
Ah great - if we're killing the old code soon then no worries on the temporary duplication.
Test build #47453 has finished for PR 10231 at commit
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
val featureArity = metadata.featureArity(i)
val split: IndexedSeq[Split] = Range(0, metadata.numSplits(i)).map { splitIndex =>
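As an aside on the `2^(maxFeatureValue - 1) - 1` comment in the snippet above, the count can be sketched with a small hypothetical helper (not part of the PR):

```scala
// For an unordered categorical feature of arity M, a candidate split sends a
// nonempty proper subset of the M categories to the left child. There are
// 2^M - 2 such subsets; halving for left/right complement symmetry gives
// (2^M - 2) / 2 = 2^(M - 1) - 1 distinct splits.
def numUnorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1
```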
You could use an `Array.tabulate` here. Something like

Array.tabulate[Split](numSplits(i)) { splitIndex =>
  ...
}

This avoids allocating two collections, one for the splits range and the other for `splits.toArray`.

Also note that the type parameter `[Split]` is required here. This is because the compiler would otherwise infer `Array[CategoricalSplit]` as the return type which, because arrays are not covariant, is not a subtype of `Array[Split]` and would thus not compile.
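A self-contained sketch of the invariance point, using simplified stand-in types rather than the real spark.ml classes:

```scala
// Simplified stand-ins for the spark.ml Split hierarchy.
trait Split { def featureIndex: Int }
final class CategoricalSplit(val featureIndex: Int, val categories: Set[Double]) extends Split

// Without the explicit [Split] type parameter, the compiler infers
// Array[CategoricalSplit]; since Scala arrays are invariant, that is not an
// Array[Split], and the annotated val below would then fail to compile.
val numSplits = 3
val splits: Array[Split] = Array.tabulate[Split](numSplits) { splitIndex =>
  new CategoricalSplit(0, Set(splitIndex.toDouble))
}
```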
Done. Thanks for the suggestion!
Test build #47583 has finished for PR 10231 at commit
@NathanHowell do you think you'll have any time to take a look at this?
@sethah looks good to me. 👍
@NathanHowell Thank you for reviewing!
Would you have time to test this on a small dataset? The original PR confirmed it's faster for a larger dataset, but I'm curious if it affects timing (adversely) on small data.
I can set something up. Do you have a specific dataset size in mind or even a specific dataset?
Test build #2645 has finished for PR 10231 at commit
No specific dataset size. I was thinking of something in this ballpark:
Thanks!
    metadata: DecisionTreeMetadata,
    continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {

  val continuousSplits = {
Put type here for code clarity
Could you also make this change: [https://github.com//pull/8246/files#diff-8ad842a043888473bb2b527e818de04bR645]

Done with my pass. I added a few minor comments which weren't in the spark.mllib PR.
@jkbradley I ran some local timings comparing before/after this change. I just ran five trials each, but I can set up something more robust if needed.
That does not seem that bad. I'd say we should go ahead with your PR. If we want to optimize for small data, we can add a local implementation at some point. (But that's far-future.)
Test build #53552 has finished for PR 10231 at commit
Test build #53553 has finished for PR 10231 at commit
Test build #53555 has finished for PR 10231 at commit
Test build #53591 has finished for PR 10231 at commit
@@ -956,7 +956,7 @@ private[ml] object RandomForest extends Logging {
       valueCounts.map(_._1)
     } else {
       // stride between splits
-      val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+      val stride: Double = featureSamples.size.toDouble / (numSplits + 1)
This will do a second pass over the `Iterable`. Would it be preferable to combine this into the `foldLeft` above so it only does a single pass?
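A minimal sketch of the single-pass idea, using plain collections and hypothetical names (the actual fix lives in `findSplitsForContinuousFeature`):

```scala
// Accumulate the per-value counts and the total sample count together in one
// foldLeft, instead of traversing the Iterable a second time for .size.
val featureSamples: Iterable[Double] = Seq(1.0, 1.0, 2.0, 3.0, 3.0, 3.0)

val (valueCounts, numSamples) =
  featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
    case ((counts, total), value) =>
      (counts + (value -> (counts.getOrElse(value, 0) + 1)), total + 1)
  }

val numSplits = 2
// stride between splits, computed without a second pass over the samples
val stride: Double = numSamples.toDouble / (numSplits + 1)
```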
Thanks for the suggestion! The latest commit should take care of it.
Test build #53595 has finished for PR 10231 at commit
LGTM
This PR changes the `findSplits` method in spark.ml to perform split calculations on the workers. This PR is meant to copy [PR-8246](apache#8246) which added the same feature for MLlib.

Author: sethah <seth.hendrickson16@gmail.com>

Closes apache#10231 from sethah/SPARK-12182.
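The pattern the PR introduces can be sketched with plain collections standing in for the RDD operations (the real code calls `groupByKey` on an RDD of sampled `(featureIndex, value)` pairs and computes thresholds on the workers; everything below is a simplified, hypothetical stand-in):

```scala
// Sampled (featureIndex, value) pairs, as produced by an upstream sampling step.
val sampled: Seq[(Int, Double)] =
  Seq((0, 1.0), (0, 2.0), (0, 3.0), (1, 10.0), (1, 20.0))

// Stand-in for RDD.groupByKey: each feature's samples end up together, so the
// split thresholds for that feature can be computed in one place (on a worker
// rather than the driver, in the Spark version).
val splitsByFeature: Map[Int, Array[Double]] =
  sampled
    .groupBy(_._1)
    .map { case (idx, pairs) =>
      val sorted = pairs.map(_._2).sorted.toVector
      // Stand-in for findSplitsForContinuousFeature: midpoints between
      // consecutive sorted sample values.
      val thresholds = sorted.zip(sorted.drop(1)).map { case (a, b) => (a + b) / 2 }
      idx -> thresholds.toArray
    }
```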