botorch/models/transforms/outcome.py (10 changes: 9 additions & 1 deletion)
@@ -286,7 +286,15 @@ def forward(
                 f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
                 f"{self._m}."
             )
-        stdvs = Y.std(dim=-2, keepdim=True)
+        if Y.shape[-2] < 1:
+            raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
+
+        elif Y.shape[-2] == 1:
+            stdvs = torch.ones(
+                (*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device
+            )
+        else:
+            stdvs = Y.std(dim=-2, keepdim=True)
         stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0))
         means = Y.mean(dim=-2, keepdim=True)
         if self._outputs is not None:
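A note on why the branching is needed: torch.std applies Bessel's correction (it divides by n - 1), so taking the std over a single observation along dim=-2 yields NaN, which would silently poison the whole transform. A minimal standalone sketch of the same logic, for illustration only (the helper name safe_stdvs is hypothetical, not part of the PR):

import torch

def safe_stdvs(Y: torch.Tensor) -> torch.Tensor:
    # Mirrors the new Standardize.forward branching: torch.std along
    # dim=-2 divides by n - 1, so a single observation would produce NaN.
    if Y.shape[-2] < 1:
        raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
    if Y.shape[-2] == 1:
        return torch.ones(
            (*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device
        )
    return Y.std(dim=-2, keepdim=True)

Y = torch.tensor([[2.5, -1.0]])  # a single observation, m = 2
print(Y.std(dim=-2, keepdim=True))  # tensor([[nan, nan]]) -- the old failure mode
print(safe_stdvs(Y))  # tensor([[1., 1.]])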
botorch/models/utils/assorted.py (33 changes: 23 additions & 10 deletions)
@@ -171,7 +171,7 @@ def check_min_max_scaling(
         )
         if raise_on_fail:
             raise InputDataError(msg)
-        warnings.warn(msg, InputDataWarning)
+        warnings.warn(msg, InputDataWarning, stacklevel=2)


 def check_standardization(
@@ -191,15 +191,28 @@ def check_standardization(
         raise_on_fail: If True, raise an exception instead of a warning.
     """
     with torch.no_grad():
-        Ymean, Ystd = torch.mean(Y, dim=-2), torch.std(Y, dim=-2)
-        if torch.abs(Ymean).max() > atol_mean or torch.abs(Ystd - 1).max() > atol_std:
-            msg = (
-                f"Input data is not standardized (mean = {Ymean}, std = {Ystd}). "
-                "Please consider scaling the input to zero mean and unit variance."
-            )
-            if raise_on_fail:
-                raise InputDataError(msg)
-            warnings.warn(msg, InputDataWarning)
+        Ymean = torch.mean(Y, dim=-2)
+        mean_not_zero = torch.abs(Ymean).max() > atol_mean
+        if Y.shape[-2] <= 1:
+            if mean_not_zero:
+                msg = (
+                    f"Data is not standardized (mean = {Ymean}). "
+                    "Please consider scaling the input to zero mean and unit variance."
+                )
+                if raise_on_fail:
+                    raise InputDataError(msg)
+                warnings.warn(msg, InputDataWarning, stacklevel=2)
+        else:
+            Ystd = torch.std(Y, dim=-2)
+            std_not_one = torch.abs(Ystd - 1).max() > atol_std
+            if mean_not_zero or std_not_one:
+                msg = (
+                    f"Data is not standardized (std = {Ystd}, mean = {Ymean}). "
+                    "Please consider scaling the input to zero mean and unit variance."
+                )
+                if raise_on_fail:
+                    raise InputDataError(msg)
+                warnings.warn(msg, InputDataWarning, stacklevel=2)


 def validate_input_scaling(
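The practical effect of the split: with a single observation per batch only the mean check can fire, since a sample standard deviation is undefined for n = 1. A usage sketch, assuming check_standardization and InputDataWarning are importable from their usual re-export locations:

import warnings

import torch
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.utils import check_standardization

# n = 1: the std check is skipped entirely; only a nonzero mean warns.
with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    check_standardization(Y=torch.zeros(3, 1, 2))  # silent: mean is zero
    check_standardization(Y=torch.ones(3, 1, 2))   # warns about the mean only
print([str(w.message) for w in ws if issubclass(w.category, InputDataWarning)])

# n >= 2: both the mean and the std checks apply.
Y = torch.randn(3, 4, 2)
Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True)
check_standardization(Y=Yst)  # silent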
botorch/utils/testing.py (2 changes: 1 addition & 1 deletion)
@@ -60,7 +60,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
         )
         warnings.filterwarnings(
             "ignore",
-            message="Input data is not standardized.",
+            message="Data is not standardized.",
             category=InputDataWarning,
         )
         warnings.filterwarnings(
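This filter update matters because the message argument of warnings.filterwarnings is a regex matched against the start of the warning text, so the shortened pattern suppresses both new message variants while no longer matching the old "Input data ..." wording. A quick sketch of that matching behavior:

import re

# warnings.filterwarnings compiles `message` and matches it against the
# *start* of the warning text (re.match semantics).
pattern = re.compile("Data is not standardized.")
print(bool(pattern.match("Data is not standardized (mean = 1.0).")))             # True
print(bool(pattern.match("Data is not standardized (std = 2.0, mean = 1.0).")))  # True
print(bool(pattern.match("Input data is not standardized.")))                    # False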
test/models/transforms/test_outcome.py (32 changes: 21 additions & 11 deletions)
@@ -115,7 +115,14 @@ def test_is_linear(self) -> None:
             )
             self.assertEqual(posterior_is_gpt, transform._is_linear)

-    def test_standardize(self):
+    def test_standardize_raises_when_no_observations(self) -> None:
+        tf = Standardize(m=1)
+        with self.assertRaisesRegex(
+            ValueError, "Can't standardize with no observations."
+        ):
+            tf(torch.zeros(0, 1, device=self.device), None)
+
+    def test_standardize(self) -> None:
         # test error on incompatible dim
         tf = Standardize(m=1)
         with self.assertRaisesRegex(
@@ -134,9 +141,10 @@ def test_standardize(self):
         ms = (1, 2)
         batch_shapes = (torch.Size(), torch.Size([2]))
         dtypes = (torch.float, torch.double)
+        ns = [1, 3]

         # test transform, untransform, untransform_posterior
-        for m, batch_shape, dtype in itertools.product(ms, batch_shapes, dtypes):
+        for m, batch_shape, dtype, n in itertools.product(ms, batch_shapes, dtypes, ns):
             # test init
             tf = Standardize(m=m, batch_shape=batch_shape)
             self.assertTrue(tf.training)
@@ -148,7 +156,7 @@ def test_standardize(self):
             # no observation noise
             with torch.random.fork_rng():
                 torch.manual_seed(0)
-                Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
+                Y = torch.rand(*batch_shape, n, m, device=self.device, dtype=dtype)
             Y_tf, Yvar_tf = tf(Y, None)
             self.assertTrue(tf.training)
             self.assertTrue(torch.all(Y_tf.mean(dim=-2).abs() < 1e-4))
@@ -171,14 +179,16 @@ def test_standardize(self):
             tf = Standardize(m=m, batch_shape=batch_shape)
             with torch.random.fork_rng():
                 torch.manual_seed(0)
-                Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
+                Y = torch.rand(*batch_shape, n, m, device=self.device, dtype=dtype)
                 Yvar = 1e-8 + torch.rand(
-                    *batch_shape, 3, m, device=self.device, dtype=dtype
+                    *batch_shape, n, m, device=self.device, dtype=dtype
                 )
             Y_tf, Yvar_tf = tf(Y, Yvar)
             self.assertTrue(tf.training)
             self.assertTrue(torch.all(Y_tf.mean(dim=-2).abs() < 1e-4))
-            Yvar_tf_expected = Yvar / Y.std(dim=-2, keepdim=True) ** 2
+            Yvar_tf_expected = (
+                Yvar if n == 1 else Yvar / Y.std(dim=-2, keepdim=True) ** 2
+            )
             self.assertAllClose(Yvar_tf, Yvar_tf_expected)
             tf.eval()
             self.assertFalse(tf.training)
@@ -190,7 +200,7 @@ def test_standardize(self):
             for interleaved, lazy in itertools.product((True, False), (True, False)):
                 if m == 1 and interleaved:  # interleave has no meaning for m=1
                     continue
-                shape = batch_shape + torch.Size([3, m])
+                shape = batch_shape + torch.Size([n, m])
                 posterior = _get_test_posterior(
                     shape,
                     device=self.device,
@@ -216,12 +226,12 @@ def test_standardize(self):
             # Untransform BlockDiagLinearOperator.
             if m > 1:
                 base_lcv = DiagLinearOperator(
-                    torch.rand(*batch_shape, m, 3, device=self.device, dtype=dtype)
+                    torch.rand(*batch_shape, m, n, device=self.device, dtype=dtype)
                 )
                 lcv = BlockDiagLinearOperator(base_lcv)
                 mvn = MultitaskMultivariateNormal(
                     mean=torch.rand(
-                        *batch_shape, 3, m, device=self.device, dtype=dtype
+                        *batch_shape, n, m, device=self.device, dtype=dtype
                     ),
                     covariance_matrix=lcv,
                     interleaved=False,
@@ -240,7 +250,7 @@ def test_standardize(self):
                 samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
                 self.assertEqual(
                     samples2.shape,
-                    torch.Size([4, 2]) + batch_shape + torch.Size([3, m]),
+                    torch.Size([4, 2]) + batch_shape + torch.Size([n, m]),
                 )

             # untransform_posterior for non-GPyTorch posterior
@@ -252,7 +262,7 @@ def test_standardize(self):
             )
             p_utf2 = tf.untransform_posterior(posterior2)
             self.assertEqual(p_utf2.device.type, self.device.type)
-            self.assertTrue(p_utf2.dtype == dtype)
+            self.assertEqual(p_utf2.dtype, dtype)
             mean_expected = tf.means + tf.stdvs * posterior.mean
             variance_expected = tf.stdvs**2 * posterior.variance
             self.assertAllClose(p_utf2.mean, mean_expected)
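For intuition on the n = 1 case these tests exercise: with a single observation the means equal Y itself and stdvs fall back to 1, so the transform centers Y to zero and passes the observation noise through unchanged. A small usage sketch, for illustration only:

import torch
from botorch.models.transforms.outcome import Standardize

# A single observation: means == Y and stdvs == 1, so Y is centered to
# zero and Yvar is returned untouched (Yvar / stdvs**2 == Yvar).
tf = Standardize(m=2)
Y = torch.tensor([[2.0, -3.0]])    # shape (1, 2): n = 1, m = 2
Yvar = torch.tensor([[0.1, 0.2]])
Y_tf, Yvar_tf = tf(Y, Yvar)
print(Y_tf)     # tensor([[0., 0.]])
print(Yvar_tf)  # tensor([[0.1000, 0.2000]]) -- unchanged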
test/models/utils/test_assorted.py (31 changes: 23 additions & 8 deletions)
@@ -141,26 +141,41 @@ def test_check_min_max_scaling(self):
     def test_check_standardization(self):
         # Ensure that it is not filtered out.
         warnings.filterwarnings("always", category=InputDataWarning)
+        torch.manual_seed(0)
         Y = torch.randn(3, 4, 2)
         # check standardized input
         Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True)
         with warnings.catch_warnings(record=True) as ws:
             check_standardization(Y=Yst)
             self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
             check_standardization(Y=Yst, raise_on_fail=True)
-        # check nonzero mean
+
+        # check standardized input with one observation
+        y = torch.zeros((3, 1, 2))
+        with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=y)
+            self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
+            check_standardization(Y=y, raise_on_fail=True)
+
+        # check nonzero mean for case where >= 2 observations per batch
+        msg_more_than_1_obs = r"Data is not standardized \(std ="
+        with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
             check_standardization(Y=Yst + 1)
-            self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
-            self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-        with self.assertRaises(InputDataError):
+        with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
            check_standardization(Y=Yst + 1, raise_on_fail=True)

+        # check nonzero mean for case where < 2 observations per batch
+        msg_one_obs = r"Data is not standardized \(mean ="
+        y = torch.ones((3, 1, 2), dtype=torch.float32)
+        with self.assertWarnsRegex(InputDataWarning, msg_one_obs):
+            check_standardization(Y=y)
+        with self.assertRaisesRegex(InputDataError, msg_one_obs):
+            check_standardization(Y=y, raise_on_fail=True)
+
         # check non-unit variance
-        with warnings.catch_warnings(record=True) as ws:
+        with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
             check_standardization(Y=Yst * 2)
-            self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
-            self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-        with self.assertRaises(InputDataError):
+        with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
             check_standardization(Y=Yst * 2, raise_on_fail=True)

     def test_validate_input_scaling(self):
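One detail worth noting in these tests: assertWarnsRegex and assertRaisesRegex interpret the expected message as a regular expression (checked with re.search), which is why the literal parenthesis in the new messages is escaped in the patterns. A small sketch of what those patterns match:

import re

# The test patterns must escape "(" because it is a regex metacharacter.
msg_more_than_1_obs = r"Data is not standardized \(std ="
msg_one_obs = r"Data is not standardized \(mean ="
assert re.search(msg_more_than_1_obs, "Data is not standardized (std = 2.0, mean = 0.5).")
assert re.search(msg_one_obs, "Data is not standardized (mean = 0.5).")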