Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions botorch/models/utils/assorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def check_min_max_scaling(
msg = "contained"
if msg is not None:
msg = (
f"Input data is not {msg} to the unit cube. "
f"Data (input features) not {msg} to the unit cube. "
"Please consider min-max scaling the input data."
)
if raise_on_fail:
Expand Down Expand Up @@ -197,7 +197,7 @@ def check_standardization(
if Y.shape[-2] <= 1:
if mean_not_zero:
msg = (
f"Data is not standardized (mean = {Ymean}). "
f"Data (outcome observations) not standardized (mean = {Ymean}). "
"Please consider scaling the input to zero mean and unit variance."
)
if raise_on_fail:
Expand All @@ -208,7 +208,8 @@ def check_standardization(
std_not_one = torch.abs(Ystd - 1).max() > atol_std
if mean_not_zero or std_not_one:
msg = (
f"Data is not standardized (std = {Ystd}, mean = {Ymean}). "
"Data (outcome observations) not standardized "
f"(std = {Ystd}, mean = {Ymean})."
"Please consider scaling the input to zero mean and unit variance."
)
if raise_on_fail:
Expand Down
4 changes: 2 additions & 2 deletions test/models/utils/test_assorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ def test_check_standardization(self):
check_standardization(Y=y, raise_on_fail=True)

# check nonzero mean for case where >= 2 observations per batch
msg_more_than_1_obs = r"Data is not standardized \(std ="
msg_more_than_1_obs = r"Data \(outcome observations\) not standardized \(std ="
with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
check_standardization(Y=Yst + 1)
with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
check_standardization(Y=Yst + 1, raise_on_fail=True)

# check nonzero mean for case where < 2 observations per batch
msg_one_obs = r"Data is not standardized \(mean ="
msg_one_obs = r"Data \(outcome observations\) not standardized \(mean ="
y = torch.ones((3, 1, 2), dtype=torch.float32)
with self.assertWarnsRegex(InputDataWarning, msg_one_obs):
check_standardization(Y=y)
Expand Down