diff --git a/src/Microsoft.ML.Data/DataLoadSave/LegacyCompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/LegacyCompositeDataLoader.cs index c293b87c68..e64333bbeb 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/LegacyCompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/LegacyCompositeDataLoader.cs @@ -403,6 +403,7 @@ public static IDataView LoadSelectedTransforms(ModelLoadContext ctx, IDataView s internal TransformerChain GetTransformer() { var result = new TransformerChain(); + IDataTransform lastTransformer = null; foreach (var transform in _transforms) { if (transform.Transform is RowToRowMapperTransform mapper) @@ -412,9 +413,11 @@ internal TransformerChain GetTransformer() } else { - ITransformer transformer = new TransformWrapper(_host, transform.Transform); + ITransformer transformer = new TransformWrapper(_host, transform.Transform, false, lastTransformer is RowToRowMapperTransform); result = result.Append(transformer); } + + lastTransformer = transform.Transform; } return result; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs index 1f876f459b..18dedc8f6a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs @@ -29,8 +29,9 @@ internal sealed class TransformWrapper : ITransformer private readonly IDataView _xf; private readonly bool _allowSave; private readonly bool _isRowToRowMapper; + private readonly bool _useLastTransformOnly; - public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false) + public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = false, bool useLastTransformOnly = false) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(TransformWrapper)); @@ -38,6 +39,7 @@ public TransformWrapper(IHostEnvironment env, IDataView xf, bool allowSave = fal _xf = xf; _allowSave = allowSave; _isRowToRowMapper = IsChainRowToRowMapper(_xf); + _useLastTransformOnly = useLastTransformOnly; } public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) @@ -45,7 +47,9 @@ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) _host.CheckValue(inputSchema, nameof(inputSchema)); var dv = new EmptyDataView(_host, inputSchema); - var output = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv); + var output = _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, dv) : + ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, dv); + return output.Schema; } @@ -115,7 +119,8 @@ private TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) _isRowToRowMapper = IsChainRowToRowMapper(_xf); } - public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); + public IDataView Transform(IDataView input) => _useLastTransformOnly ? ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input) : + ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); private static bool IsChainRowToRowMapper(IDataView view) { @@ -133,18 +138,29 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); var input = new EmptyDataView(_host, inputSchema); - var revMaps = new List(); IDataView chain; - for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source) + if (_useLastTransformOnly) + { + chain = ApplyTransformUtils.ApplyTransformToData(_host, (IDataTransform)_xf, input); + return new CompositeRowToRowMapper(inputSchema, new[] { (IRowToRowMapper)chain }); + } + else { - // Everything in the chain ought to be a row mapper. - _host.Assert(xf is IRowToRowMapper); - revMaps.Add((IRowToRowMapper)xf); + var revMaps = new List(); + for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); + chain is IDataTransform xf; + chain = xf.Source) + { + // Everything in the chain ought to be a row mapper. + _host.Assert(xf is IRowToRowMapper); + revMaps.Add((IRowToRowMapper)xf); + } + + // The walkback should have ended at the input. + Contracts.Assert(chain == input); + revMaps.Reverse(); + return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray()); } - // The walkback should have ended at the input. - Contracts.Assert(chain == input); - revMaps.Reverse(); - return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray()); } }