API Proposal: Update PFI API to be easier to use #5625

Closed
@JakeRadMSFT

Background and Motivation

The current Permutation Feature Importance (PFI) API is difficult to use. A few issues have already been opened asking to make it easier; we can use this issue to track a proposed API.

Prior Issue:
#4216

Example Support Issue to help developers use it:
dotnet/machinelearning-modelbuilder#1031 (comment)

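The main issue with the API is that it returns an array indexed by feature slot, and it's not easy to get from an index back to the column name/feature name. Today you have to read the slot names out of the schema annotations yourself: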

// NOTE: The column name "Features" needs to match the featureColumnName used in the trainer.
// The annotation name "SlotNames" is always the same regardless of trainer.
VBuffer<ReadOnlyMemory<char>> nameBuffer = default;
preprocessedTrainData.Schema["Features"].Annotations.GetValue("SlotNames", ref nameBuffer);
var featureColumnNames = nameBuffer.DenseValues().ToList();
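
To make the index problem concrete, here is a minimal sketch of the pairing step callers currently have to write by hand. `PfiNameMapping.ByFeatureName` is a hypothetical helper, not an ML.NET API, and the numbers in `Main` are made-up placeholders standing in for real PFI results.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class PfiNameMapping
{
    // Hypothetical helper (not part of ML.NET): pairs each slot name with the
    // result stored at the same index. Assumes slot names are unique;
    // ToDictionary throws on duplicate keys.
    public static Dictionary<string, T> ByFeatureName<T>(
        IReadOnlyList<string> featureNames, IReadOnlyList<T> resultsBySlot)
    {
        if (featureNames.Count != resultsBySlot.Count)
            throw new ArgumentException("Expected exactly one result per feature slot.");

        return featureNames
            .Select((name, index) => (name, value: resultsBySlot[index]))
            .ToDictionary(pair => pair.name, pair => pair.value);
    }
}

public static class Demo
{
    public static void Main()
    {
        // Placeholder data: in real code the names would come from the "SlotNames"
        // annotation and the values from the PFI result array.
        var names = new[] { "rate_code", "passenger_count", "trip_distance" };
        var rSquaredChange = new[] { -0.01, -0.002, -0.35 };

        Dictionary<string, double> byName = PfiNameMapping.ByFeatureName(names, rSquaredChange);
        foreach (var entry in byName)
            Console.WriteLine($"{entry.Key}\t{entry.Value:F6}");
    }
}
```

Returning something keyed by name from the API itself would make this step unnecessary.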

The second biggest issue (which actually comes earlier in the process :) is that it's hard to know what to pass for the ISingleFeaturePredictionTransformer argument. Perhaps this is something we can figure out how to extract for them from the training pipeline?

// Option 1: Extract the predictor with a cast. Requires knowing the type in advance:
// var predictor = ((TransformerChain<RegressionPredictionTransformer<LightGbmRegressionModelParameters>>)mlModel).LastTransformer;

// Option 2: Should always work, as long as you _know_ the predictor is the last transformer in the chain.
var predictor = ((IEnumerable<ITransformer>)mlModel).Last();

// Option 3: Save the model and load it back from disk first.
//var path = @"C:\Users\anvelazq\Desktop\PfiSample\model.zip";
//mlContext.Model.Save(mlModel, trainingDataView.Schema, path);
//var mlModel2 = mlContext.Model.Load(path, out var _);
//var predictor = ((TransformerChain<ITransformer>) mlModel2).LastTransformer;

If we can do that, then the API can just take in "Microsoft.ML.IEstimator<Microsoft.ML.ITransformer> estimator", similar to the CrossValidate APIs.

Proposed API

namespace Microsoft.ML
{
     public static class PermutationFeatureImportanceExtensions {

     public static System.Collections.Immutable.ImmutableArray<Microsoft.ML.Data.RegressionMetricsStatistics> PermutationFeatureImportance<TModel> (this Microsoft.ML.RegressionCatalog catalog, Microsoft.ML.ISingleFeaturePredictionTransformer<TModel> predictionTransformer, Microsoft.ML.IDataView data, string labelColumnName = "Label", bool useFeatureWeightFilter = false, int? numberOfExamplesToUse = default, int permutationCount = 1) where TModel : class;
+    public static System.Collections.Generic.Dictionary<string, Microsoft.ML.Data.RegressionMetricsStatistics> PermutationFeatureImportance (this Microsoft.ML.RegressionCatalog catalog, Microsoft.ML.IEstimator<Microsoft.ML.ITransformer> estimator, Microsoft.ML.IDataView data, string labelColumnName = "Label", bool useFeatureWeightFilter = false, int? numberOfExamplesToUse = default, int permutationCount = 1);
     }
}


Usage Examples

This is how it works today: dotnet/machinelearning-modelbuilder#1031 (comment)
Below is how I think it should work. The key thing to note is the similarity to the CrossValidate API.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using PfiSample.Model;
using System.Collections.Immutable;

namespace PfiSample.ConsoleApp
{
    public static class ModelBuilder
    {
        private static string TRAIN_DATA_FILEPATH = @"C:\Users\anvelazq\Desktop\PfiSample\PfiSample.ConsoleApp\taxi-fare-train.csv";
        private static string MODEL_FILE = ConsumeModel.MLNetModelPath;

        // Create MLContext to be shared across the model creation workflow objects 
        // Set a random seed for repeatable/deterministic results across multiple trainings.
        private static MLContext mlContext = new MLContext(seed: 1);

        public static void CreateModel()
        {
            // Load Data
            IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
                                            path: TRAIN_DATA_FILEPATH,
                                            hasHeader: true,
                                            separatorChar: ',',
                                            allowQuoting: true,
                                            allowSparse: false);

            // Build training pipeline and Train Model

            // Data process configuration with pipeline data transformations 
            var dataProcessPipeline = mlContext.Transforms.Categorical.OneHotEncoding(new[] { new InputOutputColumnPair("vendor_id", "vendor_id"), new InputOutputColumnPair("payment_type", "payment_type") })
                                      .Append(mlContext.Transforms.Concatenate("Features", new[] { "vendor_id", "payment_type", "rate_code", "passenger_count", "trip_time_in_secs", "trip_distance" }));
            // Set the training algorithm 
            var trainer = mlContext.Regression.Trainers.LightGbm(labelColumnName: "fare_amount", featureColumnName: "Features");

            IEstimator<ITransformer> trainingPipeline = dataProcessPipeline.Append(trainer);
           
           ITransformer model = trainingPipeline.Fit(trainingDataView);
            
            // Calculate PFI
            CalculatePFI(mlContext, trainingDataView, trainingPipeline);
            
            // Evaluate quality of Model
            Evaluate(mlContext, trainingDataView, trainingPipeline);

            // Save model
            SaveModel(mlContext, mlModel, MODEL_FILE, trainingDataView.Schema);
        }


        private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            // Cross-validate on the single dataset (we don't have separate training and evaluation datasets)
            // in order to evaluate the model and get its accuracy metrics
            Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
            var crossValidationResults = mlContext.Regression.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "fare_amount");
            PrintRegressionFoldsAverageMetrics(crossValidationResults);
        }

        private static void CalculatePFI(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            Dictionary<string, RegressionMetricsStatistics> permutationFeatureImportance =
                mlContext
                .Regression
                .PermutationFeatureImportance(trainingPipeline, trainingDataView, permutationCount: 1, labelColumnName: "fare_amount");

            Console.WriteLine("Feature\tPFI");
            foreach (KeyValuePair<string, RegressionMetricsStatistics> entry in permutationFeatureImportance)
            {
                Console.WriteLine($"{entry.Key}\t{entry.Value.RSquared.Mean:F6}");
            }
        }

        private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
        {
            // Save/persist the trained model to a .ZIP file
            Console.WriteLine($"=============== Saving the model  ===============");
            mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath));
            Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
        }

        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;

            string fullPath = Path.Combine(assemblyFolderPath, relativePath);

            return fullPath;
        }
    }
}

Alternative Designs

If there is any opposition to, or technical challenge in, making PFI have a similar API to CrossValidate, I'm open to alternatives, but I don't know the ML.NET APIs well enough to come up with other patterns.

Risks

I think the biggest risk/challenge is that folks can do a lot of things with pipelines and models that make them incompatible with PFI. I also believe PFI takes much longer to calculate as the number of feature columns grows, since every feature is permuted and re-scored separately, and certain transforms like one-hot hash encoding can create hundreds of columns ...
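
For reference on the cost concern, here is a rough sketch of the permutation procedure in plain C#, independent of ML.NET. It is not how ML.NET implements PFI; `PfiSketch`, the `predict` delegate, and the "permuted minus baseline" sign convention are all assumptions made for illustration. In this sketch the number of scoring passes is featureCount × permutationCount, so hundreds of one-hot slots mean hundreds of passes over the data.

```csharp
using System;
using System.Linq;

public static class PfiSketch
{
    // Coefficient of determination (R^2) of predictions against labels.
    static double RSquared(double[] labels, double[] predictions)
    {
        double mean = labels.Average();
        double ssRes = labels.Zip(predictions, (y, p) => (y - p) * (y - p)).Sum();
        double ssTot = labels.Sum(y => (y - mean) * (y - mean));
        return 1.0 - ssRes / ssTot;
    }

    // Returns, per feature, the mean change in R^2 (permuted minus baseline)
    // after shuffling that feature's values across rows.
    public static double[] Compute(
        Func<double[], double> predict, double[][] rows, double[] labels,
        int permutationCount = 1, int seed = 1)
    {
        var rng = new Random(seed);
        int featureCount = rows[0].Length;
        double baseline = RSquared(labels, rows.Select(predict).ToArray());
        var deltas = new double[featureCount];

        for (int f = 0; f < featureCount; f++)          // one scoring pass per feature ...
        {
            for (int p = 0; p < permutationCount; p++)  // ... per permutation
            {
                // Fisher-Yates shuffle of the row order used for feature f only.
                int[] order = Enumerable.Range(0, rows.Length).ToArray();
                for (int i = order.Length - 1; i > 0; i--)
                {
                    int j = rng.Next(i + 1);
                    (order[i], order[j]) = (order[j], order[i]);
                }

                var predictions = new double[rows.Length];
                for (int r = 0; r < rows.Length; r++)
                {
                    var permuted = (double[])rows[r].Clone();
                    permuted[f] = rows[order[r]][f];
                    predictions[r] = predict(permuted);
                }
                deltas[f] += (RSquared(labels, predictions) - baseline) / permutationCount;
            }
        }
        return deltas;
    }
}
```

With a toy model that only reads the first feature, e.g. `x => 3 * x[0]`, shuffling the second feature leaves every prediction unchanged, so its reported change is exactly zero, while the first feature's change can only be zero or negative.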

Labels: API (Issues pertaining the friendly API), enhancement (New feature or request)