Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to execute only the last transform in TransformWrapper and have WordBagEstimator return transformer chain #3700

Merged
merged 13 commits into from
May 16, 2019
89 changes: 9 additions & 80 deletions src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Data.DataLoadSave;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Runtime;

[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
codemzs marked this conversation as resolved.
Show resolved Hide resolved
"Transform wrapper", TransformWrapper.LoaderSignature)]

namespace Microsoft.ML.Data
{
/// <summary>
Expand All @@ -23,20 +16,19 @@ namespace Microsoft.ML.Data
internal sealed class TransformWrapper : ITransformer
{
internal const string LoaderSignature = "TransformWrapper";
private const string TransformDirTemplate = "Step_{0:000}";

private readonly IHost _host;
private readonly IDataView _xf;
private readonly bool _allowSave;
private readonly bool _isRowToRowMapper;

public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false)
public TransformWrapper(IHostEnvironment env, IDataView xf)
{
Contracts.CheckValue(env, nameof(env));
Contracts.Check(xf is IDataTransform);

_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(xf, nameof(xf));
_xf = xf;
_allowSave = allowSave;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
codemzs marked this conversation as resolved.
Show resolved Hide resolved
}

Expand All @@ -45,39 +37,12 @@ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
_host.CheckValue(inputSchema, nameof(inputSchema));

var dv = new EmptyDataView(_host, inputSchema);
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
var output = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv);
codemzs marked this conversation as resolved.
Show resolved Hide resolved

return output.Schema;
}

void ICanSaveModel.Save(ModelSaveContext ctx)
{
if (!_allowSave)
throw _host.Except("Saving is not permitted.");
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

var dataPipe = _xf;
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
transforms.Reverse();

ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));

ctx.Writer.Write(transforms.Count);
for (int i = 0; i < transforms.Count; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(transforms[i], dirName);
}
}
void ICanSaveModel.Save(ModelSaveContext ctx) => throw _host.Except("Saving is not permitted.");
codemzs marked this conversation as resolved.
Show resolved Hide resolved

private static VersionInfo GetVersionInfo()
codemzs marked this conversation as resolved.
Show resolved Hide resolved
{
Expand All @@ -90,32 +55,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(TransformWrapper).Assembly.FullName);
}

// Factory for SignatureLoadModel.
private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(ctx, nameof(ctx));
_allowSave = true;
ctx.CheckAtModel(GetVersionInfo());
int n = ctx.Reader.ReadInt32();
_host.CheckDecode(n >= 0);

ctx.LoadModel<ILegacyDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));

IDataView data = loader;
for (int i = 0; i < n; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
data = xf;
}

_xf = data;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
}

public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);

private static bool IsChainRowToRowMapper(IDataView view)
{
Expand All @@ -132,19 +72,8 @@ private static bool IsChainRowToRowMapper(IDataView view)
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var input = new EmptyDataView(_host, inputSchema);
var revMaps = new List<IRowToRowMapper>();
IDataView chain;
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source)
{
// Everything in the chain ought to be a row mapper.
_host.Assert(xf is IRowToRowMapper);
revMaps.Add((IRowToRowMapper)xf);
}
// The walkback should have ended at the input.
Contracts.Assert(chain == input);
revMaps.Reverse();
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
return new CompositeRowToRowMapper(inputSchema,
new[] { (IRowToRowMapper)ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, new EmptyDataView(_host, inputSchema)) });
codemzs marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
23 changes: 14 additions & 9 deletions src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ internal sealed class Options : NgramExtractorTransform.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
+ "a dictionary of n-grams and using the id in the dictionary as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransfomer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -150,10 +150,16 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
}

IDataView view = input;
view = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns, view);
view = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view).Transform(view);
return NgramExtractorTransform.CreateDataTransform(h, extractorArgs, view);
ITransformer t0 = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns);
view = t0.Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view);
view = t1.Transform(view);
ITransformer t2 = NgramExtractorTransform.Create(h, extractorArgs, t1.Transform(view));
codemzs marked this conversation as resolved.
Show resolved Hide resolved
return new TransformerChain<ITransformer>(new[] { t0, t1, t2 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransfomer(env, options, input).Transform(input);
}

/// <summary>
Expand Down Expand Up @@ -489,13 +495,11 @@ public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColum

internal static class NgramExtractionUtils
{
public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns, IDataView input)
public static ITransformer ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(columns, nameof(columns));
env.CheckValue(input, nameof(input));

IDataView view = input;
var concatColumns = new List<ColumnConcatenatingTransformer.ColumnOptions>();
foreach (var col in columns)
{
Expand All @@ -506,10 +510,11 @@ public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColu
if (col.Source.Length > 1)
concatColumns.Add(new ColumnConcatenatingTransformer.ColumnOptions(col.Name, col.Source));
}

if (concatColumns.Count > 0)
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray()).Transform(view);
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray());

return view;
return new TransformerChain<ITransformer>();
codemzs marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
Expand Down
14 changes: 10 additions & 4 deletions src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ internal sealed class Options : NgramHashExtractingTransformer.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. "
+ "It does so by hashing each n-gram and using the hash value as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransformer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -132,7 +132,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
};
}

view = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view).Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view);

var featurizeArgs =
new NgramHashExtractingTransformer.Options
Expand All @@ -147,11 +147,17 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
MaximumNumberOfInverts = options.MaximumNumberOfInverts
};

view = NgramHashExtractingTransformer.Create(h, featurizeArgs, view).Transform(view);
view = t1.Transform(view);
ITransformer t2 = NgramHashExtractingTransformer.Create(h, featurizeArgs, t1.Transform(view));
codemzs marked this conversation as resolved.
Show resolved Hide resolved

// Since we added columns with new names, we need to explicitly drop them before we return the IDataTransform.
return ColumnSelectingTransformer.CreateDrop(h, view, tmpColNames.ToArray()) as IDataTransform;
ITransformer t3 = new ColumnSelectingTransformer(env, null, tmpColNames.ToArray());

return new TransformerChain<ITransformer>(new[] { t1, t2, t3 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransformer(env, options, input).Transform(input);
}

/// <summary>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public ITransformer Fit(IDataView input)
Weighting = _weighting
};

return new TransformWrapper(_host, WordBagBuildingTransformer.Create(_host, options, input), true);
return WordBagBuildingTransformer.CreateTransfomer(_host, options, input);
}

/// <summary>
Expand Down Expand Up @@ -365,7 +365,7 @@ public ITransformer Fit(IDataView input)
MaximumNumberOfInverts = _maximumNumberOfInverts
};

return new TransformWrapper(_host, WordHashBagProducingTransformer.Create(_host, options, input), true);
return WordHashBagProducingTransformer.CreateTransformer(_host, options, input);
}

/// <summary>
Expand Down
6 changes: 5 additions & 1 deletion test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,11 @@ public void EntryPointPipelineEnsembleText()
new WordHashBagProducingTransformer.Options()
{
Columns =
new[] { new WordHashBagProducingTransformer.Column() { Name = "Features", Source = new[] { "Text" } }, }
new[]
{
new WordHashBagProducingTransformer.Column()
{Name = "Features", Source = new[] {"Text"}},
}
},
data);
}
Expand Down
37 changes: 37 additions & 0 deletions test/Microsoft.ML.Functional.Tests/ModelFiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,43 @@ void AssertIsGam(ITransformer trans)
Done();
}

public class ModelInput
{
#pragma warning disable SA1401
public string[] CategoricalFeatures;
public float[] NumericalFeatures;
#pragma warning restore SA1401
}

public class ModelOutput
{
#pragma warning disable SA1401
public float[] Score;
#pragma warning restore SA1401
}


[Fact]
public void LoadModelWithOptionalColumnTransform()
{
SchemaDefinition inputSchemaDefinition = SchemaDefinition.Create(typeof(ModelInput));
inputSchemaDefinition[nameof(ModelInput.CategoricalFeatures)].ColumnType = new VectorDataViewType(TextDataViewType.Instance, 5);
inputSchemaDefinition[nameof(ModelInput.NumericalFeatures)].ColumnType = new VectorDataViewType(NumberDataViewType.Single, 3);
var mlContext = new MLContext();
ITransformer trainedModel;
DataViewSchema dataViewSchema;
using (var stream = new FileStream(Path.Combine(Directory.GetCurrentDirectory(), @"..\..\..\..\test\data\backcompat\modelwithoptionalcolumntransform.zip"),
FileMode.Open, FileAccess.Read, FileShare.Read))
{
trainedModel = mlContext.Model.Load(stream, out dataViewSchema);
codemzs marked this conversation as resolved.
Show resolved Hide resolved
}

var model = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel, inputSchemaDefinition: inputSchemaDefinition);
codemzs marked this conversation as resolved.
Show resolved Hide resolved
var prediction = model.Predict(new ModelInput() { CategoricalFeatures = new[] { "ABC", "ABC", "ABC", "ABC", "ABC" }, NumericalFeatures = new float [] { 1, 1, 1 } });

Assert.Equal(1, prediction.Score[0]);
}

[Fact]
public void SaveAndLoadModelWithLoader()
{
Expand Down
Loading