Skip to content

Allow user to save PredictorTransform in file and then convert it to … #3986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions src/Microsoft.ML.OnnxConverter/SaveOnnxCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,20 @@ public sealed class Arguments : DataCommand.ArgumentsBase
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Whether we should attempt to load the predictor and attach the scorer to the pipeline if one is present.", ShortName = "pred", SortOrder = 9)]
public bool? LoadPredictor;

[Argument(ArgumentType.Required, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)]
/// <summary>
/// Entry point API can save either <see cref="TransformModel"/> or <see cref="PredictorModel"/>.
/// <see cref="Model"/> is used when the saved model is typed to <see cref="TransformModel"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Model that needs to be converted to ONNX format.", SortOrder = 10)]
public TransformModel Model;

/// <summary>
/// Entry point API can save either <see cref="TransformModel"/> or <see cref="PredictorModel"/>.
/// <see cref="PredictiveModel"/> is used when the saved model is typed to <see cref="PredictorModel"/>.
/// </summary>
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, HelpText = "Predictor model that needs to be converted to ONNX format.", SortOrder = 12)]
public PredictorModel PredictiveModel;

[Argument(ArgumentType.AtMostOnce, HelpText = "The targeted ONNX version. It can be either \"Stable\" or \"Experimental\". If \"Experimental\" is used, produced model can contain components that is not officially supported in ONNX standard.", SortOrder = 11)]
public OnnxVersion OnnxVersion;
}
Expand All @@ -72,6 +83,7 @@ public sealed class Arguments : DataCommand.ArgumentsBase
private readonly HashSet<string> _inputsToDrop;
private readonly HashSet<string> _outputsToDrop;
private readonly TransformModel _model;
private readonly PredictorModel _predictiveModel;
private const string ProducerName = "ML.NET";
private const long ModelVersion = 0;

Expand All @@ -96,7 +108,13 @@ public SaveOnnxCommand(IHostEnvironment env, Arguments args)
_inputsToDrop = CreateDropMap(args.InputsToDropArray ?? args.InputsToDrop?.Split(','));
_outputsToDrop = CreateDropMap(args.OutputsToDropArray ?? args.OutputsToDrop?.Split(','));
_domain = args.Domain;

if (args.Model != null && args.PredictiveModel != null)
throw env.Except(nameof(args.Model) + " and " + nameof(args.PredictiveModel) +
" cannot be specified at the same time when calling ONNX converter. Please check the content of " + nameof(args) + ".");

_model = args.Model;
_predictiveModel = args.PredictiveModel;
}

private static HashSet<string> CreateDropMap(string[] toDrop)
Expand Down Expand Up @@ -198,7 +216,7 @@ private void Run(IChannel ch)
IDataView view;
RoleMappedSchema trainSchema = null;

if (_model == null)
if (_model == null && _predictiveModel == null)
{
if (string.IsNullOrEmpty(ImplOptions.InputModelFile))
{
Expand All @@ -213,8 +231,16 @@ private void Run(IChannel ch)

view = loader;
}
else
else if (_model != null)
{
view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
}
else
{
view = _predictiveModel.TransformModel.Apply(Host, new EmptyDataView(Host, _predictiveModel.TransformModel.InputSchema));
rawPred = _predictiveModel.Predictor;
trainSchema = _predictiveModel.GetTrainingSchema(Host);
}

// Create the ONNX context for storing global information
var assembly = System.Reflection.Assembly.GetExecutingAssembly();
Expand Down
14 changes: 12 additions & 2 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -2275,9 +2275,10 @@
"Name": "Model",
"Type": "TransformModel",
"Desc": "Model that needs to be converted to ONNX format.",
"Required": true,
"Required": false,
"SortOrder": 10.0,
"IsNullable": false
"IsNullable": false,
"Default": null
},
{
"Name": "OnnxVersion",
Expand All @@ -2293,6 +2294,15 @@
"SortOrder": 11.0,
"IsNullable": false,
"Default": "Stable"
},
{
"Name": "PredictiveModel",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PredictiveModel [](start = 19, length = 15)

Can it be named PredictorModel?

"Type": "PredictorModel",
"Desc": "Predictor model that needs to be converted to ONNX format.",
"Required": false,
"SortOrder": 12.0,
"IsNullable": false,
"Default": null
}
],
"Outputs": []
Expand Down
126 changes: 124 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Text.RegularExpressions;
using Google.Protobuf;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.RunTests;
using Microsoft.ML.Runtime;
Expand Down Expand Up @@ -186,7 +187,7 @@ void CommandLineOnnxConversionTest()
string modelPath = GetOutputPath("ModelWithLessIO.zip");
var trainingPathArgs = $"data={dataPath} out={modelPath}";
var trainingArgs = " loader=text{col=Label:BL:0 col=F1:R4:1-8 col=F2:TX:9} xf=Cat{col=F2} xf=Concat{col=Features:F1,F2} tr=ft{numberOfThreads=1 numberOfLeaves=8 numberOfTrees=3} seed=1";
Assert.Equal(0, Maml.Main(new[] { "train " + trainingPathArgs + trainingArgs}));
Assert.Equal(0, Maml.Main(new[] { "train " + trainingPathArgs + trainingArgs }));

var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "BinaryClassification", "BreastCancer");
var onnxTextName = "ModelWithLessIO.txt";
Expand Down Expand Up @@ -403,6 +404,127 @@ public void MulticlassLogisticRegressionOnnxConversionTest()
Done();
}

[Fact]
public void LoadingPredictorModelAndOnnxConversionTest()
{
string dataPath = GetDataPath("iris.txt");
string modelPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.bin";
string onnxPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.onnx";
string onnxJsonPath = Path.GetTempPath() + Guid.NewGuid().ToString() + ".model.onnx.json";

string inputGraph = string.Format(@"
{{
'Inputs': {{
'inputFile': '{0}'
}},
'Nodes': [
{{
'Name': 'Data.TextLoader',
'Inputs':
{{
'InputFile': '$inputFile',
'Arguments':
{{
'UseThreads': true,
'HeaderFile': null,
'MaxRows': null,
'AllowQuoting': true,
'AllowSparse': true,
'InputSize': null,
'TrimWhitespace': false,
'HasHeader': false,
'Column':
[
{{'Name':'Sepal_Width','Type':null,'Source':[{{'Min':2,'Max':2,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}},
{{'Name':'Petal_Length','Type':null,'Source':[{{'Min':3,'Max':4,'AutoEnd':false,'VariableEnd':false,'AllOther':false,'ForceVector':false}}],'KeyCount':null}},
]
}}
}},
'Outputs':
{{
'Data': '$training_data'
}}
}},
{{
'Inputs': {{
'FeatureColumnName': 'Petal_Length',
'LabelColumnName': 'Sepal_Width',
'TrainingData': '$training_data',
}},
'Name': 'Trainers.StochasticDualCoordinateAscentRegressor',
'Outputs': {{
'PredictorModel': '$output_model'
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This output is typed as PredictorModel, which could not be loaded correctly before this change.

}}
}}
],
'Outputs': {{
'output_model': '{1}'
}}
}}", dataPath.Replace("\\", "\\\\"), modelPath.Replace("\\", "\\\\"));

// Write the entry point graph to a file so that it can be invoked by the graph runner below.
var jsonPath = DeleteOutputPath("graph.json");
File.WriteAllLines(jsonPath, new[] { inputGraph });

// Execute the saved entry point graph to produce a predictive model.
var args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath };
var cmd = new ExecuteGraphCommand(Env, args);
cmd.Run();

// Make entry point graph to conduct ONNX conversion.
inputGraph = string.Format(@"
{{
'Inputs': {{
'model': '{0}'
}},
'Nodes': [
{{
'Inputs': {{
'Domain': 'com.microsoft.models',
'Json': '{1}',
'PredictiveModel': '$model',
'Onnx': '{2}',
'OnnxVersion': 'Experimental'
}},
'Name': 'Models.OnnxConverter',
'Outputs': {{}}
}}
],
'Outputs': {{}}
}}
", modelPath.Replace("\\", "\\\\"), onnxJsonPath.Replace("\\", "\\\\"), onnxPath.Replace("\\", "\\\\"));

// Write the entry point graph for ONNX conversion to a file so that it can be invoked by the graph runner below.
jsonPath = DeleteOutputPath("graph.json");
File.WriteAllLines(jsonPath, new[] { inputGraph });

// Onnx converter's assembly is not loaded by default, so we need to register it before calling it.
Env.ComponentCatalog.RegisterAssembly(typeof(OnnxExportExtensions).Assembly);

// Execute the saved entry point graph to convert the saved model to ONNX format.
args = new ExecuteGraphCommand.Arguments() { GraphPath = jsonPath };
cmd = new ExecuteGraphCommand(Env, args);
cmd.Run();

// Load the resulting ONNX model from the file so that we can check if the conversion looks good.
var model = new OnnxCSharpToProtoWrapper.ModelProto();
using (var modelStream = File.OpenRead(onnxPath))
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream);

// Make sure a PredictorModel is loaded by seeing if a predictive model exists. In this case the
// predictive model is "LinearRegressor" (converted from StochasticDualCoordinateAscentRegressor
// in the original training entry-point graph).
Assert.Equal("Scaler", model.Graph.Node[0].OpType);
Assert.Equal("LinearRegressor", model.Graph.Node[1].OpType);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assert.Equal("LinearRegressor", model.Graph.Node[1].OpType); [](start = 11, length = 61)

Is it possible to run the original model and the converted model and check that the results are the same?


File.Delete(modelPath);
File.Delete(onnxPath);
File.Delete(onnxJsonPath);

Done();
}


[Fact]
public void RemoveVariablesInPipelineTest()
{
Expand Down Expand Up @@ -451,7 +573,7 @@ public void RemoveVariablesInPipelineTest()

private class SmallSentimentExample
{
[LoadColumn(0,3), VectorType(4)]
[LoadColumn(0, 3), VectorType(4)]
public string[] Tokens;
}

Expand Down