Skip to content

Commit d276d0c

Browse files
Balandatfacebook-github-bot
authored andcommitted
Fix sobol maxdim limitation in prune_baseline (pytorch#419)
Summary: Fixes facebook/Ax#291 Pull Request resolved: pytorch#419 Test Plan: unit tests Reviewed By: sdaulton Differential Revision: D20978955 Pulled By: Balandat fbshipit-source-id: 2ff30700f1d86542c535f9b39b3c8350e0ec4c25
1 parent 0e24fcb commit d276d0c

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

botorch/acquisition/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@
1111
from __future__ import annotations
1212

1313
import math
14+
import warnings
1415
from typing import Callable, Dict, List, Optional
1516

1617
import torch
1718
from torch import Tensor
19+
from torch.quasirandom import SobolEngine
1820

21+
from .. import settings
1922
from ..exceptions.errors import UnsupportedError
23+
from ..exceptions.warnings import SamplingWarning
2024
from ..models.model import Model
2125
from ..sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
2226
from ..utils.transforms import squeeze_last_dim
@@ -213,10 +217,19 @@ def prune_inferior_points(
213217
max_points = math.ceil(max_frac * X.size(-2))
214218
if max_points < 1 or max_points > X.size(-2):
215219
raise ValueError(f"max_frac must take values in (0, 1], is {max_frac}")
216-
sampler = SobolQMCNormalSampler(num_samples=num_samples)
217220
with torch.no_grad():
218221
posterior = model.posterior(X=X)
219-
samples = sampler(posterior)
222+
if posterior.event_shape.numel() > SobolEngine.MAXDIM:
223+
if settings.debug.on():
224+
warnings.warn(
225+
f"Sample dimension q*m={posterior.event_shape.numel()} exceeding Sobol "
226+
f"max dimension ({SobolEngine.MAXDIM}). Using iid samples instead.",
227+
SamplingWarning,
228+
)
229+
sampler = IIDNormalSampler(num_samples=num_samples)
230+
else:
231+
sampler = SobolQMCNormalSampler(num_samples=num_samples)
232+
samples = sampler(posterior)
220233
if objective is None:
221234
objective = IdentityMCObjective()
222235
obj_vals = objective(samples)

test/acquisition/test_utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import itertools
8+
import warnings
9+
from contextlib import ExitStack
810
from unittest import mock
911

1012
import torch
13+
from botorch import settings
1114
from botorch.acquisition import monte_carlo
1215
from botorch.acquisition.objective import GenericMCObjective, MCAcquisitionObjective
1316
from botorch.acquisition.utils import (
@@ -17,7 +20,8 @@
1720
project_to_target_fidelity,
1821
prune_inferior_points,
1922
)
20-
from botorch.exceptions import UnsupportedError
23+
from botorch.exceptions.errors import UnsupportedError
24+
from botorch.exceptions.warnings import SamplingWarning
2125
from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
2226
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
2327
from torch import Tensor
@@ -351,6 +355,22 @@ def test_prune_inferior_points(self):
351355
mm = MockModel(MockPosterior(samples=samples))
352356
X_pruned = prune_inferior_points(model=mm, X=X)
353357
self.assertTrue(torch.equal(X_pruned, X[:2]))
358+
# test high-dim sampling
359+
with ExitStack() as es:
360+
mock_event_shape = es.enter_context(
361+
mock.patch(
362+
"botorch.utils.testing.MockPosterior.event_shape",
363+
new_callable=mock.PropertyMock,
364+
)
365+
)
366+
mock_event_shape.return_value = torch.Size([1, 1, 1112])
367+
es.enter_context(
368+
mock.patch.object(MockPosterior, "rsample", return_value=samples)
369+
)
370+
mm = MockModel(MockPosterior(samples=samples))
371+
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
372+
prune_inferior_points(model=mm, X=X)
373+
self.assertTrue(issubclass(ws[-1].category, SamplingWarning))
354374

355375

356376
class TestFidelityUtils(BotorchTestCase):

0 commit comments

Comments
 (0)