Skip to content

Commit f4ff77c

Browse files
David Erikssonfacebook-github-bot
authored andcommitted
Use Standardize by default for SingleTaskGP (#2458)
Summary: X-link: facebook/Ax#2630 Pull Request resolved: #2458 D60080819 recently updated the default `SingleTaskGP` BoTorch priors. One significant change was to remove the use of an outputscale, which may not work well if the outputs aren't standardized. This diff changes the `SingleTaskGP` to use `Standardize` by default if no outcome transforms are specified (this allows users to explicitly pass in `None` if they don't want to use any transforms). Differential Revision: D60492937
1 parent 5ffa491 commit f4ff77c

File tree

12 files changed

+197
-67
lines changed

12 files changed

+197
-67
lines changed

botorch/acquisition/analytic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,15 +1091,17 @@ def _get_noiseless_fantasy_model(
10911091
# are used across all batches (by default, a GP with batched training data
10921092
# uses independent hyperparameters for each batch).
10931093

1094-
# Don't apply `outcome_transform` and `input_transform` here,
1095-
# since the data being passed has already been transformed.
1096-
# So we will instead set them afterwards.
1094+
# We don't want to use the true `outcome_transform` and `input_transform` here
1095+
# since the data being passed has already been transformed. We thus pass `None`
1096+
# and will instead set them afterwards.
10971097
fantasy_model = SingleTaskGP(
10981098
train_X=model.train_inputs[0],
10991099
train_Y=model.train_targets.unsqueeze(-1),
11001100
train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
11011101
covar_module=deepcopy(model.covar_module),
11021102
mean_module=deepcopy(model.mean_module),
1103+
outcome_transform=None,
1104+
input_transform=None,
11031105
)
11041106

11051107
Yvar = torch.full_like(Y_fantasized, 1e-7)

botorch/acquisition/multi_objective/max_value_entropy_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class qMultiObjectiveMaxValueEntropy(
6464
_default_sample_shape: The `sample_shape` for the default sampler.
6565
6666
Example:
67-
>>> model = SingleTaskGP(train_X, train_Y)
67+
>>> model = SingleTaskGP(train_X, train_Y, outcome_transform=None)
6868
>>> MESMO = qMultiObjectiveMaxValueEntropy(model, sample_pfs)
6969
>>> mesmo = MESMO(test_X)
7070
"""

botorch/models/contextual.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,12 @@ def __init__(
102102
dimension is set to 1 for each categorical variable.
103103
context_weight_dict: Known population weights of each context.
104104
"""
105-
super().__init__(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
105+
super().__init__(
106+
train_X=train_X,
107+
train_Y=train_Y,
108+
train_Yvar=train_Yvar,
109+
outcome_transform=None,
110+
)
106111
self.covar_module = LCEAKernel(
107112
decomposition=decomposition,
108113
batch_shape=self._aug_batch_shape,

botorch/models/converter.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch
1818
from botorch.exceptions import UnsupportedError
1919
from botorch.exceptions.warnings import BotorchWarning
20+
from botorch.models import SingleTaskGP
2021
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
2122
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
2223
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
@@ -179,6 +180,11 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
179180
batch_length = len(models)
180181
covar_module = _batched_kernel(models[0].covar_module, batch_length)
181182
kwargs["covar_module"] = covar_module
183+
# SingleTaskGP uses a default outcome transforms while this converter doesn't
184+
# support outcome transforms. We need to explicitly pass down `None` to make
185+
# sure no outcome transform is being used.
186+
if isinstance(models[0], SingleTaskGP):
187+
kwargs["outcome_transform"] = None
182188

183189
# construct the batched GP model
184190
input_transform = getattr(models[0], "input_transform", None)
@@ -418,6 +424,12 @@ def batched_multi_output_to_single_output(
418424
kwargs["train_Yvar"] = noise_covar.noise.clone().unsqueeze(-1)
419425
if isinstance(batch_mo_model, SingleTaskMultiFidelityGP):
420426
kwargs.update(batch_mo_model._init_args)
427+
# SingleTaskGP uses a default outcome transforms while this converter doesn't
428+
# support outcome transforms. We need to explicitly pass down `None` to make
429+
# sure no outcome transform is being used.
430+
if isinstance(batch_mo_model, SingleTaskGP):
431+
kwargs["outcome_transform"] = None
432+
421433
single_outcome_model = batch_mo_model.__class__(
422434
input_transform=input_transform, **kwargs
423435
)

botorch/models/gp_regression.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
3838
from botorch.models.model import FantasizeMixin
3939
from botorch.models.transforms.input import InputTransform
40-
from botorch.models.transforms.outcome import Log, OutcomeTransform
40+
from botorch.models.transforms.outcome import Log, OutcomeTransform, Standardize
4141
from botorch.models.utils import validate_input_scaling
4242
from botorch.models.utils.gpytorch_modules import (
4343
get_covar_module_with_dim_scaled_prior,
@@ -46,6 +46,7 @@
4646
)
4747
from botorch.utils.containers import BotorchContainer
4848
from botorch.utils.datasets import SupervisedDataset
49+
from botorch.utils.types import _DefaultType, DEFAULT
4950
from gpytorch.constraints.constraints import GreaterThan
5051
from gpytorch.distributions.multivariate_normal import MultivariateNormal
5152
from gpytorch.likelihoods.gaussian_likelihood import (
@@ -134,7 +135,7 @@ def __init__(
134135
likelihood: Optional[Likelihood] = None,
135136
covar_module: Optional[Module] = None,
136137
mean_module: Optional[Mean] = None,
137-
outcome_transform: Optional[OutcomeTransform] = None,
138+
outcome_transform: Optional[Union[OutcomeTransform, _DefaultType]] = DEFAULT,
138139
input_transform: Optional[InputTransform] = None,
139140
) -> None:
140141
r"""
@@ -154,16 +155,24 @@ def __init__(
154155
outcome_transform: An outcome transform that is applied to the
155156
training data during instantiation and to the posterior during
156157
inference (that is, the `Posterior` obtained by calling
157-
`.posterior` on the model will be on the original scale).
158+
`.posterior` on the model will be on the original scale). We use a
159+
`Standardize` transform if no `outcome_transform` is specified.
160+
Pass down `None` to use no outcome transform.
158161
input_transform: An input transform that is applied in the model's
159162
forward pass.
160163
"""
164+
self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
165+
if outcome_transform == DEFAULT:
166+
outcome_transform = Standardize(
167+
m=train_Y.shape[-1], batch_shape=train_X.shape[:-2]
168+
)
161169
with torch.no_grad():
162170
transformed_X = self.transform_inputs(
163171
X=train_X, input_transform=input_transform
164172
)
165173
if outcome_transform is not None:
166174
train_Y, train_Yvar = outcome_transform(train_Y, train_Yvar)
175+
# Validate again after applying the transforms
167176
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
168177
ignore_X_dims = getattr(self, "_ignore_X_dims_scaling_check", None)
169178
validate_input_scaling(
@@ -352,6 +361,7 @@ def __init__(
352361
train_X=train_X,
353362
train_Y=train_Y,
354363
likelihood=likelihood,
364+
outcome_transform=None,
355365
input_transform=input_transform,
356366
)
357367
self.register_added_loss_term("noise_added_loss")

test/acquisition/multi_objective/test_max_value_entropy_search.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
qMultiObjectiveMaxValueEntropy,
1515
)
1616
from botorch.acquisition.multi_objective.utils import compute_sample_box_decomposition
17+
from botorch.exceptions.errors import UnsupportedError
1718
from botorch.models.gp_regression import SingleTaskGP
1819
from botorch.models.model_list_gp_regression import ModelListGP
1920
from botorch.models.transforms.outcome import Standardize
@@ -71,15 +72,30 @@ def test_multi_objective_max_value_entropy(self):
7172
# test batched model
7273
train_X = torch.rand(1, 1, 2, dtype=dtype, device=self.device)
7374
train_Y = torch.rand(1, 1, m, dtype=dtype, device=self.device)
74-
model = SingleTaskGP(train_X, train_Y)
75+
model = SingleTaskGP(train_X, train_Y, outcome_transform=None)
7576
with self.assertRaises(NotImplementedError):
76-
qMultiObjectiveMaxValueEntropy(model, dummy_sample_pareto_frontiers)
77+
qMultiObjectiveMaxValueEntropy(
78+
model=model, sample_pareto_frontiers=dummy_sample_pareto_frontiers
79+
)
7780
# test initialization
7881
train_X = torch.rand(4, 2, dtype=dtype, device=self.device)
7982
train_Y = torch.rand(4, m, dtype=dtype, device=self.device)
80-
# test batched MO model
83+
# Models with outcome transforms aren't supported.
8184
model = SingleTaskGP(train_X, train_Y)
82-
mesmo = qMultiObjectiveMaxValueEntropy(model, dummy_sample_pareto_frontiers)
85+
with self.assertRaisesRegex(
86+
UnsupportedError,
87+
"Conversion of models with outcome transforms is currently "
88+
"unsupported.",
89+
):
90+
qMultiObjectiveMaxValueEntropy(
91+
model=ModelListGP(model, model),
92+
sample_pareto_frontiers=dummy_sample_pareto_frontiers,
93+
)
94+
# test batched MO model
95+
model = SingleTaskGP(train_X, train_Y, outcome_transform=None)
96+
mesmo = qMultiObjectiveMaxValueEntropy(
97+
model=model, sample_pareto_frontiers=dummy_sample_pareto_frontiers
98+
)
8399
self.assertEqual(mesmo.num_fantasies, 16)
84100
# Initialize the sampler.
85101
dummy_post = model.posterior(train_X[:1])
@@ -98,11 +114,16 @@ def test_multi_objective_max_value_entropy(self):
98114
)
99115
# test ModelListGP
100116
model = ModelListGP(
101-
*[SingleTaskGP(train_X, train_Y[:, i : i + 1]) for i in range(m)]
117+
*[
118+
SingleTaskGP(train_X, train_Y[:, i : i + 1], outcome_transform=None)
119+
for i in range(m)
120+
]
102121
)
103122
mock_sample_pfs = mock.Mock()
104123
mock_sample_pfs.return_value = dummy_sample_pareto_frontiers(model=model)
105-
mesmo = qMultiObjectiveMaxValueEntropy(model, mock_sample_pfs)
124+
mesmo = qMultiObjectiveMaxValueEntropy(
125+
model=model, sample_pareto_frontiers=mock_sample_pfs
126+
)
106127
self.assertEqual(mesmo.num_fantasies, 16)
107128
# Initialize the sampler.
108129
dummy_post = model.posterior(train_X[:1])
@@ -156,7 +177,7 @@ def test_multi_objective_max_value_entropy(self):
156177
],
157178
dim=1,
158179
)
159-
fantasy_model = SingleTaskGP(fant_X, fant_Y)
180+
fantasy_model = SingleTaskGP(fant_X, fant_Y, outcome_transform=None)
160181

161182
# test with X_pending is not None
162183
with mock.patch.object(

0 commit comments

Comments
 (0)