Skip to content

Commit

Permalink
Add option to execute only the last transform in TransformWrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
codemzs committed May 10, 2019
1 parent f6faab1 commit 0879374
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ public static IDataView LoadSelectedTransforms(ModelLoadContext ctx, IDataView s
internal TransformerChain<ITransformer> GetTransformer()
{
var result = new TransformerChain<ITransformer>();
IDataTransform lastTransformer = null;
foreach (var transform in _transforms)
{
if (transform.Transform is RowToRowMapperTransform mapper)
Expand All @@ -412,9 +413,11 @@ internal TransformerChain<ITransformer> GetTransformer()
}
else
{
ITransformer transformer = new TransformWrapper(_host, transform.Transform);
ITransformer transformer = new TransformWrapper(_host, transform.Transform, false, lastTransformer is RowToRowMapperTransform);
result = result.Append(transformer);
}

lastTransformer = transform.Transform;
}
return result;
}
Expand Down
40 changes: 28 additions & 12 deletions src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,27 @@ internal sealed class TransformWrapper : ITransformer
private readonly IDataView _xf;
private readonly bool _allowSave;
private readonly bool _isRowToRowMapper;
private readonly bool _useLastTransformOnly;

public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false)
public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false, bool useLastTransformOnly = false)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TransformWrapper));
_host.CheckValue(xf, nameof(xf));
_xf = xf;
_allowSave = allowSave;
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
_useLastTransformOnly = useLastTransformOnly;
}

public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));

var dv = new EmptyDataView(_host, inputSchema);
var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);
var output = _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv) :
ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv);

return output.Schema;
}

Expand Down Expand Up @@ -115,7 +119,8 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
_isRowToRowMapper = IsChainRowToRowMapper(_xf);
}

public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
public IDataView Transform(IDataView input) => _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input) :
ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);

private static bool IsChainRowToRowMapper(IDataView view)
{
Expand All @@ -133,18 +138,29 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
var input = new EmptyDataView(_host, inputSchema);
var revMaps = new List<IRowToRowMapper>();
IDataView chain;
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source)
if (_useLastTransformOnly)
{
chain = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input);
return new CompositeRowToRowMapper(inputSchema, new[] { (IRowToRowMapper)chain });
}
else
{
// Everything in the chain ought to be a row mapper.
_host.Assert(xf is IRowToRowMapper);
revMaps.Add((IRowToRowMapper)xf);
var revMaps = new List<IRowToRowMapper>();
for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input);
chain is IDataTransform xf;
chain = xf.Source)
{
// Everything in the chain ought to be a row mapper.
_host.Assert(xf is IRowToRowMapper);
revMaps.Add((IRowToRowMapper)xf);
}

// The walkback should have ended at the input.
Contracts.Assert(chain == input);
revMaps.Reverse();
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
}
// The walkback should have ended at the input.
Contracts.Assert(chain == input);
revMaps.Reverse();
return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray());
}
}

Expand Down

0 comments on commit 0879374

Please sign in to comment.