Skip to content

Commit 38eaf93

Browse files
Tom FinleyTomFinley
Tom Finley
authored andcommitted
No abbreviations in TrainContext, use static readonly field where convenient.
1 parent cd15b05 commit 38eaf93

File tree

8 files changed

+33
-29
lines changed

8 files changed

+33
-29
lines changed

src/Microsoft.ML.Core/Prediction/TrainContext.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,21 @@ public sealed class TrainContext
3737
/// <summary>
3838
/// Constructor, given a training set and optional other arguments.
3939
/// </summary>
40-
/// <param name="train">Will set <see cref="TrainingSet"/> to this value. This must be specified</param>
41-
/// <param name="valid">Will set <see cref="ValidationSet"/> to this value if specified</param>
42-
/// <param name="initPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param>
43-
public TrainContext(RoleMappedData train, RoleMappedData valid = null, IPredictor initPredictor = null)
40+
/// <param name="trainingSet">Will set <see cref="TrainingSet"/> to this value. This must be specified</param>
41+
/// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified</param>
42+
/// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param>
43+
public TrainContext(RoleMappedData trainingSet, RoleMappedData validationSet = null, IPredictor initialPredictor = null)
4444
{
45-
Contracts.CheckValue(train, nameof(train));
46-
Contracts.CheckValueOrNull(valid);
47-
Contracts.CheckValueOrNull(initPredictor);
45+
Contracts.CheckValue(trainingSet, nameof(trainingSet));
46+
Contracts.CheckValueOrNull(validationSet);
47+
Contracts.CheckValueOrNull(initialPredictor);
4848

4949
// REVIEW: Should there be code here to ensure that the role mappings between the two are compatible?
5050
// That is, all the role mappings are the same and the columns between them have identical types?
5151

52-
TrainingSet = train;
53-
ValidationSet = valid;
54-
InitialPredictor = initPredictor;
52+
TrainingSet = trainingSet;
53+
ValidationSet = validationSet;
54+
InitialPredictor = initialPredictor;
5555
}
5656
}
5757
}

src/Microsoft.ML.LightGBM/LightGbmTrainerBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ private sealed class CategoricalMetaData
5151
private protected int FeatureCount;
5252
private protected FastTree.Internal.Ensemble TrainedEnsemble;
5353

54-
private static TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true);
54+
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false, supportValid: true);
5555
public override TrainerInfo Info => _info;
5656

5757
private protected LightGbmTrainerBase(IHostEnvironment env, LightGbmArguments args, string name)

src/Microsoft.ML.PCA/PcaTrainer.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
7575
private readonly int _seed;
7676

7777
public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection;
78-
public override TrainerInfo Info { get; }
78+
79+
// The training performs two passes, only. Probably not worth caching.
80+
private static readonly TrainerInfo _info = new TrainerInfo(caching: false);
81+
public override TrainerInfo Info => _info;
7982

8083
public RandomizedPcaTrainer(IHostEnvironment env, Arguments args)
8184
: base(env, LoadNameValue)
@@ -84,8 +87,6 @@ public RandomizedPcaTrainer(IHostEnvironment env, Arguments args)
8487
Host.CheckUserArg(args.Rank > 0, nameof(args.Rank), "Rank must be positive");
8588
Host.CheckUserArg(args.Oversampling >= 0, nameof(args.Oversampling), "Oversampling must be non-negative");
8689

87-
// Two passes, only. Probably not worth caching.
88-
Info = new TrainerInfo(caching: false);
8990
_rank = args.Rank;
9091
_center = args.Center;
9192
_oversampling = args.Oversampling;

src/Microsoft.ML.StandardLearners/Standard/LinearClassificationTrainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ public abstract class LinearTrainerBase<TPredictor> : TrainerBase<TPredictor>
4949
{
5050
protected bool NeedShuffle;
5151

52-
public override TrainerInfo Info { get; }
52+
private static readonly TrainerInfo _info = new TrainerInfo();
53+
public override TrainerInfo Info => _info;
5354

5455
/// <summary>
5556
/// Whether data is to be shuffled every epoch.
@@ -59,7 +60,6 @@ public abstract class LinearTrainerBase<TPredictor> : TrainerBase<TPredictor>
5960
private protected LinearTrainerBase(IHostEnvironment env, string name)
6061
: base(env, name)
6162
{
62-
Info = new TrainerInfo();
6363
}
6464

6565
public override TPredictor Train(TrainContext context)

src/Microsoft.ML.StandardLearners/Standard/LogisticRegression/LbfgsPredictorBase.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,10 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
132132
private VBuffer<Float>[] _localGradients;
133133
private Float[] _localLosses;
134134

135-
public override TrainerInfo Info { get; }
135+
// REVIEW: It's pointless to request caching when we're going to load everything into
136+
// memory, that is, when using multiple threads. So should caching not be requested?
137+
private static readonly TrainerInfo _info = new TrainerInfo(caching: true, supportIncrementalTrain: true);
138+
public override TrainerInfo Info => _info;
136139

137140
internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name, bool showTrainingStats = false)
138141
: base(env, name)
@@ -160,9 +163,6 @@ internal LbfgsTrainerBase(ArgumentsBase args, IHostEnvironment env, string name,
160163
DenseOptimizer = args.DenseOptimizer;
161164
ShowTrainingStats = showTrainingStats;
162165
EnforceNonNegativity = args.EnforceNonNegativity;
163-
// REVIEW: It's pointless to request caching when we're going to load everything into
164-
// memory, that is, when using multiple threads. So should caching not be requested?
165-
Info = new TrainerInfo(caching: true, supportIncrementalTrain: true);
166166

167167
if (EnforceNonNegativity && ShowTrainingStats)
168168
{

src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ public sealed class Arguments : LearnerInputBaseWithLabel
4545

4646
public override PredictionKind PredictionKind => PredictionKind.MultiClassClassification;
4747

48-
public override TrainerInfo Info { get; }
48+
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
49+
public override TrainerInfo Info => _info;
4950

5051
public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args)
5152
: base(env, LoadName)
5253
{
5354
Host.CheckValue(args, nameof(args));
54-
Info = new TrainerInfo(normalization: false, caching: false);
5555
}
5656

5757
public override MultiClassNaiveBayesPredictor Train(TrainContext context)

src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ It assumes that the conditional mean of the dependent variable follows a linear
6161
private readonly bool _perParameterSignificance;
6262

6363
public override PredictionKind PredictionKind => PredictionKind.Regression;
64-
public override TrainerInfo Info { get; }
64+
65+
// The training performs two passes, only. Probably not worth caching.
66+
private static readonly TrainerInfo _info = new TrainerInfo(caching: false);
67+
public override TrainerInfo Info => _info;
6568

6669
public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
6770
: base(env, LoadNameValue)
@@ -70,8 +73,6 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
7073
Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative");
7174
_l2Weight = args.L2Weight;
7275
_perParameterSignificance = args.PerParameterSignificance;
73-
// Two passes, only. Probably not worth caching.
74-
Info = new TrainerInfo(caching: false);
7576
}
7677

7778
/// <summary>

src/Microsoft.ML.StandardLearners/Standard/Simple/SimpleTrainers.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,14 @@ public class Arguments
5555
}
5656

5757
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
58-
public override TrainerInfo Info { get; }
58+
59+
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
60+
public override TrainerInfo Info => _info;
5961

6062
public RandomTrainer(IHostEnvironment env, Arguments args)
6163
: base(env, LoadNameValue)
6264
{
6365
Host.CheckValue(args, nameof(args));
64-
Info = new TrainerInfo(normalization: false, caching: false);
6566
}
6667

6768
public override RandomPredictor Train(TrainContext context)
@@ -205,13 +206,14 @@ public sealed class Arguments
205206
}
206207

207208
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
208-
public override TrainerInfo Info { get; }
209+
210+
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
211+
public override TrainerInfo Info => _info;
209212

210213
public PriorTrainer(IHostEnvironment env, Arguments args)
211214
: base(env, LoadNameValue)
212215
{
213216
Host.CheckValue(args, nameof(args));
214-
Info = new TrainerInfo(normalization: false, caching: false);
215217
}
216218

217219
public override PriorPredictor Train(TrainContext context)

0 commit comments

Comments
 (0)