Avoid retracing acquisition functions #271
Conversation
tf.random.normal([self._sample_size, tf.shape(mean)[-1]], dtype=tf.float64)
)  # [S, L]
What's this for? Just for the TensorFlow retracing?
Yes. Compiling the graph means the variable mean only knows the statically determined shape (which may have None in it). You have to use tf.shape to get the dynamic shape at each execution.
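For illustration only (this sketch is not from the PR, and the input signature is made up): inside a compiled function the static shape may contain None, while tf.shape always yields concrete values at run time.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float64)])
def sample(mean: tf.Tensor) -> tf.Tensor:
    # At trace time mean.shape is (None, None), so it cannot be used to build a
    # new tensor; tf.shape(mean) is evaluated at run time and is always concrete.
    return tf.random.normal([5, tf.shape(mean)[-1]], dtype=tf.float64)  # [S, L]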
makes sense and sounds familiar
trieste/models/model_interfaces.py
Outdated
@@ -357,7 +368,7 @@ def model(self) -> GPR | SGPR:
        return self._model

    def update(self, dataset: Dataset) -> None:
        x, y = self.model.data
        x, y = map(lambda var: var.value(), self.model.data)
What happens if you keep using these as variables?
What do you mean?
Does .value() mean that the variable just becomes a constant?
I think that value() is needed to get the shape a few lines further down, and also for the type checking on the Dataset initialisation to pass. Either way, it's fine to be explicit here.
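As an aside, a standalone illustration (the shapes here are made up, not taken from the PR): value() returns a plain tensor snapshot of the variable's current contents, which has a fully concrete shape even when the variable itself was created with an unknown leading dimension.

import tensorflow as tf

x_var = tf.Variable(tf.zeros([3, 2], dtype=tf.float64), shape=[None, 2])
x = x_var.value()   # a Tensor snapshot of the current contents
print(x_var.shape)  # (None, 2) -- the variable's (partially unknown) shape
print(x.shape)      # (3, 2)    -- the concrete shape of the current value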
try:
    if track_state:
        models = copy.deepcopy(models)
        acquisition_state = copy.deepcopy(acquisition_state)
        models_copy = copy.deepcopy(models)
I assume this change is necessary because the deepcopied compiled models no longer work properly. I presume the Record saves the copy of the model so it can be queried later? Does that version of the model work properly?
Yes, but I'll add a test to show this.
assert len(history) == 4
assert len(history) == 3
Why has this changed?
As discussed: this is because we now store the model copy in the history rather than the original model. This test uses a model that can only be copied 3 times, meaning we now only store it three times.
@@ -499,6 +511,32 @@ def evaluate_loss_of_model_parameters() -> tf.Tensor:
        multiple_assign(self.model, current_best_parameters)


class NumDataPropertyMixin:
    """Mixin class for exposing num_data as a property, stored in a tf.Variable. This is to work
What does Mixin mean?
A mixin is "a class that contains methods for use by other classes without having to be the parent class of those other classes". It's useful here as both wrappers want the same behaviour.
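Purely as an illustration of the pattern (this is not the PR's actual implementation, and the variable handling is simplified):

import tensorflow as tf

class NumDataPropertyMixin:
    # Illustrative only: supplies a num_data property backed by a tf.Variable
    # to any class that lists it as an extra base class.

    @property
    def num_data(self):
        return self._num_data

    @num_data.setter
    def num_data(self, value) -> None:
        if not hasattr(self, "_num_data"):
            # the first assignment creates the variable...
            self._num_data = tf.Variable(value, trainable=False)
        else:
            # ...later assignments update it in place, so compiled graphs that
            # read it don't need to be retraced
            self._num_data.assign(value)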
class SVGPWrapper(SVGP, NumDataPropertyMixin):
    """A wrapper around GPFlow's SVGP class that stores num_data in a tf.Variable and exposes
    it as a property."""
Do we really need this (and the VGP one below)? It just complicates things. Why not always assume it's a standard SVGP coming into the SparseVariational class?
Even if we only supported passing in standard SVGPs into SparseVariational, we'd still need this, as we need to turn those standard SVGPs into SVGPs that store num_data in a Variable and expose it as a property.
I see
trieste/acquisition/function.py
Outdated
    optimization. Improvement is with respect to the current "best" observation ``eta``, where an
    improvement moves towards the objective function's minimum, and the expectation is calculated
    with respect to the ``model`` posterior. For model posterior :math:`f`, this is
class expected_improvement:
I do wonder if we should define a base class for these acquisition function bits?
We need documentation saying that these acquisition functions need an update and a __call__, and what these bits are for.
I'll add an ABC (though note these acquisition functions don't need an update: whether they have one, and what it looks like, is specific to the implementations).
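For context, a rough sketch of the pattern being discussed (not the PR's exact code: the model.predict interface and the details of the expected improvement formula here are assumptions). The state lives in a tf.Variable so that update can change it without retracing the compiled __call__:

import tensorflow as tf
import tensorflow_probability as tfp

class expected_improvement:
    def __init__(self, model, eta: tf.Tensor):
        self._model = model
        self._eta = tf.Variable(eta)  # current best observation

    def update(self, eta: tf.Tensor) -> None:
        self._eta.assign(eta)  # only the variable's value changes: no retrace

    @tf.function
    def __call__(self, x: tf.Tensor) -> tf.Tensor:
        mean, variance = self._model.predict(x)
        normal = tfp.distributions.Normal(mean, tf.sqrt(variance))
        return (self._eta - mean) * normal.cdf(self._eta) + variance * normal.prob(self._eta)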
return acquisition
How come we don't have a new lower_confidence_bound class (it's still a function like before)? Just because it doesn't have anything to update doesn't mean we don't want to stop recompiling it?
Because I haven't updated all the acquisition functions yet. See the PR description.
    :return: The updated acquisition function.
    """
    tf.debugging.assert_positive(len(dataset))
    tf.debugging.Assert(None not in [self._base_acquisition_function], [])
What does this line do?
It checks that the base acquisition function has already been constructed by a previous call to prepare_acquisition_function. The funny syntax is to make TensorFlow happy (and is copied from similar checks elsewhere in this file).
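For illustration (a standalone example rather than the PR's code): tf.debugging.Assert raises InvalidArgumentError when its condition is false, and its second argument is the list of tensors to include in the error message.

import tensorflow as tf

base_fn = lambda x: x  # stands in for self._base_acquisition_function
tf.debugging.Assert(None not in [base_fn], [])  # passes: the function has been built
# tf.debugging.Assert(None not in [None], [])   # would raise InvalidArgumentError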
    # if possible, just update the penalization function variables
    self._penalization.update(pending_points, self._lipschitz_constant, self._eta)
    return self._penalized_acquisition
else:
I don't see a situation where a penalization function isn't updatable! Maybe we can just force them to be updatable.
It's still more hassle to write updateable functions than non-updateable ones. And this bit of code is still required for the first time we generate the penalization function, so I don't think we gain anything by removing support for non-updateable functions.
# check that acquisition functions defined as classes aren't being retraced unnecessarily
if isinstance(acquisition_rule, EfficientGlobalOptimization):
    acquisition_function = acquisition_rule._acquisition_function
    if isinstance(acquisition_function, AcquisitionFunctionClass):
        assert acquisition_function.__call__._get_tracing_count() == 3  # type: ignore
Lovely!
This PR avoids retracing acquisition functions by updating them rather than generating them afresh each optimization loop, and compiling them with tf.function. For simplicity, this is made optional and backwards compatible: users can choose whether to implement the update methods or not (at the cost of performance if they don't).
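A tiny standalone demonstration of the idea (assuming a recent TF 2.x, which provides experimental_get_tracing_count): because the compiled function reads its state from a tf.Variable, updating that state in place reuses the existing graph.

import tensorflow as tf

eta = tf.Variable(0.0)

@tf.function
def acquisition(x: tf.Tensor) -> tf.Tensor:
    return x - eta  # the variable is read at run time, not baked into the graph

acquisition(tf.constant(1.0))
eta.assign(5.0)                 # update in place instead of rebuilding
acquisition(tf.constant(2.0))
print(acquisition.experimental_get_tracing_count())  # 1: traced only once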
Note that making the model wrappers compatible with being updated involves adding some tf.Variables that slow down the non-sparse GPR and VGP models. In some cases this slowdown can outweigh the speedup from not recompiling the acquisition function. If this becomes a real issue in the future, we can add more architecture to support models that don't allow AF updates.
TIP: when reviewing this, add ?w=1 to the URL to ignore whitespace changes.
Still left to do (in a separate PR)