
Commit 3384c24

saitcakmak authored and facebook-github-bot committed
Do not normalize or standardize dimension if all values are equal (#2185)
Summary:
Pull Request resolved: #2185

Issue description with the `Normalize` transform: Suppose the train data has x0 as a constant (this can happen with few data points) while x0 is being optimized over [0, 1]. Under the current behavior, we first compute a coefficient of 0.0 for x0, then clamp it up to 1e-8. During acqf optimization, we evaluate the model with values in [0, 1], which then get normalized to [0, 1e8]. This can cause numerical issues in GPyTorch and lead to non-PSD covariance matrices, since the model was trained on constant inputs and likely learned lengthscales that do not play well with such large values.

This diff updates the behavior of `min_range` / `min_std` in the `Normalize` / `InputStandardize` transforms to skip transforming a given dimension if the range / standard deviation of that dimension is less than the minimum. This is achieved by using an offset of 0 and a coefficient of 1 for that dimension.

Reviewed By: esantorella

Differential Revision: D53213759

fbshipit-source-id: 9f738e9c6654e184f6e8a74bb8abe8a530290691
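As an illustration, here is a minimal sketch of the behavior this change targets, assuming BoTorch at this revision (the training tensor, test point, and printed values are illustrative, not taken from the diff):

```python
import torch
from botorch.models.transforms.input import Normalize

# Training data in which x0 is constant (zero range) while x1 varies.
X_train = torch.tensor([[0.5, 0.0], [0.5, 2.0]])

nlz = Normalize(d=2)  # learn_bounds defaults to True when no bounds are given
nlz(X_train)  # train-mode call learns the offset and coefficient from X_train
nlz.eval()

# Before this change: the zero range of x0 was clamped up to min_range=1e-8,
# so eval-mode inputs in [0, 1] were mapped to roughly [-5e7, 5e7] in x0.
# After this change: x0 is skipped (offset 0, coefficient 1) and passes through.
X_test = torch.tensor([[1.0, 1.0]])
print(nlz(X_test))  # expected after this change: tensor([[1.0000, 0.5000]])
```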
1 parent be7a58d commit 3384c24

File tree

2 files changed: +61 -18 lines


botorch/models/transforms/input.py

Lines changed: 35 additions & 10 deletions
@@ -21,7 +21,6 @@
 from warnings import warn
 
 import numpy as np
-
 import torch
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.exceptions.warnings import UserInputWarning
@@ -513,6 +512,7 @@ def __init__(
         reverse: bool = False,
         min_range: float = 1e-8,
         learn_bounds: Optional[bool] = None,
+        almost_zero: float = 1e-12,
     ) -> None:
         r"""Normalize the inputs to the unit cube.
 
@@ -533,10 +533,28 @@ def __init__(
                 transform when called from within a `fantasize` call. Default: True.
             reverse: A boolean indicating whether the forward pass should untransform
                 the inputs.
-            min_range: Amount of noise to add to the range to ensure no division by
-                zero errors.
+            min_range: If the range of an input dimension is smaller than `min_range`,
+                that input dimension will not be normalized. This is equivalent to
+                using bounds of `[0, 1]` for this dimension, and helps avoid division
+                by zero errors and related numerical issues. See the example below.
+                NOTE: This only applies if `learn_bounds=True`.
             learn_bounds: Whether to learn the bounds in train mode. Defaults
                 to False if bounds are provided, otherwise defaults to True.
+
+        Example:
+            >>> t = Normalize(d=2)
+            >>> t(torch.tensor([[3., 2.], [3., 6.]]))
+            ... tensor([[3., 2.],
+            ...         [3., 6.]])
+            >>> t.eval()
+            ... Normalize()
+            >>> t(torch.tensor([[3.5, 2.8]]))
+            ... tensor([[3.5, 0.2]])
+            >>> t.bounds
+            ... tensor([[0., 2.],
+            ...         [1., 6.]])
+            >>> t.coefficient
+            ... tensor([[1., 4.]])
         """
         if learn_bounds is not None:
             self.learn_coefficients = learn_bounds
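To unpack the numbers in that docstring example: dimension 0 is constant in the training data (range 0 < `min_range`), so it is left untransformed, which is why `3.5` passes through unchanged and its learned bounds are `[0, 1]` with coefficient `1`; dimension 1 has offset `2` and coefficient `6 - 2 = 4`, so the eval-mode value maps as `(2.8 - 2) / 4 = 0.2`.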
@@ -601,9 +619,11 @@ def _update_coefficients(self, X) -> None:
         # Aggregate mins and ranges over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
         reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
-        self._offset = torch.amin(X, dim=reduce_dims).unsqueeze(-2)
-        self._coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - self.offset
-        self._coefficient.clamp_(min=self.min_range)
+        offset = torch.amin(X, dim=reduce_dims).unsqueeze(-2)
+        coefficient = torch.amax(X, dim=reduce_dims).unsqueeze(-2) - offset
+        almost_zero = coefficient < self.min_range
+        self._coefficient = torch.where(almost_zero, 1.0, coefficient)
+        self._offset = torch.where(almost_zero, 0.0, offset)
 
     def get_init_args(self) -> Dict[str, Any]:
         r"""Get the arguments necessary to construct an exact copy of the transform."""
@@ -655,8 +675,11 @@ def __init__(
                 transform in eval() mode. Default: True
             reverse: A boolean indicating whether the forward pass should untransform
                 the inputs.
-            min_std: Amount of noise to add to the standard deviation to ensure no
-                division by zero errors.
+            min_std: If the standard deviation of an input dimension is smaller than
+                `min_std`, that input dimension will not be standardized. This is
+                equivalent to using a standard deviation of 1.0 and a mean of 0.0 for
+                this dimension, and helps avoid division by zero errors and related
+                numerical issues.
         """
         transform_dimension = d if indices is None else len(indices)
         super().__init__(
@@ -688,11 +711,13 @@ def _update_coefficients(self, X: Tensor) -> None:
         # Aggregate means and standard deviations over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
         reduce_dims = (*range(X.ndim - batch_ndim - 2), X.ndim - 2)
-        coefficient, self._offset = (
+        coefficient, offset = (
             values.unsqueeze(-2)
             for values in torch.std_mean(X, dim=reduce_dims, unbiased=True)
         )
-        self._coefficient = coefficient.clamp_(min=self.min_std)
+        almost_zero = coefficient < self.min_std
+        self._coefficient = torch.where(almost_zero, 1.0, coefficient)
+        self._offset = torch.where(almost_zero, 0.0, offset)
 
 
 class Round(InputTransform, Module):
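For completeness, a minimal sketch of the corresponding `InputStandardize` behavior, assuming BoTorch at this revision (the tensors and printed values are illustrative):

```python
import torch
from botorch.models.transforms.input import InputStandardize

# Column 0 is constant, so its (unbiased) standard deviation is 0.
X_train = torch.tensor([[1.0, 0.0], [1.0, 2.0]])

stdz = InputStandardize(d=2)
stdzd_X = stdz(X_train)  # train-mode call learns the per-dimension mean and std

# After this change, column 0 passes through unchanged (mean 0, std 1 are used),
# while column 1 is standardized with mean 1 and unbiased std sqrt(2):
# (x - 1) / sqrt(2) for the values [0, 2].
print(stdzd_X)  # expected: tensor([[1.0000, -0.7071], [1.0000, 0.7071]])
```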

test/models/transforms/test_input.py

Lines changed: 26 additions & 8 deletions
@@ -229,8 +229,8 @@ def test_normalize(self) -> None:
             nlz.to(other_dtype)
             self.assertTrue(nlz.mins.dtype == other_dtype)
             # test incompatible dimensions of specified bounds
+            bounds = torch.zeros(2, 3, device=self.device, dtype=dtype)
             with self.assertRaises(BotorchTensorDimensionError):
-                bounds = torch.zeros(2, 3, device=self.device, dtype=dtype)
                 Normalize(d=2, bounds=bounds)
 
             # test jitter
@@ -380,7 +380,25 @@ def test_normalize(self) -> None:
             self.assertIsNone(nlz.coefficient.grad_fn)
             self.assertIsNone(nlz.offset.grad_fn)
 
-    def test_standardize(self):
+            # test that zero range is not scaled.
+            nlz = Normalize(d=2)
+            X = torch.tensor([[1.0, 0.0], [1.0, 2.0]], device=self.device, dtype=dtype)
+            nlzd_X = nlz(X)
+            self.assertAllClose(
+                nlz.coefficient,
+                torch.tensor([[1.0, 2.0]], device=self.device, dtype=dtype),
+            )
+            expected_X = torch.tensor(
+                [[1.0, 0.0], [1.0, 1.0]], device=self.device, dtype=dtype
+            )
+            self.assertAllClose(nlzd_X, expected_X)
+            nlz.eval()
+            X = torch.tensor([[1.5, 1.5]], device=self.device, dtype=dtype)
+            nlzd_X = nlz(X)
+            expected_X = torch.tensor([[1.5, 0.75]], device=self.device, dtype=dtype)
+            self.assertAllClose(nlzd_X, expected_X)
+
+    def test_standardize(self) -> None:
         for dtype in (torch.float, torch.double):
             # basic init
             stdz = InputStandardize(d=2)
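In the eval-mode check added above, `[[1.5, 1.5]]` maps to `[[1.5, 0.75]]` because dimension 0 had zero range during training and is now passed through unchanged (offset 0, coefficient 1), while dimension 1 is normalized with offset 0 and coefficient 2: `(1.5 - 0) / 2 = 0.75`.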
@@ -527,7 +545,7 @@ def test_standardize(self):
             stdz8 = InputStandardize(d=3, batch_shape=batch_shape, indices=[0, 2])
             self.assertFalse(stdz7.equals(stdz8))
 
-    def test_chained_input_transform(self):
+    def test_chained_input_transform(self) -> None:
         ds = (1, 2)
         batch_shapes = (torch.Size(), torch.Size([2]))
         dtypes = (torch.float, torch.double)
@@ -1157,7 +1175,7 @@ def test_one_hot_to_numeric(self) -> None:
 
 
 class TestAppendFeatures(BotorchTestCase):
-    def test_append_features(self):
+    def test_append_features(self) -> None:
         with self.assertRaises(ValueError):
             AppendFeatures(torch.ones(1))
         with self.assertRaises(ValueError):
@@ -1198,7 +1216,7 @@ def test_append_features(self):
         self.assertEqual(transform.feature_set.device.type, "cpu")
         self.assertEqual(transform.feature_set.dtype, torch.half)
 
-    def test_w_skip_expand(self):
+    def test_w_skip_expand(self) -> None:
         for dtype in (torch.float, torch.double):
             tkwargs = {"device": self.device, "dtype": dtype}
             feature_set = torch.tensor([[0.0], [1.0]], **tkwargs)
@@ -1221,7 +1239,7 @@ def test_w_skip_expand(self):
         tf_X = append_tf(pert_tf(test_X.expand(3, 5, -1, -1)))
         self.assertAllClose(tf_X, expected_X.expand(3, 5, -1, -1))
 
-    def test_w_f(self):
+    def test_w_f(self) -> None:
         def f1(x: Tensor, n_f: int = 1) -> Tensor:
             result = torch.sum(x, dim=-1, keepdim=True).unsqueeze(-2)
             return result.expand(*result.shape[:-2], n_f, -1)
@@ -1453,7 +1471,7 @@ def f2(x: Tensor, n_f: int = 1) -> Tensor:
 
 
 class TestFilterFeatures(BotorchTestCase):
-    def test_filter_features(self):
+    def test_filter_features(self) -> None:
         with self.assertRaises(ValueError):
             FilterFeatures(torch.tensor([[1, 2]], dtype=torch.long))
         with self.assertRaises(ValueError):
@@ -1527,7 +1545,7 @@ def test_filter_features(self):
 
 
 class TestInputPerturbation(BotorchTestCase):
-    def test_input_perturbation(self):
+    def test_input_perturbation(self) -> None:
         with self.assertRaisesRegex(ValueError, "-dim tensor!"):
             InputPerturbation(torch.ones(1))
         with self.assertRaisesRegex(ValueError, "-dim tensor!"):
