Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ internal abstract class OnnxNode
public abstract void AddAttribute(string argName, string[] value);
public abstract void AddAttribute(string argName, IEnumerable<string> value);
public abstract void AddAttribute(string argName, IEnumerable<bool> value);
public abstract void AddAttribute(string argName, Type t);
}
}
19 changes: 7 additions & 12 deletions src/Microsoft.ML.Data/Transforms/TypeConverting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,13 @@ internal static bool GetNewType(IExceptionContext ectx, DataViewType srcType, In
return true;
}

private sealed class Mapper : OneToOneMapperBase, ICanSaveOnnx
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly TypeConvertingTransformer _parent;
private readonly DataViewType[] _types;
private readonly int[] _srcCols;

public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental;
public bool CanSaveOnnx(OnnxContext ctx) => true;

public Mapper(TypeConvertingTransformer parent, DataViewSchema inputSchema)
: base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
Expand Down Expand Up @@ -497,22 +497,17 @@ public void SaveAsOnnx(OnnxContext ctx)

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
var opType = "CSharp";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("type", LoaderSignature);
node.AddAttribute("to", (byte)_parent._columns[iinfo].OutputKind);
if (_parent._columns[iinfo].OutputKeyCount != null)
{
var key = (KeyDataViewType)_types[iinfo].GetItemType();
node.AddAttribute("max", key.Count);
}
var opType = "Cast";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
var t = _parent._columns[iinfo].OutputKind.ToInternalDataKind().ToType();
Copy link
Member Author

@ganik ganik Aug 29, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

var t [](start = 16, length = 6)

Need to handle key type #Resolved

node.AddAttribute("to", t);
return true;
}
}
}

/// <summary>
/// Estimator for <see cref="KeyToVectorMappingTransformer"/>. Converts the underlying input column type to a new type.
/// Estimator for <see cref="TypeConvertingTransformer"/>. Converts the underlying input column type to a new type.
/// The input and output column types need to be compatible.
/// <see cref="PrimitiveDataViewType"/>
/// </summary>
Expand Down
8 changes: 7 additions & 1 deletion src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,17 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
var addNodeY = ctx.CreateNode("Add", new[] { nameZ, nameC2 }, new[] { nameY }, ctx.GetNodeName("Add"), "");

// Compute the most-matched cluster index, L
var nameL = outputNames[0];
var nameL = "ArgMinInt64";
var predictNodeL = ctx.CreateNode("ArgMin", nameY, nameL, ctx.GetNodeName("ArgMin"), "");
predictNodeL.AddAttribute("axis", 1);
predictNodeL.AddAttribute("keepdims", 1);

// ArgMin outputs Int64, but ML.NET's KMeans trainer outputs UInt32.
// Cast the ArgMin result to UInt32 so the two types are compatible.
var predictedNode = ctx.CreateNode("Cast", nameL, outputNames[0], ctx.GetNodeName("Cast"), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
predictedNode.AddAttribute("to", t);

return true;
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.ML.OnnxConverter/OnnxNodeImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@ public override void AddAttribute(string argName, string value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
public override void AddAttribute(string argName, bool value)
=> OnnxUtils.NodeAddAttributes(_node, argName, value);
/// <summary>
/// Adds an attribute named <paramref name="argName"/> whose value is given as a .NET type.
/// The work is delegated to <c>OnnxUtils.NodeAddAttributes</c> on the wrapped node.
/// </summary>
/// <param name="argName">Name of the attribute to add.</param>
/// <param name="value">The .NET type to record as the attribute value.</param>
public override void AddAttribute(string argName, Type value)
{
    OnnxUtils.NodeAddAttributes(_node, argName, value);
}
}
}
79 changes: 48 additions & 31 deletions src/Microsoft.ML.OnnxConverter/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ private static AttributeProto MakeAttribute(string key)
return attribute;
}

/// <summary>
/// Builds an integer-typed attribute that carries a tensor element type.
/// </summary>
/// <param name="key">Name passed to the single-argument <c>MakeAttribute</c> overload.</param>
/// <param name="value">Tensor element type; stored as its integer enum value.</param>
/// <returns>The populated attribute.</returns>
private static AttributeProto MakeAttribute(string key, TensorProto.Types.DataType value)
{
    var result = MakeAttribute(key);

    // The element type is stored as a plain integer, so the attribute is
    // tagged as Int and the enum is cast to its numeric value.
    result.Type = AttributeProto.Types.AttributeType.Int;
    result.I = (int)value;

    return result;
}

private static AttributeProto MakeAttribute(string key, double value)
{
AttributeProto attribute = MakeAttribute(key);
Expand Down Expand Up @@ -211,6 +219,45 @@ public static void NodeAddAttributes(NodeProto node, string argName, IEnumerable

public static void NodeAddAttributes(NodeProto node, string argName, bool value)
=> node.Attribute.Add(MakeAttribute(argName, value));
/// <summary>
/// Appends to <paramref name="node"/> an attribute whose value is the tensor
/// element type corresponding to the .NET type <paramref name="value"/>.
/// </summary>
/// <param name="node">Node that receives the new attribute.</param>
/// <param name="argName">Name of the attribute.</param>
/// <param name="value">The .NET type to translate via <c>ConvertToTensorProtoType</c>.</param>
public static void NodeAddAttributes(NodeProto node, string argName, Type value)
{
    var elementType = ConvertToTensorProtoType(value);
    node.Attribute.Add(MakeAttribute(argName, elementType));
}

/// <summary>
/// Translates a .NET type into the corresponding <c>TensorProto.Types.DataType</c> value.
/// </summary>
/// <param name="rawType">The .NET type to translate.</param>
/// <returns>
/// The tensor element type for <paramref name="rawType"/>. For a type that is not
/// listed below, <c>Contracts.Check</c> is invoked with an "Unsupported type" message;
/// should that call return, the result is <c>Undefined</c>.
/// </returns>
private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType)
{
    // NOTE(review): bool maps to Float, exactly like float. This matches the
    // mapping this helper was factored out of; confirm before changing it.
    if (rawType == typeof(bool) || rawType == typeof(float))
        return TensorProto.Types.DataType.Float;
    if (rawType == typeof(ReadOnlyMemory<char>))
        return TensorProto.Types.DataType.String;

    // Signed / unsigned integers, narrowest to widest.
    if (rawType == typeof(sbyte))
        return TensorProto.Types.DataType.Int8;
    if (rawType == typeof(byte))
        return TensorProto.Types.DataType.Uint8;
    if (rawType == typeof(short))
        return TensorProto.Types.DataType.Int16;
    if (rawType == typeof(ushort))
        return TensorProto.Types.DataType.Uint16;
    if (rawType == typeof(int))
        return TensorProto.Types.DataType.Int32;
    if (rawType == typeof(uint))
        return TensorProto.Types.DataType.Uint32;
    if (rawType == typeof(long))
        return TensorProto.Types.DataType.Int64;
    if (rawType == typeof(ulong))
        return TensorProto.Types.DataType.Uint64;

    if (rawType == typeof(double))
        return TensorProto.Types.DataType.Double;

    // No mapping exists for this type: report it through the contracts helper.
    Contracts.Check(false, "Unsupported type: " + rawType.ToString());
    return TensorProto.Types.DataType.Undefined;
}

private static ByteString StringToByteString(ReadOnlyMemory<char> str) => ByteString.CopyFrom(Encoding.UTF8.GetBytes(str.ToString()));
private static IEnumerable<ByteString> StringToByteString(IEnumerable<ReadOnlyMemory<char>> str)
Expand Down Expand Up @@ -295,42 +342,12 @@ public static ModelArgs GetModelArgs(DataViewType type, string colName,
Contracts.CheckValue(type, nameof(type));
Contracts.CheckNonEmpty(colName, nameof(colName));

TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
Type rawType;
if (type is VectorDataViewType vectorType)
rawType = vectorType.ItemType.RawType;
else
rawType = type.RawType;

if (rawType == typeof(bool))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(ReadOnlyMemory<char>))
dataType = TensorProto.Types.DataType.String;
else if (rawType == typeof(sbyte))
dataType = TensorProto.Types.DataType.Int8;
else if (rawType == typeof(byte))
dataType = TensorProto.Types.DataType.Uint8;
else if (rawType == typeof(short))
dataType = TensorProto.Types.DataType.Int16;
else if (rawType == typeof(ushort))
dataType = TensorProto.Types.DataType.Uint16;
else if (rawType == typeof(int))
dataType = TensorProto.Types.DataType.Int32;
else if (rawType == typeof(uint))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(long))
dataType = TensorProto.Types.DataType.Int64;
else if (rawType == typeof(ulong))
dataType = TensorProto.Types.DataType.Uint64;
else if (rawType == typeof(float))
dataType = TensorProto.Types.DataType.Float;
else if (rawType == typeof(double))
dataType = TensorProto.Types.DataType.Double;
else
{
string msg = "Unsupported type: " + type.ToString();
Contracts.Check(false, msg);
}
var dataType = ConvertToTensorProtoType(rawType);

string name = colName;
List<long> dimsLocal = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@
"name": "F20",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@
"name": "F20",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@
"name": "F20",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
21 changes: 19 additions & 2 deletions test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
"Score"
],
"output": [
"PredictedLabel"
"ArgMinInt64"
],
"name": "ArgMin",
"opType": "ArgMin",
Expand All @@ -126,6 +126,23 @@
}
]
},
{
"input": [
"ArgMinInt64"
],
"output": [
"PredictedLabel"
],
"name": "Cast",
"opType": "Cast",
"attribute": [
{
"name": "to",
"i": "12",
"type": "INT"
}
]
},
{
"input": [
"Features0"
Expand Down Expand Up @@ -272,7 +289,7 @@
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@
"name": "Label1",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down Expand Up @@ -348,7 +348,7 @@
"name": "PredictedLabel0",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down Expand Up @@ -404,7 +404,7 @@
"name": "Label0",
"type": {
"tensorType": {
"elemType": "INT64",
"elemType": "UINT32",
"shape": {
"dim": [
{
Expand Down
Loading