diff --git a/src/Microsoft.ML.AutoML/API/ExperimentBase.cs b/src/Microsoft.ML.AutoML/API/ExperimentBase.cs index c3e23687492..05601342dc3 100644 --- a/src/Microsoft.ML.AutoML/API/ExperimentBase.cs +++ b/src/Microsoft.ML.AutoML/API/ExperimentBase.cs @@ -162,9 +162,7 @@ public ExperimentResult Execute(IDataView trainData, IDataView validat { if (validationData == null) { - var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName); - trainData = splitResult.trainData; - validationData = splitResult.validationData; + return Execute(trainData, columnInformation, preFeaturizer, progressHandler); } return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler); } diff --git a/src/Microsoft.ML.AutoML/Utils/SplitUtil.cs b/src/Microsoft.ML.AutoML/Utils/SplitUtil.cs index 584e769538b..cc77f0fbc9b 100644 --- a/src/Microsoft.ML.AutoML/Utils/SplitUtil.cs +++ b/src/Microsoft.ML.AutoML/Utils/SplitUtil.cs @@ -10,28 +10,37 @@ namespace Microsoft.ML.AutoML { internal static class SplitUtil { - private const string CrossValEmptyFoldErrorMsg = @"Cross validation split has 0 rows. Perhaps " + - "try increasing number of rows provided in training data, or lowering specified number of " + - "cross validation folds."; - public static (IDataView[] trainDatasets, IDataView[] validationDatasets) CrossValSplit(MLContext context, IDataView trainData, uint numFolds, string samplingKeyColumn) { var originalColumnNames = trainData.Schema.Select(c => c.Name); var splits = context.Data.CrossValidationSplit(trainData, (int)numFolds, samplingKeyColumnName: samplingKeyColumn); - var trainDatasets = new IDataView[numFolds]; - var validationDatasets = new IDataView[numFolds]; - for (var i = 0; i < numFolds; i++) + var trainDatasets = new List(); + var validationDatasets = new List(); + + foreach (var split in splits) { - var split = splits[i]; - trainDatasets[i] = DropAllColumnsExcept(context, split.TrainSet, originalColumnNames); - validationDatasets[i] = DropAllColumnsExcept(context, split.TestSet, originalColumnNames); - if (DatasetDimensionsUtil.IsDataViewEmpty(trainDatasets[i]) || DatasetDimensionsUtil.IsDataViewEmpty(validationDatasets[i])) + if (DatasetDimensionsUtil.IsDataViewEmpty(split.TrainSet) || + DatasetDimensionsUtil.IsDataViewEmpty(split.TestSet)) { - throw new InvalidOperationException(CrossValEmptyFoldErrorMsg); + continue; } + + var trainDataset = DropAllColumnsExcept(context, split.TrainSet, originalColumnNames); + var validationDataset = DropAllColumnsExcept(context, split.TestSet, originalColumnNames); + + trainDatasets.Add(trainDataset); + validationDatasets.Add(validationDataset); } - return (trainDatasets, validationDatasets); + + if (!trainDatasets.Any()) + { + throw new InvalidOperationException("All cross validation folds have empty train or test data. " + + "Try increasing the number of rows provided in training data, or lowering specified number of " + + "cross validation folds."); + } + + return (trainDatasets.ToArray(), validationDatasets.ToArray()); } /// diff --git a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs index 7fe1f2b1567..dfbecb2634a 100644 --- a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs +++ b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs @@ -61,6 +61,11 @@ private static void ValidateTrainData(IDataView trainData, ColumnInformation col throw new ArgumentNullException(nameof(trainData), "Training data cannot be null"); } + if (DatasetDimensionsUtil.IsDataViewEmpty(trainData)) + { + throw new ArgumentException("Training data has 0 rows", nameof(trainData)); + } + foreach (var column in trainData.Schema) { if (column.Name == DefaultColumnNames.Features && column.Type.GetItemType() != NumberDataViewType.Single) @@ -164,6 +169,11 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida return; } + if (DatasetDimensionsUtil.IsDataViewEmpty(validationData)) + { + throw new ArgumentException("Validation data has 0 rows", nameof(validationData)); + } + const string schemaMismatchError = "Training data and validation data schemas do not match."; if (trainData.Schema.Count != validationData.Schema.Count) diff --git a/test/Microsoft.ML.AutoML.Tests/SplitUtilTests.cs b/test/Microsoft.ML.AutoML.Tests/SplitUtilTests.cs new file mode 100644 index 00000000000..f1faf11b878 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/SplitUtilTests.cs @@ -0,0 +1,71 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.AutoML.Test +{ + [TestClass] + public class SplitUtilTests + { + /// + /// When there's only one row of data, assert that + /// attempted cross validation throws (all splits should have empty + /// train or test set). + /// + [TestMethod] + [ExpectedException(typeof(InvalidOperationException))] + public void CrossValSplitThrowsWhenNotEnoughData() + { + var mlContext = new MLContext(); + var dataViewBuilder = new ArrayDataViewBuilder(mlContext); + dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f); + dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, 0f); + var dataView = dataViewBuilder.GetDataView(); + SplitUtil.CrossValSplit(mlContext, dataView, 10, null); + } + + /// + /// When there are few rows of data, assert that + /// cross validation succeeds, but # of splits is less than 10 + /// (splits with empty train or test sets should not be returned from this API). + /// + [TestMethod] + public void CrossValSplitSmallDataView() + { + var mlContext = new MLContext(seed: 0); + var dataViewBuilder = new ArrayDataViewBuilder(mlContext); + dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, new float[9]); + dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, new float[9]); + var dataView = dataViewBuilder.GetDataView(); + const int requestedNumSplits = 10; + var splits = SplitUtil.CrossValSplit(mlContext, dataView, requestedNumSplits, null); + Assert.IsTrue(splits.trainDatasets.Any()); + Assert.IsTrue(splits.trainDatasets.Count() < requestedNumSplits); + Assert.AreEqual(splits.trainDatasets.Count(), splits.validationDatasets.Count()); + } + + /// + /// Assert that with many rows of data, cross validation produces the requested + /// # of splits. + /// + [TestMethod] + public void CrossValSplitLargeDataView() + { + var mlContext = new MLContext(seed: 0); + var dataViewBuilder = new ArrayDataViewBuilder(mlContext); + dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, new float[10000]); + dataViewBuilder.AddColumn("Label", NumberDataViewType.Single, new float[10000]); + var dataView = dataViewBuilder.GetDataView(); + const int requestedNumSplits = 10; + var splits = SplitUtil.CrossValSplit(mlContext, dataView, requestedNumSplits, null); + Assert.IsTrue(splits.trainDatasets.Any()); + Assert.AreEqual(requestedNumSplits, splits.trainDatasets.Count()); + Assert.AreEqual(requestedNumSplits, splits.validationDatasets.Count()); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs index 8c644d97c57..444d2c78034 100644 --- a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs @@ -187,35 +187,47 @@ public void ValidateTextColumnNotText() [TestMethod] public void ValidateRegressionLabelTypes() { - ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Single, true); - ValidateLabelTypeTestCore(TaskKind.Regression, BooleanDataViewType.Instance, false); - ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Double, false); - ValidateLabelTypeTestCore(TaskKind.Regression, TextDataViewType.Instance, false); + ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Single, true); + ValidateLabelTypeTestCore(TaskKind.Regression, BooleanDataViewType.Instance, false); + ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Double, false); + ValidateLabelTypeTestCore(TaskKind.Regression, TextDataViewType.Instance, false); } [TestMethod] public void ValidateBinaryClassificationLabelTypes() { - ValidateLabelTypeTestCore(TaskKind.BinaryClassification, NumberDataViewType.Single, false); - ValidateLabelTypeTestCore(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true); + ValidateLabelTypeTestCore(TaskKind.BinaryClassification, NumberDataViewType.Single, false); + ValidateLabelTypeTestCore(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true); } [TestMethod] public void ValidateMulticlassLabelTypes() { - ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Single, true); - ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true); - ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Double, true); - ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, TextDataViewType.Instance, true); + ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Single, true); + ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true); + ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Double, true); + ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, TextDataViewType.Instance, true); } [TestMethod] public void ValidateAllowedFeatureColumnTypes() + { + var dataViewBuilder = new ArrayDataViewBuilder(new MLContext()); + dataViewBuilder.AddColumn("Boolean", BooleanDataViewType.Instance, false); + dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f); + dataViewBuilder.AddColumn("Text", "a"); + dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f); + var dataView = dataViewBuilder.GetDataView(); + UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), + null, TaskKind.Regression); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateProhibitedFeatureColumnType() { var schemaBuilder = new DataViewSchema.Builder(); - schemaBuilder.AddColumn("Boolean", BooleanDataViewType.Instance); - schemaBuilder.AddColumn("Number", NumberDataViewType.Single); - schemaBuilder.AddColumn("Text", TextDataViewType.Instance); + schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64); schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single); var schema = schemaBuilder.ToSchema(); var dataView = new EmptyDataView(new MLContext(), schema); @@ -225,10 +237,10 @@ public void ValidateAllowedFeatureColumnTypes() [TestMethod] [ExpectedException(typeof(ArgumentException))] - public void ValidateProhibitedFeatureColumnType() + public void ValidateEmptyTrainingDataThrows() { var schemaBuilder = new DataViewSchema.Builder(); - schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64); + schemaBuilder.AddColumn("Number", NumberDataViewType.Single); schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single); var schema = schemaBuilder.ToSchema(); var dataView = new EmptyDataView(new MLContext(), schema); @@ -236,13 +248,40 @@ public void ValidateProhibitedFeatureColumnType() null, TaskKind.Regression); } - private static void ValidateLabelTypeTestCore(TaskKind task, DataViewType labelType, bool labelTypeShouldBeValid) + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateEmptyValidationDataThrows() { + // Training data + var dataViewBuilder = new ArrayDataViewBuilder(new MLContext()); + dataViewBuilder.AddColumn("Number", NumberDataViewType.Single, 0f); + dataViewBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, 0f); + var trainingData = dataViewBuilder.GetDataView(); + + // Validation data var schemaBuilder = new DataViewSchema.Builder(); - schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single); - schemaBuilder.AddColumn(DefaultColumnNames.Label, labelType); + schemaBuilder.AddColumn("Number", NumberDataViewType.Single); + schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single); var schema = schemaBuilder.ToSchema(); - var dataView = new EmptyDataView(new MLContext(), schema); + var validationData = new EmptyDataView(new MLContext(), schema); + + UserInputValidationUtil.ValidateExperimentExecuteArgs(trainingData, new ColumnInformation(), + validationData, TaskKind.Regression); + } + + private static void ValidateLabelTypeTestCore(TaskKind task, PrimitiveDataViewType labelType, bool labelTypeShouldBeValid) + { + var dataViewBuilder = new ArrayDataViewBuilder(new MLContext()); + dataViewBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, 0f); + if (labelType == TextDataViewType.Instance) + { + dataViewBuilder.AddColumn(DefaultColumnNames.Label, string.Empty); + } + else + { + dataViewBuilder.AddColumn(DefaultColumnNames.Label, labelType, Activator.CreateInstance()); + } + var dataView = dataViewBuilder.GetDataView(); var validationExceptionThrown = false; try {