Skip to content

Commit

Permalink
Added onnx export support for OptionalColumnTransform (dotnet#4454)
Browse files Browse the repository at this point in the history
* Initial work for adding onnx export support for OptionalColumnTransform

* Implemented support for optional initializers in OnnxTranformer to support OptionalColumnTransform

* Fixed handling of double values and non-long numeric types

* Removed redundant line

* Updated review comment
  • Loading branch information
harishsk authored Nov 14, 2019
1 parent f96761b commit c1e190a
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 75 deletions.
18 changes: 12 additions & 6 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,50 +128,56 @@ public OnnxNode CreateNode(string opType, string input, string output, string na
/// </summary>
/// <param name="value">The float number which is going to be added</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(float value, string name = null);
public abstract string AddInitializer(float value, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function can declare a global long
/// </summary>
/// <param name="value">The long number which is going to be added into the ONNX graph</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(long value, string name = null);
public abstract string AddInitializer(long value, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function can declare a global string
/// </summary>
/// <param name="value">The string which is going to be added into the ONNX graph</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(string value, string name = null);
public abstract string AddInitializer(string value, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function can declare a global float tensor
/// </summary>
/// <param name="values">The floats which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape that the floats</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null);
public abstract string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function can declare a global long tensor
/// </summary>
/// <param name="values">The longs which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape that the floats</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null);
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);

/// <summary>
/// Call this function can declare a global string tensor
/// </summary>
/// <param name="values">The strings which are going to be added into the ONNX graph</param>
/// <param name="dims">The shape that the strings</param>
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
/// <param name="makeUniqueName">Whether a unique name should be picked for this initializer.</param>
/// <returns>The initializer's ONNX name</returns>
public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null);
public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true);
}
}
29 changes: 15 additions & 14 deletions src/Microsoft.ML.OnnxConverter/OnnxContextImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,12 @@ public string TryGetVariableName(string colName)
/// there is a collision between names in the pipeline at any point.
/// </summary>
/// <param name="colName">IDataView column name.</param>
/// <param name="makeUniqueName">Whether a unique name should be chosen for this variable.</param>
/// <returns>Unique variable name.</returns>
public string AddVariable(string colName)
public string AddVariable(string colName, bool makeUniqueName = true)
{
_host.CheckNonEmpty(colName, nameof(colName));
_columnNameMap[colName] = GetUniqueName(colName, _variableNames.Contains);
_columnNameMap[colName] = makeUniqueName ? GetUniqueName(colName, _variableNames.Contains) : colName;
_variableNames.Add(_columnNameMap[colName]);
return _columnNameMap[colName];
}
Expand Down Expand Up @@ -269,56 +270,56 @@ public override List<long> RetrieveShapeOrNull(string variableName)
}

/// Adds constant tensor into the graph.
public override string AddInitializer(float value, string name = null)
public override string AddInitializer(float value, string name = null, bool makeUniqueName = true)
{
name = AddVariable(name ?? "float");
name = AddVariable(name ?? "float", makeUniqueName);
_initializers.Add(OnnxUtils.MakeFloat(name, value));
return name;
}

public override string AddInitializer(string value, string name = null)
public override string AddInitializer(string value, string name = null, bool makeUniqueName = true)
{
name = AddVariable(name ?? "string");
name = AddVariable(name ?? "string", makeUniqueName);
_initializers.Add(OnnxUtils.MakeString(name, value));
return name;
}

public override string AddInitializer(long value, string name = null)
public override string AddInitializer(long value, string name = null, bool makeUniqueName = true)
{
name = AddVariable(name ?? "int64");
name = AddVariable(name ?? "int64", makeUniqueName);
_initializers.Add(OnnxUtils.MakeInt64(name, value));
return name;
}

public override string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null)
public override string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
{
_host.CheckValue(values, nameof(values));
if (dims != null)
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");

name = AddVariable(name ?? "floats");
name = AddVariable(name ?? "floats", makeUniqueName);
_initializers.Add(OnnxUtils.MakeFloats(name, values, dims));
return name;
}

public override string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null)
public override string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
{
_host.CheckValue(values, nameof(values));
if (dims != null)
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");

name = AddVariable(name ?? "int64s");
name = AddVariable(name ?? "int64s", makeUniqueName);
_initializers.Add(OnnxUtils.MakeInt64s(name, values, dims));
return name;
}

public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null)
public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null, bool makeUniqueName = true)
{
_host.CheckValue(values, nameof(values));
if (dims != null)
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");

name = AddVariable(name ?? "strings");
name = AddVariable(name ?? "strings", makeUniqueName);
_initializers.Add(OnnxUtils.MakeStrings(name, values, dims));
return name;
}
Expand Down
98 changes: 44 additions & 54 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ public sealed class OnnxModelInfo
/// </summary>
public List<string> OutputNames { get; }
/// <summary>
/// Initializers[i] is the name of the i-th initializer in <see cref="InitializersInfo"/>.
/// </summary>
public List<string> InitializerNames { get; }
/// <summary>
/// Inputs of the containing <see cref="OnnxModel"/>.
/// </summary>
public OnnxVariableInfo[] InputsInfo { get; }
Expand All @@ -46,12 +50,19 @@ public sealed class OnnxModelInfo
/// </summary>
public OnnxVariableInfo[] OutputsInfo { get; }

public OnnxModelInfo(IEnumerable<OnnxVariableInfo> inputsInfo, IEnumerable<OnnxVariableInfo> outputsInfo)
/// <summary>
/// Initializers of the containing <see cref="OnnxModel"/>
/// </summary>
public OnnxVariableInfo[] InitializersInfo { get; }

public OnnxModelInfo(IEnumerable<OnnxVariableInfo> inputsInfo, IEnumerable<OnnxVariableInfo> outputsInfo, IEnumerable<OnnxVariableInfo> initializersInfo)
{
InputNames = inputsInfo.Select(val => val.Name).ToList();
InputsInfo = inputsInfo.ToArray();
OutputNames = outputsInfo.Select(val => val.Name).ToList();
OutputsInfo = outputsInfo.ToArray();
InitializerNames = initializersInfo.Select(val => val.Name).ToList();
InitializersInfo = initializersInfo.ToArray();
}

/// <summary>
Expand All @@ -60,10 +71,16 @@ public OnnxModelInfo(IEnumerable<OnnxVariableInfo> inputsInfo, IEnumerable<OnnxV
public OnnxVariableInfo GetInput(string name)
{
var index = InputNames.IndexOf(name);
if (index < 0)
throw Contracts.ExceptParamValue(name, nameof(name), $"Input tensor, {name}, does not exist in the ONNX model. " +
$"Available input names are [{string.Join(",", InputNames)}].");
return InputsInfo[index];
if (index >= 0)
return InputsInfo[index];

index = InitializerNames.IndexOf(name);
if (index >= 0)
return InitializersInfo[index];

// If we dont find the index in the input, try find it in the initializers
throw Contracts.ExceptParamValue(name, nameof(name), $"Input tensor, {name}, does not exist in the ONNX model. " +
$"Available input names are [{string.Join(",", InputNames)}]. Available initializers are [{string.Join(",", InitializerNames)}]");
}

/// <summary>
Expand Down Expand Up @@ -180,8 +197,12 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
var inputTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Input)
inputTypePool[valueInfo.Name] = OnnxTypeParser.GetDataViewType(valueInfo.Type);
var outputTypePool = new Dictionary<string, DataViewType>();

var initializerTypePool = new Dictionary<string, DataViewType>();
foreach (var valueInfo in model.Graph.Initializer)
initializerTypePool[valueInfo.Name] = OnnxTypeParser.GetScalarDataViewType(valueInfo.DataType);

var outputTypePool = new Dictionary<string, DataViewType>();
// Build casters which maps NamedOnnxValue to .NET objects.
var casterPool = new Dictionary<string, Func<NamedOnnxValue, object>>();
foreach (var valueInfo in model.Graph.Output)
Expand All @@ -190,60 +211,31 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
casterPool[valueInfo.Name] = OnnxTypeParser.GetDataViewValueCasterAndResultedType(valueInfo.Type, out Type actualType);
}

var onnxRuntimeInputInfos = new List<OnnxVariableInfo>();
// Collect input information for this ONNX model from ONNXRuntime's perspective.
foreach (var pair in _session.InputMetadata)
{
var name = pair.Key;
var meta = pair.Value;
var dataViewType = inputTypePool[name];

OnnxVariableInfo info = null;
if (shapeDictionary != null && shapeDictionary.ContainsKey(name))
{
// If user provides a shape of a specific tensor, the provided shape overwrites the corresponding one loaded from
// ONNX model file and the deduced DataViewVectorType.

if (!CheckOnnxShapeCompatibility(shapeDictionary[name].ToList(), meta.Dimensions.ToList()))
throw Contracts.ExceptParamValue(shapeDictionary[name], nameof(shapeDictionary),
"The specified shape " + string.Join(",", shapeDictionary[name]) +
" is not compatible with the shape " + string.Join(",", meta.Dimensions) +
" loaded from the ONNX model file. Only unknown dimension can replace or " +
"be replaced by another dimension.");
var inputInfos = GetOnnxVariablesFromMetadata(_session.InputMetadata, shapeDictionary, inputTypePool, null);
var outputInfos = GetOnnxVariablesFromMetadata(_session.OutputMetadata, shapeDictionary, outputTypePool, casterPool);
var overrideableInitializers = GetOnnxVariablesFromMetadata(_session.OverridableInitializerMetadata, shapeDictionary, inputTypePool, null);

if (dataViewType is VectorDataViewType vectorType)
{
if (shapeDictionary[name].All(value => value > 0))
dataViewType = new VectorDataViewType(vectorType.ItemType, shapeDictionary[name]);
else
dataViewType = new VectorDataViewType(vectorType.ItemType);
}
// Create a view to the used ONNX model from ONNXRuntime's perspective.
ModelInfo = new OnnxModelInfo(inputInfos, outputInfos, overrideableInitializers);
}

info = new OnnxVariableInfo(name, shapeDictionary[name].ToList(), meta.ElementType, dataViewType, null);
}
else
{
// No user-specified shape is found, so the shape loaded from ONNX model file is used.
info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, null);
}
onnxRuntimeInputInfos.Add(info);
}
private List<OnnxVariableInfo> GetOnnxVariablesFromMetadata(IReadOnlyDictionary<string, NodeMetadata> nodeMetadata,
IDictionary<string, int[]> shapeDictionary,
Dictionary<string, DataViewType> typePool,
Dictionary<string, Func<NamedOnnxValue, object>> casterPool)
{
var onnxVariableInfos = new List<OnnxVariableInfo>();

var onnxRuntimeOutputInfos = new List<OnnxVariableInfo>();
// Collect output information for this ONNX model from ONNXRuntime's perspective.
foreach (var pair in _session.OutputMetadata)
foreach (var pair in nodeMetadata)
{
var name = pair.Key;
var meta = pair.Value;
var dataViewType = outputTypePool[name];
var caster = casterPool[name];
var dataViewType = typePool[name];
var caster = casterPool?[name];

OnnxVariableInfo info = null;
if (shapeDictionary != null && shapeDictionary.ContainsKey(name))
{
// If user provide a shape of a specific tensor, the provided shape overwrites the corresponding one loaded from
// ONNX model file.

if (!CheckOnnxShapeCompatibility(shapeDictionary[name].ToList(), meta.Dimensions.ToList()))
throw Contracts.ExceptParamValue(shapeDictionary[name], nameof(shapeDictionary),
"The specified shape " + string.Join(",", shapeDictionary[name]) +
Expand All @@ -267,11 +259,9 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
info = new OnnxVariableInfo(name, meta.Dimensions.ToList(), meta.ElementType, dataViewType, caster);
}

onnxRuntimeOutputInfos.Add(info);
onnxVariableInfos.Add(info);
}

// Create a view to the used ONNX model from ONNXRuntime's perspective.
ModelInfo = new OnnxModelInfo(onnxRuntimeInputInfos, onnxRuntimeOutputInfos);
return onnxVariableInfos;
}

/// <summary>
Expand Down
Loading

0 comments on commit c1e190a

Please sign in to comment.