Added onnx export support for WordTokenizingTransformer and NgramExtractingTransformer (#4451)

* Added onnx export support for string related transforms

* Updated baseline test files

A large portion of this commit is upgrading the baseline test files. The rest of the fixes deal with build breaks resulting from the upgrade of ORT version.

* Fixed bugs in ValueToKeyMappingTransformer and added additional tests
harishsk authored Nov 13, 2019
1 parent 5910910 commit 693250b
Showing 10 changed files with 608 additions and 50 deletions.
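For context, a minimal sketch of how a pipeline containing these transformers might be exported to ONNX once this change is in place; the column names, sample data, and file path are illustrative, and it assumes the Microsoft.ML and Microsoft.ML.OnnxConverter packages of this era:

using System.IO;
using Microsoft.ML;

public class TextData
{
    public string Text { get; set; }
}

public static class OnnxExportSketch
{
    public static void Export()
    {
        var mlContext = new MLContext();
        var data = mlContext.Data.LoadFromEnumerable(new[]
        {
            new TextData { Text = "a b a c" }
        });

        // Tokenize, map tokens to keys, then extract n-grams; the tokenizer and
        // the n-gram extractor exercise the ONNX export paths touched by this commit.
        var pipeline = mlContext.Transforms.Text.TokenizeIntoWords("Tokens", "Text")
            .Append(mlContext.Transforms.Conversion.MapValueToKey("Tokens"))
            .Append(mlContext.Transforms.Text.ProduceNgrams("Ngrams", "Tokens"));

        var model = pipeline.Fit(data);
        using (var stream = File.Create("pipeline.onnx"))
            mlContext.Model.ConvertToOnnx(model, data, stream);
    }
}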
23 changes: 11 additions & 12 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
@@ -606,16 +606,11 @@ public void SaveAsOnnx(OnnxContext ctx)
ColInfo info = _infos[iinfo];
string inputColumnName = info.InputColumnName;
if (!ctx.ContainsColumn(inputColumnName))
{
ctx.RemoveColumn(info.Name, false);
continue;
}

if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName),
ctx.AddIntermediateVariable(_types[iinfo], info.Name)))
{
ctx.RemoveColumn(info.Name, true);
}
var srcVariableName = ctx.GetVariableName(inputColumnName);
var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], info.Name);
SaveAsOnnxCore(ctx, iinfo, info, srcVariableName, dstVariableName);
}
}

@@ -692,7 +687,7 @@ private JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToke
PfaUtils.Call("cast.fanoutDouble", -1, 0, keyCount, false), PfaUtils.FuncRef("u." + funcName));
}

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
var shape = ctx.RetrieveShapeOrNull(srcVariableName);
// Make sure that the shape is present for calculating the reduction axes. The shape here is generally not null
@@ -703,8 +698,13 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
// default ONNX LabelEncoder just matches the behavior of Bag=false.
var encodedVariableName = _parent._columns[iinfo].OutputCountVector ? ctx.AddIntermediateVariable(null, "encoded", true) : dstVariableName;

string opType = "OneHotEncoder";
var node = ctx.CreateNode(opType, srcVariableName, encodedVariableName, ctx.GetNodeName(opType));
string opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(info.TypeSrc, opType, true);
var castNode = ctx.CreateNode(opType, srcVariableName, castOutput, ctx.GetNodeName(opType), "");
castNode.AddAttribute("to", typeof(long));

opType = "OneHotEncoder";
var node = ctx.CreateNode(opType, castOutput, encodedVariableName, ctx.GetNodeName(opType));
node.AddAttribute("cats_int64s", Enumerable.Range(0, info.TypeSrc.GetItemType().GetKeyCountAsInt32(Host)).Select(x => (long)x));
node.AddAttribute("zeros", true);
if (_parent._columns[iinfo].OutputCountVector)
@@ -717,7 +717,6 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
reduceNode.AddAttribute("axes", new long[] { shape.Count - 1 });
reduceNode.AddAttribute("keepdims", 0);
}
return true;
}
}
}
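To make the OutputCountVector (Bag) path above concrete, here is a small stand-alone sketch, not part of the change, of the arithmetic the exported Cast -> OneHotEncoder -> ReduceSum chain performs; the key values and vocabulary size are illustrative:

// Input: a vector of key values, already cast to long as the Cast node does.
long[] keys = { 2, 0, 2 };
int keyCount = 3; // corresponds to cats_int64s = [0, 1, 2]

// OneHotEncoder: one indicator row per input element.
var oneHot = new float[keys.Length, keyCount];
for (int i = 0; i < keys.Length; i++)
    oneHot[i, (int)keys[i]] = 1f;

// ReduceSum over the token axis (keepdims = 0) collapses the indicator rows
// into per-key counts.
var counts = new float[keyCount];
for (int i = 0; i < keys.Length; i++)
    for (int j = 0; j < keyCount; j++)
        counts[j] += oneHot[i, j];

// counts is now [1, 0, 2]: key 0 appears once, key 1 never, key 2 twice.
// Without OutputCountVector, the OneHotEncoder output is used directly.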
70 changes: 59 additions & 11 deletions src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs
@@ -768,22 +768,70 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b

private Delegate MakeGetter<T>(DataViewRow row, int src) => _termMap[src].GetMappingGetter(row);

private IEnumerable<T> GetTermsAndIds<T>(int iinfo, out long[] termIds)
{
var terms = default(VBuffer<T>);
var map = (TermMap<T>)_termMap[iinfo].Map;
map.GetTerms(ref terms);

var termValues = terms.DenseValues();
var keyMapper = map.GetKeyMapper();

int i = 0;
termIds = new long[map.Count];
foreach (var term in termValues)
{
uint id = 0;
keyMapper(term, ref id);
termIds[i++] = id;
}
return termValues;
}

private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
if (!(info.TypeSrc.GetItemType() is TextDataViewType))
OnnxNode node;
long[] termIds;
string opType = "LabelEncoder";
var labelEncoderOutput = ctx.AddIntermediateVariable(_types[iinfo], "LabelEncoderOutput", true);

if (info.TypeSrc.GetItemType().Equals(TextDataViewType.Instance))
{
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
var terms = GetTermsAndIds<ReadOnlyMemory<char>>(iinfo, out termIds);
node.AddAttribute("keys_strings", terms);
}
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Single))
{
node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
var terms = GetTermsAndIds<float>(iinfo, out termIds);
node.AddAttribute("keys_floats", terms);
}
else
{
// LabelEncoder-2 in ORT v1 only supports the following mappings
// int64-> float
// int64-> string
// float -> int64
// float -> string
// string -> int64
// string -> float
// In ML.NET the output of ValueToKeyMappingTransformer is always an integer type.
// Therefore the only input types we can accept for Onnx conversion are strings and floats handled above.
return false;
}

var terms = default(VBuffer<ReadOnlyMemory<char>>);
TermMap<ReadOnlyMemory<char>> map = (TermMap<ReadOnlyMemory<char>>)_termMap[iinfo].Map;
map.GetTerms(ref terms);
string opType = "LabelEncoder";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
node.AddAttribute("classes_strings", terms.DenseValues());
node.AddAttribute("default_int64", -1);
//default_string needs to be an empty string but there is a BUG in Lotus that
//throws a validation error when default_string is empty. As a work around, set
//default_string to a space.
node.AddAttribute("default_string", " ");
node.AddAttribute("values_int64s", termIds);

// Onnx outputs an Int64, but ML.NET outputs a keytype. So cast it here
InternalDataKind dataKind;
InternalDataKindExtensions.TryGetDataKind(_parent._unboundMaps[iinfo].OutputType.RawType, out dataKind);

opType = "Cast";
var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
castNode.AddAttribute("to", dataKind.ToType());

return true;
}
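As a rough mental model (not the exporter code itself), the LabelEncoder plus trailing Cast emitted above behaves like a dictionary lookup from term to key id, falling back to -1 for unseen values; the terms and ids below are illustrative stand-ins for the keys_strings / values_int64s attributes:

using System.Collections.Generic;

public static class LabelEncoderSketch
{
    private static readonly Dictionary<string, long> TermToId = new Dictionary<string, long>
    {
        ["apple"] = 1,
        ["banana"] = 2,
        ["cherry"] = 3,
    };

    // LabelEncoder: term -> id, using default_int64 (-1) for unseen terms.
    public static long Encode(string term)
        => TermToId.TryGetValue(term, out var id) ? id : -1;

    // The trailing Cast node converts the Int64 output to the ML.NET key type (e.g. UInt32).
    public static uint EncodeAsKey(string term) => (uint)Encode(term);
}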

19 changes: 15 additions & 4 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
@@ -433,9 +433,8 @@ public static NamedOnnxValue CreateScalarNamedOnnxValue<T>(string name, T data)
throw new NotImplementedException($"Not implemented type {typeof(T)}");

if (typeof(T) == typeof(ReadOnlyMemory<char>))
{
return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(new string[] { data.ToString() }, new int[] { 1, 1 }, false));
}
return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(new string[] { data.ToString() }, new int[] { 1, 1 }));

return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(new T[] { data }, new int[] { 1, 1 }));
}

@@ -452,7 +451,19 @@ public static NamedOnnxValue CreateNamedOnnxValue<T>(string name, ReadOnlySpan<T
{
if (!_onnxTypeMap.Contains(typeof(T)))
throw new NotImplementedException($"Not implemented type {typeof(T)}");
return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(data.ToArray(), shape.Select(x => (int)x).ToArray()));

var dimensions = shape.Select(x => (int)x).ToArray();

if (typeof(T) == typeof(ReadOnlyMemory<char>))
{
string[] stringData = new string[data.Length];
for (int i = 0; i < data.Length; i++)
stringData[i] = data[i].ToString();

return NamedOnnxValue.CreateFromTensor<string>(name, new DenseTensor<string>(stringData, dimensions));
}

return NamedOnnxValue.CreateFromTensor<T>(name, new DenseTensor<T>(data.ToArray(), dimensions));
}
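These helpers matter because ONNX Runtime represents text inputs as string tensors rather than ReadOnlyMemory<char>. As a hedged usage sketch (the model path, input name, and shape are illustrative), feeding a string tensor to an InferenceSession looks roughly like:

using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

using (var session = new InferenceSession("pipeline.onnx"))
{
    var tensor = new DenseTensor<string>(new[] { "hello", "world" }, new[] { 1, 2 });
    var inputs = new[] { NamedOnnxValue.CreateFromTensor("Text", tensor) };

    using (var results = session.Run(inputs))
    {
        // Each result exposes an output tensor produced by the exported graph.
    }
}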

/// <summary>
158 changes: 157 additions & 1 deletion src/Microsoft.ML.Transforms/Text/NgramTransform.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;

@@ -124,6 +125,7 @@ private sealed class TransformInfo
public readonly bool[] NonEmptyLevels;
public readonly int NgramLength;
public readonly int SkipLength;
public readonly bool UseAllLengths;
public readonly NgramExtractingEstimator.WeightingCriteria Weighting;

public bool RequireIdf => Weighting == NgramExtractingEstimator.WeightingCriteria.Idf || Weighting == NgramExtractingEstimator.WeightingCriteria.TfIdf;
@@ -133,6 +135,7 @@ public TransformInfo(NgramExtractingEstimator.ColumnOptions info)
NgramLength = info.NgramLength;
SkipLength = info.SkipLength;
Weighting = info.Weighting;
UseAllLengths = info.UseAllLengths;
NonEmptyLevels = new bool[NgramLength];
}

@@ -469,7 +472,7 @@ private protected override void SaveModel(ModelSaveContext ctx)

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly DataViewType[] _srcTypes;
private readonly int[] _srcCols;
@@ -551,6 +554,81 @@ private void GetSlotNames(int iinfo, int size, ref VBuffer<ReadOnlyMemory<char>>
dst = dstEditor.Commit();
}

private IEnumerable<long> GetNgramData(int iinfo, out long[] ngramCounts, out double[] weights, out List<long> indexes)
{
var transformInfo = _parent._transformInfos[iinfo];
var itemType = _srcTypes[iinfo].GetItemType();

Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length);
Host.Assert(InputSchema[_srcCols[iinfo]].HasKeyValues());

// Get the key values of the unigrams.
var keyCount = itemType.GetKeyCountAsInt32(Host);

var maxNGramLength = transformInfo.NgramLength;

var pool = _parent._ngramMaps[iinfo];

// the ngrams in ML.NET are sequentially organized. e.g. {a, a|b, b, b|c...}
// in onnx, they need to be separated by type. e.g. {a, b, c, a|b, b|c...}
// since the resulting vectors need to match, we need to create a mapping
// between the two and store it in the node attributes

// create separate lists to track the ids of 1-grams, 2-grams etc
// because they need to be in adjacent regions in the same list
// when supplied to onnx
// We later concatenate all these separate n-gram lists
var ngramIds = new List<long>[maxNGramLength];
var ngramIndexes = new List<long>[maxNGramLength];
for (int i = 0; i < ngramIds.Length; i++)
{
ngramIds[i] = new List<long>();
ngramIndexes[i] = new List<long>();
//ngramWeights[i] = new List<float>();
}

weights = new double[pool.Count];

uint[] ngram = new uint[maxNGramLength];
for (int i = 0; i < pool.Count; i++)
{
var n = pool.GetById(i, ref ngram);
Host.Assert(n >= 0);

// add the id of each gram to the corresponding ids list
for (int j = 0; j < n; j++)
ngramIds[n - 1].Add(ngram[j]);

// add the indexes to the corresponding list
ngramIndexes[n - 1].Add(i);

if (transformInfo.RequireIdf)
weights[i] = _parent._invDocFreqs[iinfo][i];
else
weights[i] = 1.0f;
}

// initialize the ngramCounts array with start-index of each n-gram
int start = 0;
ngramCounts = new long[maxNGramLength];
for (int i = 0; i < maxNGramLength; i++)
{
ngramCounts[i] = start;
start += ngramIds[i].Count;
}

// concatenate all the lists and return
IEnumerable<long> allNGramIds = ngramIds[0];
indexes = ngramIndexes[0];
for (int i = 1; i < maxNGramLength; i++)
{
allNGramIds = Enumerable.Concat(allNGramIds, ngramIds[i]);
indexes = indexes.Concat(ngramIndexes[i]).ToList();
}

return allNGramIds;
}
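A small worked example of the reordering above, with illustrative key ids: suppose the ML.NET pool, in insertion order, holds a (slot 0), a|b (slot 1), b (slot 2), b|c (slot 3), with unigram ids a=0, b=1, c=2. Grouping by n-gram length for ONNX gives:

long[] poolInt64s   = { 0, 1,  0, 1, 1, 2 }; // "pool_int64s": unigrams {a, b}, then bigrams {a|b, b|c} flattened
long[] ngramCounts  = { 0, 2 };              // "ngram_counts": start offsets of the 1-gram and 2-gram regions
long[] ngramIndexes = { 0, 2, 1, 3 };        // "ngram_indexes": ML.NET output slot for each grouped pool entry

so the TfIdfVectorizer writes the count (or weight) of the i-th grouped pool entry into output coordinate ngramIndexes[i], reproducing the ML.NET slot order.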

private void ComposeNgramString(uint[] ngram, int count, StringBuilder sb, int keyCount, in VBuffer<ReadOnlyMemory<char>> terms)
{
Host.AssertValue(sb);
@@ -660,6 +738,84 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
}
return del;
}

public bool CanSaveOnnx(OnnxContext ctx) => true;

public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));

int numColumns = _parent.ColumnPairs.Length;
for (int iinfo = 0; iinfo < numColumns; ++iinfo)
{
string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
if (!ctx.ContainsColumn(inputColumnName))
continue;

string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
string dstVariableName = ctx.AddIntermediateVariable(_srcTypes[iinfo], outputColumnName, true);
SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), dstVariableName);
}
}

private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName )
{
VBuffer<ReadOnlyMemory<char>> slotNames = default;
GetSlotNames(iinfo, 0, ref slotNames);

var transformInfo = _parent._transformInfos[iinfo];

// TfIdfVectorizer accepts strings, int32 and int64 tensors.
// But in the ML.NET implementation of the NGramTransform, it only accepts keys as inputs
// that are the result of the ValueToKeyMapping transformer, which outputs UInt32 values.
// So, if the input is UInt32 or UInt64, cast it here to Int32/Int64
string opType;
var vectorType = _srcTypes[iinfo] as VectorDataViewType;

if ((vectorType != null) &&
((vectorType.RawType == typeof(VBuffer<UInt32>)) || (vectorType.RawType == typeof(VBuffer<UInt64>))))
{
var dataKind = _srcTypes[iinfo] == NumberDataViewType.UInt32 ? DataKind.Int32 : DataKind.Int64;

opType = "Cast";
string castOutput = ctx.AddIntermediateVariable(_srcTypes[iinfo], "CastOutput", true);

var castNode = ctx.CreateNode(opType, srcVariableName, castOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(dataKind).ToType();
castNode.AddAttribute("to", t);

srcVariableName = castOutput;
}

opType = "TfIdfVectorizer";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");
node.AddAttribute("max_gram_length", transformInfo.NgramLength);
node.AddAttribute("max_skip_count", transformInfo.SkipLength);
node.AddAttribute("min_gram_length", transformInfo.UseAllLengths ? 1 : transformInfo.NgramLength);

string mode;
if (transformInfo.RequireIdf)
{
mode = transformInfo.Weighting == NgramExtractingEstimator.WeightingCriteria.Idf ? "IDF" : "TFIDF";
}
else
{
mode = "TF";
}
node.AddAttribute("mode", mode);

long[] ngramCounts;
double[] ngramWeights;
List<long> ngramIndexes;

var ngramIds = GetNgramData(iinfo, out ngramCounts, out ngramWeights, out ngramIndexes);

node.AddAttribute("ngram_counts", ngramCounts);
node.AddAttribute("pool_int64s", ngramIds);
node.AddAttribute("ngram_indexes", ngramIndexes);
node.AddAttribute("weights", ngramWeights);
}
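For concreteness, one illustrative correspondence between the estimator options and the attributes written on the TfIdfVectorizer node above (the option values are made up):

// NgramLength = 2, SkipLength = 0, UseAllLengths = true, Weighting = Tf
long maxGramLength = 2;      // "max_gram_length"
long maxSkipCount  = 0;      // "max_skip_count"
long minGramLength = 1;      // "min_gram_length" (equals NgramLength when UseAllLengths is false)
string mode        = "TF";   // "IDF" or "TFIDF" when idf-based weighting is used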

}
}
