Skip to content

Commit

Permalink
Adds PriorTrainer Onnx conversion (#4515)
Browse files Browse the repository at this point in the history
* Onnx conversion for prior

* resolving comments

* resolved comments

* resolving comments
  • Loading branch information
Lynx1820 authored Dec 17, 2019
1 parent f1f8942 commit dc7ddb4
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
14 changes: 13 additions & 1 deletion src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ public sealed class BinaryPredictionTransformer<TModel> : SingleFeaturePredictio
{
internal readonly string ThresholdColumn;
internal readonly float Threshold;
internal readonly string LabelColumnName;

[BestFriend]
internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn,
Expand All @@ -383,6 +384,17 @@ internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, DataVie
SetScorer();
}

/// <summary>
/// Creates a binary prediction transformer that additionally records the name of the label column.
/// </summary>
/// <param name="env">Host environment; checked for null and used to register the host for this transformer.</param>
/// <param name="model">The model forwarded to the base transformer.</param>
/// <param name="inputSchema">The input schema forwarded to the base transformer.</param>
/// <param name="featureColumn">The feature column name forwarded to the base transformer.</param>
/// <param name="labelColumn">The label column name; stored in <see cref="LabelColumnName"/>.</param>
/// <param name="threshold">The threshold value; stored in <see cref="Threshold"/>.</param>
/// <param name="thresholdColumn">The threshold column name; must be non-empty. Stored in <see cref="ThresholdColumn"/>.</param>
internal BinaryPredictionTransformer(
    IHostEnvironment env,
    TModel model,
    DataViewSchema inputSchema,
    string featureColumn,
    string labelColumn,
    float threshold = 0f,
    string thresholdColumn = DefaultColumnNames.Score)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)),
        model,
        inputSchema,
        featureColumn)
{
    // Validate before touching any state.
    Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));

    // All three fields must be populated before SetScorer() runs, because it reads them.
    LabelColumnName = labelColumn;
    ThresholdColumn = thresholdColumn;
    Threshold = threshold;

    SetScorer();
}
internal BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), ctx)
{
Expand All @@ -409,7 +421,7 @@ private void InitializationLogic(ModelLoadContext ctx, out float threshold, out

// Builds the scorer for this transformer: creates a RoleMappedSchema over TrainSchema, binds
// BindableMapper to it, and assigns a BinaryClassifierScorer configured with Threshold and
// ThresholdColumn to Scorer.
private void SetScorer()
{
// NOTE(review): the next two lines both declare `schema`; they look like the removed and added
// sides of this diff hunk (null as the second argument before, LabelColumnName after) — confirm.
// The second argument is presumably the label role column name; verify against RoleMappedSchema.
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumnName);
var schema = new RoleMappedSchema(TrainSchema, LabelColumnName, FeatureColumnName);
// Scorer arguments carry this transformer's threshold settings.
var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn };
// The scorer is created over an empty data view built from TrainSchema.
Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}
Expand Down
11 changes: 9 additions & 2 deletions src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,16 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema

var mapper = ValueMapper as ISingleCanSaveOnnx;
Contracts.CheckValue(mapper, nameof(mapper));
Contracts.Assert(schema.Feature.HasValue);
Contracts.Assert(Utils.Size(outputNames) == 3); // Predicted Label, Score and Probability.

// Prior doesn't have a feature column and uses the training label column to determine predicted labels
if (!schema.Feature.HasValue)
{
Contracts.Assert(schema.Label.HasValue);
var labelColumnName = schema.Label.Value.Name;
return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(labelColumnName));
}

var featName = schema.Feature.Value.Name;
if (!ctx.ContainsColumn(featName))
return false;
Expand Down Expand Up @@ -511,7 +518,7 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto

// Enumerates the (column role, column name) pairs that this row mapper reports as its input,
// based on InputRoleMappedSchema.
public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
{
// NOTE(review): two yield statements appear below; they look like the removed and added sides
// of this diff hunk rather than two results that are both produced — confirm.
// Earlier form: always yields the Feature role bound to the feature column name (possibly null).
yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature?.Name);
// Later form: yields the Feature binding when the schema has a feature column; otherwise
// yields the Label role bound to the label column name.
yield return (InputRoleMappedSchema.Feature.HasValue) ? RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature?.Name) : RoleMappedSchema.ColumnRole.Label.Bind(InputRoleMappedSchema.Label?.Name);
}

private Delegate[] CreateGetters(DataViewRow input, bool[] active)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;

Expand Down Expand Up @@ -240,9 +241,9 @@ internal PriorTrainer(IHostEnvironment env, String labelColumn, String weightCol
/// </summary>
public BinaryPredictionTransformer<PriorModelParameters> Fit(IDataView input)
{
// NOTE(review): the two `trainRoles` declarations below look like the removed and added sides
// of this diff hunk; they pass the same named arguments (no feature column, the label column
// and the weight column) in a different order — confirm.
RoleMappedData trainRoles = new RoleMappedData(input, feature: null, label: _labelColumnName, weight: _weightColumnName);
RoleMappedData trainRoles = new RoleMappedData(input, label: _labelColumnName, feature: null, weight: _weightColumnName);
// Train through the ITrainer<PriorModelParameters> interface using the role-mapped data.
var pred = ((ITrainer<PriorModelParameters>)this).Train(new TrainContext(trainRoles));
// NOTE(review): likewise, the two return statements look like the removed and added sides of
// the hunk; the second one additionally passes the label column name to the transformer — confirm.
return new BinaryPredictionTransformer<PriorModelParameters>(_host, pred, input.Schema, featureColumn: null);
return new BinaryPredictionTransformer<PriorModelParameters>(_host, pred, input.Schema, featureColumn: null, labelColumn: _labelColumnName);
}

private PriorModelParameters Train(TrainContext context)
Expand Down Expand Up @@ -330,7 +331,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
public sealed class PriorModelParameters :
ModelParametersBase<float>,
IDistPredictorProducing<float, float>,
IValueMapperDist
IValueMapperDist, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "PriorPredictor";
private static VersionInfo GetVersionInfo()
Expand All @@ -346,6 +347,7 @@ private static VersionInfo GetVersionInfo()

private readonly float _prob;
private readonly float _raw;
/// <summary>
/// Reports that this model can be saved as ONNX; the answer is <see langword="true"/> regardless of <paramref name="ctx"/>.
/// </summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return true;
}

/// <summary>
/// Instantiates a model that returns the prior probability of the positive class in the training set.
Expand Down Expand Up @@ -397,6 +399,38 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.Writer.Write(_prob);
}

/// <summary>
/// Adds the ONNX nodes that produce this model's score and probability outputs.
/// </summary>
/// <remarks>
/// The emitted graph computes Not(Xor(label, label)), casts that result to the type obtained from
/// <see cref="DataKind.Single"/>, and multiplies it by two initializers holding the stored
/// probability and raw score. Only <c>outputs[1]</c> (score) and <c>outputs[2]</c> (probability)
/// are written here; <c>outputs[0]</c> is not touched by this method.
/// NOTE(review): Xor of the label with itself followed by Not presumably exists to obtain an
/// all-true value shaped like the label input, so the constants follow the input's shape — confirm.
/// </remarks>
/// <param name="ctx">The ONNX context that receives the initializers, variables and nodes; must not be null.</param>
/// <param name="outputs">Output variable names; must contain at least three entries.</param>
/// <param name="labelColumn">The variable name that is fed to both inputs of the Xor node.</param>
/// <returns>Always <see langword="true"/>.</returns>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string labelColumn)
{
    const string XorOp = "Xor";
    const string NotOp = "Not";
    const string CastOp = "Cast";
    const string MulOp = "Mul";

    Host.CheckValue(ctx, nameof(ctx));
    Host.Check(Utils.Size(outputs) >= 3);

    string scoreTarget = outputs[1];
    string probabilityTarget = outputs[2];

    // Constants of the model, registered as initializers (probability first, then score).
    var probabilityInit = ctx.AddInitializer(_prob, "probability");
    var scoreInit = ctx.AddInitializer(_raw, "score");

    // label XOR label
    var xorResult = ctx.AddIntermediateVariable(null, "XorOutput", true);
    ctx.CreateNode(XorOp, new[] { labelColumn, labelColumn }, new[] { xorResult }, ctx.GetNodeName(XorOp), "");

    // NOT (label XOR label)
    var notResult = ctx.AddIntermediateVariable(null, "NotOutput", true);
    ctx.CreateNode(NotOp, xorResult, notResult, ctx.GetNodeName(NotOp), "");

    // Cast the boolean result so it can be multiplied by the float initializers.
    var castResult = ctx.AddIntermediateVariable(null, "CastOutput", true);
    var castNode = ctx.CreateNode(CastOp, notResult, castResult, ctx.GetNodeName(CastOp), "");
    castNode.AddAttribute("to", InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType());

    // Probability output is emitted before the score output.
    ctx.CreateNode(MulOp, new[] { castResult, probabilityInit }, new[] { probabilityTarget }, ctx.GetNodeName(MulOp), "");
    ctx.CreateNode(MulOp, new[] { castResult, scoreInit }, new[] { scoreTarget }, ctx.GetNodeName(MulOp), "");

    return true;
}

private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

private readonly DataViewType _inputType;
Expand Down
5 changes: 3 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ public void BinaryClassificationTrainersOnnxConversionTest()
mlContext.BinaryClassification.Trainers.FastTree(),
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
mlContext.BinaryClassification.Trainers.LinearSvm(),
mlContext.BinaryClassification.Trainers.Prior(),
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
Expand Down Expand Up @@ -301,8 +302,8 @@ public void BinaryClassificationTrainersOnnxConversionTest()
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3);
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult);
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels
}
}
Done();
Expand Down

0 comments on commit dc7ddb4

Please sign in to comment.