@@ -43,7 +43,7 @@ public struct TrainTestData
43
43
/// </summary>
44
44
/// <param name="trainSet">Training set.</param>
45
45
/// <param name="testSet">Testing set.</param>
46
- public TrainTestData ( IDataView trainSet , IDataView testSet )
46
+ internal TrainTestData ( IDataView trainSet , IDataView testSet )
47
47
{
48
48
TrainSet = trainSet ;
49
49
TestSet = testSet ;
@@ -98,7 +98,7 @@ protected internal struct CrossValidationResult
98
98
/// </summary>
99
99
public readonly ITransformer Model ;
100
100
/// <summary>
101
- /// <see cref="IDataView "/> for scored fold.
101
+ /// Scored test set with <see cref="Model "/> for this fold.
102
102
/// </summary>
103
103
public readonly IDataView Scores ;
104
104
/// <summary>
@@ -113,27 +113,30 @@ public CrossValidationResult(ITransformer model, IDataView scores, int fold)
113
113
Fold = fold ;
114
114
}
115
115
}
116
-
116
+ /// <summary>
117
+ /// Results of running crossvalidation.
118
+ /// </summary>
119
+ /// <typeparam name="T">Type of metric class</typeparam>
117
120
public class CrossValidationResult < T > where T : class
118
121
{
119
122
/// <summary>
120
- /// Metrics for cross validation fold.
123
+ /// Metrics for this cross validation fold.
121
124
/// </summary>
122
125
public readonly T Metrics ;
123
126
/// <summary>
124
127
/// Model trained during cross validation fold.
125
128
/// </summary>
126
129
public readonly ITransformer Model ;
127
130
/// <summary>
128
- /// <see cref="IDataView "/> for scored fold.
131
+ /// Scored test set with <see cref="Model "/> for this fold.
129
132
/// </summary>
130
133
public readonly IDataView Scores ;
131
134
/// <summary>
132
135
/// Fold number.
133
136
/// </summary>
134
137
public readonly int Fold ;
135
138
136
- public CrossValidationResult ( ITransformer model , T metrics , IDataView scores , int fold )
139
+ internal CrossValidationResult ( ITransformer model , T metrics , IDataView scores , int fold )
137
140
{
138
141
Model = model ;
139
142
Metrics = metrics ;
@@ -341,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
341
344
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
342
345
/// And if it is not provided, the default value will be used.</param>
343
346
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
344
- public ( BinaryClassificationMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidateNonCalibrated (
347
+ public CrossValidationResult < BinaryClassificationMetrics > [ ] CrossValidateNonCalibrated (
345
348
IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
346
349
string stratificationColumn = null , uint ? seed = null )
347
350
{
348
351
Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
349
352
var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
350
- return result . Select ( x => ( EvaluateNonCalibrated ( x . Scores , labelColumn ) , x . Model , x . Scores ) ) . ToArray ( ) ;
353
+ return result . Select ( x => new CrossValidationResult < BinaryClassificationMetrics > ( x . Model ,
354
+ EvaluateNonCalibrated ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
351
355
}
352
356
353
357
/// <summary>
0 commit comments