|
31 | 31 | from __future__ import annotations |
32 | 32 |
|
33 | 33 | import warnings |
34 | | -from typing import NoReturn, Optional |
| 34 | +from typing import Dict, NoReturn, Optional, Union |
35 | 35 |
|
36 | 36 | import torch |
37 | 37 | from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel |
|
44 | 44 | get_matern_kernel_with_gamma_prior, |
45 | 45 | MIN_INFERRED_NOISE_LEVEL, |
46 | 46 | ) |
| 47 | +from botorch.utils.containers import BotorchContainer |
| 48 | +from botorch.utils.datasets import SupervisedDataset |
47 | 49 | from gpytorch.constraints.constraints import GreaterThan |
48 | 50 | from gpytorch.distributions.multivariate_normal import MultivariateNormal |
49 | 51 | from gpytorch.likelihoods.gaussian_likelihood import ( |
@@ -207,6 +209,31 @@ def __init__( |
207 | 209 | self.input_transform = input_transform |
208 | 210 | self.to(train_X) |
209 | 211 |
|
| 212 | + @classmethod |
| 213 | + def construct_inputs( |
| 214 | + cls, training_data: SupervisedDataset, *, task_feature: Optional[int] = None |
| 215 | + ) -> Dict[str, Union[BotorchContainer, Tensor]]: |
| 216 | + r"""Construct `SingleTaskGP` keyword arguments from a `SupervisedDataset`. |
| 217 | +
|
| 218 | + Args: |
| 219 | + training_data: A `SupervisedDataset`, with attributes `train_X`, |
| 220 | + `train_Y`, and, optionally, `train_Yvar`. |
| 221 | + task_feature: Deprecated and allowed only for backward |
| 222 | + compatibility; ignored. |
| 223 | +
|
| 224 | + Returns: |
| 225 | + A dict of keyword arguments that can be used to initialize a `SingleTaskGP`, |
| 226 | + with keys `train_X`, `train_Y`, and, optionally, `train_Yvar`. |
| 227 | + """ |
| 228 | + if task_feature is not None: |
| 229 | + warnings.warn( |
| 230 | + "`task_feature` is deprecated and will be ignored. In the " |
| 231 | + "future, this will be an error.", |
| 232 | + DeprecationWarning, |
| 233 | + stacklevel=2, |
| 234 | + ) |
| 235 | + return super().construct_inputs(training_data=training_data) |
| 236 | + |
210 | 237 | def forward(self, x: Tensor) -> MultivariateNormal: |
211 | 238 | if self.training: |
212 | 239 | x = self.transform_inputs(x) |
|
0 commit comments