Skip to content

Commit 04edaf2

Browse files
esantorellafacebook-github-bot
authored andcommitted
Don't allow unused keyword arguments in Model.construct_inputs
Summary: BoTorch changes to go with the subsequent Ax diff. This will eventually need to be landed with this diff going before the Ax diff. Differential Revision: D53086323
1 parent 3384c24 commit 04edaf2

File tree

3 files changed

+2
-8
lines changed

3 files changed

+2
-8
lines changed

botorch/models/model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,13 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
187187
def construct_inputs(
188188
cls,
189189
training_data: SupervisedDataset,
190-
**kwargs: Any,
191190
) -> Dict[str, Union[BotorchContainer, Tensor]]:
192191
"""
193192
Construct `Model` keyword arguments from a `SupervisedDataset`.
194193
195194
Args:
196195
training_data: A `SupervisedDataset`, with attributes `train_X`,
197196
`train_Y`, and, optionally, `train_Yvar`.
198-
kwargs: Ignored.
199197
200198
Returns:
201199
A dict of keyword arguments that can be used to initialize a `Model`,

botorch/models/multitask.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,6 @@ def construct_inputs(
277277
task_covar_prior: Optional[Prior] = None,
278278
prior_config: Optional[dict] = None,
279279
rank: Optional[int] = None,
280-
**kwargs,
281280
) -> Dict[str, Any]:
282281
r"""Construct `Model` keyword arguments from a dataset and other args.
283282
@@ -310,9 +309,8 @@ def construct_inputs(
310309
raise ValueError(f"eta must be a real number, your eta was {eta}.")
311310
task_covar_prior = LKJCovariancePrior(num_tasks, eta, sd_prior)
312311

313-
base_inputs = super().construct_inputs(
314-
training_data=training_data, task_feature=task_feature, **kwargs
315-
)
312+
# Call Model.construct_inputs to parse training data
313+
base_inputs = super().construct_inputs(training_data=training_data)
316314
return {
317315
**base_inputs,
318316
"task_feature": task_feature,

botorch/models/pairwise_gp.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -788,15 +788,13 @@ def batch_shape(self) -> torch.Size:
788788
def construct_inputs(
789789
cls,
790790
training_data: SupervisedDataset,
791-
**kwargs: Any,
792791
) -> Dict[str, Tensor]:
793792
r"""
794793
Construct `Model` keyword arguments from a `RankingDataset`.
795794
796795
Args:
797796
training_data: A `RankingDataset`, with attributes `train_X`,
798797
`train_Y`, and, optionally, `train_Yvar`.
799-
kwargs: Ignored.
800798
801799
Returns:
802800
A dict of keyword arguments that can be used to initialize a

0 commit comments

Comments
 (0)