Skip to content

Commit 0790005

Browse files
author
Ivan Matantsev
committed
address small things
1 parent 3d9b174 commit 0790005

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public struct TrainTestData
4343
/// </summary>
4444
/// <param name="trainSet">Training set.</param>
4545
/// <param name="testSet">Testing set.</param>
46-
public TrainTestData(IDataView trainSet, IDataView testSet)
46+
internal TrainTestData(IDataView trainSet, IDataView testSet)
4747
{
4848
TrainSet = trainSet;
4949
TestSet = testSet;
@@ -98,7 +98,7 @@ protected internal struct CrossValidationResult
9898
/// </summary>
9999
public readonly ITransformer Model;
100100
/// <summary>
101-
/// <see cref="IDataView"/> for scored fold.
101+
/// Scored test set with <see cref="Model"/> for this fold.
102102
/// </summary>
103103
public readonly IDataView Scores;
104104
/// <summary>
@@ -113,27 +113,30 @@ public CrossValidationResult(ITransformer model, IDataView scores, int fold)
113113
Fold = fold;
114114
}
115115
}
116-
116+
/// <summary>
117+
/// Results of running crossvalidation.
118+
/// </summary>
119+
/// <typeparam name="T">Type of metric class</typeparam>
117120
public class CrossValidationResult<T> where T : class
118121
{
119122
/// <summary>
120-
/// Metrics for cross validation fold.
123+
/// Metrics for this cross validation fold.
121124
/// </summary>
122125
public readonly T Metrics;
123126
/// <summary>
124127
/// Model trained during cross validation fold.
125128
/// </summary>
126129
public readonly ITransformer Model;
127130
/// <summary>
128-
/// <see cref="IDataView"/> for scored fold.
131+
/// Scored test set with <see cref="Model"/> for this fold.
129132
/// </summary>
130133
public readonly IDataView Scores;
131134
/// <summary>
132135
/// Fold number.
133136
/// </summary>
134137
public readonly int Fold;
135138

136-
public CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold)
139+
internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold)
137140
{
138141
Model = model;
139142
Metrics = metrics;
@@ -341,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
341344
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
342345
/// And if it is not provided, the default value will be used.</param>
343346
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
344-
public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated(
347+
public CrossValidationResult<BinaryClassificationMetrics>[] CrossValidateNonCalibrated(
345348
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
346349
string stratificationColumn = null, uint? seed = null)
347350
{
348351
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
349352
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
350-
return result.Select(x => (EvaluateNonCalibrated(x.Scores, labelColumn), x.Model, x.Scores)).ToArray();
353+
return result.Select(x => new CrossValidationResult<BinaryClassificationMetrics>(x.Model,
354+
EvaluateNonCalibrated(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
351355
}
352356

353357
/// <summary>

src/Microsoft.ML.StaticPipe/TrainingStaticExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,9 @@ public static (BinaryClassificationMetrics metrics, Transformer<TInShape, TOutSh
221221
var results = catalog.CrossValidateNonCalibrated(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName, seed);
222222

223223
return results.Select(x => (
224-
x.metrics,
225-
new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
226-
new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
224+
x.Metrics,
225+
new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.Model, data.Shape, estimator.Shape),
226+
new DataView<TOutShape>(env, x.Scores, estimator.Shape)))
227227
.ToArray();
228228
}
229229

0 commit comments

Comments
 (0)