Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds PriorTrainer Onnx conversion #4515

Merged
merged 4 commits into from
Dec 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ public sealed class BinaryPredictionTransformer<TModel> : SingleFeaturePredictio
{
internal readonly string ThresholdColumn;
internal readonly float Threshold;
internal readonly string LabelColumnName;

[BestFriend]
internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn,
Expand All @@ -383,6 +384,17 @@ internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, DataVie
SetScorer();
}

internal BinaryPredictionTransformer(IHostEnvironment env, TModel model, DataViewSchema inputSchema, string featureColumn, string labelColumn,
Lynx1820 marked this conversation as resolved.
Show resolved Hide resolved
float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
{
Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
Threshold = threshold;
ThresholdColumn = thresholdColumn;
LabelColumnName = labelColumn;

SetScorer();
}
internal BinaryPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(BinaryPredictionTransformer<TModel>)), ctx)
{
Expand All @@ -409,7 +421,7 @@ private void InitializationLogic(ModelLoadContext ctx, out float threshold, out

private void SetScorer()
{
var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumnName);
var schema = new RoleMappedSchema(TrainSchema, LabelColumnName, FeatureColumnName);
var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn };
Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
}
Expand Down
11 changes: 9 additions & 2 deletions src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,16 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema

var mapper = ValueMapper as ISingleCanSaveOnnx;
Contracts.CheckValue(mapper, nameof(mapper));
Contracts.Assert(schema.Feature.HasValue);
Contracts.Assert(Utils.Size(outputNames) == 3); // Predicted Label, Score and Probablity.

// Prior doesn't have a feature column and uses the training label column to determine predicted labels
if (!schema.Feature.HasValue)
{
Contracts.Assert(schema.Label.HasValue);
Lynx1820 marked this conversation as resolved.
Show resolved Hide resolved
var labelColumnName = schema.Label.Value.Name;
return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(labelColumnName));
Lynx1820 marked this conversation as resolved.
Show resolved Hide resolved
}

var featName = schema.Feature.Value.Name;
if (!ctx.ContainsColumn(featName))
return false;
Expand Down Expand Up @@ -511,7 +518,7 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto

public IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRoles()
{
yield return RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature?.Name);
yield return (InputRoleMappedSchema.Feature.HasValue) ? RoleMappedSchema.ColumnRole.Feature.Bind(InputRoleMappedSchema.Feature?.Name) : RoleMappedSchema.ColumnRole.Label.Bind(InputRoleMappedSchema.Label?.Name);
}

private Delegate[] CreateGetters(DataViewRow input, bool[] active)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;

Expand Down Expand Up @@ -240,9 +241,9 @@ internal PriorTrainer(IHostEnvironment env, String labelColumn, String weightCol
/// </summary>
public BinaryPredictionTransformer<PriorModelParameters> Fit(IDataView input)
{
RoleMappedData trainRoles = new RoleMappedData(input, feature: null, label: _labelColumnName, weight: _weightColumnName);
RoleMappedData trainRoles = new RoleMappedData(input, label: _labelColumnName, feature: null, weight: _weightColumnName);
var pred = ((ITrainer<PriorModelParameters>)this).Train(new TrainContext(trainRoles));
return new BinaryPredictionTransformer<PriorModelParameters>(_host, pred, input.Schema, featureColumn: null);
return new BinaryPredictionTransformer<PriorModelParameters>(_host, pred, input.Schema, featureColumn: null, labelColumn: _labelColumnName);
}

private PriorModelParameters Train(TrainContext context)
Expand Down Expand Up @@ -330,7 +331,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
public sealed class PriorModelParameters :
ModelParametersBase<float>,
IDistPredictorProducing<float, float>,
IValueMapperDist
IValueMapperDist, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "PriorPredictor";
private static VersionInfo GetVersionInfo()
Expand All @@ -346,6 +347,7 @@ private static VersionInfo GetVersionInfo()

private readonly float _prob;
private readonly float _raw;
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;

/// <summary>
/// Instantiates a model that returns the prior probability of the positive class in the training set.
Expand Down Expand Up @@ -397,6 +399,38 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.Writer.Write(_prob);
}

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputs, string labelColumn)
{
Lynx1820 marked this conversation as resolved.
Show resolved Hide resolved
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) >= 3);

string scoreVarName = outputs[1];
string probVarName = outputs[2];
var prob = ctx.AddInitializer(_prob, "probability");
var score = ctx.AddInitializer(_raw, "score");

var xorOutput = ctx.AddIntermediateVariable(null, "XorOutput", true);
string opType = "Xor";
ctx.CreateNode(opType, new[] { labelColumn, labelColumn }, new[] { xorOutput }, ctx.GetNodeName(opType), "");

var notOutput = ctx.AddIntermediateVariable(null, "NotOutput", true);
opType = "Not";
ctx.CreateNode(opType, xorOutput, notOutput, ctx.GetNodeName(opType), "");

var castOutput = ctx.AddIntermediateVariable(null, "CastOutput", true);
opType = "Cast";
var node = ctx.CreateNode(opType, notOutput, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
node.AddAttribute("to", t);

opType = "Mul";
ctx.CreateNode(opType, new[] { castOutput, prob }, new[] { probVarName }, ctx.GetNodeName(opType), "");

opType = "Mul";
ctx.CreateNode(opType, new[] { castOutput, score }, new[] { scoreVarName }, ctx.GetNodeName(opType), "");
return true;
}

private protected override PredictionKind PredictionKind => PredictionKind.BinaryClassification;

private readonly DataViewType _inputType;
Expand Down
5 changes: 3 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ public void BinaryClassificationTrainersOnnxConversionTest()
mlContext.BinaryClassification.Trainers.FastTree(),
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
mlContext.BinaryClassification.Trainers.LinearSvm(),
mlContext.BinaryClassification.Trainers.Prior(),
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
Expand Down Expand Up @@ -301,8 +302,8 @@ public void BinaryClassificationTrainersOnnxConversionTest()
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3);
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult);
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels
}
}
Done();
Expand Down