Skip to content

Commit

Permalink
PR feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
codemzs committed May 15, 2019
1 parent e74e52a commit 0e3dc6a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ internal TransformerChain<ITransformer> GetTransformer()
}
else
{
ITransformer transformer = new TransformWrapper(_host, transform.Transform, false, true);
ITransformer transformer = new TransformWrapper(_host, transform.Transform);
result = result.Append(transformer);
}
}
Expand Down
69 changes: 6 additions & 63 deletions src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,61 +27,28 @@ internal sealed class TransformWrapper : ITransformer

private readonly IHost _host;
private readonly IDataView _xf;
private readonly bool _allowSave;
private readonly bool _isRowToRowMapper;
private readonly bool _useLastTransformOnly;

public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false, bool useLastTransformOnly = false)
public TransformWrapper(IHostEnvironment env, IDataView xf)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(xf, nameof(xf));
_xf = xf;
_allowSave = allowSave;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
_useLastTransformOnly = useLastTransformOnly;
}

public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));

var dv = new EmptyDataView(_host, inputSchema);
var output = _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv) :
ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
var output = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv);

return output.Schema;
}

void ICanSaveModel.Save(ModelSaveContext ctx)
{
if (!_allowSave)
throw _host.Except("Saving is not permitted.");
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

var dataPipe = _xf;
var transforms = new List<IDataTransform>();
while (dataPipe is IDataTransform xf)
{
// REVIEW: a malicious user could construct a loop in the Source chain, that would
// cause this method to iterate forever (and throw something when the list overflows). There's
// no way to insulate from ALL malicious behavior.
transforms.Add(xf);
dataPipe = xf.Source;
Contracts.AssertValue(dataPipe);
}
transforms.Reverse();

ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema));

ctx.Writer.Write(transforms.Count);
for (int i = 0; i < transforms.Count; i++)
{
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(transforms[i], dirName);
}
}
void ICanSaveModel.Save(ModelSaveContext ctx) => throw _host.Except("Saving is not permitted.");

private static VersionInfo GetVersionInfo()
{
Expand All @@ -100,7 +67,6 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(ctx, nameof(ctx));
_allowSave = true;
ctx.CheckAtModel(GetVersionInfo());
int n = ctx.Reader.ReadInt32();
_host.CheckDecode(n >= 0);
Expand All @@ -119,8 +85,7 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
}

public IDataView Transform(IDataView input) => _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input) :
ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);

private static bool IsChainRowToRowMapper(IDataView view)
{
Expand All @@ -137,30 +102,8 @@ private static bool IsChainRowToRowMapper(IDataView view)
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var input = new EmptyDataView(_host, inputSchema);
IDataView chain;
if (_useLastTransformOnly)
{
chain = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);
return new CompositeRowToRowMapper(inputSchema, new[] { (IRowToRowMapper)chain });
}
else
{
var revMaps = new List<IRowToRowMapper>();
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
chain is IDataTransform xf;
chain = xf.Source)
{
// Everything in the chain ought to be a row mapper.
_host.Assert(xf is IRowToRowMapper);
revMaps.Add((IRowToRowMapper)xf);
}

// The walkback should have ended at the input.
Contracts.Assert(chain == input);
revMaps.Reverse();
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
}
return new CompositeRowToRowMapper(inputSchema,
new[] { (IRowToRowMapper)ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, new EmptyDataView(_host, inputSchema)) });
}
}

Expand Down
37 changes: 37 additions & 0 deletions test/Microsoft.ML.Functional.Tests/ModelFiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,43 @@ void AssertIsGam(ITransformer trans)
Done();
}

// Input row type for the prediction engine in LoadModelWithOptionalColumnTransform.
// Both columns are declared as plain arrays here; their fixed vector sizes
// (5 text slots, 3 float slots) are supplied separately through a SchemaDefinition.
public class ModelInput
{
// Public fields (rather than properties) are used intentionally; the pragma
// suppresses StyleCop SA1401 ("fields must be private") for them.
#pragma warning disable SA1401
public string[] CategoricalFeatures;
public float[] NumericalFeatures;
#pragma warning restore SA1401
}

// Output row type for the prediction engine: the model's score vector.
// A public field is used intentionally; the pragma suppresses StyleCop SA1401.
public class ModelOutput
{
#pragma warning disable SA1401
public float[] Score;
#pragma warning restore SA1401
}


[Fact]
public void LoadModelWithOptionalColumnTransform()
{
    // The model expects a 5-slot text vector and a 3-slot float vector, so pin
    // those sizes on the input schema before creating the prediction engine.
    SchemaDefinition inputSchemaDefinition = SchemaDefinition.Create(typeof(ModelInput));
    inputSchemaDefinition[nameof(ModelInput.CategoricalFeatures)].ColumnType = new VectorDataViewType(TextDataViewType.Instance, 5);
    inputSchemaDefinition[nameof(ModelInput.NumericalFeatures)].ColumnType = new VectorDataViewType(NumberDataViewType.Single, 3);
    var mlContext = new MLContext();
    ITransformer trainedModel;
    DataViewSchema dataViewSchema;

    // Build the model path from individual segments via Path.Combine so the test
    // also runs on non-Windows agents (the previous literal hard-coded '\' separators).
    string modelPath = Path.Combine(Directory.GetCurrentDirectory(),
        "..", "..", "..", "..", "test", "data", "backcompat", "modelwithoptionalcolumntransform.zip");
    using (var stream = new FileStream(modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
    {
        trainedModel = mlContext.Model.Load(stream, out dataViewSchema);
    }

    var model = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel, inputSchemaDefinition: inputSchemaDefinition);
    var prediction = model.Predict(new ModelInput() { CategoricalFeatures = new[] { "ABC", "ABC", "ABC", "ABC", "ABC" }, NumericalFeatures = new float[] { 1, 1, 1 } });

    // NOTE(review): exact float equality — presumably the back-compat model emits
    // exactly 1.0 for this input; confirm before tightening/loosening tolerance.
    Assert.Equal(1, prediction.Score[0]);
}

[Fact]
public void SaveAndLoadModelWithLoader()
{
Expand Down
32 changes: 30 additions & 2 deletions test/Microsoft.ML.Tests/Scenarios/WordBagTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ public static void WordBags()
var textPipeline =
mlContext.Transforms.Text.ProduceWordBags("Text", "Text",
ngramLength: 3, useAllLengths: false, weighting: NgramExtractingEstimator.WeightingCriteria.Tf).Append(
mlContext.Transforms.Text.ProduceWordBags("Text2", "Text2",
mlContext.Transforms.Text.ProduceWordBags("Text2", new[] { "Text2", "Text2" },
ngramLength: 3, useAllLengths: false, weighting: NgramExtractingEstimator.WeightingCriteria.Tf));


var textTransformer = textPipeline.Fit(dataview);
var transformedDataView = textTransformer.Transform(dataview);
var predictionEngine = mlContext.Model.CreatePredictionEngine<TextData, TransformedTextData>(textTransformer);
var prediction = predictionEngine.Predict(samples[0]);
Assert.Equal(prediction.Text, new float[] {
Expand All @@ -46,6 +45,35 @@ public static void WordBags()
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 });
}

[Fact]
public static void WordBagsHash()
{
    // Exercises ProduceHashedWordBags both with a single input column ("Text")
    // and with a multi-column input (the same column listed twice), then runs
    // one sample through a prediction engine and checks the output vector size.
    var context = new MLContext();
    var samples = new List<TextData>
    {
        new TextData { Text = "This is an example to compute bag-of-word features." },
        new TextData { Text = "ML.NET's ProduceWordBags API produces bag-of-word features from input text." },
        new TextData { Text = "It does so by first tokenizing text/string into words/tokens then " },
        new TextData { Text = "computing n-grams and their neumeric values." },
        new TextData { Text = "Each position in the output vector corresponds to a particular n-gram." },
        new TextData { Text = "The value at each position corresponds to," },
        new TextData { Text = "the number of times n-gram occured in the data (Tf), or" },
        new TextData { Text = "the inverse of the number of documents contain the n-gram (Idf)," },
        new TextData { Text = "or compute both and multipy together (Tf-Idf)." },
    };

    var data = context.Data.LoadFromEnumerable(samples);
    var pipeline = context.Transforms.Text
        .ProduceHashedWordBags("Text", "Text", ngramLength: 3, useAllLengths: false)
        .Append(context.Transforms.Text.ProduceHashedWordBags(
            "Text2", new[] { "Text2", "Text2" }, ngramLength: 3, useAllLengths: false));

    var model = pipeline.Fit(data);
    var engine = context.Model.CreatePredictionEngine<TextData, TransformedTextData>(model);
    var prediction = engine.Predict(samples[0]);

    // Default hashing produces a 2^16-slot output vector.
    Assert.Equal(65536, prediction.Text.Length);
}

private class TextData
{
public string Text { get; set; }
Expand Down
Binary file not shown.

0 comments on commit 0e3dc6a

Please sign in to comment.