Skip to content

ONNXTransform Upgrade to Enable Non-tensor Types #3881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
<PropertyGroup>
<BenchmarkDotNetVersion>0.11.3</BenchmarkDotNetVersion>
<MicrosoftCodeAnalysisTestingVersion>1.0.0-beta1-63812-02</MicrosoftCodeAnalysisTestingVersion>
<MicrosoftMLTestModelsPackageVersion>0.0.4-test</MicrosoftMLTestModelsPackageVersion>
<MicrosoftMLTestModelsPackageVersion>0.0.5-test</MicrosoftMLTestModelsPackageVersion>
<MicrosoftMLTensorFlowTestModelsVersion>0.0.11-test</MicrosoftMLTensorFlowTestModelsVersion>
<MicrosoftMLOnnxTestModelsVersion>0.0.4-test</MicrosoftMLOnnxTestModelsVersion>
<MicrosoftMLOnnxTestModelsVersion>0.0.5-test</MicrosoftMLOnnxTestModelsVersion>
</PropertyGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

<ItemGroup>
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)"/>
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
</ItemGroup>

<ItemGroup>
<Compile Include="..\Microsoft.ML.OnnxConverter\OnnxMl.cs">
<Link>OnnxMl.cs</Link>
</Compile>
</ItemGroup>
</Project>
101 changes: 101 additions & 0 deletions src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;

namespace Microsoft.ML.Transforms.Onnx
{
/// <summary>
/// The corresponding <see cref="DataViewSchema.Column.Type"/> of ONNX's map type in <see cref="IDataView"/>'s type system.
/// In other words, if an ONNX model produces a map, a column in <see cref="IDataView"/> may be typed to <see cref="OnnxMapType"/>.
/// Its underlying type is <see cref="IDictionary{TKey, TValue}"/>, where the generic type "TKey" and "TValue" are the input arguments of
/// <see cref="OnnxMapType.OnnxMapType(Type,Type)"/>.
/// </summary>
public sealed class OnnxMapType : StructuredDataViewType
{
/// <summary>
/// Create the corresponding <see cref="DataViewType"/> for ONNX map.
/// </summary>
/// <param name="keyType">Key type of the associated ONNX map.</param>
/// <param name="valueType">Value type of the associated ONNX map.</param>
public OnnxMapType(Type keyType, Type valueType) : base(typeof(IDictionary<,>).MakeGenericType(keyType, valueType))
{
DataViewTypeManager.Register(this, RawType, new[] { new OnnxMapTypeAttribute(keyType, valueType) });
}

public override bool Equals(DataViewType other)
{
if (other is OnnxMapType)
return RawType == other.RawType;
else
return false;
}

public override int GetHashCode()
{
return RawType.GetHashCode();
}
}

/// <summary>
/// To declare <see cref="OnnxMapType"/> column in <see cref="IDataView"/> as a field
/// in a <see langword="class"/>, the associated field should be marked with <see cref="OnnxMapTypeAttribute"/>.
/// Its uses are similar to those of <see cref="VectorTypeAttribute"/> and other <see langword="class"/>es derived
/// from <see cref="DataViewTypeAttribute"/>.
/// </summary>
public sealed class OnnxMapTypeAttribute : DataViewTypeAttribute
{
private Type _keyType;
private Type _valueType;

/// <summary>
/// Create a map (aka dictionary) type.
/// </summary>
public OnnxMapTypeAttribute()
{
}

/// <summary>
/// Create a map (aka dictionary) type. A map is a collection of key-value
/// pairs. <paramref name="keyType"/> specifies the type of keys and <paramref name="valueType"/>
/// is the type of values.
/// </summary>
public OnnxMapTypeAttribute(Type keyType, Type valueType)
{
_keyType = keyType;
_valueType = valueType;
}

/// <summary>
/// Map types with the same key type and the same value type should be equal.
/// </summary>
public override bool Equals(DataViewTypeAttribute other)
{
if (other is OnnxMapTypeAttribute otherSequence)
return _keyType.Equals(otherSequence._keyType) && _valueType.Equals(otherSequence._valueType);
return false;
}

/// <summary>
/// Produce the same hash code for map types with the same key type and the same value type.
/// </summary>
public override int GetHashCode()
{
return Hashing.CombineHash(_keyType.GetHashCode(), _valueType.GetHashCode());
}

/// <summary>
/// An implementation of <see cref="DataViewTypeAttribute.Register"/>.
/// </summary>
public override void Register()
{
var enumerableType = typeof(IDictionary<,>);
var type = enumerableType.MakeGenericType(_keyType, _valueType);
DataViewTypeManager.Register(new OnnxMapType(_keyType, _valueType), type, new[] { this });
}
}
}
102 changes: 102 additions & 0 deletions src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Data;

namespace Microsoft.ML.Transforms.Onnx
{
/// <summary>
/// The corresponding <see cref="DataViewSchema.Column.Type"/> of ONNX's sequence type in <see cref="IDataView"/>'s type system.
/// In other words, if an ONNX model produces a sequence, a column in <see cref="IDataView"/> may be typed to <see cref="OnnxSequenceType"/>.
/// Its underlying type is <see cref="IEnumerable{T}"/>, where the generic type "T" is the input argument of
/// <see cref="OnnxSequenceType.OnnxSequenceType(Type)"/>.
/// </summary>
public sealed class OnnxSequenceType : StructuredDataViewType
{
private static Type MakeNativeType(Type elementType)
{
var enumerableTypeInfo = typeof(IEnumerable<>);
var enumerableType = enumerableTypeInfo.MakeGenericType(elementType);
return enumerableType;
}

/// <summary>
/// Create the corresponding <see cref="DataViewType"/> for ONNX sequence.
/// </summary>
/// <param name="elementType">The element type of a sequence.</param>
public OnnxSequenceType(Type elementType) : base(MakeNativeType(elementType))
{
DataViewTypeManager.Register(this, RawType, new[] { new OnnxSequenceTypeAttribute(elementType) });
}

public override bool Equals(DataViewType other)
{
if (other is OnnxSequenceType)
return RawType == other.RawType;
else
return false;
}

public override int GetHashCode()
{
return RawType.GetHashCode();
}
}

/// <summary>
/// To declare <see cref="OnnxSequenceType"/> column in <see cref="IDataView"/> as a field
/// in a <see langword="class"/>, the associated field should be marked with <see cref="OnnxSequenceTypeAttribute"/>.
/// Its uses are similar to those of <see cref="VectorTypeAttribute"/> and other <see langword="class"/>es derived
/// from <see cref="DataViewTypeAttribute"/>.
/// </summary>
public sealed class OnnxSequenceTypeAttribute : DataViewTypeAttribute
{
private Type _elemType;

/// <summary>
/// Create a sequence type.
/// </summary>
public OnnxSequenceTypeAttribute()
{
}

/// <summary>
/// Create a <paramref name="elemType"/>-sequence type.
/// </summary>
public OnnxSequenceTypeAttribute(Type elemType)
{
_elemType = elemType;
}

/// <summary>
/// Sequence types with the same element type should be equal.
/// </summary>
public override bool Equals(DataViewTypeAttribute other)
{
if (other is OnnxSequenceTypeAttribute otherSequence)
return _elemType.Equals(otherSequence._elemType);
return false;
}

/// <summary>
/// Produce the same hash code for sequence types with the same element type.
/// </summary>
public override int GetHashCode()
{
return _elemType.GetHashCode();
}

/// <summary>
/// An implementation of <see cref="DataViewTypeAttribute.Register"/>.
/// </summary>
public override void Register()
{
var enumerableType = typeof(IEnumerable<>);
var type = enumerableType.MakeGenericType(_elemType);
DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, new[] { this });
}
}
}
Loading