Skip to content

Commit 59e3d8f

Browse files
esantorellafacebook-github-bot
authored andcommitted
Don't allow unused keyword arguments in Model.construct_inputs (#2186)
Summary: Pull Request resolved: #2186 Differential Revision: D53086323
1 parent d213f3c commit 59e3d8f

File tree

7 files changed

+18
-18
lines changed

7 files changed

+18
-18
lines changed

botorch/models/contextual.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def construct_inputs(
4747
cls,
4848
training_data: SupervisedDataset,
4949
decomposition: Dict[str, List[int]],
50-
**kwargs: Any,
5150
) -> Dict[str, Any]:
5251
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
5352
@@ -56,7 +55,7 @@ def construct_inputs(
5655
decomposition: Dictionary of context names and their indexes of the
5756
corresponding active context parameters.
5857
"""
59-
base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
58+
base_inputs = super().construct_inputs(training_data=training_data)
6059
return {
6160
**base_inputs,
6261
"decomposition": decomposition,
@@ -127,7 +126,6 @@ def construct_inputs(
127126
embs_feature_dict: Optional[Dict] = None,
128127
embs_dim_list: Optional[List[int]] = None,
129128
context_weight_dict: Optional[Dict] = None,
130-
**kwargs: Any,
131129
) -> Dict[str, Any]:
132130
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
133131
@@ -147,7 +145,7 @@ def construct_inputs(
147145
dimension is set to 1 for each categorical variable.
148146
context_weight_dict: Known population weights of each context.
149147
"""
150-
base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
148+
base_inputs = super().construct_inputs(training_data=training_data)
151149
index_decomp = {
152150
c: [training_data.feature_names.index(i) for i in v]
153151
for c, v in decomposition.items()

botorch/models/fully_bayesian_multitask.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,18 +388,28 @@ def construct_inputs(
388388
cls,
389389
training_data: Union[SupervisedDataset, MultiTaskDataset],
390390
task_feature: int,
391+
output_tasks: Optional[List[int]] = None,
392+
prior_config: Optional[dict] = None,
391393
rank: Optional[int] = None,
392-
**kwargs: Any,
393394
) -> Dict[str, Any]:
394395
r"""Construct `Model` keyword arguments from a dataset and other args.
395396
396397
Args:
397398
training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
398399
task_feature: Column index of embedded task indicator features.
400+
output_tasks: A list of task indices for which to compute model
401+
outputs for. If omitted, return outputs for all task indices.
402+
prior_config: Configuration for inter-task covariance prior.
403+
Should only be used if `task_covar_prior` is not passed directly. Must
404+
contain `use_LKJ_prior` indicator and should contain float value `eta`.
399405
rank: The rank of the cross-task covariance matrix.
400406
"""
401407
inputs = super().construct_inputs(
402-
training_data=training_data, task_feature=task_feature, rank=rank, **kwargs
408+
training_data=training_data,
409+
task_feature=task_feature,
410+
rank=rank,
411+
output_tasks=output_tasks,
412+
prior_config=prior_config,
403413
)
404414
inputs.pop("task_covar_prior")
405415
if "train_Yvar" not in inputs:

botorch/models/gp_regression_fidelity.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,14 @@ def construct_inputs(
167167
cls,
168168
training_data: SupervisedDataset,
169169
fidelity_features: List[int],
170-
**kwargs,
171170
) -> Dict[str, Any]:
172171
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
173172
174173
Args:
175174
training_data: Dictionary of `SupervisedDataset`.
176175
fidelity_features: Index of fidelity parameter as input columns.
177176
"""
178-
inputs = super().construct_inputs(training_data=training_data, **kwargs)
177+
inputs = super().construct_inputs(training_data=training_data)
179178
inputs["data_fidelities"] = fidelity_features
180179
return inputs
181180

botorch/models/gp_regression_mixed.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def construct_inputs(
187187
training_data: SupervisedDataset,
188188
categorical_features: List[int],
189189
likelihood: Optional[Likelihood] = None,
190-
**kwargs: Any,
191190
) -> Dict[str, Any]:
192191
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
193192
@@ -196,7 +195,7 @@ def construct_inputs(
196195
categorical_features: Column indices of categorical features.
197196
likelihood: Optional likelihood used to constuct the model.
198197
"""
199-
base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
198+
base_inputs = super().construct_inputs(training_data=training_data)
200199
return {
201200
**base_inputs,
202201
"cat_dims": categorical_features,

botorch/models/model.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,13 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
187187
def construct_inputs(
188188
cls,
189189
training_data: SupervisedDataset,
190-
**kwargs: Any,
191190
) -> Dict[str, Union[BotorchContainer, Tensor]]:
192191
"""
193192
Construct `Model` keyword arguments from a `SupervisedDataset`.
194193
195194
Args:
196195
training_data: A `SupervisedDataset`, with attributes `train_X`,
197196
`train_Y`, and, optionally, `train_Yvar`.
198-
kwargs: Ignored.
199197
200198
Returns:
201199
A dict of keyword arguments that can be used to initialize a `Model`,

botorch/models/multitask.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,6 @@ def construct_inputs(
277277
task_covar_prior: Optional[Prior] = None,
278278
prior_config: Optional[dict] = None,
279279
rank: Optional[int] = None,
280-
**kwargs,
281280
) -> Dict[str, Any]:
282281
r"""Construct `Model` keyword arguments from a dataset and other args.
283282
@@ -310,9 +309,8 @@ def construct_inputs(
310309
raise ValueError(f"eta must be a real number, your eta was {eta}.")
311310
task_covar_prior = LKJCovariancePrior(num_tasks, eta, sd_prior)
312311

313-
base_inputs = super().construct_inputs(
314-
training_data=training_data, task_feature=task_feature, **kwargs
315-
)
312+
# Call Model.construct_inputs to parse training data
313+
base_inputs = super().construct_inputs(training_data=training_data)
316314
return {
317315
**base_inputs,
318316
"task_feature": task_feature,

botorch/models/pairwise_gp.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -788,15 +788,13 @@ def batch_shape(self) -> torch.Size:
788788
def construct_inputs(
789789
cls,
790790
training_data: SupervisedDataset,
791-
**kwargs: Any,
792791
) -> Dict[str, Tensor]:
793792
r"""
794793
Construct `Model` keyword arguments from a `RankingDataset`.
795794
796795
Args:
797796
training_data: A `RankingDataset`, with attributes `train_X`,
798797
`train_Y`, and, optionally, `train_Yvar`.
799-
kwargs: Ignored.
800798
801799
Returns:
802800
A dict of keyword arguments that can be used to initialize a

0 commit comments

Comments
 (0)