Description
🐛 Bug
Unable to fit a HeteroskedasticSingleTaskGP with a Warp input transform.
To reproduce
**Code snippet to reproduce**
# Define some training data
import numpy as np
import torch
from botorch.models import HeteroskedasticSingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms import (
    Normalize,
    Standardize,
    Warp,
    ChainedInputTransform,
)
from gpytorch.mlls import ExactMarginalLogLikelihood
# The true function
def oscillator(x):
    return np.cos((x - 5) / 2) ** 2 * x * 2
noise_scale = 3.0
n_data = 200
X_data = np.random.uniform(-10, 10, n_data)
y_data = oscillator(X_data) + np.random.normal(scale=3.0, size=X_data.shape)
# add noise to data
y_noise = np.random.normal(scale=noise_scale, size=X_data.shape[0]) * np.abs(
    X_data * 0.5
)
y_data_heteroskedastic = oscillator(X_data) + y_noise
train_X = torch.tensor(X_data).view(-1, 1)
train_y = torch.tensor(y_data_heteroskedastic).view(-1, 1)
train_yvar = torch.tensor(y_noise**2).view(-1, 1)
n = 50
normalize = Normalize(d=1)
warp = Warp(indices=list(range(train_X.shape[1])))
outcome_transform = Standardize(m=1)
input_transform = ChainedInputTransform(**{"normalize": normalize, "warp": warp})
gp = HeteroskedasticSingleTaskGP(
    train_X=train_X[0:n],
    train_Y=train_y[0:n],
    train_Yvar=train_yvar[0:n],
    input_transform=input_transform,
    outcome_transform=outcome_transform,
)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)
**Stack trace/error message**
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[3], line 15
6 gp = HeteroskedasticSingleTaskGP(
7 train_X=train_X[0:n],
8 train_Y=train_y[0:n],
(...)
11 outcome_transform=outcome_transform,
12 )
14 mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
---> 15 fit_gpytorch_mll(mll)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
102 if optimizer is not None: # defer to per-method defaults
103 kwargs[\"optimizer\"] = optimizer
--> 105 return FitGPyTorchMLL(
106 mll,
107 type(mll.likelihood),
108 type(mll.model),
109 closure=closure,
110 closure_kwargs=closure_kwargs,
111 optimizer_kwargs=optimizer_kwargs,
112 **kwargs,
113 )
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
91 func = self.__getitem__(types=types)
92 try:
---> 93 return func(*args, **kwargs)
94 except MDNotImplementedError:
95 # Traverses registered methods in order, yields whenever a match is found
96 funcs = self.dispatch_iter(*types)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/fit.py:205, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
203 with catch_warnings(record=True) as warning_list, debug(True):
204 simplefilter(\"always\", category=OptimizationWarning)
--> 205 result = optimizer(mll, closure=closure, **optimizer_kwargs)
207 # Resolve warnings and determine whether or not to retry
208 success = True
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/fit.py:94, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
91 if closure_kwargs is not None:
92 closure = partial(closure, **closure_kwargs)
---> 94 result = scipy_minimize(
95 closure=closure,
96 parameters=parameters,
97 bounds=bounds,
98 method=method,
99 options=options,
100 callback=callback,
101 timeout_sec=timeout_sec,
102 )
103 if result.status != OptimizationStatus.SUCCESS:
104 warn(
105 f\"`scipy_minimize` terminated with status {result.status}, displaying\"
106 f\" original message from `scipy.optimize.minimize`: {result.message}\",
107 OptimizationWarning,
108 )
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/core.py:110, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
102 result = OptimizationResult(
103 step=next(call_counter),
104 fval=float(wrapped_closure(x)[0]),
105 status=OptimizationStatus.RUNNING,
106 runtime=monotonic() - start_time,
107 )
108 return callback(parameters, result) # pyre-ignore [29]
--> 110 raw = minimize_with_timeout(
111 wrapped_closure,
112 wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
113 jac=True,
114 bounds=bounds_np,
115 method=method,
116 options=options,
117 callback=wrapped_callback,
118 timeout_sec=timeout_sec,
119 )
121 # Post-processing and outcome handling
122 wrapped_closure.state = asarray(raw.x) # set parameter state to optimal values
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/utils/timeout.py:83, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
81 try:
82 warnings.filterwarnings(\"error\", message=\"Method .* cannot handle\")
---> 83 return optimize.minimize(
84 fun=fun,
85 x0=x0,
86 args=args,
87 method=method,
88 jac=jac,
89 hess=hess,
90 hessp=hessp,
91 bounds=bounds,
92 constraints=constraints,
93 tol=tol,
94 callback=wrapped_callback,
95 options=options,
96 )
97 except OptimizationTimeoutError as e:
98 msg = f\"Optimization timed out after {e.runtime} seconds.\"
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_minimize.py:731, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
728 res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
729 **options)
730 elif meth == 'l-bfgs-b':
--> 731 res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
732 callback=callback, **options)
733 elif meth == 'tnc':
734 res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
735 **options)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_lbfgsb_py.py:407, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
401 task_str = task.tobytes()
402 if task_str.startswith(b'FG'):
403 # The minimization routine wants f and g at the current x.
404 # Note that interruptions due to maxfun are postponed
405 # until the completion of the current minimization iteration.
406 # Overwrite f and g:
--> 407 f, g = func_and_grad(x)
408 elif task_str.startswith(b'NEW_X'):
409 # new iteration
410 n_iterations += 1
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:343, in ScalarFunction.fun_and_grad(self, x)
341 if not np.array_equal(x, self.x):
342 self._update_x(x)
--> 343 self._update_fun()
344 self._update_grad()
345 return self.f, self.g
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:294, in ScalarFunction._update_fun(self)
292 def _update_fun(self):
293 if not self.f_updated:
--> 294 fx = self._wrapped_fun(self.x)
295 if fx < self._lowest_f:
296 self._lowest_x = self.x
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_differentiable_functions.py:20, in _wrapper_fun.<locals>.wrapped(x)
16 ncalls[0] += 1
17 # Send a copy because the user may overwrite it.
18 # Overwriting results in undefined behaviour because
19 # fun(self.x) will change self.x, with the two no longer linked.
---> 20 fx = fun(np.copy(x), *args)
21 # Make sure the function returns a true scalar
22 if not np.isscalar(fx):
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_optimize.py:79, in MemoizeJac.__call__(self, x, *args)
77 def __call__(self, x, *args):
78 \"\"\" returns the function value \"\"\"
---> 79 self._compute_if_needed(x, *args)
80 return self._value
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/scipy/optimize/_optimize.py:73, in MemoizeJac._compute_if_needed(self, x, *args)
71 if not np.all(x == self.x) or self._value is None or self.jac is None:
72 self.x = np.asarray(x).copy()
---> 73 fg = self.fun(x, *args)
74 self.jac = fg[1]
75 self._value = fg[0]
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:162, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
160 index += size
161 except RuntimeError as e:
--> 162 value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
164 return value, grads
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/utils/common.py:32, in _handle_numerical_errors(error, x, dtype)
30 _dtype = x.dtype if dtype is None else dtype
31 return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
---> 32 raise error
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:152, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
149 self.state = state
151 try:
--> 152 value_tensor, grad_tensors = self.closure(**kwargs)
153 value = self.as_array(value_tensor)
154 grads = self._get_gradient_ndarray(fill_value=self.fill_value)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
64 def __call__(self, **kwargs: Any) -> tuple[Tensor, tuple[Optional[Tensor], ...]]:
65 with self.context_manager():
---> 66 values = self.forward(**kwargs)
67 value = values if self.reducer is None else self.reducer(values)
68 self.backward(value)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/closures/model_closures.py:179, in _get_loss_closure_exact_internal.<locals>.closure(**kwargs)
177 # The inputs will get transformed in forward here.
178 model_output = model(*model.train_inputs)
--> 179 log_likelihood = mll(
180 model_output,
181 model.train_targets,
182 # During model training, the model inputs get transformed in the forward
183 # pass. The train_inputs property is not transformed yet, so we need to
184 # transform it before passing it to the likelihood for consistency.
185 *(model.transform_inputs(X=t_in) for t_in in model.train_inputs),
186 **kwargs,
187 )
188 return -log_likelihood
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31 outputs = self.forward(*inputs, **kwargs)
32 if isinstance(outputs, list):
33 return [_validate_module_outputs(output) for output in outputs]
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:83, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params, **kwargs)
81 # Get the log prob of the marginal distribution
82 res = output.log_prob(target)
---> 83 res = self._add_other_terms(res, params)
85 # Scale by the amount of data we have
86 num_data = function_dist.event_shape.numel()
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:42, in ExactMarginalLogLikelihood._add_other_terms(self, res, params)
39 def _add_other_terms(self, res, params):
40 # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
41 for added_loss_term in self.model.added_loss_terms():
---> 42 res = res.add(added_loss_term.loss(*params))
44 # Add log probs of priors on the (functions of) parameters
45 res_ndim = res.ndim
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/mlls/noise_model_added_loss_term.py:13, in NoiseModelAddedLossTerm.loss(self, *params)
12 def loss(self, *params):
---> 13 output = self.noise_mll.model(*params)
14 targets = self.noise_mll.model.train_targets
15 return self.noise_mll(output, targets)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:267, in ExactGP.__call__(self, *args, **kwargs)
263 if settings.debug.on():
264 if not all(
265 torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)
266 ):
--> 267 raise RuntimeError(\"You must train on the training inputs!\")
268 res = super().__call__(*inputs, **kwargs)
269 return res
RuntimeError: You must train on the training inputs!
Expected Behavior
The model should fit the data without raising an error.
System information
- BoTorch version: 0.12.0
- GPyTorch version: 1.13
- PyTorch version: 2.4.1
- OS: macOS Sonoma 14.5
Additional context
I suspect this originates from a recent change to HeteroskedasticSingleTaskGP in #2527. It might be because Warp works slightly differently from other input transforms like Normalize: the relevant attributes of Normalize are computed at initialisation and stay fixed throughout training, whereas for learnable transforms like Warp (whose parameters are treated somewhat like hyperparameters of the GP) the transform changes dynamically over the training iterations.
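As a minimal illustration of this difference (my own sketch, not from the issue; the parameter names assume the BoTorch 0.12 Warp implementation), Normalize exposes no trainable parameters, while Warp registers its Kumaraswamy concentrations as torch.nn.Parameters that the optimizer updates:

```python
# Hedged sketch: compare the trainable state of the two transforms.
# Parameter names like `concentration0` are taken from the BoTorch
# source and are an assumption here, not something from this issue.
from botorch.models.transforms import Normalize, Warp

normalize = Normalize(d=1)
warp = Warp(indices=[0])

print([name for name, _ in normalize.named_parameters()])  # [] -> nothing for the optimizer to update
print([name for name, _ in warp.named_parameters()])       # e.g. ['concentration1', 'concentration0']
```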
The input transform is applied to the train_inputs in the mll closure every time the mll is computed. Since the Warp parameters are updated during optimisation, the transformed train_inputs can diverge from their values at initialisation; with the changes in #2527, the train_inputs are now constantly evolving during training.
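To make the suspected mechanism concrete, here is a hedged diagnostic (my own reconstruction, not a verified reproduction of the BoTorch internals): warp the inputs, nudge a concentration parameter the way an optimizer step would, and warp again. The two results differ, so any cached copy of the transformed train_inputs no longer matches what the closure recomputes:

```python
import torch
from botorch.models.transforms import Warp

warp = Warp(indices=[0])
X = torch.rand(10, 1, dtype=torch.double)

before = warp.transform(X)         # warped inputs under the initial parameters
with torch.no_grad():
    warp.concentration0.add_(0.5)  # mimic a single optimizer update
after = warp.transform(X)          # warped inputs under the updated parameters

print(torch.equal(before, after))  # False: the transformed inputs have moved
```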
I arrived at this hypothesis because the mll is computed successfully once when I call fit_gpytorch_mll and only fails the second time, which I think is because the train_inputs diverge from their original values once the Warp parameters are updated after the first iteration.
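Consistent with this, the failing comparison in ExactGP.__call__ only runs under GPyTorch's debug mode (see the settings.debug.on() check in the traceback above), so the following untested workaround should let fitting proceed, though it only sidesteps the symptom rather than fixing the train_inputs bookkeeping:

```python
import gpytorch

# Untested workaround sketch: disable the debug-mode equality check
# that raises "You must train on the training inputs!".
with gpytorch.settings.debug(False):
    fit_gpytorch_mll(mll)
```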