Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1009,12 +1009,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, string[] outputs, string featureCol
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
castNode.AddAttribute("to", t);

// The predictedLabel is a scalar, but the ONNX output produced by ML.NET is expected to be a [1x1] tensor, so reshape it here.
opType = "Reshape";
long[] shape = { 1, 1 };
long[] shapeDim = { 2 };
var shapeVar = ctx.AddInitializer(shape, shapeDim, "ShapeVar");
var reshapeNode = ctx.CreateNode(opType, new[] { castNodeOutput, shapeVar }, new[] { predictedLabelUint32 }, ctx.GetNodeName(opType), "");
opType = "Unsqueeze";
var unsqueezeNode = ctx.CreateNode(opType, castNodeOutput, predictedLabelUint32, ctx.GetNodeName(opType), "");
unsqueezeNode.AddAttribute("axes", new long[] { 0 });

return true;
}
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV
var opType = "Squeeze";
var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
node.AddAttribute("axes", new long[] { 0 });
node.AddAttribute("axes", new long[] { 1 });

opType = "StringNormalizer";
var normalizerOutput = ctx.AddIntermediateVariable(null, "NormalizerOutput", true);
Expand All @@ -249,7 +249,7 @@ private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstV

opType = "Unsqueeze";
node = ctx.CreateNode(opType, normalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
node.AddAttribute("axes", new long[] { 0 });
node.AddAttribute("axes", new long[] { 1 });
}
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
Expand Down
26 changes: 19 additions & 7 deletions src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
private readonly DataViewType _type;
private readonly TokenizingByCharactersTransformer _parent;
private readonly bool[] _isSourceVector;
private readonly int[] _sourceVectorLength;
// Constructed and cached the first time it is needed.
private volatile string _keyValuesStr;
private volatile int[] _keyValuesBoundaries;
Expand All @@ -201,8 +202,13 @@ public Mapper(TokenizingByCharactersTransformer parent, DataViewSchema inputSche
var keyType = new KeyDataViewType(typeof(ushort), CharsCount);
_type = new VectorDataViewType(keyType);
_isSourceVector = new bool[_parent.ColumnPairs.Length];
_sourceVectorLength = new int[_parent.ColumnPairs.Length];
for (int i = 0; i < _isSourceVector.Length; i++)
_isSourceVector[i] = inputSchema[_parent.ColumnPairs[i].inputColumnName].Type is VectorDataViewType;
{
var type = inputSchema[_parent.ColumnPairs[i].inputColumnName].Type;
_isSourceVector[i] = type is VectorDataViewType;
_sourceVectorLength[i] = type.GetValueCount();
}
}

public bool CanSaveOnnx(OnnxContext ctx) => true;
Expand All @@ -219,27 +225,33 @@ public void SaveAsOnnx(OnnxContext ctx)
string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
string srcVariableName = ctx.GetVariableName(inputColumnName);
string dstVariableName = ctx.AddIntermediateVariable(_type, outputColumnName, true);
SaveAsOnnxCore(ctx, srcVariableName, dstVariableName);
SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName);
}
}

private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
string opType = "Tokenizer";
string tokenizerOutput = ctx.AddIntermediateVariable(null, "TokenizerOutput", true);
DataViewType dataViewType;
if (_isSourceVector[iinfo])
dataViewType = new VectorDataViewType(TextDataViewType.Instance, _sourceVectorLength[iinfo]);
else
dataViewType = TextDataViewType.Instance;

string tokenizerOutput = ctx.AddIntermediateVariable(dataViewType, "TokenizerOutput", true);
var node = ctx.CreateNode(opType, srcVariableName, tokenizerOutput, ctx.GetNodeName(opType), "com.microsoft");
node.AddAttribute("mark", _parent._useMarkerChars);
node.AddAttribute("mincharnum", 1);
node.AddAttribute("pad_value", "");
node.AddAttribute("separators", new string[] { "" });

opType = "Squeeze";
var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
var squeezeOutput = ctx.AddIntermediateVariable(dataViewType, "SqueezeOutput");
node = ctx.CreateNode(opType, tokenizerOutput, squeezeOutput, ctx.GetNodeName(opType), "");
node.AddAttribute("axes", new long[] { 0 });
node.AddAttribute("axes", new long[] { 1 });

opType = "LabelEncoder";
var labelEncoderOutput = ctx.AddIntermediateVariable(null, "LabelEncoderOutput", true);
var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput");
node = ctx.CreateNode(opType, squeezeOutput, labelEncoderOutput, ctx.GetNodeName(opType));

IEnumerable<string> charStrings = Enumerable.Range(0, 65535).Select(x => ((char)x).ToString());
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/WordTokenizing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,9 @@ public void SaveAsOnnx(OnnxContext ctx)
tokenizerNode.AddAttribute("separators", separators);

opType = "Squeeze";
var squeezeOutput = ctx.AddIntermediateVariable(_type, column.Name, true);
var squeezeOutput = ctx.AddIntermediateVariable(_type, column.Name);
var squeezeNode = ctx.CreateNode(opType, intermediateVar, squeezeOutput, ctx.GetNodeName(opType), "");
squeezeNode.AddAttribute("axes", new long[] { 0 });
squeezeNode.AddAttribute("axes", new long[] { 1 });
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,22 @@
},
{
"input": [
"CastNodeOutput",
"ShapeVar"
"CastNodeOutput"
],
"output": [
"PredictedLabel"
],
"name": "Reshape",
"opType": "Reshape"
"name": "Unsqueeze",
"opType": "Unsqueeze",
"attribute": [
{
"name": "axes",
"ints": [
"0"
],
"type": "INTS"
}
]
},
{
"input": [
Expand Down Expand Up @@ -371,17 +379,6 @@
],
"name": "model",
"initializer": [
{
"dims": [
"2"
],
"dataType": 7,
"int64Data": [
"1",
"1"
],
"name": "ShapeVar"
},
{
"dims": [
"1",
Expand Down