Skip to content

Commit

Permalink
Support onnx export with previous OpSet version (#5176)
Browse files Browse the repository at this point in the history
* initial checkin

* fix

* fix build error

* change exception capture method

* print

* separate DEBUG mode in test

* fix one build error

* use throw instead of assert for consistent test

* apply exception to all

* more fix

* review comments

* review comments2

* review part of comments

* review comments

* resolve comments

* revert files

* add description
  • Loading branch information
wangyems authored Jun 3, 2020
1 parent 1ea2b47 commit d1bf425
Show file tree
Hide file tree
Showing 44 changed files with 251 additions and 24 deletions.
6 changes: 6 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ internal abstract class OnnxContext
/// <returns>Whether the column is mapped in this context</returns>
public abstract bool ContainsColumn(string colName);

/// <summary>
/// Checks that the OpSet version of this context satisfies the minimum OpSet version a transformer requires.
/// The method returns nothing; the OnnxContextImpl implementation reports a violation by throwing.
/// </summary>
/// <param name="thisTransformerMinumumOpSetVersion">The lowest OpSet version the calling transformer can be exported with.</param>
/// <param name="registerTransformerName">Name identifying the calling transformer; OnnxContextImpl includes it in the error message.</param>
public abstract void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName);

/// <summary>
/// Stops tracking a column.
/// </summary>
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
_host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames));
_host.Check(Utils.Size(scoreProbablityColumnNames) == 2);

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "PlattCalibrator");

// The Affine operator is no longer supported in the v11 opset.
// So we have to decompose it using Mul and Add
string opType = "Mul";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,9 @@ public void SaveAsOnnx(OnnxContext ctx)
Host.CheckValue(ctx, nameof(ctx));
Contracts.Assert(CanSaveOnnx(ctx));

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)
{
var colInfo = _parent._columns[iinfo];
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ColumnCopying.cs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()

public void SaveAsOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var opType = "Identity";

foreach (var column in _columns)
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ColumnSelecting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,9 @@ IDataTransform ITransformTemplate.ApplyToData(IHostEnvironment env, IDataView ne

public void SaveAsOnnx(OnnxContext ctx)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var outputToInputMap = _mapper.OutputToInputMap;
for(int i = 0; i < outputToInputMap.Length; i++)
{
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,9 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder)

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, string dstVariable)
{
const int minimumOpSetVersion = 11;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

string castOutput;
string isGreaterThanZeroOutput = "";
OnnxNode castNode;
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,9 @@ public override JToken SavePfa(BoundPfaContext ctx, JToken srcToken)

public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

string opType;

// Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,9 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var dim = info.TypeSrc.GetValueCount();

string opType = "Cast";
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/Normalizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColumnOptions info, stri
Contracts.Assert(_parent.Columns[iinfo] == info);
Contracts.Assert(CanSaveOnnx(ctx));

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

int valueCount = info.InputType.GetValueCount();
if (valueCount == 0)
return false;
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,9 @@ public void SaveAsOnnx(OnnxContext ctx)

public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

string opType;
var slots = _slotDropper[iinfo].GetPreservedSlots();
// vector column is not suppressed
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Data/Transforms/TypeConverting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,9 @@ public void SaveAsOnnx(OnnxContext ctx)

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var opType = "Cast";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,9 @@ private void CastInputToFloat<T>(OnnxContext ctx, out OnnxNode node, out long[]

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

OnnxNode node;
long[] termIds;
string opType = "LabelEncoder";
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3032,6 +3032,9 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputNames,
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputNames) >= 1);

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "TreeEnsembleModelParameters");

//Nodes.
var nodesTreeids = new List<long>();
var nodesIds = new List<long>();
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,9 @@ internal static FastTreeTweedieModelParameters Create(IHostEnvironment env, Mode

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

// Mapping score to prediction
var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
base.SaveAsOnnx(ctx, new[] { fastTreeOutput }, featureColumn);
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ internal static FastForestRegressionModelParameters Create(IHostEnvironment env,

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

// Mapping score to prediction
var fastTreeOutput = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
var numTrees = ctx.AddInitializer((float)TrainedEnsemble.NumTrees, "NumTrees");
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,9 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
// v
// L [l] <--- ArgMin <--- Y [l, k]

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

// Allocate C, which is a constant tensor in prediction phase
var shapeC = new long[] { _centroids.Length, _centroids[0].Length };
var tensorC = new List<float>();
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Mkl.Components/VectorWhitening.cs
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,9 @@ public void SaveAsOnnx(OnnxContext ctx)

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

var model = _parent._models[iinfo];
int dimension = _srcTypes[iinfo].GetValueCount();
Host.Assert(model.Length == dimension * dimension);
Expand Down
16 changes: 14 additions & 2 deletions src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Microsoft.ML.Model.OnnxConverter
/// </summary>
internal sealed class OnnxContextImpl : OnnxContext
{
private const int CurrentOpSetVersion = 12;
private const int MinimumOpSetVersion = 9;
private readonly List<OnnxCSharpToProtoWrapper.NodeProto> _nodes;
private readonly List<OnnxUtils.ModelArgs> _inputs;
// The map from IDataView column names to variable names.
Expand All @@ -32,9 +34,10 @@ internal sealed class OnnxContextImpl : OnnxContext
private readonly string _producerVersion;
private readonly long _modelVersion;
private readonly OnnxVersion _onnxVersion;
private readonly int _opSetVersion;

public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion)
string producerVersion, long modelVersion, string domain, OnnxVersion onnxVersion, int opSetVersion = CurrentOpSetVersion)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(OnnxContext));
Expand All @@ -55,6 +58,9 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
_modelVersion = modelVersion;
_domain = domain;
_onnxVersion = onnxVersion;
_opSetVersion = opSetVersion <= CurrentOpSetVersion ?
opSetVersion >= MinimumOpSetVersion ? opSetVersion : throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is lower than the minimum required OpSet version {MinimumOpSetVersion}") :
throw _host.ExceptParam(nameof(opSetVersion), $"Requested OpSet version {opSetVersion} is higher than the current most updated OpSet version {CurrentOpSetVersion}");
}

public override bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
Expand Down Expand Up @@ -127,6 +133,12 @@ public override string GetNodeName(string prefix)
return GetUniqueName(prefix, _nodeNames.Contains);
}

/// <summary>
/// Throws when the OpSet version held by this context is lower than the minimum version a transformer asks for.
/// </summary>
/// <param name="thisTransformerMinumumOpSetVersion">The lowest OpSet version the calling transformer can be exported with.</param>
/// <param name="registerTransformerName">Name of the calling transformer, used in the error message.</param>
public override void CheckOpSetVersion(int thisTransformerMinumumOpSetVersion, string registerTransformerName)
{
    // Nothing to report when the context's version already meets the transformer's minimum.
    if (_opSetVersion >= thisTransformerMinumumOpSetVersion)
        return;

    string message = $"Requested OpSet version {_opSetVersion} is lower than {registerTransformerName}'s minimum OpSet version requirement: {thisTransformerMinumumOpSetVersion}";
    throw _host.ExceptParam(nameof(thisTransformerMinumumOpSetVersion), message);
}

/// <summary>
/// Adds a node to the node list of the graph.
/// </summary>
Expand Down Expand Up @@ -409,7 +421,7 @@ public override string AddInitializer(IEnumerable<ulong> values, bool isUint64,
/// Makes the ONNX model based on the context.
/// </summary>
public OnnxCSharpToProtoWrapper.ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues, _initializers);
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _opSetVersion, _inputs, _outputs, _intermediateValues, _initializers);

/// <summary>
/// Return either "Experimental" or "Stable". The string "Experimental" indicates that some experimental features which are
Expand Down
47 changes: 40 additions & 7 deletions src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,23 @@
using Google.Protobuf;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;

namespace Microsoft.ML
{
public static class OnnxExportExtensions
{
/// <summary>
/// Shared conversion routine used by the ConvertToOnnxProtobuf overloads. It applies <paramref name="transform"/>
/// to <paramref name="inputData"/>, passes the result to <c>SaveOnnxCommand.GetPipe</c>, and then builds the model
/// with <c>SaveOnnxCommand.ConvertTransformListToOnnxModel</c> using the supplied <paramref name="ctx"/>.
/// </summary>
private static ModelProto ConvertToOnnxProtobufCore(IHostEnvironment env, OnnxContextImpl ctx, ITransformer transform, IDataView inputData)
{
    // The transform is applied before the channel is opened, as in the original ordering.
    IDataView transformedData = transform.Transform(inputData);

    using (var channel = env.Start("ONNX conversion"))
    {
        SaveOnnxCommand.GetPipe(ctx, channel, transformedData, out IDataView pipeRoot, out IDataView pipeSink, out LinkedList<ITransformCanSaveOnnx> pipeTransforms);
        ModelProto model = SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, channel, pipeRoot, pipeSink, pipeTransforms, null, null);
        return model;
    }
}

/// <summary>
/// Convert the specified <see cref="ITransformer"/> to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object.
Expand All @@ -26,13 +37,23 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
{
var env = catalog.GetEnvironment();
var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable);
var outputData = transform.Transform(inputData);
LinkedList<ITransformCanSaveOnnx> transforms = null;
using (var ch = env.Start("ONNX conversion"))
{
SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms);
return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null);
}
return ConvertToOnnxProtobufCore(env, ctx, transform, inputData);
}

/// <summary>
/// Converts the specified <see cref="ITransformer"/> to ONNX format using a caller-chosen OpSet version.
/// ONNX uses Google's Protobuf, so the returned value is a Protobuf object.
/// </summary>
/// <param name="catalog">The class that <see cref="ConvertToOnnxProtobuf(ModelOperationsCatalog, ITransformer, IDataView, int)"/> attached to.</param>
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
/// <param name="inputData">The input of the specified transform.</param>
/// <param name="opSetVersion">The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12.</param>
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
[BestFriend]
internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion)
{
    IHostEnvironment env = catalog.GetEnvironment();

    // The context carries the requested OpSet version; its constructor receives opSetVersion as the last argument.
    return ConvertToOnnxProtobufCore(
        env,
        new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable, opSetVersion),
        transform,
        inputData);
}

/// <summary>
Expand All @@ -45,5 +66,17 @@ internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog cat
/// <returns>An ONNX model equivalent to the converted ML.NET model.</returns>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream)
{
    // Build the Protobuf model first, then serialize it to the caller's stream.
    ModelProto onnxModel = ConvertToOnnxProtobuf(catalog, transform, inputData);
    onnxModel.WriteTo(stream);
}

/// <summary>
/// Converts the specified <see cref="ITransformer"/> to ONNX format with the requested OpSet version and writes the result to a stream.
/// </summary>
/// <param name="catalog">The class that <see cref="ConvertToOnnx(ModelOperationsCatalog, ITransformer, IDataView, int, Stream)"/> attached to.</param>
/// <param name="transform">The <see cref="ITransformer"/> that will be converted into ONNX format.</param>
/// <param name="inputData">The input of the specified transform.</param>
/// <param name="opSetVersion">The OpSet version to use for exporting the model. This value must be greater than or equal to 9 and less than or equal to 12.</param>
/// <param name="stream">The stream to write the protobuf model to.</param>
public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion, Stream stream)
{
    // Build the Protobuf model for the requested OpSet version, then serialize it to the caller's stream.
    ModelProto onnxModel = ConvertToOnnxProtobuf(catalog, transform, inputData, opSetVersion);
    onnxModel.WriteTo(stream);
}
}
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List<long> di
}

public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs,
string domain, string producerVersion, long modelVersion, int opSetVersion, List<ModelArgs> inputs,
List<ModelArgs> outputs, List<ModelArgs> intermediateValues, List<TensorProto> initializers)
{
Contracts.CheckValue(nodes, nameof(nodes));
Expand All @@ -305,7 +305,7 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
model.IrVersion = (long)OnnxCSharpToProtoWrapper.Version.IrVersion;
model.ModelVersion = modelVersion;
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 2 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = 11 });
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "", Version = opSetVersion });
model.Graph = new GraphProto();
var graph = model.Graph;
graph.Node.Add(nodes);
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.PCA/PcaTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName,
{
Host.CheckValue(ctx, nameof(ctx));

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

TransformInfo transformInfo = _parent._transformInfos[iinfo];

// When the transformer is loaded from a model file,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ private protected virtual bool SaveAsOnnx(OnnxContext ctx, string[] outputs, str
{
Host.CheckValue(ctx, nameof(ctx));
Host.Check(Utils.Size(outputs) >= 1);

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "LinearModel");

string opType = "LinearRegressor";
string scoreVarName = (Utils.Size(outputs) >= 2) ? outputs[1] : outputs[0]; // Get Score from PredictedLabel and/or Score columns
var node = ctx.CreateNode(opType, new[] { featureColumn }, new[] { scoreVarName }, ctx.GetNodeName(opType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,10 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, JToken input)
private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureColumn)
{
Host.CheckValue(ctx, nameof(ctx));

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "MultiClassLogisticRegression");

Host.Assert(outputs[0] == DefaultColumnNames.PredictedLabel);
Host.Assert(outputs[1] == DefaultColumnNames.Score);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,9 @@ ValueMapper<TIn, TOut> IValueMapper.GetMapper<TIn, TOut>()
/// </summary>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "MulticlassNaiveBayes");

float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

Expand Down
Loading

0 comments on commit d1bf425

Please sign in to comment.