Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions botorch/models/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def batched_multi_output_to_single_output(
Example:
>>> train_X = torch.rand(5, 2)
>>> train_Y = torch.rand(5, 2)
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
>>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y, outcome_transform=None)
>>> batch_so_gp = batched_multi_output_to_single_output(batch_mo_gp)
"""
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
was_training = batch_mo_model.training
Expand Down
2 changes: 1 addition & 1 deletion botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def _transform(self, X: Tensor) -> Tensor:
Returns:
A `batch_shape x n x d`-dim tensor of transformed inputs.
"""
self._check_shape(X)
if self.learn_coefficients and self.training:
self._check_shape(X)
self._update_coefficients(X)
self._to(X)
return (X - self.offset) / self.coefficient
Expand Down
8 changes: 4 additions & 4 deletions test/acquisition/test_proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_proximal(self):
proximal_test_X = test_X.clone()
if transformed_weighting:
if input_transform is not None:
last_X = input_transform(train_X[-1])
last_X = input_transform(train_X[-1].unsqueeze(0))
proximal_test_X = input_transform(test_X)

mv_normal = MultivariateNormal(last_X, torch.diag(proximal_weights))
Expand All @@ -105,7 +105,7 @@ def test_proximal(self):
proximal_test_X = test_X.clone()
if transformed_weighting:
if input_transform is not None:
last_X = input_transform(train_X[-1])
last_X = input_transform(train_X[-1].unsqueeze(0))
proximal_test_X = input_transform(test_X)

mv_normal = MultivariateNormal(last_X, torch.diag(proximal_weights))
Expand All @@ -122,7 +122,7 @@ def test_proximal(self):
proximal_test_X = test_X.clone()
if transformed_weighting:
if input_transform is not None:
last_X = input_transform(train_X[-1])
last_X = input_transform(train_X[-1].unsqueeze(0))
proximal_test_X = input_transform(test_X)

ei = EI(test_X)
Expand All @@ -143,7 +143,7 @@ def test_proximal(self):
proximal_test_X = test_X.clone()
if transformed_weighting:
if input_transform is not None:
last_X = input_transform(train_X[-1])
last_X = input_transform(train_X[-1].unsqueeze(0))
proximal_test_X = input_transform(test_X)

qEI_prox = ProximalAcquisitionFunction(
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,5 +327,5 @@ def test_input_transform(self) -> None:
model.likelihood, model.model, num_data=train_X.shape[-2]
)
fit_gpytorch_mll(mll)
post = model.posterior(torch.tensor([train_X.mean()]))
post = model.posterior(torch.tensor([[train_X.mean()]]))
self.assertAllClose(post.mean[0][0], y.mean(), atol=1e-3, rtol=1e-3)
15 changes: 11 additions & 4 deletions test/models/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,21 @@ def test_model_list_to_batched(self):
batch_shape=torch.Size([3]),
)
gp1_ = SingleTaskGP(
train_X, train_Y1, input_transform=input_tf2, outcome_transform=None
train_X=train_X.unsqueeze(0),
train_Y=train_Y1.unsqueeze(0),
input_transform=input_tf2,
outcome_transform=None,
)
gp2_ = SingleTaskGP(
train_X, train_Y2, input_transform=input_tf2, outcome_transform=None
train_X=train_X.unsqueeze(0),
train_Y=train_Y2.unsqueeze(0),
input_transform=input_tf2,
outcome_transform=None,
)
list_gp = ModelListGP(gp1_, gp2_)
with self.assertRaises(UnsupportedError):
with self.assertRaisesRegex(
UnsupportedError, "Batched input_transforms are not supported."
):
model_list_to_batched(list_gp)

# test outcome transform
Expand Down Expand Up @@ -457,7 +465,6 @@ def test_batched_multi_output_to_single_output(self):
bounds=torch.tensor(
[[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype
),
batch_shape=torch.Size([2]),
)
batched_mo_model = SingleTaskGP(
train_X, train_Y, input_transform=input_tf, outcome_transform=None
Expand Down
17 changes: 12 additions & 5 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ def test_normalize(self) -> None:
self.assertTrue(nlz.mins.dtype == other_dtype)
# test incompatible dimensions of specified bounds
bounds = torch.zeros(2, 3, device=self.device, dtype=dtype)
with self.assertRaises(BotorchTensorDimensionError):
with self.assertRaisesRegex(
BotorchTensorDimensionError,
"Dimensions of provided `bounds` are incompatible",
):
Normalize(d=2, bounds=bounds)

# test jitter
Expand Down Expand Up @@ -266,7 +269,12 @@ def test_normalize(self) -> None:
# test errors on wrong shape
nlz = Normalize(d=2, batch_shape=batch_shape)
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
with self.assertRaises(BotorchTensorDimensionError):
expected_msg = "Wrong input dimension. Received 1, expected 2."
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
nlz(X)
# Same error in eval mode
nlz.eval()
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
nlz(X)

# fixed bounds
Expand Down Expand Up @@ -328,9 +336,8 @@ def test_normalize(self) -> None:
[X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
dim=-2,
)[..., indices]
self.assertTrue(
torch.allclose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
)
self.assertAllClose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)

# test errors on wrong shape
nlz = Normalize(d=2, batch_shape=batch_shape)
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
Expand Down
8 changes: 5 additions & 3 deletions test_community/models/test_gp_regression_multisource.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _get_model_and_data(
None if train_Yvar else get_gaussian_likelihood_with_gamma_prior()
),
}
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=OptimizationWarning)
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
return model, model_kwargs

def test_data_init(self):
Expand Down Expand Up @@ -139,8 +141,8 @@ def test_get_reliable_observation(self):
self.assertListEqual(res.tolist(), true_res.tolist())

def test_gp(self):
bounds = torch.tensor([[-1.0], [1.0]])
d = 5
bounds = torch.stack((torch.full((d - 1,), -1), torch.ones(d - 1)))
for batch_shape, dtype, use_octf, use_intf, train_Yvar in itertools.product(
(torch.Size(), torch.Size([2])),
(torch.float, torch.double),
Expand All @@ -151,7 +153,7 @@ def test_gp(self):
tkwargs = {"device": self.device, "dtype": dtype}
octf = Standardize(m=1, batch_shape=torch.Size()) if use_octf else None
intf = (
Normalize(d=1, bounds=bounds.to(**tkwargs), transform_on_train=True)
Normalize(d=d - 1, bounds=bounds.to(**tkwargs), transform_on_train=True)
if use_intf
else None
)
Expand Down