Skip to content

Commit 5910910

Browse files
Fixes #4292 about using PFI with BPT and CMPB (#4306)
*Changes in PredictionTransformer.cs and Calibrator.cs to fix the problem of the create methods not being called, to make CMP load its internal calibrator and predictor first so to assign the correct paramaters types and runtimes, and added a PredictionTransformerLoadTypeAttribute so that the binary prediction transformer knows what type to assign when loading a CMP as its internal model. *Added a working sample for using PFI with BPT and CMPB while loading a model from disk. This is based entirely in the original sample. *Added file CalibratedModelParametersTests.cs with tests that the CMPs modified in this PR are now being correctly loaded from disk. *Changed a couple of tests in LbfgsTests.cs that failed because they used casts that now return 'null'.
1 parent bcdac55 commit 5910910

File tree

6 files changed

+437
-49
lines changed

6 files changed

+437
-49
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Microsoft.ML;
5+
using Microsoft.ML.Calibrators;
6+
using Microsoft.ML.Data;
7+
using Microsoft.ML.Trainers;
8+
9+
namespace Samples.Dynamic.Trainers.BinaryClassification
10+
{
11+
public static class PermutationFeatureImportanceLoadFromDisk
12+
{
13+
public static void Example()
14+
{
15+
16+
var mlContext = new MLContext(seed: 1);
17+
var samples = GenerateData();
18+
var data = mlContext.Data.LoadFromEnumerable(samples);
19+
20+
// Create pipeline
21+
var featureColumns =
22+
new string[] { nameof(Data.Feature1), nameof(Data.Feature2) };
23+
var pipeline = mlContext.Transforms
24+
.Concatenate("Features", featureColumns)
25+
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
26+
.Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression()
27+
);
28+
29+
// Create and save model
30+
var model0 = pipeline.Fit(data);
31+
var lt = model0.LastTransformer;
32+
var modelPath = "./model.zip";
33+
mlContext.Model.Save(model0, data.Schema, modelPath);
34+
35+
// Load model
36+
var model = mlContext.Model.Load(modelPath, out var schema);
37+
38+
// Transform the dataset.
39+
var transformedData = model.Transform(data);
40+
41+
var linearPredictor = (model as TransformerChain<ITransformer>).LastTransformer as BinaryPredictionTransformer<CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>;
42+
43+
// Execute PFI with the linearPredictor
44+
var permutationMetrics = mlContext.BinaryClassification
45+
.PermutationFeatureImportance(linearPredictor, transformedData,
46+
permutationCount: 30);
47+
48+
// Sort indices according to PFI results
49+
var sortedIndices = permutationMetrics
50+
.Select((metrics, index) => new { index, metrics.AreaUnderRocCurve })
51+
.OrderByDescending(
52+
feature => Math.Abs(feature.AreaUnderRocCurve.Mean))
53+
.Select(feature => feature.index);
54+
55+
Console.WriteLine("Feature\tModel Weight\tChange in AUC"
56+
+ "\t95% Confidence in the Mean Change in AUC");
57+
var auc = permutationMetrics.Select(x => x.AreaUnderRocCurve).ToArray();
58+
foreach (int i in sortedIndices)
59+
{
60+
Console.WriteLine("{0}\t{1:0.00}\t{2:G4}\t{3:G4}",
61+
featureColumns[i],
62+
linearPredictor.Model.SubModel.Weights[i], // this way we can access the weights inside the submodel
63+
auc[i].Mean,
64+
1.96 * auc[i].StandardError);
65+
}
66+
67+
// Expected output:
68+
// Feature Model Weight Change in AUC 95% Confidence in the Mean Change in AUC
69+
// Feature2 35.15 -0.387 0.002015
70+
// Feature1 17.94 -0.1514 0.0008963
71+
}
72+
73+
private class Data
74+
{
75+
public bool Label { get; set; }
76+
77+
public float Feature1 { get; set; }
78+
79+
public float Feature2 { get; set; }
80+
}
81+
82+
/// Generate Data
83+
private static IEnumerable<Data> GenerateData(int nExamples = 10000,
84+
double bias = 0, double weight1 = 1, double weight2 = 2, int seed = 1)
85+
{
86+
var rng = new Random(seed);
87+
for (int i = 0; i < nExamples; i++)
88+
{
89+
var data = new Data
90+
{
91+
Feature1 = (float)(rng.Next(10) * (rng.NextDouble() - 0.5)),
92+
Feature2 = (float)(rng.Next(10) * (rng.NextDouble() - 0.5)),
93+
};
94+
95+
// Create a noisy label.
96+
var value = (float)(bias + weight1 * data.Feature1 + weight2 *
97+
data.Feature2 + rng.NextDouble() - 0.5);
98+
99+
data.Label = Sigmoid(value) > 0.5;
100+
yield return data;
101+
}
102+
}
103+
104+
private static double Sigmoid(double x) => 1.0 / (1.0 + Math.Exp(-1 * x));
105+
}
106+
}

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Collections.Immutable;
99
using System.IO;
1010
using System.Linq;
11+
using System.Reflection;
1112
using Microsoft.ML;
1213
using Microsoft.ML.Calibrators;
1314
using Microsoft.ML.CommandLine;
@@ -396,6 +397,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
396397
}
397398

398399
[BestFriend]
400+
[PredictionTransformerLoadType(typeof(CalibratedModelParametersBase<,>))]
399401
internal sealed class ValueMapperCalibratedModelParameters<TSubModel, TCalibrator> :
400402
ValueMapperCalibratedModelParametersBase<TSubModel, TCalibrator>, ICanSaveModel
401403
where TSubModel : class
@@ -430,8 +432,8 @@ private static VersionInfo GetVersionInfoBulk()
430432
loaderAssemblyName: typeof(ValueMapperCalibratedModelParameters<TSubModel, TCalibrator>).Assembly.FullName);
431433
}
432434

433-
private ValueMapperCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx)
434-
: base(env, RegistrationName, GetPredictor(env, ctx), GetCalibrator(env, ctx))
435+
private ValueMapperCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx, TSubModel predictor, TCalibrator calibrator)
436+
: base(env, RegistrationName, predictor, calibrator)
435437
{
436438
}
437439

@@ -443,7 +445,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
443445
var ver2 = GetVersionInfoBulk();
444446
var ver = ctx.Header.ModelSignature == ver2.ModelSignature ? ver2 : ver1;
445447
ctx.CheckAtModel(ver);
446-
return new ValueMapperCalibratedModelParameters<TSubModel, TCalibrator>(env, ctx);
448+
449+
// Load first the predictor and calibrator
450+
var predictor = GetPredictor(env, ctx);
451+
var calibrator = GetCalibrator(env, ctx);
452+
453+
// Create a generic type using the correct parameter types of predictor and calibrator
454+
Type genericType = typeof(ValueMapperCalibratedModelParameters<,>);
455+
var genericInstance = CreateCalibratedModelParameters.Create(env, ctx, predictor, calibrator, genericType);
456+
457+
return (CalibratedModelParametersBase)genericInstance;
447458
}
448459

449460
void ICanSaveModel.Save(ModelSaveContext ctx)
@@ -456,6 +467,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx)
456467
}
457468

458469
[BestFriend]
470+
[PredictionTransformerLoadType(typeof(CalibratedModelParametersBase<,>))]
459471
internal sealed class FeatureWeightsCalibratedModelParameters<TSubModel, TCalibrator> :
460472
ValueMapperCalibratedModelParametersBase<TSubModel, TCalibrator>,
461473
IPredictorWithFeatureWeights<float>,
@@ -487,8 +499,9 @@ private static VersionInfo GetVersionInfo()
487499
loaderAssemblyName: typeof(FeatureWeightsCalibratedModelParameters<TSubModel, TCalibrator>).Assembly.FullName);
488500
}
489501

490-
private FeatureWeightsCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx)
491-
: base(env, RegistrationName, GetPredictor(env, ctx), GetCalibrator(env, ctx))
502+
private FeatureWeightsCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx,
503+
TSubModel predictor, TCalibrator calibrator)
504+
: base(env, RegistrationName, predictor, calibrator)
492505
{
493506
Host.Check(SubModel is IPredictorWithFeatureWeights<float>, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights<float>));
494507
_featureWeights = (IPredictorWithFeatureWeights<float>)SubModel;
@@ -499,7 +512,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
499512
Contracts.CheckValue(env, nameof(env));
500513
env.CheckValue(ctx, nameof(ctx));
501514
ctx.CheckAtModel(GetVersionInfo());
502-
return new FeatureWeightsCalibratedModelParameters<TSubModel, TCalibrator>(env, ctx);
515+
516+
// Load first the predictor and calibrator
517+
var predictor = GetPredictor(env, ctx);
518+
var calibrator = GetCalibrator(env, ctx);
519+
520+
// Create a generic type using the correct parameter types of predictor and calibrator
521+
Type genericType = typeof(FeatureWeightsCalibratedModelParameters<,>);
522+
var genericInstance = CreateCalibratedModelParameters.Create(env, ctx, predictor, calibrator, genericType);
523+
524+
return (CalibratedModelParametersBase) genericInstance;
503525
}
504526

505527
void ICanSaveModel.Save(ModelSaveContext ctx)
@@ -520,6 +542,7 @@ public void GetFeatureWeights(ref VBuffer<float> weights)
520542
/// Encapsulates a predictor and a calibrator that implement <see cref="IParameterMixer"/>.
521543
/// Its implementation of <see cref="IParameterMixer.CombineParameters"/> combines both the predictors and the calibrators.
522544
/// </summary>
545+
[PredictionTransformerLoadType(typeof(CalibratedModelParametersBase <,>))]
523546
internal sealed class ParameterMixingCalibratedModelParameters<TSubModel, TCalibrator> :
524547
ValueMapperCalibratedModelParametersBase<TSubModel, TCalibrator>,
525548
IParameterMixer<float>,
@@ -553,8 +576,8 @@ private static VersionInfo GetVersionInfo()
553576
loaderAssemblyName: typeof(ParameterMixingCalibratedModelParameters<TSubModel, TCalibrator>).Assembly.FullName);
554577
}
555578

556-
private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx)
557-
: base(env, RegistrationName, GetPredictor(env, ctx), GetCalibrator(env, ctx))
579+
private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoadContext ctx, TSubModel predictor, TCalibrator calibrator)
580+
: base(env, RegistrationName, predictor, calibrator)
558581
{
559582
Host.Check(SubModel is IParameterMixer<float>, "Predictor does not implement " + nameof(IParameterMixer));
560583
Host.Check(SubModel is IPredictorWithFeatureWeights<float>, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights<float>));
@@ -566,7 +589,16 @@ private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelL
566589
Contracts.CheckValue(env, nameof(env));
567590
env.CheckValue(ctx, nameof(ctx));
568591
ctx.CheckAtModel(GetVersionInfo());
569-
return new ParameterMixingCalibratedModelParameters<TSubModel, TCalibrator>(env, ctx);
592+
593+
// Load first the predictor and calibrator
594+
var predictor = GetPredictor(env, ctx);
595+
var calibrator = GetCalibrator(env, ctx);
596+
597+
// Create a generic type using the correct parameter types of predictor and calibrator
598+
Type genericType = typeof(ParameterMixingCalibratedModelParameters<,>);
599+
object genericInstance = CreateCalibratedModelParameters.Create(env, ctx, predictor, calibrator, genericType);
600+
601+
return (CalibratedModelParametersBase) genericInstance;
570602
}
571603

572604
void ICanSaveModel.Save(ModelSaveContext ctx)
@@ -777,6 +809,28 @@ ValueMapper<TSrc, VBuffer<float>> IFeatureContributionMapper.GetFeatureContribut
777809
}
778810
}
779811

812+
internal static class CreateCalibratedModelParameters
813+
{
814+
internal static object Create(IHostEnvironment env, ModelLoadContext ctx, object predictor, ICalibrator calibrator, Type calibratedModelParametersType)
815+
{
816+
Type[] genericTypeArgs = { predictor.GetType(), calibrator.GetType() };
817+
Type constructed = calibratedModelParametersType.MakeGenericType(genericTypeArgs);
818+
819+
Type[] constructorArgs = {
820+
typeof(IHostEnvironment),
821+
typeof(ModelLoadContext),
822+
predictor.GetType(),
823+
calibrator.GetType()
824+
};
825+
826+
// Call the appropiate constructor of the created generic type passing on the previously loaded predictor and calibrator
827+
var genericCtor = constructed.GetConstructor(BindingFlags.NonPublic | BindingFlags.Instance, null, constructorArgs, null);
828+
object genericInstance = genericCtor.Invoke(new object[] { env, ctx, predictor, calibrator });
829+
830+
return genericInstance;
831+
}
832+
}
833+
780834
[BestFriend]
781835
internal static class CalibratorUtils
782836
{

0 commit comments

Comments
 (0)