Skip to content

Commit

Permalink
fix issue in WaiterWaiter caused by race condition (#4829)
Browse files Browse the repository at this point in the history
* fix issue in WaiterWaiter

* re-enable tests that affected by the fixed issue

* refine comments

* refactor based on discussion

* take comments
  • Loading branch information
frank-dong-ms-zz authored Feb 13, 2020
1 parent 229ef37 commit 6cd6081
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 40 deletions.
23 changes: 15 additions & 8 deletions src/Microsoft.ML.Data/DataView/CacheDataView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,13 @@ public int MapInputToCacheColumnIndex(int inputIndex)
/// </summary>
public long? GetRowCount()
{
if (_rowCount < 0)
// _rowCount may or may not be initialized at this point. Only read the value once
// to avoid race conditions.
long rowCount = _rowCount;

if (rowCount < 0)
return null;
return _rowCount;
return rowCount;
}

public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
Expand Down Expand Up @@ -605,8 +609,9 @@ private sealed class TrivialWaiter : IWaiter

private TrivialWaiter(CacheDataView parent)
{
Contracts.Assert(parent._rowCount >= 0);
_lim = parent._rowCount;
var rowCount = parent.GetRowCount();
Contracts.Assert(rowCount.HasValue);
_lim = rowCount.Value;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down Expand Up @@ -669,7 +674,7 @@ private WaiterWaiter(CacheDataView parent, Func<int, bool> pred)
waiters.Add(waiter);
}
// Make the array of waiters.
if (_parent._rowCount < 0 && waiters.Count == 0)
if (!_parent.GetRowCount().HasValue && waiters.Count == 0)
{
Contracts.AssertValue(_parent._cacheDefaultWaiter);
waiters.Add(_parent._cacheDefaultWaiter);
Expand All @@ -682,7 +687,9 @@ public bool Wait(long pos)
{
foreach (var w in _waiters)
w.Wait(pos);
return pos < _parent._rowCount || _parent._rowCount == -1;

var rowCount = _parent.GetRowCount();
return !rowCount.HasValue || pos < rowCount.Value;
}

public static Wrapper Create(CacheDataView parent, Func<int, bool> pred)
Expand Down Expand Up @@ -1419,8 +1426,8 @@ public ImplOne(CacheDataView parent, DataViewRowCursor input, int srcCol, Ordere
: base(parent, input, srcCol, waiter)
{
_getter = input.GetGetter<T>(input.Schema[srcCol]);
if (parent._rowCount >= 0)
_values = new T[(int)parent._rowCount];
if (parent.GetRowCount() is { } rowCount)
_values = new T[rowCount];
}

public override void CacheCurrent()
Expand Down
12 changes: 0 additions & 12 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2640,8 +2640,6 @@ public void EntryPointEvaluateMulticlass()
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void EntryPointEvaluateRegression()
{
var dataPath = GetDataPath(TestDatasets.generatedRegressionDatasetmacro.trainFilename);
Expand Down Expand Up @@ -2752,8 +2750,6 @@ public void EntryPointLightGbmMulticlass()
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void EntryPointSdcaBinary()
{
TestEntryPointRoutine("breast-cancer.txt", "Trainers.StochasticDualCoordinateAscentBinaryClassifier");
Expand All @@ -2766,8 +2762,6 @@ public void EntryPointSDCAMulticlass()
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void EntryPointSDCARegression()
{
TestEntryPointRoutine(TestDatasets.generatedRegressionDatasetmacro.trainFilename, "Trainers.StochasticDualCoordinateAscentRegressor", loader: TestDatasets.generatedRegressionDatasetmacro.loaderSettings);
Expand Down Expand Up @@ -3855,8 +3849,6 @@ public void EntryPointChainedTrainTestMacros()
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void EntryPointChainedCrossValMacros()
{
string inputGraph = @"
Expand Down Expand Up @@ -5506,8 +5498,6 @@ public void TestCrossValidationMacroMulticlassWithWarnings()
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void TestCrossValidationMacroWithStratification()
{
var dataPath = GetDataPath(@"breast-cancer.txt");
Expand Down Expand Up @@ -6039,8 +6029,6 @@ public void TestCrossValidationMacroWithNonDefaultNames()
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void TestOvaMacro()
{
var dataPath = GetDataPath(@"iris.txt");
Expand Down
2 changes: 0 additions & 2 deletions test/Microsoft.ML.Functional.Tests/IntrospectiveTraining.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,6 @@ public void InspectLdaModelParameters()
/// Introspective Training: Linear model parameters may be inspected.
/// </summary>
[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void InpsectLinearModelParameters()
{
var mlContext = new MLContext(seed: 1);
Expand Down
4 changes: 0 additions & 4 deletions test/Microsoft.ML.Predictor.Tests/TestPredictors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,6 @@ public void KMeansClusteringTest()
[X64Fact("x86 output differs from Baseline")]
[TestCategory("Binary")]
[TestCategory("SDCA")]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void LinearClassifierTest()
{
var binaryPredictors = new[]
Expand Down Expand Up @@ -324,8 +322,6 @@ public void BinaryClassifierLogisticRegressionNormTest()
///</summary>
[LessThanNetCore30OrNotNetCoreAndX64Fact("netcoreapp3.0 and x86 output differs from Baseline")]
[TestCategory("Binary")]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void BinaryClassifierLogisticRegressionNonNegativeTest()
{
var binaryPredictors = new[] { TestLearners.logisticRegressionNonNegative };
Expand Down
2 changes: 0 additions & 2 deletions test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1022,8 +1022,6 @@ public void SavePipeTrainAndScoreFccFastTree()

[TestCategory("DataPipeSerialization")]
[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void SavePipeTrainAndScoreFccTransformStr()
{
TestCore(null, false,
Expand Down
2 changes: 0 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,6 @@ public void LogisticRegressionOnnxConversionTest()
}

[LightGBMFact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void LightGbmBinaryClassificationOnnxConversionTest()
{
// Step 1: Create and train a ML.NET pipeline.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,6 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void TrainRegressionModel()
=> TrainRegression(GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename), GetDataPath(TestDatasets.generatedRegressionDataset.testFilename),
DeleteOutputPath("cook_model.zip"));
Expand Down Expand Up @@ -293,8 +291,6 @@ private void PredictOnIris(ITransformer model)
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void TrainAndPredictOnIris()
=> PredictOnIris(TrainOnIris(GetDataPath("iris.data")));

Expand Down Expand Up @@ -629,8 +625,6 @@ private void CategoricalFeaturizationOn(params string[] dataPath)
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void CrossValidationIris()
=> CrossValidationOn(GetDataPath("iris.data"));

Expand Down Expand Up @@ -708,8 +702,6 @@ public class OutputRow
}

[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void CustomTransformer()
{
var mlContext = new MLContext(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ namespace Microsoft.ML.Scenarios
public partial class ScenariosTests : BaseTestClass
{
[Fact]
//Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined
[Trait("Category", "SkipInCI")]
public void TrainAndPredictIrisModelTest()
{
var mlContext = new MLContext(seed: 1);
Expand Down

0 comments on commit 6cd6081

Please sign in to comment.