PFI (Permutation Feature Importance) API needs to be simpler to use #4216

Closed
@CESARDELATORRE

1. First, it is awkward to have to access the LastTransformer property of the model (a chain of transformers). In addition, if you split training, evaluation and PFI calculation into separate methods and pass the model around as an ITransformer (the usual way), you need to cast it back to the concrete transformer chain type (such as TransformerChain<RegressionPredictionTransformer<LightGbmRegressionModelParameters>>), which in turn requires a hard reference to the type of algorithm used.

This is the code to calculate the PFI metrics:

// Make predictions (Transform the dataset)
IDataView transformedData = trainedModel.Transform(trainingDataView);

// Extract the trainer (last transformer in the model)
var singleLightGbmModel = (trainedModel as TransformerChain<RegressionPredictionTransformer<LightGbmRegressionModelParameters>>).LastTransformer;

// or, more simply, if trainedModel was declared with 'var' right after the call to Fit():
// var singleLightGbmModel = trainedModel.LastTransformer;

// Calculate Feature Permutation
ImmutableArray<RegressionMetricsStatistics> permutationMetrics =
    mlContext.Regression.PermutationFeatureImportance(
        predictionTransformer: singleLightGbmModel,
        data: transformedData,
        labelColumnName: "fare_amount",
        numberOfExamplesToUse: 100,
        permutationCount: 1);

Having to extract and pass only the last transformer feels convoluted.
The API should be simpler to use here and handle this step transparently for the user.

2. Second, once you get the permutation metrics (such as ImmutableArray<RegressionMetricsStatistics> permutationMetrics), you only get values addressed by index; the names of the input columns are not included. Correlating the metrics with the input column names is therefore not straightforward, because the same indexes have to be used across two separate arrays, and if either array was sorted beforehand, the indexes no longer match.

You need to do something like the following or comparable loops driven by the indexes in the permutationMetrics array:

First, obtain all the column names used in the PFI process and exclude the ones not used:

var usedColumnNamesInPFI = dataView.Schema
    .Where(col => (col.Name != "SamplingKeyColumn") && (col.Name != "Features") && (col.Name != "Score"))
    .Select(col => col.Name);

Then you need to correlate and find the column names based on the indexes in the permutationMetrics:

var results = usedColumnNamesInPFI
    .Select((t, i) => new FeatureImportance
    {
        Name = t,
        RSquaredMean = Math.Abs(permutationMetrics[i].RSquared.Mean)
    })
    .OrderByDescending(x => x.RSquaredMean);

The API should provide this mapping directly, so that you only need to read it and display it.
The current code feels very convoluted.
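
To make the request in point 2 concrete, here is a minimal sketch of the kind of name-to-metric pairing I would expect the API to return. FeatureImportanceReport and Rank are hypothetical names, not part of ML.NET; the helper works on plain lists and assumes the names and scores are given in the same feature order. It pairs them by position before sorting, so the pairing cannot be broken by a later reorder.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper (not an ML.NET API): pairs each feature name with its
// score by position BEFORE sorting, so the pairing survives the reordering.
public static class FeatureImportanceReport
{
    public static IReadOnlyList<(string Name, double Score)> Rank(
        IReadOnlyList<string> featureNames,
        IReadOnlyList<double> scores)
    {
        // Assumption: both lists describe the same features in the same order.
        // A length mismatch means that assumption is already violated.
        if (featureNames.Count != scores.Count)
            throw new ArgumentException(
                $"Got {featureNames.Count} names but {scores.Count} scores.");

        return featureNames
            // Pair by index first; use the absolute value as the importance.
            .Zip(scores, (name, score) => (Name: name, Score: Math.Abs(score)))
            // Only then sort, most important feature first.
            .OrderByDescending(pair => pair.Score)
            .ToList();
    }
}
```

With the code from this issue, the scores argument would be something like permutationMetrics.Select(m => m.RSquared.Mean).ToList() and the names would be usedColumnNamesInPFI.ToList(), assuming the filtered schema columns really are in the same order as the slots of the Features column.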

Labels: API (issues pertaining to the friendly API), P2 (priority of the issue for triage purposes: needs to be fixed at some point), enhancement (new feature or request)