Skip to content

Commit 1033c52

Browse files
esantorellafacebook-github-bot
authored andcommitted
Affine input transforms should error with data of incorrect dimension, even in eval mode (#2510)
Summary: Pull Request resolved: #2510 Context: #2509 gives a clear overview This PR: * Checks the shape of the `X` provided to an `AffineInputTransform` when it transforms the data, regardless of whether it is updating the coefficients. Makes some unrelated changes: * Fixes the example in the docstring for `batched_multi_output_to_single_output` * fixes an incorrect shape in `test_approximate_gp` * Makes data and transform batch shapes match in `TestConverters`, since those usages will now (appropriately) error Reviewed By: saitcakmak Differential Revision: D62318530
1 parent 1417189 commit 1033c52

File tree

5 files changed

+27
-13
lines changed

5 files changed

+27
-13
lines changed

botorch/models/converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ def batched_multi_output_to_single_output(
388388
Example:
389389
>>> train_X = torch.rand(5, 2)
390390
>>> train_Y = torch.rand(5, 2)
391-
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
392-
>>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
391+
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y, outcome_transform=None)
392+
>>> batch_so_gp = batched_multi_output_to_single_output(batch_mo_gp)
393393
"""
394394
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
395395
was_training = batch_mo_model.training

botorch/models/transforms/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,8 @@ def _transform(self, X: Tensor) -> Tensor:
412412
Returns:
413413
A `batch_shape x n x d`-dim tensor of transformed inputs.
414414
"""
415+
self._check_shape(X)
415416
if self.learn_coefficients and self.training:
416-
self._check_shape(X)
417417
self._update_coefficients(X)
418418
self._to(X)
419419
return (X - self.offset) / self.coefficient

test/models/test_approximate_gp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,5 +327,5 @@ def test_input_transform(self) -> None:
327327
model.likelihood, model.model, num_data=train_X.shape[-2]
328328
)
329329
fit_gpytorch_mll(mll)
330-
post = model.posterior(torch.tensor([train_X.mean()]))
330+
post = model.posterior(torch.tensor([[train_X.mean()]]))
331331
self.assertAllClose(post.mean[0][0], y.mean(), atol=1e-3, rtol=1e-3)

test/models/test_converter.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,21 @@ def test_model_list_to_batched(self):
278278
batch_shape=torch.Size([3]),
279279
)
280280
gp1_ = SingleTaskGP(
281-
train_X, train_Y1, input_transform=input_tf2, outcome_transform=None
281+
train_X=train_X.unsqueeze(0),
282+
train_Y=train_Y1.unsqueeze(0),
283+
input_transform=input_tf2,
284+
outcome_transform=None,
282285
)
283286
gp2_ = SingleTaskGP(
284-
train_X, train_Y2, input_transform=input_tf2, outcome_transform=None
287+
train_X=train_X.unsqueeze(0),
288+
train_Y=train_Y2.unsqueeze(0),
289+
input_transform=input_tf2,
290+
outcome_transform=None,
285291
)
286292
list_gp = ModelListGP(gp1_, gp2_)
287-
with self.assertRaises(UnsupportedError):
293+
with self.assertRaisesRegex(
294+
UnsupportedError, "Batched input_transforms are not supported."
295+
):
288296
model_list_to_batched(list_gp)
289297

290298
# test outcome transform
@@ -457,7 +465,6 @@ def test_batched_multi_output_to_single_output(self):
457465
bounds=torch.tensor(
458466
[[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype
459467
),
460-
batch_shape=torch.Size([2]),
461468
)
462469
batched_mo_model = SingleTaskGP(
463470
train_X, train_Y, input_transform=input_tf, outcome_transform=None

test/models/transforms/test_input.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,10 @@ def test_normalize(self) -> None:
228228
self.assertTrue(nlz.mins.dtype == other_dtype)
229229
# test incompatible dimensions of specified bounds
230230
bounds = torch.zeros(2, 3, device=self.device, dtype=dtype)
231-
with self.assertRaises(BotorchTensorDimensionError):
231+
with self.assertRaisesRegex(
232+
BotorchTensorDimensionError,
233+
"Dimensions of provided `bounds` are incompatible",
234+
):
232235
Normalize(d=2, bounds=bounds)
233236

234237
# test jitter
@@ -266,7 +269,12 @@ def test_normalize(self) -> None:
266269
# test errors on wrong shape
267270
nlz = Normalize(d=2, batch_shape=batch_shape)
268271
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
269-
with self.assertRaises(BotorchTensorDimensionError):
272+
expected_msg = "Wrong input dimension. Received 1, expected 2."
273+
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
274+
nlz(X)
275+
# Same error in eval mode
276+
nlz.eval()
277+
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
270278
nlz(X)
271279

272280
# fixed bounds
@@ -328,9 +336,8 @@ def test_normalize(self) -> None:
328336
[X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
329337
dim=-2,
330338
)[..., indices]
331-
self.assertTrue(
332-
torch.allclose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
333-
)
339+
self.assertAllClose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
340+
334341
# test errors on wrong shape
335342
nlz = Normalize(d=2, batch_shape=batch_shape)
336343
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)

0 commit comments

Comments
 (0)