
[Bug] Unable to fit Heteroskedastic GP with input warping #2551

@SaiAakash

Description


🐛 Bug

Unable to fit a HeteroskedasticSingleTaskGP with a Warp input transform.

To reproduce

Code snippet to reproduce

# Imports
import numpy as np
import torch
from botorch.models import HeteroskedasticSingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms import (
    Normalize,
    Standardize,
    Warp,
    ChainedInputTransform
)
from gpytorch.mlls import ExactMarginalLogLikelihood


# The true function
def oscillator(x):
    return np.cos((x - 5) / 2) ** 2 * x * 2


noise_scale = 3.0

n_data = 200
X_data = np.random.uniform(-10, 10, n_data)

# Heteroskedastic noise: its standard deviation grows linearly with |x|
# (std = noise_scale * 0.5 * |x|)
y_noise = np.random.normal(scale=noise_scale, size=X_data.shape[0]) * np.abs(
    X_data * 0.5
)
y_data_heteroskedastic = oscillator(X_data) + y_noise

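# float64 tensors of shape (n_data, 1)
train_X = torch.tensor(X_data).view(-1, 1)
train_y = torch.tensor(y_data_heteroskedastic).view(-1, 1)
# Observed noise variances (here, the squared noise draws)
train_yvar = torch.tensor(y_noise**2).view(-1, 1)

n = 50  # fit on the first n points only
normalize = Normalize(d=1)
# Learnable input warping applied to every input dimension
warp = Warp(indices=list(range(train_X.shape[1])))
outcome_transform = Standardize(m=1)
# Normalize to the unit cube first, then warp
input_transform = ChainedInputTransform(normalize=normalize, warp=warp)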


gp = HeteroskedasticSingleTaskGP(
    train_X=train_X[0:n],
    train_Y=train_y[0:n],
    train_Yvar=train_yvar[0:n],
    input_transform=input_transform,
    outcome_transform=outcome_transform,
)

mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

Stack trace/error message

{
	"name": "RuntimeError",
	"message": "You must train on the training inputs!",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 15
      6 gp = HeteroskedasticSingleTaskGP(
      7     train_X=train_X[0:n],
      8     train_Y=train_y[0:n],
   (...)
     11     outcome_transform=outcome_transform,
     12 )
     14 mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
---> 15 fit_gpytorch_mll(mll)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs[\"optimizer\"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/fit.py:205, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
    203 with catch_warnings(record=True) as warning_list, debug(True):
    204     simplefilter(\"always\", category=OptimizationWarning)
--> 205     result = optimizer(mll, closure=closure, **optimizer_kwargs)
    207 # Resolve warnings and determine whether or not to retry
    208 success = True

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/fit.py:94, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
     91 if closure_kwargs is not None:
     92     closure = partial(closure, **closure_kwargs)
---> 94 result = scipy_minimize(
     95     closure=closure,
     96     parameters=parameters,
     97     bounds=bounds,
     98     method=method,
     99     options=options,
    100     callback=callback,
    101     timeout_sec=timeout_sec,
    102 )
    103 if result.status != OptimizationStatus.SUCCESS:
    104     warn(
    105         f\"`scipy_minimize` terminated with status {result.status}, displaying\"
    106         f\" original message from `scipy.optimize.minimize`: {result.message}\",
    107         OptimizationWarning,
    108     )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/core.py:110, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
    102         result = OptimizationResult(
    103             step=next(call_counter),
    104             fval=float(wrapped_closure(x)[0]),
    105             status=OptimizationStatus.RUNNING,
    106             runtime=monotonic() - start_time,
    107         )
    108         return callback(parameters, result)  # pyre-ignore [29]
--> 110 raw = minimize_with_timeout(
    111     wrapped_closure,
    112     wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
    113     jac=True,
    114     bounds=bounds_np,
    115     method=method,
    116     options=options,
    117     callback=wrapped_callback,
    118     timeout_sec=timeout_sec,
    119 )
    121 # Post-processing and outcome handling
    122 wrapped_closure.state = asarray(raw.x)  # set parameter state to optimal values

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/utils/timeout.py:83, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
     81 try:
     82     warnings.filterwarnings(\"error\", message=\"Method .* cannot handle\")
---> 83     return optimize.minimize(
     84         fun=fun,
     85         x0=x0,
     86         args=args,
     87         method=method,
     88         jac=jac,
     89         hess=hess,
     90         hessp=hessp,
     91         bounds=bounds,
     92         constraints=constraints,
     93         tol=tol,
     94         callback=wrapped_callback,
     95         options=options,
     96     )
     97 except OptimizationTimeoutError as e:
     98     msg = f\"Optimization timed out after {e.runtime} seconds.\"

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_minimize.py:731, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    728     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    729                              **options)
    730 elif meth == 'l-bfgs-b':
--> 731     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    732                            callback=callback, **options)
    733 elif meth == 'tnc':
    734     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    735                         **options)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_lbfgsb_py.py:407, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    401 task_str = task.tobytes()
    402 if task_str.startswith(b'FG'):
    403     # The minimization routine wants f and g at the current x.
    404     # Note that interruptions due to maxfun are postponed
    405     # until the completion of the current minimization iteration.
    406     # Overwrite f and g:
--> 407     f, g = func_and_grad(x)
    408 elif task_str.startswith(b'NEW_X'):
    409     # new iteration
    410     n_iterations += 1

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:343, in ScalarFunction.fun_and_grad(self, x)
    341 if not np.array_equal(x, self.x):
    342     self._update_x(x)
--> 343 self._update_fun()
    344 self._update_grad()
    345 return self.f, self.g

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:294, in ScalarFunction._update_fun(self)
    292 def _update_fun(self):
    293     if not self.f_updated:
--> 294         fx = self._wrapped_fun(self.x)
    295         if fx < self._lowest_f:
    296             self._lowest_x = self.x

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:20, in _wrapper_fun.<locals>.wrapped(x)
     16 ncalls[0] += 1
     17 # Send a copy because the user may overwrite it.
     18 # Overwriting results in undefined behaviour because
     19 # fun(self.x) will change self.x, with the two no longer linked.
---> 20 fx = fun(np.copy(x), *args)
     21 # Make sure the function returns a true scalar
     22 if not np.isscalar(fx):

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_optimize.py:79, in MemoizeJac.__call__(self, x, *args)
     77 def __call__(self, x, *args):
     78     \"\"\" returns the function value \"\"\"
---> 79     self._compute_if_needed(x, *args)
     80     return self._value

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_optimize.py:73, in MemoizeJac._compute_if_needed(self, x, *args)
     71 if not np.all(x == self.x) or self._value is None or self.jac is None:
     72     self.x = np.asarray(x).copy()
---> 73     fg = self.fun(x, *args)
     74     self.jac = fg[1]
     75     self._value = fg[0]

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:162, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    160         index += size
    161 except RuntimeError as e:
--> 162     value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
    164 return value, grads

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/utils/common.py:32, in _handle_numerical_errors(error, x, dtype)
     30     _dtype = x.dtype if dtype is None else dtype
     31     return np.full((), \"nan\", dtype=_dtype), np.full_like(x, \"nan\", dtype=_dtype)
---> 32 raise error

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:152, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    149     self.state = state
    151 try:
--> 152     value_tensor, grad_tensors = self.closure(**kwargs)
    153     value = self.as_array(value_tensor)
    154     grads = self._get_gradient_ndarray(fill_value=self.fill_value)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
     64 def __call__(self, **kwargs: Any) -> tuple[Tensor, tuple[Optional[Tensor], ...]]:
     65     with self.context_manager():
---> 66         values = self.forward(**kwargs)
     67         value = values if self.reducer is None else self.reducer(values)
     68         self.backward(value)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/model_closures.py:179, in _get_loss_closure_exact_internal.<locals>.closure(**kwargs)
    177 # The inputs will get transformed in forward here.
    178 model_output = model(*model.train_inputs)
--> 179 log_likelihood = mll(
    180     model_output,
    181     model.train_targets,
    182     # During model training, the model inputs get transformed in the forward
    183     # pass. The train_inputs property is not transformed yet, so we need to
    184     # transform it before passing it to the likelihood for consistency.
    185     *(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
    186     **kwargs,
    187 )
    188 return -log_likelihood

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
     30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31     outputs = self.forward(*inputs, **kwargs)
     32     if isinstance(outputs, list):
     33         return [_validate_module_outputs(output) for output in outputs]

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:83, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params, **kwargs)
     81 # Get the log prob of the marginal distribution
     82 res = output.log_prob(target)
---> 83 res = self._add_other_terms(res, params)
     85 # Scale by the amount of data we have
     86 num_data = function_dist.event_shape.numel()

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:42, in ExactMarginalLogLikelihood._add_other_terms(self, res, params)
     39 def _add_other_terms(self, res, params):
     40     # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
     41     for added_loss_term in self.model.added_loss_terms():
---> 42         res = res.add(added_loss_term.loss(*params))
     44     # Add log probs of priors on the (functions of) parameters
     45     res_ndim = res.ndim

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/noise_model_added_loss_term.py:13, in NoiseModelAddedLossTerm.loss(self, *params)
     12 def loss(self, *params):
---> 13     output = self.noise_mll.model(*params)
     14     targets = self.noise_mll.model.train_targets
     15     return self.noise_mll(output, targets)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:267, in ExactGP.__call__(self, *args, **kwargs)
    263 if settings.debug.on():
    264     if not all(
    265         torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)
    266     ):
--> 267         raise RuntimeError(\"You must train on the training inputs!\")
    268 res = super().__call__(*inputs, **kwargs)
    269 return res

RuntimeError: You must train on the training inputs!"
}

Expected Behavior

The model should fit the data without raising an error.

System information


  • BoTorch version: 0.12.0
  • GPyTorch version: 1.13
  • PyTorch version: 2.4.1
  • OS: macOS Sonoma 14.5

Additional context

I suspect this originates from the recent change to HeteroskedasticSingleTaskGP in #2527, and that it shows up with Warp because Warp behaves differently from input transforms like Normalize. The relevant attributes of Normalize are computed at initialisation and stay fixed throughout training, whereas a learnable transform like Warp has parameters that are treated somewhat like GP hyperparameters, so its output changes from one training iteration to the next.
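To make the distinction concrete, here is a minimal pure-Python sketch with no BoTorch dependency. MinMaxNormalize and LearnableWarp are hypothetical stand-ins, not BoTorch classes. The warp uses a Kumaraswamy CDF, which to my understanding is the warping function BoTorch's Warp is built on, although the real implementation differs in details (e.g. it operates on tensors and learns the concentrations by gradient descent).

```python
def kumaraswamy_cdf(x, a, b):
    """Kumaraswamy CDF on [0, 1]: F(x) = 1 - (1 - x**a)**b."""
    return 1.0 - (1.0 - x**a) ** b


class MinMaxNormalize:
    """Hypothetical fixed transform: bounds are taken from the training data
    once and never change during hyperparameter fitting."""

    def __init__(self, train_x):
        self.lo, self.hi = min(train_x), max(train_x)

    def __call__(self, xs):
        return [(x - self.lo) / (self.hi - self.lo) for x in xs]


class LearnableWarp:
    """Hypothetical learnable transform: the concentrations a and b are
    treated like GP hyperparameters and updated by the optimizer."""

    def __init__(self, a=1.0, b=1.0):
        self.a, self.b = a, b

    def __call__(self, xs):
        return [kumaraswamy_cdf(x, self.a, self.b) for x in xs]


train_x = [-10.0, -2.5, 0.0, 4.0, 10.0]
normalize = MinMaxNormalize(train_x)
warp = LearnableWarp()

# Transformed training inputs before any optimizer step
z0 = normalize(train_x)
w0 = warp(z0)

# Simulate one optimizer step that updates the warp parameters
warp.a, warp.b = 1.3, 0.8
z1 = normalize(train_x)
w1 = warp(z1)

print(z0 == z1)  # True: the normalized inputs are unchanged
print(w0 == w1)  # False: the warped inputs have moved
```

The same raw training inputs therefore map to different transformed inputs once the warp parameters change, while the normalized inputs do not.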

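The input transform is applied to the train_inputs inside the MLL closure every time the MLL is computed, and the transformed inputs are then passed on to the noise model's added loss term (NoiseModelAddedLossTerm.loss in the trace above). With the changes in #2527, once the warp parameters are updated these transformed inputs keep evolving during training and diverge from the values they had at initialisation, which would explain why the torch.equal check in ExactGP.__call__ fails (the check is active because fit_gpytorch_mll runs under debug(True)).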

I arrived at this hypothesis because the MLL is computed successfully the first time fit_gpytorch_mll evaluates it and only fails on the second evaluation. I think this is because the warp parameters are updated after the first iteration, so the transformed training inputs no longer match their original values.
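The hypothesized failure mode can be simulated in plain Python. This is a sketch, not BoTorch or GPyTorch code: CachedInputsModel, warp and the parameter updates are hypothetical stand-ins. The equality check mimics the debug-mode torch.equal check in ExactGP.__call__ shown in the trace, and the cached inputs play the role of the noise model's stored training inputs.

```python
class CachedInputsModel:
    """Hypothetical stand-in for the noise GP: it stores its transformed
    training inputs once, at construction time."""

    def __init__(self, train_inputs):
        self.train_inputs = list(train_inputs)

    def __call__(self, inputs, debug=True):
        # Mimics the debug-mode check in gpytorch's ExactGP.__call__
        if debug and inputs != self.train_inputs:
            raise RuntimeError("You must train on the training inputs!")
        return "ok"


def warp(xs, a, b):
    """Kumaraswamy-CDF warp of inputs already scaled to [0, 1]."""
    return [1.0 - (1.0 - x**a) ** b for x in xs]


train_z = [0.0, 0.375, 0.5, 0.7, 1.0]  # already-normalized inputs
params = {"a": 1.0, "b": 1.0}
noise_model = CachedInputsModel(warp(train_z, **params))

results = []
for step in range(3):
    try:
        # Each MLL evaluation re-transforms the inputs with the current params
        results.append(noise_model(warp(train_z, **params)))
    except RuntimeError as err:
        results.append(str(err))
        break
    # Simulated optimizer update of the warp parameters
    params["a"] += 0.3
    params["b"] -= 0.2

print(results)  # ['ok', 'You must train on the training inputs!']
```

As in the report, the first evaluation passes because the parameters still match those used at initialisation, and the second one raises.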
