Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/Microsoft.ML.Transforms/Text/TextCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,31 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf
=> new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
outputColumnName, inputColumnName, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting);

/// <summary>
/// Create a <see cref="WordBagEstimator"/>, which maps the column specified in <paramref name="inputColumnName"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
/// </summary>
/// <remarks>
/// <see cref="WordBagEstimator"/> is different from <see cref="NgramExtractingEstimator"/> in that the former
/// tokenizes text internally and the latter takes tokenized text as input.
/// This overload always uses an n-gram length of 1, a skip length of 0 and
/// <see cref="NgramExtractingEstimator.WeightingCriteria.Tf"/> weighting, and forwards
/// <paramref name="termSeparator"/> and <paramref name="freqSeparator"/> to the estimator.
/// </remarks>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.
/// This column's data type will be known-size vector of <see cref="System.Single"/>.</param>
/// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
/// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
/// <param name="inputColumnName">Name of the column to take the data from.
/// This estimator operates over vector of text.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransforms catalog,
    string outputColumnName,
    char termSeparator,
    char freqSeparator,
    string inputColumnName = null,
    int maximumNgramsCount = NgramExtractingEstimator.Defaults.MaximumNgramsCount)
    => new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
        outputColumnName,
        inputColumnName,
        // Named arguments make the fixed settings of this overload explicit at the call site.
        ngramLength: 1,
        skipLength: 0,
        useAllLengths: true,
        maximumNgramsCount: maximumNgramsCount,
        weighting: NgramExtractingEstimator.WeightingCriteria.Tf,
        termSeparator: termSeparator,
        freqSeparator: freqSeparator);

/// <summary>
/// Create a <see cref="WordBagEstimator"/>, which maps the multiple columns specified in <paramref name="inputColumnNames"/>
/// to a vector of n-gram counts in a new column named <paramref name="outputColumnName"/>.
Expand Down
200 changes: 198 additions & 2 deletions src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
Expand All @@ -12,6 +13,7 @@
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;
using static Microsoft.ML.Transforms.Text.WordBagBuildingTransformer;

[assembly: LoadableClass(WordBagBuildingTransformer.Summary, typeof(IDataTransform), typeof(WordBagBuildingTransformer), typeof(WordBagBuildingTransformer.Options), typeof(SignatureDataTransform),
"Word Bag Transform", "WordBagTransform", "WordBag")]
Expand All @@ -21,6 +23,16 @@

[assembly: EntryPointModule(typeof(NgramExtractorTransform.NgramExtractorArguments))]

// These are for the internal-only TextExpandingTransformer. Not exposed publicly.
[assembly: LoadableClass(TextExpandingTransformer.Summary, typeof(IDataTransform), typeof(TextExpandingTransformer), null, typeof(SignatureLoadDataTransform),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]

[assembly: LoadableClass(typeof(TextExpandingTransformer), null, typeof(SignatureLoadModel),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]

[assembly: LoadableClass(typeof(IRowMapper), typeof(TextExpandingTransformer), null, typeof(SignatureLoadRowMapper),
TextExpandingTransformer.UserName, TextExpandingTransformer.LoaderSignature)]

namespace Microsoft.ML.Transforms.Text
{
/// <summary>
Expand Down Expand Up @@ -144,18 +156,195 @@ internal static IEstimator<ITransformer> CreateEstimator(IHostEnvironment env, O
NgramLength = column.NgramLength,
SkipLength = column.SkipLength,
Weighting = column.Weighting,
UseAllLengths = column.UseAllLengths
UseAllLengths = column.UseAllLengths,
};
}

IEstimator<ITransformer> estimator = NgramExtractionUtils.GetConcatEstimator(h, options.Columns);
estimator = estimator.Append(new WordTokenizingEstimator(env, tokenizeColumns));
if (options.FreqSeparator != default)
{
estimator = estimator.Append(new TextExpandingEstimator(h, tokenizeColumns[0].InputColumnName, options.FreqSeparator, options.TermSeparator));
}
estimator = estimator.Append(new WordTokenizingEstimator(h, tokenizeColumns));
estimator = estimator.Append(NgramExtractorTransform.CreateEstimator(h, extractorArgs, estimator.GetOutputSchema(inputSchema)));
return estimator;
}

// Builds the word-bag estimator chain for the given options, fits it on the input
// and returns the transformed data as an IDataTransform.
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
    var inputShape = SchemaShape.Create(input.Schema);
    var fittedChain = CreateEstimator(env, options, inputShape).Fit(input);
    return (IDataTransform)fittedChain.Transform(input);
}

#region TextExpander

// Internal only estimator used to facilitate the expansion of ngrams with pre-defined weights.
// It needs no training: the transformer is fully built from the constructor arguments.
internal sealed class TextExpandingEstimator : TrivialEstimator<TextExpandingTransformer>
{
    // Name of the text column that is expanded in place.
    private readonly string _columnName;

    /// <summary>
    /// Creates the estimator together with the <see cref="TextExpandingTransformer"/> it returns.
    /// </summary>
    /// <param name="env">The host environment.</param>
    /// <param name="columnName">Name of the text column to expand.</param>
    /// <param name="freqSeparator">Separator between a term and its frequency.</param>
    /// <param name="termSeparator">Separator between term/frequency pairs.</param>
    public TextExpandingEstimator(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
        : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingEstimator)), new TextExpandingTransformer(env, columnName, freqSeparator, termSeparator))
    {
        _columnName = columnName;
    }

    /// <summary>
    /// Validates that the input schema contains the configured column with a text item type.
    /// The schema is returned unchanged because the column is rewritten under the same name and type.
    /// </summary>
    public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        Host.CheckValue(inputSchema, nameof(inputSchema));

        // Reject the schema when the column is missing OR when it is present but not text.
        // (The previous '&&' only ever fired for a missing column, so a non-text column slipped through.)
        if (!inputSchema.TryFindColumn(_columnName, out SchemaShape.Column outCol) || outCol.ItemType != TextDataViewType.Instance)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columnName);
        }

        return inputSchema;
    }
}

// Internal only transformer used to facilitate the expansion of ngrams with pre-defined weights.
// It rewrites a single text column in place: every "term<freqSeparator>count" entry is replaced
// by the term repeated 'count' times, each followed by a single space.
internal sealed class TextExpandingTransformer : RowToRowTransformerBase
{
    internal const string Summary = "Expands text in the format of term:freq; to have the correct number of terms";
    internal const string UserName = "Text Expanding Transform";
    internal const string LoadName = "TextExpand";

    internal const string LoaderSignature = "TextExpandTransform";

    private readonly string _columnName;
    private readonly char _freqSeparator;
    private readonly char _termSeparator;

    /// <summary>
    /// Creates a transformer that expands <paramref name="columnName"/>.
    /// </summary>
    /// <param name="env">The host environment.</param>
    /// <param name="columnName">Name of the text column to expand; the output column uses the same name.</param>
    /// <param name="freqSeparator">Separator between a term and its frequency.</param>
    /// <param name="termSeparator">Separator between term/frequency pairs.</param>
    public TextExpandingTransformer(IHostEnvironment env, string columnName, char freqSeparator, char termSeparator)
        : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
    {
        _columnName = columnName;
        _freqSeparator = freqSeparator;
        _termSeparator = termSeparator;
    }

    private static VersionInfo GetVersionInfo()
    {
        return new VersionInfo(
            modelSignature: "TEXT EXP",
            verWrittenCur: 0x00010001, // Initial
            verReadableCur: 0x00010001,
            verWeCanReadBack: 0x00010001,
            loaderSignature: LoaderSignature,
            loaderAssemblyName: typeof(TextExpandingTransformer).Assembly.FullName);
    }

    /// <summary>
    /// Factory method for SignatureLoadModel.
    /// </summary>
    private TextExpandingTransformer(IHostEnvironment env, ModelLoadContext ctx) :
        // Register the host under this transformer's own name (it was previously registered
        // as ColumnConcatenatingTransformer, a copy-paste slip).
        base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextExpandingTransformer)))
    {
        Host.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel(GetVersionInfo());
        // *** Binary format ***
        // string: column name
        // char: frequency separator
        // char: term separator

        _columnName = ctx.Reader.ReadString();
        _freqSeparator = ctx.Reader.ReadChar();
        _termSeparator = ctx.Reader.ReadChar();
    }

    /// <summary>
    /// Factory method for SignatureLoadRowMapper.
    /// </summary>
    private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
        => new TextExpandingTransformer(env, ctx).MakeRowMapper(inputSchema);

    /// <summary>
    /// Factory method for SignatureLoadDataTransform.
    /// </summary>
    private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
        => new TextExpandingTransformer(env, ctx).MakeDataTransform(input);

    private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
    {
        return new Mapper(Host, schema, this);
    }

    private protected override void SaveModel(ModelSaveContext ctx)
    {
        Host.CheckValue(ctx, nameof(ctx));
        ctx.CheckAtModel();
        ctx.SetVersionInfo(GetVersionInfo());

        // *** Binary format ***
        // string: column name
        // char: frequency separator
        // char: term separator

        ctx.Writer.Write(_columnName);
        ctx.Writer.Write(_freqSeparator);
        ctx.Writer.Write(_termSeparator);
    }

    // Row mapper that produces the expanded text column for a given input schema.
    private sealed class Mapper : MapperBase
    {
        private readonly TextExpandingTransformer _parent;

        public Mapper(IHost host, DataViewSchema inputSchema, RowToRowTransformerBase parent)
            : base(host, inputSchema, parent)
        {
            _parent = (TextExpandingTransformer)parent;
        }

        // A single text output column that carries the same name as the input column.
        protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
        {
            return new DataViewSchema.DetachedColumn[]
            {
                new DataViewSchema.DetachedColumn(_parent._columnName, TextDataViewType.Instance)
            };
        }

        // Builds the getter that reads the source text and writes the expanded text.
        // Example with termSeparator ';' and freqSeparator ':': "a:2;b" becomes "a a b ".
        protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
        {
            disposer = null;
            ValueGetter<ReadOnlyMemory<char>> srcGetter = input.GetGetter<ReadOnlyMemory<char>>(input.Schema.GetColumnOrNull(_parent._columnName).Value);
            ReadOnlyMemory<char> inputMem = default;
            // The builder is reused across rows and cleared at the start of each call.
            var sb = new StringBuilder();

            ValueGetter<ReadOnlyMemory<char>> result = (ref ReadOnlyMemory<char> dst) =>
            {
                sb.Clear();
                srcGetter(ref inputMem);
                var inputText = inputMem.ToString();
                foreach (var termFreq in inputText.Split(_parent._termSeparator))
                {
                    var tf = termFreq.Split(_parent._freqSeparator);

                    // An entry without exactly one frequency separator is emitted once as-is
                    // (only the text before the first separator is kept).
                    // Otherwise the count is parsed once, culture-independently, instead of on
                    // every loop iteration; a malformed count still throws FormatException.
                    int repeat = tf.Length == 2
                        ? int.Parse(tf[1], System.Globalization.CultureInfo.InvariantCulture)
                        : 1;

                    for (int i = 0; i < repeat; i++)
                        sb.Append(tf[0]).Append(' ');
                }

                dst = sb.ToString().AsMemory();
            };

            return result;
        }

        // The only input dependency is the source column, and only when the output is active.
        private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
        {
            var active = new bool[InputSchema.Count];
            if (activeOutput(0))
            {
                active[InputSchema.GetColumnOrNull(_parent._columnName).Value.Index] = true;
            }
            return col => active[col];
        }

        // Saving is delegated to the parent transformer, which owns all persisted state.
        private protected override void SaveModel(ModelSaveContext ctx)
        {
            _parent.SaveModel(ctx);
        }
    }
}

#endregion TextExpander
}

/// <summary>
Expand Down Expand Up @@ -235,6 +424,13 @@ internal abstract class ArgumentsBase

[Argument(ArgumentType.AtMostOnce, HelpText = "The weighting criteria")]
public NgramExtractingEstimator.WeightingCriteria Weighting = NgramExtractingEstimator.Defaults.Weighting;

[Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms/frequency pairs.")]
public char TermSeparator = default;

[Argument(ArgumentType.AtMostOnce, HelpText = "Separator used to separate terms from their frequency.")]
public char FreqSeparator = default;

}

[TlcModule.Component(Name = "NGram", FriendlyName = "NGram Extractor Transform", Alias = "NGramExtractorTransform,NGramExtractor",
Expand Down
30 changes: 24 additions & 6 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public sealed class WordBagEstimator : IEstimator<ITransformer>
private readonly bool _useAllLengths;
private readonly int _maxNumTerms;
private readonly NgramExtractingEstimator.WeightingCriteria _weighting;
private readonly char _termSeparator;
private readonly char _freqSeparator;

/// <summary>
/// Options for how the n-grams are extracted.
Expand Down Expand Up @@ -99,15 +101,19 @@ public Options()
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
/// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
/// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
string outputColumnName,
string inputColumnName = null,
int ngramLength = 1,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
: this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
char termSeparator = default,
char freqSeparator = default)
: this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting, termSeparator, freqSeparator)
{
}

Expand All @@ -123,15 +129,19 @@ internal WordBagEstimator(IHostEnvironment env,
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
/// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
/// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
string outputColumnName,
string[] inputColumnNames,
int ngramLength = 1,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
: this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting)
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
char termSeparator = default,
char freqSeparator = default)
: this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, useAllLengths, maximumNgramsCount, weighting, termSeparator, freqSeparator)
{
}

Expand All @@ -146,13 +156,17 @@ internal WordBagEstimator(IHostEnvironment env,
/// <param name="useAllLengths">Whether to include all n-gram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
/// <param name="maximumNgramsCount">Maximum number of n-grams to store in the dictionary.</param>
/// <param name="weighting">Statistical measure used to evaluate how important a word is to a document in a corpus.</param>
/// <param name="termSeparator">Separator used to separate terms/frequency pairs.</param>
/// <param name="freqSeparator">Separator used to separate terms from their frequency.</param>
internal WordBagEstimator(IHostEnvironment env,
(string outputColumnName, string[] inputColumnNames)[] columns,
int ngramLength = 1,
int skipLength = 0,
bool useAllLengths = true,
int maximumNgramsCount = 10000000,
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf)
NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf,
char termSeparator = default,
char freqSeparator = default)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(WordBagEstimator));
Expand All @@ -169,6 +183,8 @@ internal WordBagEstimator(IHostEnvironment env,
_useAllLengths = useAllLengths;
_maxNumTerms = maximumNgramsCount;
_weighting = weighting;
_termSeparator = termSeparator;
_freqSeparator = freqSeparator;
}

/// <summary> Trains and returns a <see cref="ITransformer"/>.</summary>
Expand All @@ -187,7 +203,9 @@ private WordBagBuildingTransformer.Options CreateOptions()
SkipLength = _skipLength,
UseAllLengths = _useAllLengths,
MaxNumTerms = new[] { _maxNumTerms },
Weighting = _weighting
Weighting = _weighting,
TermSeparator = _termSeparator,
FreqSeparator = _freqSeparator,
};
}

Expand Down
Loading