Skip to content

Commit c57c8ee

Browse files
esantorellafacebook-github-bot
authored andcommitted
Don't allow unused keyword arguments in Model.construct_inputs (#2186)
Summary: Pull Request resolved: #2186 Reviewed By: saitcakmak Differential Revision: D53086323
1 parent 44ee88c commit c57c8ee

File tree

8 files changed

+58
-19
lines changed

8 files changed

+58
-19
lines changed

botorch/models/contextual.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def construct_inputs(
4747
cls,
4848
training_data: SupervisedDataset,
4949
decomposition: Dict[str, List[int]],
50-
**kwargs: Any,
5150
) -> Dict[str, Any]:
5251
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
5352
@@ -56,7 +55,7 @@ def construct_inputs(
5655
decomposition: Dictionary of context names and their indexes of the
5756
corresponding active context parameters.
5857
"""
59-
base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
58+
base_inputs = super().construct_inputs(training_data=training_data)
6059
return {
6160
**base_inputs,
6261
"decomposition": decomposition,
@@ -127,7 +126,6 @@ def construct_inputs(
127126
embs_feature_dict: Optional[Dict] = None,
128127
embs_dim_list: Optional[List[int]] = None,
129128
context_weight_dict: Optional[Dict] = None,
130-
**kwargs: Any,
131129
) -> Dict[str, Any]:
132130
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
133131
@@ -147,7 +145,7 @@ def construct_inputs(
147145
dimension is set to 1 for each categorical variable.
148146
context_weight_dict: Known population weights of each context.
149147
"""
150-
base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
148+
base_inputs = super().construct_inputs(training_data=training_data)
151149
index_decomp = {
152150
c: [training_data.feature_names.index(i) for i in v]
153151
for c, v in decomposition.items()

botorch/models/gp_regression.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from __future__ import annotations
3232

3333
import warnings
34-
from typing import NoReturn, Optional
34+
from typing import Dict, NoReturn, Optional, Union
3535

3636
import torch
3737
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
@@ -44,6 +44,8 @@
4444
get_matern_kernel_with_gamma_prior,
4545
MIN_INFERRED_NOISE_LEVEL,
4646
)
47+
from botorch.utils.containers import BotorchContainer
48+
from botorch.utils.datasets import SupervisedDataset
4749
from gpytorch.constraints.constraints import GreaterThan
4850
from gpytorch.distributions.multivariate_normal import MultivariateNormal
4951
from gpytorch.likelihoods.gaussian_likelihood import (
@@ -207,6 +209,31 @@ def __init__(
207209
self.input_transform = input_transform
208210
self.to(train_X)
209211

212+
@classmethod
213+
def construct_inputs(
214+
cls, training_data: SupervisedDataset, *, task_feature: Optional[int] = None
215+
) -> Dict[str, Union[BotorchContainer, Tensor]]:
216+
r"""Construct `SingleTaskGP` keyword arguments from a `SupervisedDataset`.
217+
218+
Args:
219+
training_data: A `SupervisedDataset`, with attributes `train_X`,
220+
`train_Y`, and, optionally, `train_Yvar`.
221+
task_feature: Deprecated and allowed only for backward
222+
compatibility; ignored.
223+
224+
Returns:
225+
A dict of keyword arguments that can be used to initialize a `SingleTaskGP`,
226+
with keys `train_X`, `train_Y`, and, optionally, `train_Yvar`.
227+
"""
228+
if task_feature is not None:
229+
warnings.warn(
230+
"`task_feature` is deprecated and will be ignored. In the "
231+
"future, this will be an error.",
232+
DeprecationWarning,
233+
stacklevel=2,
234+
)
235+
return super().construct_inputs(training_data=training_data)
236+
210237
def forward(self, x: Tensor) -> MultivariateNormal:
211238
if self.training:
212239
x = self.transform_inputs(x)

botorch/models/gp_regression_fidelity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,14 @@ def construct_inputs(
167167
cls,
168168
training_data: SupervisedDataset,
169169
fidelity_features: List[int],
170-
**kwargs,
171170
) -> Dict[str, Any]:
172171
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
173172
174173
Args:
175174
training_data: Dictionary of `SupervisedDataset`.
176175
fidelity_features: Index of fidelity parameter as input columns.
177176
"""
178-
inputs = super().construct_inputs(training_data=training_data, **kwargs)
177+
inputs = super().construct_inputs(training_data=training_data)
179178
inputs["data_fidelities"] = fidelity_features
180179
return inputs
181180

botorch/models/gp_regression_mixed.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def construct_inputs(
187187
training_data: SupervisedDataset,
188188
categorical_features: List[int],
189189
likelihood: Optional[Likelihood] = None,
190-
**kwargs: Any,
191190
) -> Dict[str, Any]:
192191
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
193192
@@ -196,7 +195,7 @@ def construct_inputs(
196195
categorical_features: Column indices of categorical features.
197196
likelihood: Optional likelihood used to constuct the model.
198197
"""
199-
base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
198+
base_inputs = super().construct_inputs(training_data=training_data)
200199
return {
201200
**base_inputs,
202201
"cat_dims": categorical_features,

botorch/models/model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,13 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
186186
def construct_inputs(
187187
cls,
188188
training_data: SupervisedDataset,
189-
**kwargs: Any,
190189
) -> Dict[str, Union[BotorchContainer, Tensor]]:
191190
"""
192191
Construct `Model` keyword arguments from a `SupervisedDataset`.
193192
194193
Args:
195194
training_data: A `SupervisedDataset`, with attributes `train_X`,
196195
`train_Y`, and, optionally, `train_Yvar`.
197-
kwargs: Ignored.
198196
199197
Returns:
200198
A dict of keyword arguments that can be used to initialize a `Model`,

botorch/models/multitask.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,6 @@ def construct_inputs(
334334
task_covar_prior: Optional[Prior] = None,
335335
prior_config: Optional[dict] = None,
336336
rank: Optional[int] = None,
337-
**kwargs,
338337
) -> Dict[str, Any]:
339338
r"""Construct `Model` keyword arguments from a dataset and other args.
340339
@@ -367,9 +366,8 @@ def construct_inputs(
367366
raise ValueError(f"eta must be a real number, your eta was {eta}.")
368367
task_covar_prior = LKJCovariancePrior(num_tasks, eta, sd_prior)
369368

370-
base_inputs = super().construct_inputs(
371-
training_data=training_data, task_feature=task_feature, **kwargs
372-
)
369+
# Call Model.construct_inputs to parse training data
370+
base_inputs = super().construct_inputs(training_data=training_data)
373371
if isinstance(training_data, MultiTaskDataset):
374372
all_tasks = list(range(len(training_data.datasets)))
375373
base_inputs["all_tasks"] = all_tasks

botorch/models/pairwise_gp.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -781,15 +781,13 @@ def batch_shape(self) -> torch.Size:
781781
def construct_inputs(
782782
cls,
783783
training_data: SupervisedDataset,
784-
**kwargs: Any,
785784
) -> Dict[str, Tensor]:
786785
r"""
787786
Construct `Model` keyword arguments from a `RankingDataset`.
788787
789788
Args:
790789
training_data: A `RankingDataset`, with attributes `train_X`,
791790
`train_Y`, and, optionally, `train_Yvar`.
792-
kwargs: Ignored.
793791
794792
Returns:
795793
A dict of keyword arguments that can be used to initialize a

test/models/test_gp_regression.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from gpytorch.priors import GammaPrior
3737

3838

39-
class TestSingleTaskGP(BotorchTestCase):
39+
class TestGPRegressionBase(BotorchTestCase):
4040
def _get_model_and_data(
4141
self,
4242
batch_shape,
@@ -398,6 +398,28 @@ def test_set_transformed_inputs(self):
398398
self.assertEqual(X.shape, tf_X.shape)
399399

400400

401+
class TestSingleTaskGP(TestGPRegressionBase):
402+
model_class = SingleTaskGP
403+
404+
def test_construct_inputs_task_feature_deprecated(self) -> None:
405+
model, model_kwargs = self._get_model_and_data(
406+
batch_shape=torch.Size([]),
407+
m=1,
408+
device=self.device,
409+
dtype=torch.double,
410+
)
411+
X = model_kwargs["train_X"]
412+
Y = model_kwargs["train_Y"]
413+
training_data = SupervisedDataset(
414+
X,
415+
Y,
416+
feature_names=[f"x{i}" for i in range(X.shape[-1])],
417+
outcome_names=["y"],
418+
)
419+
with self.assertWarnsRegex(DeprecationWarning, "`task_feature` is deprecated"):
420+
model.construct_inputs(training_data, task_feature=0)
421+
422+
401423
class TestFixedNoiseGP(TestSingleTaskGP):
402424
model_class = FixedNoiseGP
403425

@@ -542,7 +564,7 @@ class TestFixedNoiseSingleTaskGP(TestFixedNoiseGP):
542564
model_class = SingleTaskGP
543565

544566

545-
class TestHeteroskedasticSingleTaskGP(TestSingleTaskGP):
567+
class TestHeteroskedasticSingleTaskGP(TestGPRegressionBase):
546568
def _get_model_and_data(
547569
self, batch_shape, m, outcome_transform=None, input_transform=None, **tkwargs
548570
):

0 commit comments

Comments
 (0)