
Commit 054cd5c

Merge branch 'main' into entropy_search
2 parents fb459a1 + 706e2a1

File tree

6 files changed, +293 -36 lines changed


botorch/models/transforms/input.py

Lines changed: 57 additions & 18 deletions
@@ -12,19 +12,19 @@
 rounding functions, and log transformations. The input transformation
 is typically part of a Model and applied within the model.forward()
 method.
-
 """
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from typing import Any, Callable, Dict, List, Optional, Union
+from warnings import warn
 
 import torch
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.transforms.utils import subset_transform
 from botorch.models.utils import fantasize
-from botorch.utils.rounding import approximate_round
+from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
 from gpytorch import Module as GPyTorchModule
 from gpytorch.constraints import GreaterThan
 from gpytorch.priors import Prior
@@ -649,10 +649,10 @@ def _update_coefficients(self, X: Tensor) -> None:
 
 
 class Round(InputTransform, Module):
-    r"""A rounding transformation for integer inputs.
+    r"""A discretization transformation for discrete inputs.
 
-    This will typically be used in conjunction with normalization as
-    follows:
+    For integers, this will typically be used in conjunction
+    with normalization as follows:
 
     In eval() mode (i.e. after training), the inputs pass
     would typically be normalized to the unit cube (e.g. during candidate
@@ -667,19 +667,26 @@ class Round(InputTransform, Module):
     should be set to False, so that the raw inputs are rounded and then
     normalized to the unit cube.
 
-    This transformation uses differentiable approximate rounding by default.
-    The rounding function is approximated with a piece-wise function where
-    each piece is a hyperbolic tangent function.
+    By default, the straight through estimators are used for the gradients as
+    proposed in [Daulton2022bopr]_. This transformation supports differentiable
+    approximate rounding (currently only for integers). The rounding function
+    is approximated with a piece-wise function where each piece is a hyperbolic
+    tangent function.
+
+    For categorical parameters, the input must be one-hot encoded.
 
     Example:
+        >>> bounds = torch.tensor([[0, 5], [0, 1], [0, 1]]).t()
+        >>> integer_indices = [0]
+        >>> categorical_features = {1: 2}
         >>> unnormalize_tf = Normalize(
         >>>     d=d,
         >>>     bounds=bounds,
         >>>     transform_on_eval=True,
         >>>     transform_on_train=True,
         >>>     reverse=True,
         >>> )
-        >>> round_tf = Round(integer_indices)
+        >>> round_tf = Round(integer_indices, categorical_features)
         >>> normalize_tf = Normalize(d=d, bounds=bounds)
         >>> tf = ChainedInputTransform(
         >>>     tf1=unnormalize_tf, tf2=round_tf, tf3=normalize_tf
@@ -688,46 +695,76 @@ class Round(InputTransform, Module):
 
     def __init__(
         self,
-        indices: List[int],
+        integer_indices: Optional[List[int]] = None,
+        categorical_features: Optional[Dict[int, int]] = None,
         transform_on_train: bool = True,
        transform_on_eval: bool = True,
         transform_on_fantasize: bool = True,
-        approximate: bool = True,
+        approximate: bool = False,
         tau: float = 1e-3,
+        **kwargs,
     ) -> None:
         r"""Initialize transform.
 
         Args:
-            indices: The indices of the integer inputs.
+            integer_indices: The indices of the integer inputs.
+            categorical_features: A dictionary mapping the starting index of each
+                categorical feature to its cardinality. This assumes that categoricals
+                are one-hot encoded.
             transform_on_train: A boolean indicating whether to apply the
                 transforms in train() mode. Default: True.
             transform_on_eval: A boolean indicating whether to apply the
                 transform in eval() mode. Default: True.
             transform_on_fantasize: A boolean indicating whether to apply the
                 transform when called from within a `fantasize` call. Default: True.
             approximate: A boolean indicating whether approximate or exact
-                rounding should be used. Default: approximate.
+                rounding should be used. Default: False.
             tau: The temperature parameter for approximate rounding.
         """
+        indices = kwargs.get("indices")
+        if indices is not None:
+            warn(
+                "`indices` is marked for deprecation in favor of `integer_indices`.",
+                DeprecationWarning,
+            )
+            integer_indices = indices
+        if approximate and categorical_features is not None:
+            raise NotImplementedError
         super().__init__()
         self.transform_on_train = transform_on_train
         self.transform_on_eval = transform_on_eval
         self.transform_on_fantasize = transform_on_fantasize
-        self.register_buffer("indices", torch.tensor(indices, dtype=torch.long))
+        integer_indices = integer_indices or []
+        self.register_buffer(
+            "integer_indices", torch.tensor(integer_indices, dtype=torch.long)
+        )
+        self.categorical_features = categorical_features or {}
         self.approximate = approximate
         self.tau = tau
 
-    @subset_transform
     def transform(self, X: Tensor) -> Tensor:
-        r"""Round the inputs.
+        r"""Discretize the inputs.
 
         Args:
             X: A `batch_shape x n x d`-dim tensor of inputs.
 
         Returns:
-            A `batch_shape x n x d`-dim tensor of rounded inputs.
+            A `batch_shape x n x d`-dim tensor of discretized inputs.
         """
-        return approximate_round(X, tau=self.tau) if self.approximate else X.round()
+        X_rounded = X.clone()
+        # round integers
+        X_int = X_rounded[..., self.integer_indices]
+        if self.approximate:
+            X_int = approximate_round(X_int, tau=self.tau)
+        else:
+            X_int = RoundSTE.apply(X_int)
+        X_rounded[..., self.integer_indices] = X_int
+        # discrete categoricals to the category with the largest value
+        # in the continuous relaxation of the one-hot encoding
+        for start, card in self.categorical_features.items():
+            end = start + card
+            X_rounded[..., start:end] = OneHotArgmaxSTE.apply(X[..., start:end])
+        return X_rounded
 
     def equals(self, other: InputTransform) -> bool:
        r"""Check if another input transform is equivalent.
@@ -740,6 +777,8 @@ def equals(self, other: InputTransform) -> bool:
         """
         return (
             super().equals(other=other)
+            and (self.integer_indices == other.integer_indices).all()
+            and self.categorical_features == other.categorical_features
             and self.approximate == other.approximate
             and self.tau == other.tau
         )
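To make the docstring example in this diff concrete, here is a minimal sketch of how the new transform chain could be assembled. This snippet is not part of the commit: it assumes a 3-dimensional input with one integer parameter at index 0 (range 0 to 5) and one two-category one-hot categorical at indices 1-2, and the candidate tensor `X` and the names `tf` and `X_discretized` are made up for illustration.

    import torch
    from botorch.models.transforms.input import ChainedInputTransform, Normalize, Round

    d = 3  # integer dim at index 0, one-hot categorical block at indices 1:3
    bounds = torch.tensor([[0, 5], [0, 1], [0, 1]], dtype=torch.double).t()
    integer_indices = [0]
    categorical_features = {1: 2}  # block starts at index 1, cardinality 2

    # unit cube -> raw space -> discretize -> back to the unit cube
    unnormalize_tf = Normalize(
        d=d, bounds=bounds, transform_on_eval=True, transform_on_train=True, reverse=True
    )
    round_tf = Round(
        integer_indices=integer_indices, categorical_features=categorical_features
    )
    normalize_tf = Normalize(d=d, bounds=bounds)
    tf = ChainedInputTransform(tf1=unnormalize_tf, tf2=round_tf, tf3=normalize_tf).eval()

    X = torch.rand(4, d, dtype=torch.double)  # hypothetical candidates in the unit cube
    X_discretized = tf(X)  # integer dim rounded; categorical block hard one-hot encoded

Note that exact straight-through rounding is now the default (`approximate=False`), so gradients still flow through `tf` during gradient-based candidate optimization.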

botorch/optim/initializers.py

Lines changed: 0 additions & 1 deletion
@@ -114,7 +114,6 @@ def gen_batch_initial_conditions(
     batch_limit: Optional[int] = options.get(
         "init_batch_limit", options.get("batch_limit")
     )
-    batch_initial_arms: Tensor
     factor, max_factor = 1, 5
     init_kwargs = {}
     device = bounds.device

botorch/test_functions/multi_objective.py

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,8 @@
 
 .. [Daulton2022]
     S. Daulton, S. Cakmak, M. Balandat, M. A. Osborne, E. Zhou, and E. Bakshy.
-    Robust Multi-Objective Bayesian Optimization Under Input Noise. 2022.
+    Robust Multi-Objective Bayesian Optimization Under Input Noise.
+    Proceedings of the 39th International Conference on Machine Learning, 2022.
 
 .. [Deb2005dtlz]
     K. Deb, L. Thiele, M. Laumanns, E. Zitzler, A. Abraham, L. Jain, and

botorch/utils/rounding.py

Lines changed: 79 additions & 0 deletions
@@ -4,10 +4,24 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+r"""
+Discretization (rounding) functions for acquisition optimization.
+
+References
+
+.. [Daulton2022bopr]
+    S. Daulton, X. Wan, D. Eriksson, M. Balandat, M. A. Osborne, E. Bakshy.
+    Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic
+    Reparameterization. Advances in Neural Information Processing Systems
+    35, 2022.
+"""
+
 from __future__ import annotations
 
 import torch
 from torch import Tensor
+from torch.autograd import Function
+from torch.nn.functional import one_hot
 
 
 def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
@@ -27,3 +41,68 @@ def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
     scaled_remainder = (X - offset - 0.5) / tau
     rounding_component = (torch.tanh(scaled_remainder) + 1) / 2
     return offset + rounding_component
+
+
+class IdentitySTEFunction(Function):
+    """Base class for functions using straight through gradient estimators.
+
+    This class approximates the gradient with the identity function.
+    """
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor) -> Tensor:
+        r"""Use a straight-through estimator of the gradient.
+
+        This uses the identity function.
+
+        Args:
+            grad_output: A tensor of gradients.
+
+        Returns:
+            The provided tensor.
+        """
+        return grad_output
+
+
+class RoundSTE(IdentitySTEFunction):
+    r"""Round the input tensor and use a straight-through gradient estimator.
+
+    [Daulton2022bopr]_ proposes using this in acquisition optimization.
+    """
+
+    @staticmethod
+    def forward(ctx, X: Tensor) -> Tensor:
+        r"""Round the input tensor element-wise.
+
+        Args:
+            X: The tensor to be rounded.
+
+        Returns:
+            A tensor where each element is rounded to the nearest integer.
+        """
+        return X.round()
+
+
+class OneHotArgmaxSTE(IdentitySTEFunction):
+    r"""Discretize a continuous relaxation of a one-hot encoded categorical.
+
+    This returns a one-hot encoded categorical and uses a straight-through
+    gradient estimator via an identity function.
+
+    [Daulton2022bopr]_ proposes using this in acquisition optimization.
+    """
+
+    @staticmethod
+    def forward(ctx, X: Tensor) -> Tensor:
+        r"""Discretize the input tensor.
+
+        This applies an argmax along the last dimension of the input tensor
+        and one-hot encodes the result.
+
+        Args:
+            X: The tensor to be discretized.
+
+        Returns:
+            A one-hot encoded tensor of the argmax along the last dimension.
+        """
+        return one_hot(X.argmax(dim=-1), num_classes=X.shape[-1]).to(X)
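As a quick sanity check of the straight-through behavior these classes implement: the forward pass is a hard, non-differentiable discretization, while the backward pass is the identity, so the continuous relaxation keeps receiving gradient signal. A hypothetical sketch, not part of the commit; the tensor values are illustrative only.

    import torch
    from botorch.utils.rounding import OneHotArgmaxSTE, RoundSTE

    X = torch.tensor([0.2, 1.7, 3.5], requires_grad=True)
    X_round = RoundSTE.apply(X)  # forward: exact rounding -> tensor([0., 2., 4.])
    X_round.sum().backward()
    print(X.grad)  # identity STE passes gradients through -> tensor([1., 1., 1.])

    Z = torch.tensor([[0.1, 0.6, 0.3]], requires_grad=True)
    Z_onehot = OneHotArgmaxSTE.apply(Z)  # forward: hard one-hot -> tensor([[0., 1., 0.]])
    Z_onehot.sum().backward()
    print(Z.grad)  # identity STE again -> tensor([[1., 1., 1.]])

By contrast, `approximate_round` (used when `approximate=True`) is smooth: each piece is a scaled tanh, so it yields ordinary nonzero gradients but returns values that are only approximately integral.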
