Description
PFI (Permutation Feature Importance) API needs to be simpler to use
1. First, it is awkward to have to access the LastTransformer property of the model (a chain of transformers). In addition, if you use separate methods to structure your training, evaluation and PFI calculation and pass the model around as an ITransformer (the usual way), you need to cast it back to the concrete transformer chain type (such as TransformerChain<RegressionPredictionTransformer<LightGbmRegressionModelParameters>>), which then requires a hard reference to the type of algorithm used.
This is the code to calculate the PFI metrics:
    // Make predictions (Transform the dataset)
    IDataView transformedData = trainedModel.Transform(trainingDataView);

    // Extract the trainer (last transformer in the model)
    var singleLightGbmModel = (trainedModel as TransformerChain<RegressionPredictionTransformer<LightGbmRegressionModelParameters>>).LastTransformer;
    // or simpler if the trainedModel was 'var' right after the call to Fit():
    // var singleLightGbmModel = trainedModel.LastTransformer;

    // Calculate Feature Permutation
    ImmutableArray<RegressionMetricsStatistics> permutationMetrics =
        mlContext.Regression.PermutationFeatureImportance(
            predictionTransformer: singleLightGbmModel,
            data: transformedData,
            labelColumnName: "fare_amount",
            numberOfExamplesToUse: 100,
            permutationCount: 1);
Having to extract and pass only the last transformer feels convoluted.
Could the API be simpler here and handle this transparently for the user?
2. Second, once you get the permutation metrics (such as ImmutableArray<RegressionMetricsStatistics> permutationMetrics), you only get the values by index; you don't get the names of the input columns. Correlating the values with the input column names is not straightforward, since the same indexes have to be used across two separate arrays, and if either array has been sorted beforehand, the indexes no longer match.
You need to do something like the following or comparable loops driven by the indexes in the permutationMetrics array:
First, obtain all the column names used in the PFI process and exclude the ones not used:
    var usedColumnNamesInPFI = dataView.Schema
        .Where(col => (col.Name != "SamplingKeyColumn") && (col.Name != "Features") && (col.Name != "Score"))
        .Select(col => col.Name);
Then you need to correlate and find the column names based on the indexes in the permutationMetrics:
    var results = usedColumnNamesInPFI
        .Select((t, i) => new FeatureImportance
        {
            Name = t,
            RSquaredMean = Math.Abs(permutationMetrics[i].RSquared.Mean)
        })
        .OrderByDescending(x => x.RSquaredMean);
This mapping should be provided directly by the API, so that you only need to access it and display it.
The current code feels very convoluted.
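To make the index problem from point 2 concrete, here is a minimal self-contained sketch that does not use ML.NET at all. The feature names and values are invented for illustration (they are not real PFI results), and the two plain arrays merely stand in for the column names and for permutationMetrics[i].RSquared.Mean. It assumes C# 9 or later for top-level statements.

```csharp
using System;
using System.Linq;

// Hypothetical feature names and R-squared changes, standing in for the
// input column names and permutationMetrics[i].RSquared.Mean; not real results.
string[] featureNames = { "trip_distance", "passenger_count", "trip_time_in_secs" };
double[] rSquaredMeans = { -0.42, -0.01, -0.17 };

// Sorting the metrics on their own discards the link to the feature names:
// sortedAlone[1] is -0.17 (trip_time_in_secs), but featureNames[1] is passenger_count.
double[] sortedAlone = rSquaredMeans.OrderBy(v => v).ToArray();

// Pairing each name with its metric by index BEFORE sorting keeps them together.
var ranked = featureNames
    .Zip(rSquaredMeans, (name, mean) => (Name: name, Importance: Math.Abs(mean)))
    .OrderByDescending(x => x.Importance)
    .ToArray();

// Print the features from most to least important.
foreach (var (name, importance) in ranked)
{
    Console.WriteLine($"{name}: {importance:F2}");
}
```

This is the same pairing that the Select((t, i) => ...) snippet above performs by hand; an API that returned names together with metrics would make this step unnecessary.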