-
Notifications
You must be signed in to change notification settings - Fork 450
Change how qNEHVI handles pending points #2985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@TobyBoyne Thanks for the PR! I'm pretty sure that @sdaulton meant qLogNEHVI and q(non-Log)NEHVI when he mentioned that more than one acqf have the same issue, so you're already good on that. |
Co-authored-by: Carl Hvarfner <58733990+hvarfner@users.noreply.github.com>
…orch into qnehvi-pending-bug
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @TobyBoyne, this looks good to me!
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2985 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 216 216
Lines 20327 20337 +10
=========================================
+ Hits 20327 20337 +10 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
| @t_batch_mode_transform() | ||
| @average_over_ensemble_models | ||
| def forward(self, X: Tensor) -> Tensor: | ||
| # Manually concatenate pending points only if: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like some 90% of the code in the forward() method here is shared with qLogNoisyExpectedHypervolumeImprovement.forward() above. We should probably deduplicate this by putting it into a shared utility.
@TobyBoyne if you're up for it that would be great, otherwise @hvarfner let's put this on the backlog.
|
@Balandat, I've extracted the common code in the forward passes into the shared Also, it looks like the non-log acqf always returns |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @TobyBoyne ,
This change looks great. Some minor nits to make the forward passes less opaque, then we should be good to merge this.
| q_in = X.shape[-2] * n_w | ||
| self._set_sampler(q_in=q_in, posterior=posterior) | ||
| samples = self._get_f_X_samples(posterior=posterior, q_in=q_in) | ||
| samples, X = super().forward(X) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, this is awesome. Please just add a comment for both to say where the forward pass goes and what it does, since this forward pass is now more opaque.
| q_in = X.shape[-2] * n_w | ||
| self._set_sampler(q_in=q_in, posterior=posterior) | ||
| samples = self._get_f_X_samples(posterior=posterior, q_in=q_in) | ||
| samples, X = super().forward(X) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
...and same here.
|
@TobyBoyne Thank you for being persistent with this. Looks ready to be merged! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, looks great. Let me import this and make sure this works; I'll land this after the long weekend here in the US.
Motivation
Currently, qNEHVI proposes repeated experiments in a batch when initial pending points are passed. This PR changes how this class handles pending points -
X_pendingis now always populated, and only appended in the forward pass if those points have not yet been cached. See issue #2983 for further discussion.Have you read the Contributing Guidelines on pull requests?
Yes
Test Plan
I will rewrite the tests in
test/acquisition/multi_objective/test_monte_carlo.pyto ensure that they pass.