Avoid retracing acquisition functions #271

Merged: 38 commits merged into secondmind-labs:develop on Aug 11, 2021

Conversation

@uri-granta (Collaborator) commented Jun 16, 2021

This PR avoids retracing acquisition functions by updating them in place rather than generating them afresh at each optimization loop, and by compiling them with tf.function. For simplicity, this is optional and backwards compatible: users can choose whether or not to implement the update methods (at the cost of performance if they don't).
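To illustrate the pattern (a minimal sketch with hypothetical names, not the actual trieste API): the acquisition function becomes a class whose __call__ is compiled with tf.function and closes over tf.Variables; update then just assigns new values to those variables, so the compiled graph is reused without retracing.

import tensorflow as tf

class expected_improvement_sketch:
    def __init__(self, eta: tf.Tensor):
        self._eta = tf.Variable(eta)  # current best observation

    def update(self, eta: tf.Tensor) -> None:
        self._eta.assign(eta)  # mutate in place; no retracing is triggered

    @tf.function
    def __call__(self, x: tf.Tensor) -> tf.Tensor:
        # placeholder formula; a real implementation would query the model posterior
        return tf.maximum(self._eta - tf.reduce_sum(x, axis=-1), 0.0)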

Note that making the model wrappers compatible with being updated involves adding some tf.Variables, which slow down the non-sparse GPR and VGP models. In some cases this slowdown can outweigh the speedup from not recompiling the acquisition function. If this becomes a real issue in the future, we can add more architecture to support models that don't allow acquisition function updates.
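For context, the kind of change involved in the wrappers (a standalone sketch, assuming 2-d inputs): dataset tensors are held in tf.Variables whose leading dimension is left unspecified, so new observations can be assigned in place without recreating the variable, at the cost of some overhead for the dense models.

import tensorflow as tf

x = tf.Variable(tf.zeros([0, 2], tf.float64), shape=tf.TensorShape([None, 2]))
x.assign(tf.constant([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float64))  # now holds 2 rows
x.assign(tf.random.uniform([10, 2], dtype=tf.float64))             # now holds 10 rows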

TIP: when reviewing this, add ?w=1 to the URL to ignore whitespace changes.

Still left to do (in a separate PR)

  • update all the acquisition functions:
      • ExpectedImprovement
      • AugmentedExpectedImprovement
      • MinValueEntropySearch
      • NegativeLowerConfidenceBound
      • ProbabilityOfFeasibility
      • ExpectedConstrainedImprovement
      • ExpectedHypervolumeImprovement
      • BatchMonteCarloExpectedImprovement
      • LocalPenalizationAcquisitionFunction
  • update all the penalizers too:
      • soft_local_penalizer
      • hard_local_penalizer

@uri-granta uri-granta requested a review from henrymoss June 29, 2021 09:30
@uri-granta uri-granta marked this pull request as ready for review June 30, 2021 08:18
Comment on lines +240 to 241
tf.random.normal([self._sample_size, tf.shape(mean)[-1]], dtype=tf.float64)
) # [S, L]
Collaborator:
What's this for? Just for the TensorFlow retracing?

Collaborator (Author):
Yes. Compiling the graph means that mean only knows its statically determined shape (which may contain None). You have to use tf.shape to get the dynamic shape at each execution.
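A minimal sketch of the difference (a hypothetical function, not the PR code): inside a compiled function the static shape can contain None, so tf.shape must be used to size tensors at run time.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float64)])
def sample(mean: tf.Tensor) -> tf.Tensor:
    # mean.shape[0] is None at trace time; tf.shape(mean)[0] is the actual
    # batch size at each execution
    return tf.random.normal([tf.shape(mean)[0], 3], dtype=tf.float64)

sample(tf.zeros([7, 3], tf.float64))  # works for any leading dimension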

Collaborator:
Makes sense, and sounds familiar.

@@ -357,7 +368,7 @@ def model(self) -> GPR | SGPR:
return self._model

def update(self, dataset: Dataset) -> None:
x, y = self.model.data
x, y = map(lambda var: var.value(), self.model.data)
Collaborator:
What happens if you keep using these as variables?

Collaborator (Author):
What do you mean?

Collaborator:
Does .value() mean that the variable just becomes a constant?

Collaborator (Author):
I think that value() is needed to get the concrete shape a few lines further down, and also for the type checking on the Dataset initialisation to pass. Either way, it's fine to be explicit here.
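For illustration (a standalone sketch, not the wrapper code): value() returns a plain tensor snapshot of the variable's current contents, whose shape is fully concrete even when the variable itself was declared with an unspecified dimension.

import tensorflow as tf

x = tf.Variable(tf.zeros([4, 2]), shape=tf.TensorShape([None, 2]))
print(x.shape)          # (None, 2) -- the variable's partially unknown shape
print(x.value().shape)  # (4, 2)    -- the concrete shape of the current value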

@uri-granta uri-granta requested a review from johnamcleod July 1, 2021 12:28
try:
if track_state:
models = copy.deepcopy(models)
acquisition_state = copy.deepcopy(acquisition_state)
models_copy = copy.deepcopy(models)
Collaborator:
I assume this change is necessary because the deepcopied compiled models no longer work properly. I presume the Record saves the copy of the model so it can be queried later? Does that version of the model work properly?

Collaborator (Author):
Yes, but I'll add a test to show this.

@uri-granta uri-granta force-pushed the uri/avoid_retracing branch from 81ae27c to 07aeb06 Compare July 8, 2021 08:01
@henrymoss henrymoss self-requested a review July 22, 2021 14:35
assert len(history) == 4
assert len(history) == 3
Collaborator:
Why has this changed?

Collaborator (Author):
As discussed: this is because we now store the model copy in the history rather than the original model. This test uses a model that can only be copied three times, meaning we now only store it three times.

@henrymoss henrymoss self-requested a review August 11, 2021 10:36
@@ -499,6 +511,32 @@ def evaluate_loss_of_model_parameters() -> tf.Tensor:
multiple_assign(self.model, current_best_parameters)


class NumDataPropertyMixin:
"""Mixin class for exposing num_data as a property, stored in a tf.Variable. This is to work
Collaborator:
What does Mixin mean?

Collaborator (Author):
A mixin is "a class that contains methods for use by other classes without having to be the parent class of those other classes". It's useful here as both wrappers want the same behaviour.
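A generic sketch of the idea (hypothetical names, not the actual NumDataPropertyMixin): the mixin supplies a num_data property backed by a tf.Variable, and any class that wants this behaviour simply lists the mixin as an extra base.

import tensorflow as tf

class NumDataVariableMixin:
    # assumes the class it is mixed into creates self._num_data as a tf.Variable
    _num_data: tf.Variable

    @property
    def num_data(self) -> tf.Tensor:
        return self._num_data.value()

    @num_data.setter
    def num_data(self, value: int) -> None:
        self._num_data.assign(value)

class SomeModel(NumDataVariableMixin):
    def __init__(self, num_data: int):
        self._num_data = tf.Variable(num_data, dtype=tf.int32)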

Comment on lines +535 to +539
class SVGPWrapper(SVGP, NumDataPropertyMixin):
"""A wrapper around GPFlow's SVGP class that stores num_data in a tf.Variable and exposes
it as a property."""


Collaborator:
Do we really need this (and the VGP one below)? It just complicates things. Why not always assume it's a standard SVGP coming into the SparseVariational class?

Collaborator (Author):
Even if we only supported passing standard SVGPs into SparseVariational, we'd still need this, as we need to turn those standard SVGPs into ones that store num_data in a Variable and expose it as a property.

Collaborator:
I see

optimization. Improvement is with respect to the current "best" observation ``eta``, where an
improvement moves towards the objective function's minimum, and the expectation is calculated
with respect to the ``model`` posterior. For model posterior :math:`f`, this is
class expected_improvement:
Collaborator:
I do wonder if we should define a base class for these acquisition function bits?

Collaborator:
We need documentation saying that these acquisition functions need an update and a __call__, and what these bits are for.

Collaborator (Author):
I'll add an ABC (though note these acquisition functions don't need an update: whether they have one, and what it looks like, is specific to the implementations).
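For reference, a sketch of what such a base class might look like (hypothetical; not necessarily the definition that ends up in trieste): it only fixes the __call__ signature, while any update method remains implementation-specific.

from abc import ABC, abstractmethod
import tensorflow as tf

class AcquisitionFunctionClassSketch(ABC):
    """Base class for acquisition functions defined as classes."""

    @abstractmethod
    def __call__(self, x: tf.Tensor) -> tf.Tensor:
        """Evaluate the acquisition function at the query points x."""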

Comment on lines -368 to -369
return acquisition

Collaborator:
How come we don't have a new lower_confidence_bound class (it's still a function like before)? Just because it doesn't have anything to update doesn't mean we don't want to stop recompiling it.

Collaborator (Author):
Because I haven't updated all the acquisition functions yet. See the PR description above.

:return: The updated acquisition function.
"""
tf.debugging.assert_positive(len(dataset))
tf.debugging.Assert(None not in [self._base_acquisition_function], [])
Collaborator:
What does this line do?

Collaborator (Author):
It checks that the base acquisition function has already been constructed by a previous call to prepare_acquisition_function. The funny syntax is to make TensorFlow happy (and is copied from similar checks elsewhere in this file).

Comment on lines +1334 to +1337
# if possible, just update the penalization function variables
self._penalization.update(pending_points, self._lipschitz_constant, self._eta)
return self._penalized_acquisition
else:
Collaborator:
I don't see a situation where a penalization function isn't updatable! Maybe we can just force them to be updatable.

Collaborator (Author):
It's still more hassle to write updatable functions than non-updatable ones. And this bit of code is still required the first time we generate the penalization function, so I don't think we gain anything by removing support for non-updatable functions.

Comment on lines +137 to +141
# check that acquisition functions defined as classes aren't being retraced unnecessarily
if isinstance(acquisition_rule, EfficientGlobalOptimization):
acquisition_function = acquisition_rule._acquisition_function
if isinstance(acquisition_function, AcquisitionFunctionClass):
assert acquisition_function.__call__._get_tracing_count() == 3 # type: ignore
Collaborator:
Lovely!

@uri-granta uri-granta merged commit a05075b into secondmind-labs:develop Aug 11, 2021