From 8a45f37cf87e380a93146d08acac19f215648f9a Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 15 Nov 2018 09:33:05 -0800 Subject: [PATCH] Remove lazy parameters for GetRowCount (#1621) * Remove lazy parameters for GetRowCount * Address comments --- .../DataViewConstructionUtils.cs | 8 ++++---- src/Microsoft.ML.Api/StatefulFilterTransform.cs | 2 +- src/Microsoft.ML.Core/Data/IDataView.cs | 16 +++++++--------- src/Microsoft.ML.Data/Data/DataViewUtils.cs | 2 +- src/Microsoft.ML.Data/Data/RowCursorUtils.cs | 2 +- .../DataLoadSave/Binary/BinaryLoader.cs | 2 +- .../DataLoadSave/CompositeDataLoader.cs | 4 ++-- .../DataLoadSave/PartitionedFileLoader.cs | 2 +- .../DataLoadSave/Text/TextLoader.cs | 2 +- .../DataLoadSave/Text/TextSaver.cs | 2 +- .../DataLoadSave/Transpose/TransposeLoader.cs | 2 +- .../DataView/AppendRowsDataView.cs | 6 +++--- .../DataView/ArrayDataViewBuilder.cs | 2 +- src/Microsoft.ML.Data/DataView/CacheDataView.cs | 17 ++++++----------- src/Microsoft.ML.Data/DataView/EmptyDataView.cs | 2 +- .../DataView/OpaqueDataView.cs | 4 ++-- src/Microsoft.ML.Data/DataView/Transposer.cs | 8 ++++---- src/Microsoft.ML.Data/DataView/ZipDataView.cs | 4 ++-- .../Evaluators/RankerEvaluator.cs | 4 ++-- .../Transforms/NopTransform.cs | 4 ++-- .../Transforms/PerGroupTransformBase.cs | 4 ++-- .../Transforms/SelectColumnsTransform.cs | 2 +- .../Transforms/SkipTakeFilter.cs | 4 ++-- .../Transforms/TermTransform.cs | 4 ++-- .../Transforms/TransformBase.cs | 6 +++--- src/Microsoft.ML.FastTree/FastTree.cs | 2 +- src/Microsoft.ML.Parquet/ParquetLoader.cs | 2 +- .../AutoInference.cs | 2 +- .../LogisticRegression/LbfgsPredictorBase.cs | 2 +- .../SequentialTransformBase.cs | 4 ++-- .../SequentialTransformerBase.cs | 4 ++-- .../CountFeatureSelectionTransformer.cs | 2 +- src/Microsoft.ML.Transforms/GroupTransform.cs | 2 +- .../Text/NgramTransform.cs | 2 +- src/Microsoft.ML.Transforms/UngroupTransform.cs | 4 ++-- src/Microsoft.ML.Transforms/VectorWhitening.cs | 2 +- .../Microsoft.ML.Tests/Scenarios/Api/TestApi.cs | 2 +- 37 files changed, 69 insertions(+), 76 deletions(-) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index a4bc8cfda5..e01a993dcd 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -397,7 +397,7 @@ protected DataViewBase(IHostEnvironment env, string name, InternalSchemaDefiniti } } - public abstract long? GetRowCount(bool lazy = true); + public abstract long? GetRowCount(); public abstract IRowCursor GetRowCursor(Func predicate, IRandom rand = null); @@ -555,7 +555,7 @@ public override bool CanShuffle get { return true; } } - public override long? GetRowCount(bool lazy = true) + public override long? GetRowCount() { return _data.Count; } @@ -654,7 +654,7 @@ public override bool CanShuffle get { return false; } } - public override long? GetRowCount(bool lazy = true) + public override long? GetRowCount() { return (_data as ICollection)?.Count; } @@ -735,7 +735,7 @@ public override bool CanShuffle get { return false; } } - public override long? GetRowCount(bool lazy = true) + public override long? GetRowCount() { return null; } diff --git a/src/Microsoft.ML.Api/StatefulFilterTransform.cs b/src/Microsoft.ML.Api/StatefulFilterTransform.cs index eb67085bc2..9b93425f63 100644 --- a/src/Microsoft.ML.Api/StatefulFilterTransform.cs +++ b/src/Microsoft.ML.Api/StatefulFilterTransform.cs @@ -99,7 +99,7 @@ private StatefulFilterTransform(IHostEnvironment env, StatefulFilterTransform _bindings.Schema; - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { // REVIEW: currently stateful map is implemented via filter, and this is sub-optimal. return null; diff --git a/src/Microsoft.ML.Core/Data/IDataView.cs b/src/Microsoft.ML.Core/Data/IDataView.cs index 32cfbd2285..3f0b89aed7 100644 --- a/src/Microsoft.ML.Core/Data/IDataView.cs +++ b/src/Microsoft.ML.Core/Data/IDataView.cs @@ -82,17 +82,15 @@ public interface IDataView : ISchematized bool CanShuffle { get; } /// - /// Returns the number of rows if known. Null means unknown. If lazy is true, then - /// this is permitted to return null when it might return a non-null value on a subsequent - /// call. This indicates, that the transform does not YET know the number of rows, but - /// may in the future. If lazy is false, then this is permitted to do some work (no more - /// that it would normally do for cursoring) to determine the number of rows. + /// Returns the number of rows if known. Returning null means that the row count is unknown but + /// it might return a non-null value on a subsequent call. This indicates, that the transform does + /// not YET know the number of rows, but may in the future. Its implementation's computation + /// complexity should be O(1). /// - /// Most components will return the same answer whether lazy is true or false. Some, like - /// a cache, might return null until the cache is fully populated (when lazy is true). When - /// lazy is false, such a cache would block until the cache was populated. + /// Most implementation will return the same answer every time. Some, like a cache, might + /// return null until the cache is fully populated. /// - long? GetRowCount(bool lazy = true); + long? GetRowCount(); /// /// Get a row cursor. The active column indices are those for which needCol(col) returns true. diff --git a/src/Microsoft.ML.Data/Data/DataViewUtils.cs b/src/Microsoft.ML.Data/Data/DataViewUtils.cs index 0e33390c7e..b6ebaef135 100644 --- a/src/Microsoft.ML.Data/Data/DataViewUtils.cs +++ b/src/Microsoft.ML.Data/Data/DataViewUtils.cs @@ -77,7 +77,7 @@ public static string[] GetTempColumnNames(this ISchema schema, int n, string tag /// public static long ComputeRowCount(IDataView view) { - long? countNullable = view.GetRowCount(lazy: false); + long? countNullable = view.GetRowCount(); if (countNullable != null) return countNullable.Value; long count = 0; diff --git a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs index e26e61fff6..2979825d56 100644 --- a/src/Microsoft.ML.Data/Data/RowCursorUtils.cs +++ b/src/Microsoft.ML.Data/Data/RowCursorUtils.cs @@ -541,7 +541,7 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun return new IRowCursor[] { GetRowCursor(needCol, rand) }; } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { return 1; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 39816d3f24..bc8c5178c2 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -761,7 +761,7 @@ public void GetMetadata(string kind, int col, ref TValue value) private long RowCount { get { return _header.RowCount; } } - public long? GetRowCount(bool lazy = true) { return RowCount; } + public long? GetRowCount() { return RowCount; } public bool CanShuffle { get { return true; } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs index 0009ad4768..7206745e03 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs @@ -557,9 +557,9 @@ private static string GenerateTag(int index) return string.Format("xf{0:00}", index); } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { - return View.GetRowCount(lazy); + return View.GetRowCount(); } public bool CanShuffle => View.CanShuffle; diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs index 5998cd0f22..eb2fd269e7 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedFileLoader.cs @@ -287,7 +287,7 @@ public void Save(ModelSaveContext ctx) public Schema Schema { get; } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { return null; } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index 282b3feea3..a808374573 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -1352,7 +1352,7 @@ public BoundLoader(TextLoader reader, IMultiStreamSource files) _files = files; } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { // We don't know how many rows there are. // REVIEW: Should we try to support RowCount? diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs index db84057b79..8161c8e653 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextSaver.cs @@ -420,7 +420,7 @@ private void WriteDataCore(IChannel ch, TextWriter writer, IDataView data, if (_outputSchema) WriteSchemaAsComment(writer, header); - double rowCount = data.GetRowCount(true) ?? double.NaN; + double rowCount = data.GetRowCount() ?? double.NaN; using (var pch = !_silent ? _host.StartProgressChannel("TextSaver: saving data") : null) { long stateCount = 0; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs index de1964c27b..b12f3ad1e7 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Transpose/TransposeLoader.cs @@ -662,7 +662,7 @@ public VectorType GetSlotType(int col) } } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { return _header.RowCount; } diff --git a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs index 4ea44a6957..89f863ca18 100644 --- a/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs +++ b/src/Microsoft.ML.Data/DataView/AppendRowsDataView.cs @@ -91,7 +91,7 @@ private AppendRowsDataView(IHostEnvironment env, Schema schema, IDataView[] sour _counts = null; break; } - long? count = dv.GetRowCount(true); + long? count = dv.GetRowCount(); if (count == null || count < 0 || count > int.MaxValue) { _canShuffle = false; @@ -127,12 +127,12 @@ private void CheckSchemaConsistency() } } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { long sum = 0; foreach (var source in _sources) { - var cur = source.GetRowCount(lazy); + var cur = source.GetRowCount(); if (cur == null) return null; _host.Check(cur.Value >= 0, "One of the sources returned a negative row count"); diff --git a/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs b/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs index b7f9b494e9..ef6c06d9ad 100644 --- a/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs +++ b/src/Microsoft.ML.Data/DataView/ArrayDataViewBuilder.cs @@ -197,7 +197,7 @@ private sealed class DataView : IDataView public Schema Schema { get { return _schema; } } - public long? GetRowCount(bool lazy = true) { return _rowCount; } + public long? GetRowCount() { return _rowCount; } public bool CanShuffle { get { return true; } } diff --git a/src/Microsoft.ML.Data/DataView/CacheDataView.cs b/src/Microsoft.ML.Data/DataView/CacheDataView.cs index 3865229d27..876663b33c 100644 --- a/src/Microsoft.ML.Data/DataView/CacheDataView.cs +++ b/src/Microsoft.ML.Data/DataView/CacheDataView.cs @@ -193,18 +193,13 @@ public int MapInputToCacheColumnIndex(int inputIndex) public Schema Schema => _subsetInput.Schema; - public long? GetRowCount(bool lazy = true) + /// + /// Return the number of rows if available. + /// + public long? GetRowCount() { if (_rowCount < 0) - { - if (lazy) - return null; - if (_cacheDefaultWaiter == null) - KickoffFiller(new int[0]); - _host.Assert(_cacheDefaultWaiter != null); - _cacheDefaultWaiter.Wait(long.MaxValue); - _host.Assert(_rowCount >= 0); - } + return null; return _rowCount; } @@ -317,7 +312,7 @@ public IRowSeeker GetSeeker(Func predicate) _host.CheckValue(predicate, nameof(predicate)); // The seeker needs to know the row count when it validates the row index to move to. // Calling GetRowCount here to force a wait indirectly so that _rowCount will have a valid value. - GetRowCount(false); + GetRowCount(); _host.Assert(_rowCount >= 0); var waiter = WaiterWaiter.Create(this, predicate); if (waiter.IsTrivial) diff --git a/src/Microsoft.ML.Data/DataView/EmptyDataView.cs b/src/Microsoft.ML.Data/DataView/EmptyDataView.cs index 543ac42952..8c6f385f88 100644 --- a/src/Microsoft.ML.Data/DataView/EmptyDataView.cs +++ b/src/Microsoft.ML.Data/DataView/EmptyDataView.cs @@ -25,7 +25,7 @@ public EmptyDataView(IHostEnvironment env, Schema schema) Schema = schema; } - public long? GetRowCount(bool lazy = true) => 0; + public long? GetRowCount() => 0; public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) { diff --git a/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs b/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs index 44d8d0dcad..cc8b08a87a 100644 --- a/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs +++ b/src/Microsoft.ML.Data/DataView/OpaqueDataView.cs @@ -21,9 +21,9 @@ public OpaqueDataView(IDataView source) _source = source; } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { - return _source.GetRowCount(lazy); + return _source.GetRowCount(); } public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) diff --git a/src/Microsoft.ML.Data/DataView/Transposer.cs b/src/Microsoft.ML.Data/DataView/Transposer.cs index 5424f5a04b..ae38fea8b0 100644 --- a/src/Microsoft.ML.Data/DataView/Transposer.cs +++ b/src/Microsoft.ML.Data/DataView/Transposer.cs @@ -274,7 +274,7 @@ public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Fun return _view.GetRowCursorSet(out consolidator, predicate, n, rand); } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { // Not a passthrough. return RowCount; @@ -818,9 +818,9 @@ public DataViewSlicer(IHost host, IDataView input, int[] toSlice) _schema = new SchemaImpl(this, nameToCol); } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { - return _input.GetRowCount(lazy); + return _input.GetRowCount(); } /// @@ -1503,7 +1503,7 @@ public SlotDataView(IHostEnvironment env, ITransposeDataView data, int col) _schemaImpl = new SchemaImpl(this); } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { var type = _data.Schema.GetColumnType(_col); int valueCount = type.ValueCount; diff --git a/src/Microsoft.ML.Data/DataView/ZipDataView.cs b/src/Microsoft.ML.Data/DataView/ZipDataView.cs index b87efd9195..5489491b3f 100644 --- a/src/Microsoft.ML.Data/DataView/ZipDataView.cs +++ b/src/Microsoft.ML.Data/DataView/ZipDataView.cs @@ -54,12 +54,12 @@ private ZipDataView(IHost host, IDataView[] sources) public Schema Schema => _compositeSchema.AsSchema; - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { long min = -1; foreach (var source in _sources) { - var cur = source.GetRowCount(lazy); + var cur = source.GetRowCount(); if (cur == null) return null; _host.Check(cur.Value >= 0, "One of the sources returned a negative row count"); diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index 46278662f2..3cc3ddb4a6 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -635,9 +635,9 @@ public void Save(ModelSaveContext ctx) _transform.Save(ctx); } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { - return _transform.GetRowCount(lazy); + return _transform.GetRowCount(); } public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs index bf48e357f7..1ba9ed4c38 100644 --- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs @@ -103,9 +103,9 @@ public bool CanShuffle public Schema Schema => Source.Schema; - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { - return Source.GetRowCount(lazy); + return Source.GetRowCount(); } public IRowCursor GetRowCursor(Func predicate, IRandom rand = null) diff --git a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs index 5d870f898a..1276ab8e22 100644 --- a/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/PerGroupTransformBase.cs @@ -144,9 +144,9 @@ public virtual void Save(ModelSaveContext ctx) protected abstract BindingsBase GetBindings(); - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { - return Source.GetRowCount(lazy); + return Source.GetRowCount(); } public IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) diff --git a/src/Microsoft.ML.Data/Transforms/SelectColumnsTransform.cs b/src/Microsoft.ML.Data/Transforms/SelectColumnsTransform.cs index d15c1cfbd4..fd72520c0a 100644 --- a/src/Microsoft.ML.Data/Transforms/SelectColumnsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/SelectColumnsTransform.cs @@ -628,7 +628,7 @@ public SelectColumnsDataTransform(IHostEnvironment env, SelectColumnsTransform t Schema ISchematized.Schema => _mapper.Schema; - public long? GetRowCount(bool lazy = true) => Source.GetRowCount(lazy); + public long? GetRowCount() => Source.GetRowCount(); public IRowCursor GetRowCursor(Func needCol, IRandom rand = null) { diff --git a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs index 69f184607d..353cf2ed47 100644 --- a/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs +++ b/src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs @@ -169,11 +169,11 @@ public override void Save(ModelSaveContext ctx) /// Returns the computed count of rows remaining after skip and take operation. /// Returns null if count is unknown. /// - public override long? GetRowCount(bool lazy = true) + public override long? GetRowCount() { if (_take == 0) return 0; - long? count = Source.GetRowCount(lazy); + long? count = Source.GetRowCount(); if (count == null) return null; diff --git a/src/Microsoft.ML.Data/Transforms/TermTransform.cs b/src/Microsoft.ML.Data/Transforms/TermTransform.cs index 6c082d50df..dc14c014c9 100644 --- a/src/Microsoft.ML.Data/Transforms/TermTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TermTransform.cs @@ -507,7 +507,7 @@ private static TermMap CreateFileTermMap(IHostEnvironment env, IChannel ch, stri { var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" }); var trainer = Trainer.Create(cursor, colSrc, autoConvert, int.MaxValue, bldr); - double rowCount = termData.GetRowCount(true) ?? double.NaN; + double rowCount = termData.GetRowCount() ?? double.NaN; long rowCur = 0; pch.SetHeader(header, e => @@ -606,7 +606,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info using (var pch = env.StartProgressChannel("Building term dictionary")) { long rowCur = 0; - double rowCount = trainingData.GetRowCount(true) ?? double.NaN; + double rowCount = trainingData.GetRowCount() ?? double.NaN; var header = new ProgressHeader(new[] { "Total Terms" }, new[] { "examples" }); itrainer = 0; diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index 9cdc99d1f9..102dfa6998 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -44,7 +44,7 @@ protected TransformBase(IHost host, IDataView input) public abstract void Save(ModelSaveContext ctx); - public abstract long? GetRowCount(bool lazy = true); + public abstract long? GetRowCount(); public virtual bool CanShuffle { get { return Source.CanShuffle; } } @@ -104,7 +104,7 @@ protected RowToRowTransformBase(IHost host, IDataView input) { } - public sealed override long? GetRowCount(bool lazy = true) { return Source.GetRowCount(lazy); } + public sealed override long? GetRowCount() { return Source.GetRowCount(); } } /// @@ -124,7 +124,7 @@ private protected FilterBase(IHost host, IDataView input) { } - public override long? GetRowCount(bool lazy = true) => null; + public override long? GetRowCount() => null; public sealed override Schema Schema => Source.Schema; diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index a71827f892..9728598d3b 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1862,7 +1862,7 @@ private void MakeBoundariesAndCheckLabels(out long missingInstances, out long to ch.Info("Changing data from row-wise to column-wise"); long pos = 0; - double rowCountDbl = (double?)_data.Data.GetRowCount(lazy: true) ?? Double.NaN; + double rowCountDbl = (double?)_data.Data.GetRowCount() ?? Double.NaN; pch.SetHeader(new ProgressHeader("examples"), e => e.SetProgress(0, pos, rowCountDbl)); // REVIEW: Should we ignore rows with bad label, weight, or group? The previous code seemed to let diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 4bbfc608f5..6254c70a1a 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -384,7 +384,7 @@ private static Stream OpenStream(string filename) public Schema Schema { get; } - public long? GetRowCount(bool lazy = true) + public long? GetRowCount() { return _rowCount; } diff --git a/src/Microsoft.ML.PipelineInference/AutoInference.cs b/src/Microsoft.ML.PipelineInference/AutoInference.cs index 5fb8e70d20..d04b61403f 100644 --- a/src/Microsoft.ML.PipelineInference/AutoInference.cs +++ b/src/Microsoft.ML.PipelineInference/AutoInference.cs @@ -470,7 +470,7 @@ public static AutoMlMlState InferPipelines(IHostEnvironment env, PipelineOptimiz env.CheckValue(trainData, nameof(trainData)); env.CheckValue(testData, nameof(testData)); - int numOfRows = (int)(trainData.GetRowCount(false) ?? 1000); + int numOfRows = (int)(trainData.GetRowCount() ?? 1000); AutoMlMlState amls = new AutoMlMlState(env, metric, autoMlEngine, terminator, trainerKind, trainData, testData); bestPipeline = amls.InferPipelines(numTransformLevels, batchSize, numOfRows); return amls; diff --git a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs index f17e617e29..37d838d6b3 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs @@ -447,7 +447,7 @@ protected virtual void TrainCore(IChannel ch, RoleMappedData data) { // REVIEW: maybe it makes sense for the factory to capture the good row count after // the first successful cursoring? - Double totalCount = data.Data.GetRowCount(true) ?? Double.NaN; + Double totalCount = data.Data.GetRowCount() ?? Double.NaN; long exCount = 0; pch.SetHeader(new ProgressHeader(null, new[] { "examples" }), diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs index 3482487a6e..0737809cb7 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs @@ -359,9 +359,9 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando public override Schema Schema => _transform.Schema; - public override long? GetRowCount(bool lazy = true) + public override long? GetRowCount() { - return _transform.GetRowCount(lazy); + return _transform.GetRowCount(); } public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index 91ccc51101..b1a95b916f 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -356,9 +356,9 @@ protected override IRowCursor GetRowCursorCore(Func predicate, IRando public override Schema Schema => _bindings.Schema; - public override long? GetRowCount(bool lazy = true) + public override long? GetRowCount() { - return _transform.GetRowCount(lazy); + return _transform.GetRowCount(); } public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelectionTransformer.cs b/src/Microsoft.ML.Transforms/CountFeatureSelectionTransformer.cs index c7f8986752..313108f33c 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelectionTransformer.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelectionTransformer.cs @@ -179,7 +179,7 @@ public static long[][] Train(IHostEnvironment env, IDataView input, string[] col var aggregators = new CountAggregator[size]; long rowCur = 0; - double rowCount = input.GetRowCount(true) ?? double.NaN; + double rowCount = input.GetRowCount() ?? double.NaN; using (var pch = env.StartProgressChannel("Aggregating counts")) using (var cursor = input.GetRowCursor(col => activeInput[col])) { diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index e679299831..6ca3f25831 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -147,7 +147,7 @@ public override void Save(ModelSaveContext ctx) _groupSchema.Save(ctx); } - public override long? GetRowCount(bool lazy = true) + public override long? GetRowCount() { // We have no idea how many total rows we'll have. return null; diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 49082aa254..a443b9ecdd 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -502,7 +502,7 @@ private SequencePool[] Train(Arguments args, IDataView trainingData, out double[ invDocFreqs = new double[Infos.Length][]; long totalDocs = 0; - Double rowCount = trainingData.GetRowCount(true) ?? Double.NaN; + Double rowCount = trainingData.GetRowCount() ?? Double.NaN; var buffers = new VBuffer[Infos.Length]; pch.SetHeader(new ProgressHeader(new[] { "Total n-grams" }, new[] { "documents" }), e => e.SetProgress(0, totalDocs, rowCount)); diff --git a/src/Microsoft.ML.Transforms/UngroupTransform.cs b/src/Microsoft.ML.Transforms/UngroupTransform.cs index 230138151d..c34100da37 100644 --- a/src/Microsoft.ML.Transforms/UngroupTransform.cs +++ b/src/Microsoft.ML.Transforms/UngroupTransform.cs @@ -149,13 +149,13 @@ public override void Save(ModelSaveContext ctx) _schemaImpl.Save(ctx); } - public override long? GetRowCount(bool lazy = true) + public override long? GetRowCount() { // Row count is known if the input's row count is known, and pivot column sizes are fixed. var commonSize = _schemaImpl.GetCommonPivotColumnSize(); if (commonSize > 0) { - long? srcRowCount = Source.GetRowCount(true); + long? srcRowCount = Source.GetRowCount(); if (srcRowCount.HasValue && srcRowCount.Value <= (long.MaxValue / commonSize)) return srcRowCount.Value * commonSize; } diff --git a/src/Microsoft.ML.Transforms/VectorWhitening.cs b/src/Microsoft.ML.Transforms/VectorWhitening.cs index dd129cd6b8..b60098d572 100644 --- a/src/Microsoft.ML.Transforms/VectorWhitening.cs +++ b/src/Microsoft.ML.Transforms/VectorWhitening.cs @@ -341,7 +341,7 @@ private static void ValidateModel(IExceptionContext ectx, float[] model, ColumnT // A more reliable solution is to turely iterate through all rows via a RowCursor. private static long GetRowCount(IDataView inputData, params ColumnInfo[] columns) { - long? rows = inputData.GetRowCount(lazy: false); + long? rows = inputData.GetRowCount(); if (rows != null) return rows.GetValueOrDefault(); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs index 5115a25be9..9507d4bce3 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs @@ -157,7 +157,7 @@ public void LambdaTransformCreate() var filter = LambdaTransform.CreateFilter(env, idv, (input, state) => input.Label == 0, null); - Assert.Null(filter.GetRowCount(false)); + Assert.Null(filter.GetRowCount()); // test re-apply var applied = env.CreateDataView(data);