1212rounding functions, and log transformations. The input transformation
1313is typically part of a Model and applied within the model.forward()
1414method.
15-
1615"""
1716from __future__ import annotations
1817
1918from abc import ABC , abstractmethod
2019from collections import OrderedDict
2120from typing import Any , Callable , Dict , List , Optional , Union
21+ from warnings import warn
2222
2323import torch
2424from botorch .exceptions .errors import BotorchTensorDimensionError
2525from botorch .models .transforms .utils import subset_transform
2626from botorch .models .utils import fantasize
27- from botorch .utils .rounding import approximate_round
27+ from botorch .utils .rounding import approximate_round , OneHotArgmaxSTE , RoundSTE
2828from gpytorch import Module as GPyTorchModule
2929from gpytorch .constraints import GreaterThan
3030from gpytorch .priors import Prior
@@ -649,10 +649,10 @@ def _update_coefficients(self, X: Tensor) -> None:
649649
650650
651651class Round (InputTransform , Module ):
652- r"""A rounding transformation for integer inputs.
652+ r"""A discretization transformation for discrete inputs.
653653
654- This will typically be used in conjunction with normalization as
655- follows:
654+ For integers, this will typically be used in conjunction
655+ with normalization as follows:
656656
657657 In eval() mode (i.e. after training), the inputs pass
658658 would typically be normalized to the unit cube (e.g. during candidate
@@ -667,19 +667,26 @@ class Round(InputTransform, Module):
667667 should be set to False, so that the raw inputs are rounded and then
668668 normalized to the unit cube.
669669
    By default, straight-through estimators are used for the gradients, as
    proposed in [Daulton2022bopr]_. Setting `approximate=True` instead uses
    differentiable approximate rounding (currently supported only for
    integers), in which the rounding function is approximated with a
    piece-wise function where each piece is a hyperbolic tangent function.
675+
676+ For categorical parameters, the input must be one-hot encoded.
673677
674678 Example:
679+ >>> bounds = torch.tensor([[0, 5], [0, 1], [0, 1]]).t()
680+ >>> integer_indices = [0]
681+ >>> categorical_features = {1: 2}
675682 >>> unnormalize_tf = Normalize(
676683 >>> d=d,
677684 >>> bounds=bounds,
678685 >>> transform_on_eval=True,
679686 >>> transform_on_train=True,
680687 >>> reverse=True,
681688 >>> )
682- >>> round_tf = Round(integer_indices)
689+ >>> round_tf = Round(integer_indices, categorical_features )
683690 >>> normalize_tf = Normalize(d=d, bounds=bounds)
684691 >>> tf = ChainedInputTransform(
685692 >>> tf1=unnormalize_tf, tf2=round_tf, tf3=normalize_tf
@@ -688,46 +695,76 @@ class Round(InputTransform, Module):
688695
def __init__(
    self,
    integer_indices: Optional[Union[List[int], Tensor]] = None,
    categorical_features: Optional[Dict[int, int]] = None,
    transform_on_train: bool = True,
    transform_on_eval: bool = True,
    transform_on_fantasize: bool = True,
    approximate: bool = False,
    tau: float = 1e-3,
    **kwargs,
) -> None:
    r"""Initialize transform.

    Args:
        integer_indices: The indices of the integer inputs. May be a list of
            ints or an integer tensor.
        categorical_features: A dictionary mapping the starting index of each
            categorical feature to its cardinality. This assumes that
            categoricals are one-hot encoded.
        transform_on_train: A boolean indicating whether to apply the
            transforms in train() mode. Default: True.
        transform_on_eval: A boolean indicating whether to apply the
            transform in eval() mode. Default: True.
        transform_on_fantasize: A boolean indicating whether to apply the
            transform when called from within a `fantasize` call. Default: True.
        approximate: A boolean indicating whether approximate or exact
            rounding should be used. Default: False.
        tau: The temperature parameter for approximate rounding.
        **kwargs: Only the deprecated `indices` keyword is read; when given
            (and not None) it overrides `integer_indices`. Any other keyword
            is ignored.

    Raises:
        NotImplementedError: If `approximate` is True and
            `categorical_features` is not None.
    """
    # Backward compatibility: `indices` is the old name of `integer_indices`.
    indices = kwargs.get("indices")
    if indices is not None:
        warn(
            "`indices` is marked for deprecation in favor of `integer_indices`.",
            DeprecationWarning,
            # Attribute the warning to the caller rather than to this module.
            stacklevel=2,
        )
        integer_indices = indices
    if approximate and categorical_features is not None:
        raise NotImplementedError(
            "Approximate rounding is not supported for categorical features."
        )
    super().__init__()
    self.transform_on_train = transform_on_train
    self.transform_on_eval = transform_on_eval
    self.transform_on_fantasize = transform_on_fantasize
    # Compare against None explicitly: the truth value of a multi-element
    # tensor/array is ambiguous, so `integer_indices or []` would raise.
    if integer_indices is None:
        integer_indices = []
    self.register_buffer(
        "integer_indices", torch.as_tensor(integer_indices, dtype=torch.long)
    )
    self.categorical_features = categorical_features or {}
    self.approximate = approximate
    self.tau = tau
def transform(self, X: Tensor) -> Tensor:
    r"""Discretize the inputs.

    Integer columns are rounded (approximately if `self.approximate` is set,
    otherwise via `RoundSTE`), and each one-hot encoded categorical block is
    discretized via `OneHotArgmaxSTE`. The input tensor is not modified.

    Args:
        X: A `batch_shape x n x d`-dim tensor of inputs.

    Returns:
        A `batch_shape x n x d`-dim tensor of discretized inputs.
    """
    # Work on a copy so the caller's tensor is left untouched.
    discretized = X.clone()

    # Integer-valued columns.
    int_idx = self.integer_indices
    int_vals = discretized[..., int_idx]
    discretized[..., int_idx] = (
        approximate_round(int_vals, tau=self.tau)
        if self.approximate
        else RoundSTE.apply(int_vals)
    )

    # Categorical columns: discretize to the category with the largest value
    # in the continuous relaxation of the one-hot encoding.
    for first, cardinality in self.categorical_features.items():
        block = slice(first, first + cardinality)
        discretized[..., block] = OneHotArgmaxSTE.apply(X[..., block])
    return discretized
731768
732769 def equals (self , other : InputTransform ) -> bool :
733770 r"""Check if another input transform is equivalent.
@@ -740,6 +777,8 @@ def equals(self, other: InputTransform) -> bool:
740777 """
741778 return (
742779 super ().equals (other = other )
780+ and (self .integer_indices == other .integer_indices ).all ()
781+ and self .categorical_features == other .categorical_features
743782 and self .approximate == other .approximate
744783 and self .tau == other .tau
745784 )
0 commit comments