New PFI API #5934
Codecov Report
@@            Coverage Diff             @@
##             main    #5934      +/-   ##
==========================================
+ Coverage   68.20%   68.23%   +0.03%
==========================================
  Files        1142     1142
  Lines      242534   242769     +235
  Branches    25378    25385       +7
==========================================
+ Hits       165416   165663     +247
+ Misses      70418    70412       -6
+ Partials     6700     6694       -6
pfi = ML.Regression.PermutationFeatureImportance(model, data);
pfiDict = ml2.Regression.PermutationFeatureImportance((ITransformer)model, data);
Why is this cast needed? Shouldn't model already be something that implements ITransformer?
Yes. The issue is that the PFI method (of the same name) takes an ISingleFeaturePredictionTransformer<TModel>, and so the compiler defaults to that method instead of the method that just takes an ITransformer. model in this case is just a single estimator, so it's the actual type instead of ITransformer. Is there another way to work around that?
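The overload behaviour described in this reply can be reproduced in isolation. The sketch below uses hypothetical stand-in types (`LinearModel`, `LinearTransformer`, `PfiSketch`, and simplified versions of the two interfaces), not the real ML.NET definitions, and assumes the two overloads have roughly the shapes named in the thread.

```csharp
using System;

// Hypothetical, simplified stand-ins for the ML.NET types named in the thread.
public interface ITransformer { }
public interface ISingleFeaturePredictionTransformer<TModel> : ITransformer where TModel : class { }

public sealed class LinearModel { }
public sealed class LinearTransformer : ISingleFeaturePredictionTransformer<LinearModel> { }

public static class PfiSketch
{
    // Shape of the original overload: generic over the model type.
    public static string Importance<TModel>(ISingleFeaturePredictionTransformer<TModel> model)
        where TModel : class
        => "array overload";

    // Shape of the new overload: accepts any ITransformer.
    public static string Importance(ITransformer model)
        => "dictionary overload";
}
```

With a variable whose static type is `LinearTransformer`, `PfiSketch.Importance(model)` binds to the generic overload, because the conversion to the more derived interface is the better one. `PfiSketch.Importance((ITransformer)model)` binds to the second overload, since type inference for `TModel` fails on a plain `ITransformer`. Declaring the variable as `ITransformer model = ...` has the same effect as the cast, because overload resolution only looks at the static type.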
@@ -57,6 +71,12 @@ public void TestPfiRegressionOnDenseFeatures(bool saveModel)
// X3: 2
// X4Rand: 3

// Make sure that PFI from the array and the dictionary both have the same value for each feature.
Assert.Equal(JsonConvert.SerializeObject(pfi[0]), JsonConvert.SerializeObject(pfiDict["X1"]));
One concern I have with testing like this is that if RegressionMetricsStatistics weren't convertible to JSON, and just serialized to {} (an empty object), this assert would pass and not really verify anything.
Is there a way we can ensure that these MetricsStatistics classes serialize to JSON correctly if we are going to compare them this way?
Another option would be to write a simple "AssertStatisticsEquals" method that compares all the properties of two instances.
Would that be better? I thought about doing that, but each type of statistics (Regression/Binary/Multiclass/Ranking) is a little different, so I would need a different method for each one. That's why I ended up using this approach.
I'm not sure of a good way to know if they serialize correctly. I have pretty much always assumed Newtonsoft.Json would just work, but I'm sure there are cases where it doesn't. In this case it does, as I checked manually...
Would just verifying it's not an empty string or an empty {} be enough, or would we need better verification? If it's not, I can always just write an Equals method for these objects. None of the stats classes have one, so I could add it to all of them.
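One way to get a single "AssertStatisticsEquals"-style helper without writing a method per statistics type is to walk the public properties by reflection. The sketch below is only an illustration: `StatisticsAssert`, `SketchMetric`, and `SketchRegressionStats` are hypothetical names, and the helper assumes the compared objects expose scalar or nested-object properties only (no collections, no cyclic references). It also refuses to pass when a type exposes nothing to compare, which is the counterpart of the empty {} case raised above.

```csharp
using System;
using System.Linq;
using System.Reflection;

// Hypothetical stand-ins for the statistics classes discussed above.
public sealed class SketchMetric
{
    public SketchMetric(double mean, double standardError) { Mean = mean; StandardError = standardError; }
    public double Mean { get; }
    public double StandardError { get; }
}

public sealed class SketchRegressionStats
{
    public SketchRegressionStats(SketchMetric rSquared) { RSquared = rSquared; }
    public SketchMetric RSquared { get; }
}

public static class StatisticsAssert
{
    // Recursively compares every readable public instance property of two objects.
    public static void PropertiesEqual(object expected, object actual, string path = "root")
    {
        if (expected is null || actual is null)
        {
            if (!ReferenceEquals(expected, actual))
                throw new InvalidOperationException(path + ": only one side is null");
            return;
        }

        Type type = expected.GetType();
        if (type != actual.GetType())
            throw new InvalidOperationException(path + ": types differ");

        // Leaf values are compared directly.
        if (type.IsPrimitive || type.IsEnum || type == typeof(string) || type == typeof(decimal))
        {
            if (!expected.Equals(actual))
                throw new InvalidOperationException(path + ": expected " + expected + ", got " + actual);
            return;
        }

        PropertyInfo[] props = type
            .GetProperties(BindingFlags.Public | BindingFlags.Instance)
            .Where(p => p.CanRead && p.GetIndexParameters().Length == 0)
            .ToArray();

        // Counterpart of the "{}" case: fail when there is nothing to compare.
        if (props.Length == 0)
            throw new InvalidOperationException(path + ": no readable properties to compare");

        foreach (PropertyInfo p in props)
            PropertiesEqual(p.GetValue(expected), p.GetValue(actual), path + "." + p.Name);
    }
}
```

In a test this would be called as `StatisticsAssert.PropertiesEqual(pfi[0], pfiDict["X1"])`, and a mismatch reports the property path that differed. Supporting list-valued properties would need an extra branch that this sketch leaves out.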
@@ -241,7 +320,7 @@ public void TestPfiBinaryClassificationOnDenseFeatures(bool saveModel)
 Assert.Equal(1, MinDeltaIndex(pfi, m => m.PositiveRecall.Mean));
 Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativePrecision.Mean));
 Assert.Equal(1, MinDeltaIndex(pfi, m => m.NegativePrecision.Mean));
-Assert.Equal(3, MaxDeltaIndex(pfi, m => m.NegativeRecall.Mean));
+Assert.Equal(0, MaxDeltaIndex(pfi, m => m.NegativeRecall.Mean));
?? Is this change intentional? Your changes shouldn't have affected these numbers.
This may be a good one to sync on, but the problem is that PFI changes the random state. So to make sure that both PFI methods return the same results, I reset the seed to 42 before each call. Since that reset didn't happen before, it slightly changed the results. It's possible I could find a seed that would give the same results as before, though. Do you think that would be better?
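The effect of resetting the seed can be shown without ML.NET. The sketch below is a stand-in only: `SeedSketch.Permute` is a hypothetical helper that shuffles indices with `System.Random`, playing the role of one PFI run that consumes random state. It is not how ML.NET manages its seed internally.

```csharp
using System;
using System.Linq;

public static class SeedSketch
{
    // Stand-in for one PFI run: a Fisher-Yates shuffle of 0..n-1 that
    // advances the state of the supplied random number generator.
    public static int[] Permute(Random rng, int n)
    {
        int[] indices = Enumerable.Range(0, n).ToArray();
        for (int i = n - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (indices[i], indices[j]) = (indices[j], indices[i]);
        }
        return indices;
    }
}
```

Two calls that each receive a fresh `new Random(42)` produce the same permutation, which is what lets the array-based and dictionary-based results be compared value for value. Two consecutive calls on one shared `Random` instance will generally produce different permutations, because the first call has already advanced the state. That is also why adding the reset can shift previously asserted indices.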
src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs (six review threads, all outdated and resolved)
This looks good. Should really help out our customers who want to use this API.
Nice work.
69300eb to be2ad61
Based on the feedback we received (thanks @houghj16 for all your work on that study!), our PFI API was too complex (and let's be honest, getting past the initial cast stops everyone...).
This PR fixes #5625 by adding a new PFI API that is much more user-friendly, both in how it is called and in how the results are returned. The added tests also make sure the output of the original API and the new API is identical when the features are matched up correctly.
@JakeRadMSFT @briacht