Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoML] Cross validation fixes; validate empty training / validation input data #3794

Merged
merged 4 commits into from
Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/Microsoft.ML.AutoML/API/ExperimentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ public ExperimentResult<TMetrics> Execute(IDataView trainData, IDataView validat
{
if (validationData == null)
{
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
trainData = splitResult.trainData;
validationData = splitResult.validationData;
return Execute(trainData, columnInformation, preFeaturizer, progressHandler);
}
return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler);
}
Expand Down
35 changes: 22 additions & 13 deletions src/Microsoft.ML.AutoML/Utils/SplitUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,37 @@ namespace Microsoft.ML.AutoML
{
internal static class SplitUtil
{
private const string CrossValEmptyFoldErrorMsg = @"Cross validation split has 0 rows. Perhaps " +
"try increasing number of rows provided in training data, or lowering specified number of " +
"cross validation folds.";

public static (IDataView[] trainDatasets, IDataView[] validationDatasets) CrossValSplit(MLContext context,
IDataView trainData, uint numFolds, string samplingKeyColumn)
{
var originalColumnNames = trainData.Schema.Select(c => c.Name);
var splits = context.Data.CrossValidationSplit(trainData, (int)numFolds, samplingKeyColumnName: samplingKeyColumn);
var trainDatasets = new IDataView[numFolds];
var validationDatasets = new IDataView[numFolds];
for (var i = 0; i < numFolds; i++)
var trainDatasets = new List<IDataView>();
var validationDatasets = new List<IDataView>();

foreach (var split in splits)
{
var split = splits[i];
trainDatasets[i] = DropAllColumnsExcept(context, split.TrainSet, originalColumnNames);
validationDatasets[i] = DropAllColumnsExcept(context, split.TestSet, originalColumnNames);
if (DatasetDimensionsUtil.IsDataViewEmpty(trainDatasets[i]) || DatasetDimensionsUtil.IsDataViewEmpty(validationDatasets[i]))
if (DatasetDimensionsUtil.IsDataViewEmpty(split.TrainSet) ||
DatasetDimensionsUtil.IsDataViewEmpty(split.TestSet))
{
throw new InvalidOperationException(CrossValEmptyFoldErrorMsg);
continue;
}

var trainDataset = DropAllColumnsExcept(context, split.TrainSet, originalColumnNames);
var validationDataset = DropAllColumnsExcept(context, split.TestSet, originalColumnNames);

trainDatasets.Add(trainDataset);
validationDatasets.Add(validationDataset);
}
return (trainDatasets, validationDatasets);

if (!trainDatasets.Any())
{
throw new InvalidOperationException("All cross validation folds have empty train or test data. " +
"Try increasing the number of rows provided in training data, or lowering specified number of " +
"cross validation folds.");
}

return (trainDatasets.ToArray(), validationDatasets.ToArray());
}

/// <summary>
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ private static void ValidateTrainData(IDataView trainData, ColumnInformation col
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
}

if (DatasetDimensionsUtil.IsDataViewEmpty(trainData))
{
throw new ArgumentException("Training data has 0 rows", nameof(trainData));
}

foreach (var column in trainData.Schema)
{
if (column.Name == DefaultColumnNames.Features && column.Type.GetItemType() != NumberDataViewType.Single)
Expand Down Expand Up @@ -164,6 +169,11 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
return;
}

if (DatasetDimensionsUtil.IsDataViewEmpty(validationData))
{
throw new ArgumentException("Validation data has 0 rows", nameof(validationData));
}

const string schemaMismatchError = "Training data and validation data schemas do not match.";

if (trainData.Schema.Count != validationData.Schema.Count)
Expand Down
71 changes: 71 additions & 0 deletions test/Microsoft.ML.AutoML.Tests/SplitUtilTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace Microsoft.ML.AutoML.Test
{
[TestClass]
public class SplitUtilTests
{
/// <summary>
/// When there's only one row of data, assert that
/// attempted cross validation throws (all splits should have empty
/// train or test set).
/// </summary>
[TestMethod]
[ExpectedException(typeof(InvalidOperationException))]
public void CrossValSplitThrowsWhenNotEnoughData()
{
var mlContext = new MLContext();
var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, 0f);
var dataView = dataViewBuilder.GetDataView();
SplitUtil.CrossValSplit(mlContext, dataView, 10, null);
}

/// <summary>
/// When there are few rows of data, assert that
/// cross validation succeeds, but # of splits is less than 10
/// (splits with empty train or test sets should not be returned from this API).
/// </summary>
[TestMethod]
public void CrossValSplitSmallDataView()
{
var mlContext = new MLContext(seed: 0);
var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, new float[9]);
dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, new float[9]);
var dataView = dataViewBuilder.GetDataView();
const int requestedNumSplits = 10;
var splits = SplitUtil.CrossValSplit(mlContext, dataView, requestedNumSplits, null);
Assert.IsTrue(splits.trainDatasets.Any());
Assert.IsTrue(splits.trainDatasets.Count() < requestedNumSplits);
Assert.AreEqual(splits.trainDatasets.Count(), splits.validationDatasets.Count());
}

/// <summary>
/// Assert that with many rows of data, cross validation produces the requested
/// # of splits.
/// </summary>
[TestMethod]
public void CrossValSplitLargeDataView()
{
var mlContext = new MLContext(seed: 0);
var dataViewBuilder = new ArrayDataViewBuilder(mlContext);
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, new float[10000]);
dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, new float[10000]);
var dataView = dataViewBuilder.GetDataView();
const int requestedNumSplits = 10;
var splits = SplitUtil.CrossValSplit(mlContext, dataView, requestedNumSplits, null);
Assert.IsTrue(splits.trainDatasets.Any());
Assert.AreEqual(requestedNumSplits, splits.trainDatasets.Count());
Assert.AreEqual(requestedNumSplits, splits.validationDatasets.Count());
}
}
}
77 changes: 58 additions & 19 deletions test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,35 +187,47 @@ public void ValidateTextColumnNotText()
[TestMethod]
public void ValidateRegressionLabelTypes()
{
ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Single, true);
ValidateLabelTypeTestCore(TaskKind.Regression, BooleanDataViewType.Instance, false);
ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Double, false);
ValidateLabelTypeTestCore(TaskKind.Regression, TextDataViewType.Instance, false);
ValidateLabelTypeTestCore<float>(TaskKind.Regression, NumberDataViewType.Single, true);
ValidateLabelTypeTestCore<bool>(TaskKind.Regression, BooleanDataViewType.Instance, false);
ValidateLabelTypeTestCore<double>(TaskKind.Regression, NumberDataViewType.Double, false);
ValidateLabelTypeTestCore<string>(TaskKind.Regression, TextDataViewType.Instance, false);
}

[TestMethod]
public void ValidateBinaryClassificationLabelTypes()
{
ValidateLabelTypeTestCore(TaskKind.BinaryClassification, NumberDataViewType.Single, false);
ValidateLabelTypeTestCore(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true);
ValidateLabelTypeTestCore<float>(TaskKind.BinaryClassification, NumberDataViewType.Single, false);
ValidateLabelTypeTestCore<bool>(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true);
}

[TestMethod]
public void ValidateMulticlassLabelTypes()
{
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Single, true);
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true);
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Double, true);
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, TextDataViewType.Instance, true);
ValidateLabelTypeTestCore<float>(TaskKind.MulticlassClassification, NumberDataViewType.Single, true);
ValidateLabelTypeTestCore<bool>(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true);
ValidateLabelTypeTestCore<double>(TaskKind.MulticlassClassification, NumberDataViewType.Double, true);
ValidateLabelTypeTestCore<string>(TaskKind.MulticlassClassification, TextDataViewType.Instance, true);
}

[TestMethod]
public void ValidateAllowedFeatureColumnTypes()
{
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext());
dataViewBuilder.AddColumn("Boolean", BooleanDataViewType.Instance, false);
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
dataViewBuilder.AddColumn("Text", "a");
dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
var dataView = dataViewBuilder.GetDataView();
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
null, TaskKind.Regression);
}

[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateProhibitedFeatureColumnType()
{
var schemaBuilder = new DataViewSchema.Builder();
schemaBuilder.AddColumn("Boolean", BooleanDataViewType.Instance);
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
schemaBuilder.AddColumn("Text", TextDataViewType.Instance);
schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
var schema = schemaBuilder.ToSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
Expand All @@ -225,24 +237,51 @@ public void ValidateAllowedFeatureColumnTypes()

[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateProhibitedFeatureColumnType()
public void ValidateEmptyTrainingDataThrows()
{
var schemaBuilder = new DataViewSchema.Builder();
schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64);
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
var schema = schemaBuilder.ToSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
null, TaskKind.Regression);
}

private static void ValidateLabelTypeTestCore(TaskKind task, DataViewType labelType, bool labelTypeShouldBeValid)
[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateEmptyValidationDataThrows()
{
// Training data
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext());
dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f);
dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f);
var trainingData = dataViewBuilder.GetDataView();

// Validation data
var schemaBuilder = new DataViewSchema.Builder();
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
schemaBuilder.AddColumn(DefaultColumnNames.Label, labelType);
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
var schema = schemaBuilder.ToSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
var validationData = new EmptyDataView(new MLContext(), schema);

UserInputValidationUtil.ValidateExperimentExecuteArgs(trainingData, new ColumnInformation(),
validationData, TaskKind.Regression);
}

private static void ValidateLabelTypeTestCore<LabelRawType>(TaskKind task, PrimitiveDataViewType labelType, bool labelTypeShouldBeValid)
{
var dataViewBuilder = new ArrayDataViewBuilder(new MLContext());
dataViewBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, 0f);
if (labelType == TextDataViewType.Instance)
{
dataViewBuilder.AddColumn(DefaultColumnNames.Label, string.Empty);
}
else
{
dataViewBuilder.AddColumn(DefaultColumnNames.Label, labelType, Activator.CreateInstance<LabelRawType>());
}
var dataView = dataViewBuilder.GetDataView();
var validationExceptionThrown = false;
try
{
Expand Down