3636import torch
3737from botorch .models .gpytorch import BatchedMultiOutputGPyTorchModel
3838from botorch .models .model import FantasizeMixin
39- from botorch .models .transforms .input import InputTransform
40- from botorch .models .transforms .outcome import Log , OutcomeTransform
39+ from botorch .models .transforms .input import InputTransform , Normalize
40+ from botorch .models .transforms .outcome import Log , OutcomeTransform , Standardize
4141from botorch .models .utils import validate_input_scaling
4242from botorch .models .utils .gpytorch_modules import (
4343 get_covar_module_with_dim_scaled_prior ,
4646)
4747from botorch .utils .containers import BotorchContainer
4848from botorch .utils .datasets import SupervisedDataset
49+ from botorch .utils .types import _DefaultType , DEFAULT
4950from gpytorch .constraints .constraints import GreaterThan
5051from gpytorch .distributions .multivariate_normal import MultivariateNormal
5152from gpytorch .likelihoods .gaussian_likelihood import (
@@ -134,8 +135,8 @@ def __init__(
134135 likelihood : Optional [Likelihood ] = None ,
135136 covar_module : Optional [Module ] = None ,
136137 mean_module : Optional [Mean ] = None ,
137- outcome_transform : Optional [OutcomeTransform ] = None ,
138- input_transform : Optional [InputTransform ] = None ,
138+ outcome_transform : Optional [Union [ OutcomeTransform , _DefaultType ]] = DEFAULT ,
139+ input_transform : Optional [Union [ InputTransform , _DefaultType ]] = DEFAULT ,
139140 ) -> None :
140141 r"""
141142 Args:
@@ -154,16 +155,27 @@ def __init__(
154155 outcome_transform: An outcome transform that is applied to the
155156 training data during instantiation and to the posterior during
156157 inference (that is, the `Posterior` obtained by calling
157- `.posterior` on the model will be on the original scale).
158- input_transform: An input transform that is applied in the model's
159- forward pass.
158+ `.posterior` on the model will be on the original scale). We use a
159+ `Standardize` transform if no `outcome_transform` is specified.
160+ Pass down `None` to use no outcome transform.
161+ input_transform: An input transform that is applied in the model's forward
162+ pass. We use a `Normalize` transform if no `input_transform` is
163+ specified. Pass down `None` to use no input transform.
160164 """
165+ self ._validate_tensor_args (X = train_X , Y = train_Y , Yvar = train_Yvar )
166+ if outcome_transform == DEFAULT :
167+ outcome_transform = Standardize (
168+ m = train_Y .shape [- 1 ], batch_shape = train_X .shape [:- 2 ]
169+ )
170+ if input_transform == DEFAULT :
171+ input_transform = Normalize (d = train_X .shape [- 1 ], transform_on_train = True )
161172 with torch .no_grad ():
162173 transformed_X = self .transform_inputs (
163174 X = train_X , input_transform = input_transform
164175 )
165176 if outcome_transform is not None :
166177 train_Y , train_Yvar = outcome_transform (train_Y , train_Yvar )
178+ # Validate again after applying the transforms
167179 self ._validate_tensor_args (X = transformed_X , Y = train_Y , Yvar = train_Yvar )
168180 ignore_X_dims = getattr (self , "_ignore_X_dims_scaling_check" , None )
169181 validate_input_scaling (
@@ -352,6 +364,7 @@ def __init__(
352364 train_X = train_X ,
353365 train_Y = train_Y ,
354366 likelihood = likelihood ,
367+ outcome_transform = None ,
355368 input_transform = input_transform ,
356369 )
357370 self .register_added_loss_term ("noise_added_loss" )
0 commit comments