Refactor of model - sampler interactions #398
Conversation
(still looking, but a couple of early questions)
Looks good to me in general. My main issue is that
Yeah. So the idea will be to define what sort of trajectory sampler you want when you define the model. For now that defaults to RFF, but as more methods become available (e.g. quadrature or even exact sampling) you will be able to choose between them. This is similar to GPflux, where you need to define your chosen kernel decomposition when building the model.
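A rough sketch of that idea (class names and signatures here are illustrative stand-ins, not Trieste's actual API): the model accepts an optional trajectory sampler at construction time and falls back to an RFF-style default, so the choice is fixed when the model is defined:

```python
from __future__ import annotations

from typing import Callable, Optional


class TrajectorySampler:
    """Stand-in base class for trajectory samplers."""

    def get_trajectory(self) -> Callable[[float], float]:
        raise NotImplementedError


class RFFTrajectorySampler(TrajectorySampler):
    """Placeholder for a random-Fourier-feature based sampler."""

    def get_trajectory(self) -> Callable[[float], float]:
        # a degenerate "sample" purely for illustration
        return lambda x: 0.0


class Model:
    """The sampler choice is made at model-construction time, defaulting to RFF."""

    def __init__(self, trajectory_sampler: Optional[TrajectorySampler] = None):
        self._trajectory_sampler = trajectory_sampler or RFFTrajectorySampler()

    def trajectory_sampler(self) -> TrajectorySampler:
        return self._trajectory_sampler
```

Swapping in a quadrature or exact sampler would then be a constructor argument rather than a change to acquisition-function code.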
Reset function looks good, remaining question about stack model.
Just a couple more comments.
```diff
@@ -27,9 +27,10 @@
 from ...data import Dataset
 from ...types import TensorType
 from ...utils import DEFAULTS, jit
-from ..interfaces import FastUpdateModel, TrainableProbabilisticModel
+from ..interfaces import FastUpdateModel, TrainableProbabilisticModel, TrajectorySampler
```
Same as with FastUpdateModel: we should probably have ReparametrizationSamplerModel and TrajectorySamplerModel interfaces. Then we can also be more precise in each acquisition function about what type of probabilistic model the user needs (although we would still need to check the existence of methods at runtime, as users might not use mypy).
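One way such interfaces could look (a hedged sketch using `typing.Protocol`; the interface names come from the comment above, but the helper and its error message are my own): `runtime_checkable` protocols give both static precision for mypy users and an `isinstance` check at runtime for everyone else:

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ReparametrizationSamplerModel(Protocol):
    """A model that can provide a reparametrization sampler."""

    def reparam_sampler(self, sample_size: int) -> Any: ...


@runtime_checkable
class TrajectorySamplerModel(Protocol):
    """A model that can provide a trajectory sampler."""

    def trajectory_sampler(self) -> Any: ...


def require_reparam_model(model: Any) -> None:
    # users might not run mypy, so the method's existence is still checked at runtime
    if not isinstance(model, ReparametrizationSamplerModel):
        raise NotImplementedError(
            f"Model {model!r} does not have a reparametrization sampler"
        )
```

Acquisition functions could then annotate their `model` argument as `ReparametrizationSamplerModel` while keeping a runtime guard.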
""" | ||
raise NotImplementedError(f"Model {self!r} does not have a reparametrization sampler") | ||
|
||
def trajectory_sampler(self) -> TrajectorySampler: |
Now that we have the FastUpdateModel interface, we should probably move these methods to new interfaces that we would put here; see my comment above.
```diff
@@ -255,13 +254,21 @@ def prepare_acquisition_function(
             # hypervolume improvement in this area
             _partition_bounds = prepare_default_non_dominated_partition_bounds(_reference_pt, _pf.front)

-        sampler = BatchReparametrizationSampler(self._sample_size, model)
+        try:
+            sampler = model.reparam_sampler(self._sample_size)
```
If we implement the ReparametrizationSamplerModel interface, then we can also be more precise above in specifying the type of the probabilistic model.
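The try/except pattern in the hunk above can be wrapped in a small helper so every acquisition function fails with the same clear message (a sketch; the helper name and chosen exception are my own, not from the PR):

```python
from typing import Any


def get_reparam_sampler(model: Any, sample_size: int) -> Any:
    """Fetch a reparametrization sampler from a duck-typed model,
    raising a clear error if the model does not support one."""
    try:
        return model.reparam_sampler(sample_size)
    except (AttributeError, NotImplementedError) as e:
        raise ValueError(
            f"Model {model!r} does not support reparametrization sampling"
        ) from e
```

This keeps the per-acquisition-function code to a single call while preserving the informative failure mode the old error traps provided.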
```diff
         ) # [S, 0]

-    def sample(self, at: TensorType) -> TensorType:
+    def sample(self, model: ProbabilisticModel, sample_size: int, at: TensorType) -> TensorType:
```
If we implement the new interface, then we can be more precise here with TrajectorySamplerModel.
This PR is a prerequisite for @sebastianober's MCEI acquisition function PR and for adding in the functionality from our S-GP-TS paper.
Previously we had a set of models and a separate set of samplers (of varying types) used by our acquisition functions. Certain samplers only worked for certain models, and we dealt with this by having lots of error traps to check the suitability of models when attempting to sample from them. This is not a scalable setup, especially as we now have @sebastianober's new deep GP models and their associated custom samplers.
We have three types of samplers: IndependentReparameterizationSampler, BatchReparameterizationSampler, and @sebastianober's new reparam sampler for deep GPs.

This PR has two key parts:

1. The first concerns IndependentReparameterizationSampler, which is probably a nice feature to have (to make it consistent with the batch version) and perhaps not that contentious, as IndependentReparameterizationSampler is not actually used in our code base!
2. New reparam_sampler and trajectory_sampler methods that are defined when defining a new model and call the reparameterization and trajectory samplers relevant for that model. This makes an explicit link between the models and their supported samplers and has allowed me to move the relevant samplers into the models part of the code base. For example, the RFF sampler goes near GaussianProcessRegression and BatchReparameterizationSampler now lives near GPFlowPredictor. Note that the Thompson samplers still live near the acquisition functions.

Crucially, this PR paves the way for @sebastianober to define his custom GPflux trajectory/reparam samplers in a way that makes them easily accessible from the GPflux model, so that his new MCEI acquisition functions can work with our existing models and our MC acquisition functions can work with his (i.e. they all just require a model with a model.reparam_sampler method rather than an explicit list of names of supported models).
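To illustrate the duck-typing this enables (a minimal, self-contained sketch; `ToyModel`, `GaussianSampler`, and `monte_carlo_mean` are illustrative names, not Trieste's): an MC acquisition function needs only a `model.reparam_sampler` method, so any model implementing it works, whether it is a GP, a deep GP, or something else:

```python
import random
from typing import Any, List


class GaussianSampler:
    """Toy reparametrization sampler: perturbs inputs with fixed-seed Gaussian noise."""

    def __init__(self, sample_size: int):
        self._sample_size = sample_size

    def sample(self, at: List[float]) -> List[List[float]]:
        rng = random.Random(0)  # seeded for reproducibility in this sketch
        return [
            [x + rng.gauss(0.0, 1.0) for x in at]
            for _ in range(self._sample_size)
        ]


class ToyModel:
    """Any model exposing reparam_sampler works with the acquisition code below."""

    def reparam_sampler(self, sample_size: int) -> GaussianSampler:
        return GaussianSampler(sample_size)


def monte_carlo_mean(model: Any, at: List[float]) -> List[float]:
    # the acquisition code depends only on the method, not on a specific model class
    samples = model.reparam_sampler(sample_size=100).sample(at)
    n = len(samples)
    return [sum(s[i] for s in samples) / n for i in range(len(at))]
```

No per-model error traps or lists of supported model names are needed; an unsupported model simply fails to provide the method.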
I have updated the sampling used within the MES and GIBBON acquisition functions and the discrete Thompson sampling rule to work with the new changes; however, this part of the code will be made much nicer as soon as this PR is closed. Using the work from this PR, we can now just pass in our desired type of Thompson sampler when defining these functions/rules, which is more Pythonic and requires a lot less code.
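That last point can be sketched as follows (class names and signatures are illustrative assumptions, not the PR's final API): the rule takes the desired sampler type as a constructor argument instead of hard-coding one:

```python
from typing import List, Type


class ThompsonSampler:
    """Stand-in base class for Thompson samplers."""

    def sample_minimizers(self, num_samples: int) -> List[float]:
        raise NotImplementedError


class ExactThompsonSampler(ThompsonSampler):
    """Toy sampler returning dummy minimizers, purely for illustration."""

    def sample_minimizers(self, num_samples: int) -> List[float]:
        return [0.0] * num_samples


class DiscreteThompsonSampling:
    """The rule is parameterized by a sampler type rather than a fixed sampler."""

    def __init__(
        self,
        num_query_points: int,
        sampler_type: Type[ThompsonSampler] = ExactThompsonSampler,
    ):
        self._num_query_points = num_query_points
        self._sampler = sampler_type()

    def acquire(self) -> List[float]:
        return self._sampler.sample_minimizers(self._num_query_points)
```

Choosing a different sampling strategy then becomes `DiscreteThompsonSampling(5, sampler_type=MyOtherSampler)` with no changes to the rule itself.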