Description
There are two changes proposed here for `ITrainer`. Due to the nature of the changes, assuming we agree they are good ideas, it really makes sense that they happen in one go, since they involve changes to the same core functionality.
I am quite certain the first thing is a good idea; the second thing I am less certain about. (Of course my confidence could be misplaced. 😄) People who occur to me as potentially good contributors to the discussion: @eerhardt, @zeahmed, @shauheen, @ericstj, @glebuk. (To be clear, this is not exclusionary. Anyone can and should feel free to comment. I just want these people to get flagged is all. 😄)
## `ITrainer<TData, TPred>.Train` ought to return the predictor
Currently, in order to train a predictor there is a two-step process: you call `ITrainer.Train`, then call `ITrainer.GetPredictor`. As near as I can tell this arrangement was meant as some scheme to support online training, but of course that vision never came to pass, and if we were to actually support online training I have to imagine it would be through some separate interface anyway.

This arrangement seems unambiguously bad. It necessitates making `ITrainer` objects stateful for no particularly good reason. This complicates both the implementation and usage of the class, since (1) the caller can't do things like call `Train` multiple times, even though a developer seeing this interface might reasonably suppose such a thing were possible, and (2) the author of the `ITrainer` component has to protect against that misuse.
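To make the proposal concrete, here is a rough sketch of the shape of the change. The declarations are illustrative, not the literal interfaces in the repo:

```csharp
// Today (sketch): a stateful two-step protocol.
public interface ITrainerToday<TData>
{
    void Train(TData data);    // mutates internal trainer state
    IPredictor GetPredictor(); // only meaningful after Train has run
}

// Proposed (sketch): Train returns the predictor directly, so the
// trainer no longer needs to carry state between two calls, and
// calling Train twice is naturally well-defined.
public interface ITrainerProposed<TData, TPredictor>
{
    TPredictor Train(TData data);
}
```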
## Get rid of `IValidatingTrainer`, `IIncrementalTrainer`, `IValidatingIncrementalTrainer`
### The problem
First let's talk about the problem...
Most (all?) trainers implement `ITrainer<RoleMappedData>`, based on this interface here. We have historically handled adding more inputs to the training process by declaring specialized interfaces that represent the Cartesian product of all possible combinations of inputs, as we see:
- There was a discussion about adding a validation set. So we doubled the number of interfaces to two: `ITrainer` and `IValidatingTrainer`.
- Later on there was a discussion about adding continued training based on an initialization. So we again doubled the number of interfaces, introducing `IIncrementalTrainer` and `IValidatingIncrementalTrainer`.
- There have been discussions of adding a test set to allow computing metrics as training progresses. Following the same strategy we would of course again double the number of interfaces (the full set being represented, perhaps, by `IValidatingIncrementalTestingTrainer`), for a total of eight.
- If hypothetically we were to somehow allow one more input beyond that, we'd have a total of sixteen interfaces.
Etc., etc. That this cost is exponential makes clear that something is misdesigned. The cost shows up not only here and in the implementations of `ITrainer`, but in the usage as well: here we see a method that explores the Cartesian product of possible interfaces so it can call the right one. It seems to me something is wrong when just calling "train" requires a fairly non-obvious utility method to make sure we call the "right" train.

This issue, incidentally, is the primary reason why we haven't done anything like add support for test-set metrics during training (despite the many requests). That is, it is not any technical difficulty with the idea itself; it's just that writing such a thing would make the code unmaintainable.
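For illustration, the kind of dispatch the Cartesian-product design forces on callers looks roughly like this. This is a sketch with assumed signatures, not the actual utility method in the codebase; the point is that every optional input doubles the number of cast-and-call branches:

```csharp
// Sketch of caller-side dispatch under the interface-explosion design.
// Adding a test set would double the number of branches again.
public static IPredictor TrainCore(ITrainer trainer, RoleMappedData train,
    RoleMappedData validation = null, IPredictor init = null)
{
    if (validation != null && init != null && trainer is IValidatingIncrementalTrainer vit)
        vit.Train(train, validation, init);
    else if (validation != null && trainer is IValidatingTrainer vt)
        vt.Train(train, validation);
    else if (init != null && trainer is IIncrementalTrainer it)
        it.Train(train, init);
    else
        trainer.Train(train);
    // Still coupled to the two-step protocol criticized above.
    return trainer.GetPredictor();
}
```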
### The possible solution(s)
So: instead we might have just one interface, with one required input (the training dataset), and all these other things optional.

There are two obvious ways I could imagine doing this. The first is explicit, as part of the method signature on `ITrainer<...>`:
```csharp
public interface ITrainer<TDataset, TPredictor>
{
    TPredictor Train(TDataset train, TDataset validation = null,
        TDataset testSet = null, IPredictor initPredictor = null);
}
```
Or else have some sort of context object. (I'm not married to any of these names, to be clear. 😄 )
```csharp
public sealed class TrainContext
{
    public RoleMappedData Train { get; }
    public RoleMappedData Validation { get; }
    public RoleMappedData Test { get; }
    public IPredictor InitPredictor { get; }
}
```
and all trainers implement `ITrainer<TrainContext>` instead of `ITrainer<RoleMappedData>`.
The latter is perhaps a bit more awkward, since it involves adding a new abstraction (the hypothetical `TrainContext`), but it is more flexible in a forward-looking sense: if we add more "stuff" to how we initialize trainers, we won't break all existing `ITrainer` implementations. (My expectation is that trainers that can't support something would simply ignore it.)
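Under the context-object option, a call site might look something like the following. The `TrainContext` constructor shape is hypothetical (the class above only shows get-only properties), so take this as a sketch of the usage pattern rather than a settled API:

```csharp
// Hypothetical construction: only the training set is required,
// the rest are optional named arguments.
var context = new TrainContext(train: trainData, validation: validData);

// One Train method regardless of which optional inputs are present.
IPredictor predictor = trainer.Train(context);

// A trainer that doesn't support validation sets simply ignores
// context.Validation; no extra interface is needed.
```

The design choice here is that optionality lives in the data object rather than in the type system, which is why adding a future input (say, a test set) is a non-breaking property addition instead of a new interface.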