Skip to content

Commit

Permalink
[AutoML] Cross validation fixes; validate empty training / validation…
Browse files Browse the repository at this point in the history
… input data (dotnet#3794)
  • Loading branch information
daholste authored and Dmitry-A committed Aug 22, 2019
1 parent c07fe1f commit e5ebbf5
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 35 deletions.
4 changes: 1 addition & 3 deletions src/Microsoft.ML.AutoML/API/ExperimentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, IDataView validat
{
if (validationData == null)
{
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
trainData = splitResult.trainData;
validationData = splitResult.validationData;
return Execute(trainData, columnInformation, preFeaturizer, progressHandler);
}
return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler);
}
Expand Down
35 changes: 22 additions & 13 deletions src/Microsoft.ML.AutoML/Utils/SplitUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,37 @@ namespace Microsoft.ML.AutoML
{
internal static class SplitUtil
{
private const string CrossValEmptyFoldErrorMsg = @"Cross validation split has 0 rows. Perhaps " +
"try increasing number of rows provided in training data, or lowering specified number of " +
"cross validation folds.";

public static (IDataView[] trainDatasets, IDataView[] validationDatasets) CrossValSplit(MLContext context,
IDataView trainData, uint numFolds, string samplingKeyColumn)
{
var originalColumnNames = trainData.Schema.Select(c => c.Name);
var splits = context.Data.CrossValidationSplit(trainData, (int)numFolds, samplingKeyColumnName: samplingKeyColumn);
var trainDatasets = new IDataView[numFolds];
var validationDatasets = new IDataView[numFolds];
for (var i = 0; i < numFolds; i++)
var trainDatasets = new List<IDataView>();
var validationDatasets = new List<IDataView>();

foreach (var split in splits)
{
var split = splits[i];
trainDatasets[i] = DropAllColumnsExcept(context, split.TrainSet, originalColumnNames);
validationDatasets[i] = DropAllColumnsExcept(context, split.TestSet, originalColumnNames);
if (DatasetDimensionsUtil.IsDataViewEmpty(trainDatasets[i]) || DatasetDimensionsUtil.IsDataViewEmpty(validationDatasets[i]))
if (DatasetDimensionsUtil.IsDataViewEmpty(split.TrainSet) ||
DatasetDimensionsUtil.IsDataViewEmpty(split.TestSet))
{
throw new InvalidOperationException(CrossValEmptyFoldErrorMsg);
continue;
}

var trainDataset = DropAllColumnsExcept(context, split.TrainSet, originalColumnNames);
var validationDataset = DropAllColumnsExcept(context, split.TestSet, originalColumnNames);

trainDatasets.Add(trainDataset);
validationDatasets.Add(validationDataset);
}
return (trainDatasets, validationDatasets);

if (!trainDatasets.Any())
{
throw new InvalidOperationException("All cross validation folds have empty train or test data. " +
"Try increasing the number of rows provided in training data, or lowering specified number of " +
"cross validation folds.");
}

return (trainDatasets.ToArray(), validationDatasets.ToArray());
}

/// <summary>
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ private static void ValidateTrainData(IDataView trainData, ColumnInformation col
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
}

if (DatasetDimensionsUtil.IsDataViewEmpty(trainData))
{
throw new ArgumentException("Training data has 0 rows", nameof(trainData));
}

foreach (var column in trainData.Schema)
{
if (column.Name == DefaultColumnNames.Features && column.Type.GetItemType() != NumberDataViewType.Single)
Expand Down Expand Up @@ -164,6 +169,11 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
return;
}

if (DatasetDimensionsUtil.IsDataViewEmpty(validationData))
{
throw new ArgumentException("Validation data has 0 rows", nameof(validationData));
}

const string schemaMismatchError = "Training data and validation data schemas do not match.";

if (trainData.Schema.Count != validationData.Schema.Count)
Expand Down
71 changes: 71 additions & 0 deletions test/Microsoft.ML.AutoML.Tests/SplitUtilTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Microsoft.ML.AutoML.Test
{
[TestClass]
public class SplitUtilTests
{
/// <summary>
/// When there's only one row of data, assert that
/// attempted cross validation throws (all splits should have empty
/// train or test set).
/// </summary>
[TestMethod]
[ExpectedException(typeof(InvalidOperationException))]
public void CrossValSplitThrowsWhenNotEnoughData()
{
var mlContext = new MLContext();
var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, 0f);
var dataView = dataViewBuilder.GetDataView();
SplitUtil.CrossValSplit(mlContext, dataView, 10, null);
}

/// <summary>
/// When there are few rows of data, assert that
/// cross validation succeeds, but # of splits is less than 10
/// (splits with empty train or test sets should not be returned from this API).
/// </summary>
[TestMethod]
public void CrossValSplitSmallDataView()
{
var mlContext = new MLContext(seed: 0);
var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, new float[9]);
dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, new float[9]);
var dataView = dataViewBuilder.GetDataView();
const int requestedNumSplits = 10;
var splits = SplitUtil.CrossValSplit(mlContext, dataView, requestedNumSplits, null);
Assert.IsTrue(splits.trainDatasets.Any());
Assert.IsTrue(splits.trainDatasets.Count() < requestedNumSplits);
Assert.AreEqual(splits.trainDatasets.Count(), splits.validationDatasets.Count());
}

/// <summary>
/// Assert that with many rows of data, cross validation produces the requested
/// # of splits.
/// </summary>
[TestMethod]
public void CrossValSplitLargeDataView()
{
var mlContext = new MLContext(seed: 0);
var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, new float[10000]);
dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, new float[10000]);
var dataView = dataViewBuilder.GetDataView();
const int requestedNumSplits = 10;
var splits = SplitUtil.CrossValSplit(mlContext, dataView, requestedNumSplits, null);
Assert.IsTrue(splits.trainDatasets.Any());
Assert.AreEqual(requestedNumSplits, splits.trainDatasets.Count());
Assert.AreEqual(requestedNumSplits, splits.validationDatasets.Count());
}
}
}
77 changes: 58 additions & 19 deletions test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,35 +187,47 @@ public void ValidateTextColumnNotText()
[TestMethod]
public void ValidateRegressionLabelTypes()
{
ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Single, true);
ValidateLabelTypeTestCore(TaskKind.Regression, BooleanDataViewType.Instance, false);
ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Double, false);
ValidateLabelTypeTestCore(TaskKind.Regression, TextDataViewType.Instance, false);
ValidateLabelTypeTestCore<float>(TaskKind.Regression, NumberDataViewType.Single, true);
ValidateLabelTypeTestCore<bool>(TaskKind.Regression, BooleanDataViewType.Instance, false);
ValidateLabelTypeTestCore<double>(TaskKind.Regression, NumberDataViewType.Double, false);
ValidateLabelTypeTestCore<string>(TaskKind.Regression, TextDataViewType.Instance, false);
}

[TestMethod]
public void ValidateBinaryClassificationLabelTypes()
{
ValidateLabelTypeTestCore(TaskKind.BinaryClassification, NumberDataViewType.Single, false);
ValidateLabelTypeTestCore(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true);
ValidateLabelTypeTestCore<float>(TaskKind.BinaryClassification, NumberDataViewType.Single, false);
ValidateLabelTypeTestCore<bool>(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true);
}

[TestMethod]
public void ValidateMulticlassLabelTypes()
{
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Single, true);
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true);
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Double, true);
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, TextDataViewType.Instance, true);
ValidateLabelTypeTestCore<float>(TaskKind.MulticlassClassification, NumberDataViewType.Single, true);
ValidateLabelTypeTestCore<bool>(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true);
ValidateLabelTypeTestCore<double>(TaskKind.MulticlassClassification, NumberDataViewType.Double, true);
ValidateLabelTypeTestCore<string>(TaskKind.MulticlassClassification, TextDataViewType.Instance, true);
}

[TestMethod]
public void ValidateAllowedFeatureColumnTypes()
{
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext());
dataViewBuilder.AddColumn("Boolean", BooleanDataViewType.Instance, false);
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
dataViewBuilder.AddColumn("Text", "a");
dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
var dataView = dataViewBuilder.GetDataView();
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
null, TaskKind.Regression);
}

[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateProhibitedFeatureColumnType()
{
var schemaBuilder = new DataViewSchema.Builder();
schemaBuilder.AddColumn("Boolean", BooleanDataViewType.Instance);
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
schemaBuilder.AddColumn("Text", TextDataViewType.Instance);
schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
var schema = schemaBuilder.ToSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
Expand All @@ -225,24 +237,51 @@ public void ValidateAllowedFeatureColumnTypes()

[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateProhibitedFeatureColumnType()
public void ValidateEmptyTrainingDataThrows()
{
var schemaBuilder = new DataViewSchema.Builder();
schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64);
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
var schema = schemaBuilder.ToSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
null, TaskKind.Regression);
}

private static void ValidateLabelTypeTestCore(TaskKind task, DataViewType labelType, bool labelTypeShouldBeValid)
[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateEmptyValidationDataThrows()
{
// Training data
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext());
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
var trainingData = dataViewBuilder.GetDataView();

// Validation data
var schemaBuilder = new DataViewSchema.Builder();
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
schemaBuilder.AddColumn(DefaultColumnNames.Label, labelType);
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
var schema = schemaBuilder.ToSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
var validationData = new EmptyDataView(new MLContext(), schema);

UserInputValidationUtil.ValidateExperimentExecuteArgs(trainingData, new ColumnInformation(),
validationData, TaskKind.Regression);
}

private static void ValidateLabelTypeTestCore<LabelRawType>(TaskKind task, PrimitiveDataViewType labelType, bool labelTypeShouldBeValid)
{
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext());
dataViewBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, 0f);
if (labelType == TextDataViewType.Instance)
{
dataViewBuilder.AddColumn(DefaultColumnNames.Label, string.Empty);
}
else
{
dataViewBuilder.AddColumn(DefaultColumnNames.Label, labelType, Activator.CreateInstance<LabelRawType>());
}
var dataView = dataViewBuilder.GetDataView();
var validationExceptionThrown = false;
try
{
Expand Down

0 comments on commit e5ebbf5

Please sign in to comment.