Skip to content

Commit

Permalink
Isolate ONNX implementations in separate DLL and NuGet (#462)
Browse files Browse the repository at this point in the history
Abstraction of ONNX exporting to interfaces, and isolation of actual implementation to separate DLL. Creation of a new NuGet to isolate Protobuf dependency.
  • Loading branch information
TomFinley authored Jul 5, 2018
1 parent 4d574d6 commit 52cc874
Show file tree
Hide file tree
Showing 35 changed files with 528 additions and 338 deletions.
2 changes: 1 addition & 1 deletion Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesti
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.UniversalModelFormat", "src\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StandardLearners", "src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj", "{707BB22C-7E5F-497A-8C2F-74578F675705}"
EndProject
Expand Down
13 changes: 13 additions & 0 deletions pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.nupkgproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<PackageDescription>ML.NET component for exporting ONNX Models</PackageDescription>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
</ItemGroup>

</Project>
5 changes: 5 additions & 0 deletions pkg/Microsoft.ML.Onnx/Microsoft.ML.Onnx.symbols.nupkgproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<Project DefaultTargets="Pack">

<Import Project="Microsoft.ML.Onnx.nupkgproj" />

</Project>
1 change: 0 additions & 1 deletion pkg/Microsoft.ML/Microsoft.ML.nupkgproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Console/Microsoft.ML.Console.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
<ProjectReference Include="..\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
<ProjectReference Include="..\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
<ProjectReference Include="..\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
<ProjectReference Include="..\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
<ProjectReference Include="..\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
<ProjectReference Include="..\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
<ProjectReference Include="..\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
<ProjectReference Include="..\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
<ProjectReference Include="..\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj" />

<NativeAssemblyReference Include="FastTreeNative" />
<NativeAssemblyReference Include="CpuMathNative" />
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private static VersionInfo GetVersionInfo()
/// Returns the underlying data view of the composite loader.
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
/// </summary>
internal IDataView View { get; }
public IDataView View { get; }

/// <summary>
/// Creates a loader according to the specified <paramref name="args"/>.
Expand Down
1 change: 0 additions & 1 deletion src/Microsoft.ML.Data/Microsoft.ML.Data.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
<ProjectReference Include="..\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj" />
</ItemGroup>

</Project>
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public interface ICanSaveOnnx
}

/// <summary>
/// This data model component is savable as Onnx.
/// This data model component is savable as ONNX.
/// </summary>
public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
{
Expand Down
272 changes: 64 additions & 208 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,245 +2,101 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Runtime.UniversalModelFormat.Onnx;
using Microsoft.ML.Runtime.Data;

namespace Microsoft.ML.Runtime.Model.Onnx
{
/// <summary>
/// A context for defining a ONNX output.
/// A context for defining an ONNX output. The context internally contains the model-in-progress being built. This
/// same context object is iteratively given to exportable components via the <see cref="ICanSaveOnnx"/> interface
/// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is
/// given to a component, all other components up to that component have already attempted to express themselves in
/// this context, with their outputs possibly available in the ONNX graph.
/// </summary>
public sealed class OnnxContext
public abstract class OnnxContext
{
private readonly List<NodeProto> _nodes;
private readonly List<OnnxUtils.ModelArgs> _inputs;
private readonly List<OnnxUtils.ModelArgs> _intermediateValues;
private readonly List<OnnxUtils.ModelArgs> _outputs;
private readonly Dictionary<string, string> _columnNameMap;
private readonly HashSet<string> _variableMap;
private readonly HashSet<string> _nodeNames;
private readonly string _name;
private readonly string _producerName;
private readonly IHost _host;
private readonly string _domain;
private readonly string _producerVersion;
private readonly long _modelVersion;

public OnnxContext(IHostEnvironment env, string name, string producerName,
string producerVersion, long modelVersion, string domain)
{
Contracts.CheckValue(env, nameof(env));
Contracts.CheckValue(name, nameof(name));
Contracts.CheckValue(name, nameof(domain));

_host = env.Register(nameof(OnnxContext));
_nodes = new List<NodeProto>();
_intermediateValues = new List<OnnxUtils.ModelArgs>();
_inputs = new List<OnnxUtils.ModelArgs>();
_outputs = new List<OnnxUtils.ModelArgs>();
_columnNameMap = new Dictionary<string, string>();
_variableMap = new HashSet<string>();
_nodeNames = new HashSet<string>();
_name = name;
_producerName = producerName;
_producerVersion = producerVersion;
_modelVersion = modelVersion;
_domain = domain;
}

public bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);

/// <summary>
/// Stops tracking a column. If removeVariable is true then it also removes the
/// variable associated with it, this is useful in the event where an output variable is
/// created before realizing the transform cannot actually save as ONNX.
/// </summary>
/// <param name="colName">IDataView column name to stop tracking</param>
/// <param name="removeVariable">Remove associated ONNX variable at the time.</param>
public void RemoveColumn(string colName, bool removeVariable)
{

if (removeVariable)
{
foreach (var val in _intermediateValues)
{
if (val.Name == _columnNameMap[colName])
{
_intermediateValues.Remove(val);
break;
}
}
}

if (_columnNameMap.ContainsKey(colName))
_columnNameMap.Remove(colName);
}

/// <summary>
/// Removes an ONNX variable. If removeColumn is true then it also removes the
/// IDataView column associated with it.
/// </summary>
/// <param name="variableName">ONNX variable to remove.</param>
/// <param name="removeColumn">IDataView column to stop tracking</param>
public void RemoveVariable(string variableName, bool removeColumn)
{
_host.Assert(_columnNameMap.ContainsValue(variableName));
if (removeColumn)
{
foreach (var val in _intermediateValues)
{
if (val.Name == variableName)
{
_intermediateValues.Remove(val);
break;
}
}
}

string columnName = _columnNameMap.Single(kvp => string.Compare(kvp.Value, variableName) == 0).Key;

Contracts.Assert(_variableMap.Contains(columnName));

_columnNameMap.Remove(columnName);
_variableMap.Remove(columnName);
}

/// <summary>
/// Generates a unique name for the node based on a prefix.
/// </summary>
public string GetNodeName(string prefix)
{
_host.CheckValue(prefix, nameof(prefix));
return GetUniqueName(prefix, c => _nodeNames.Contains(c));
}
/// <param name="prefix">The prefix for the node</param>
/// <returns>A name that has not yet been returned from this function, starting with <paramref name="prefix"/></returns>
public abstract string GetNodeName(string prefix);

/// <summary>
/// Adds a node to the node list of the graph.
/// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
/// safely call <see cref="GetVariableName(string)"/>.
/// </summary>
/// <param name="node"></param>
public void AddNode(NodeProto node)
{
_host.CheckValue(node, nameof(node));
_host.Assert(!_nodeNames.Contains(node.Name));

_nodeNames.Add(node.Name);
_nodes.Add(node);
}
/// <param name="colName">The data view column name</param>
/// <returns>Whether the column is mapped in this context</returns>
public abstract bool ContainsColumn(string colName);

/// <summary>
/// Generates a unique name based on a prefix.
/// Stops tracking a column.
/// </summary>
private string GetUniqueName(string prefix, Func<string, bool> pred)
{
_host.CheckValue(prefix, nameof(prefix));
_host.CheckValue(pred, nameof(pred));

if (!pred(prefix))
return prefix;

int count = 0;
while (pred(prefix + count++)) ;
return prefix + --count;
}
/// <param name="colName">Column name to stop tracking</param>
/// <param name="removeVariable">Remove associated ONNX variable. This is useful in the event where an output
/// variable is created through <see cref="AddIntermediateVariable(ColumnType, string, bool)"/> before realizing
/// the transform cannot actually save as ONNX.</param>
public abstract void RemoveColumn(string colName, bool removeVariable = false);

/// <summary>
/// Retrieves the variable name that maps to the IDataView column name at a
/// given point in the pipeline execution.
/// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the <see
/// cref="IDataView"/> column associated with it.
/// </summary>
/// <returns>Column Name mapping.</returns>
public string GetVariableName(string colName)
{
_host.CheckValue(colName, nameof(colName));
_host.Assert(_columnNameMap.ContainsKey(colName));

return _columnNameMap[colName];
}

/// <summary>
/// Retrieves the variable name that maps to the IDataView column name at a
/// given point in the pipeline execution.
/// </summary>
/// <returns>Column Name mapping.</returns>
public string TryGetVariableName(string colName)
{
if (_columnNameMap.ContainsKey(colName))
return GetVariableName(colName);

return null;
}

/// <summary>
/// Generates a unique column name based on the IDataView column name if
/// there is a collision between names in the pipeline at any point.
/// </summary>
/// <param name="colName">IDataView column name.</param>
/// <returns>Unique variable name.</returns>
private string AddVariable(string colName)
{
_host.CheckValue(colName, nameof(colName));

if (!_columnNameMap.ContainsKey(colName))
_columnNameMap.Add(colName, colName);
else
_columnNameMap[colName] = GetUniqueName(colName, s => _variableMap.Contains(s));

_variableMap.Add(_columnNameMap[colName]);
return _columnNameMap[colName];
}
/// <param name="variableName">ONNX variable to remove. Note that this is an ONNX variable name, not an <see
/// cref="IDataView"/> column name</param>
/// <param name="removeColumn">Whether to also stop tracking the <see cref="IDataView"/> column associated with the variable</param>
public abstract void RemoveVariable(string variableName, bool removeColumn);

/// <summary>
/// Adds an intermediate column to the list.
/// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
/// <see cref="IDataView"/>'s column names will map to a variable in the ONNX graph if the intermediate steps
/// used to calculate that value are things we knew how to save as ONNX. Retrieves the variable name that maps
/// to the <see cref="IDataView"/> column name at a given point in the pipeline execution. Callers should
/// probably confirm with <see cref="ContainsColumn(string)"/> whether a mapping for that data view column
/// already exists.
/// </summary>
public string AddIntermediateVariable(ColumnType type, string colName, bool skip = false)
{

colName = AddVariable(colName);

//Let the runtime figure the shape.
if (!skip)
{
_host.CheckValue(type, nameof(type));

_intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName));
}

return colName;
}
/// <param name="colName">The data view column name</param>
/// <returns>The ONNX variable name corresponding to that data view column</returns>
public abstract string GetVariableName(string colName);

/// <summary>
/// Adds an output variable to the list.
/// Establishes a new mapping from a data view column in the context, if necessary generates a unique name, and
/// returns that newly allocated name.
/// </summary>
public string AddOutputVariable(ColumnType type, string colName, List<long> dim = null)
{
_host.CheckValue(type, nameof(type));

if (!ContainsColumn(colName))
AddVariable(colName);

colName = GetVariableName(colName);
_outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim));
return colName;
}
/// <param name="type">The data view type associated with this column name</param>
/// <param name="colName">The data view column name</param>
/// <param name="skip">Whether we should skip the process of establishing the mapping from data view column to
/// ONNX variable name.</param>
/// <returns>The newly allocated ONNX variable name corresponding to the data view column</returns>
public abstract string AddIntermediateVariable(ColumnType type, string colName, bool skip = false);

/// <summary>
/// Adds an input variable to the list.
/// Creates an ONNX node and adds it to the in-progress ONNX graph.
/// </summary>
public void AddInputVariable(ColumnType type, string colName)
{
_host.CheckValue(type, nameof(type));
_host.CheckValue(colName, nameof(colName));

colName = AddVariable(colName);
_inputs.Add(OnnxUtils.GetModelArgs(type, colName));
}
/// <param name="opType">The name of the ONNX operator to apply</param>
/// <param name="inputs">The names of the variables as inputs</param>
/// <param name="outputs">The names of the variables to create as outputs,
/// which ought to have been something returned from <see cref="AddIntermediateVariable(ColumnType, string, bool)"/></param>
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="GetNodeName(string)"/></param>
/// <param name="domain">The domain of the ONNX operator, if non-default</param>
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
IEnumerable<string> outputs, string name, string domain = null);

/// <summary>
/// Makes the ONNX model based on the context.
/// Convenience alternative to <see cref="CreateNode(string, IEnumerable{string}, IEnumerable{string}, string, string)"/>
/// for the case where there is exactly one input and output.
/// </summary>
public ModelProto MakeModel()
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
/// <param name="opType">The name of the ONNX operator to apply</param>
/// <param name="input">The name of the variable as input</param>
/// <param name="output">The name of the variable as output,
/// which ought to have been something returned from <see cref="OnnxContext.AddIntermediateVariable(ColumnType, string, bool)"/></param>
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="OnnxContext.GetNodeName(string)"/></param>
/// <param name="domain">The domain of the ONNX operator, if non-default</param>
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
=> CreateNode(opType, new[] { input }, new[] { output }, name, domain);
}
}
Loading

0 comments on commit 52cc874

Please sign in to comment.