@@ -1000,6 +1000,8 @@ def optimize_acqf_discrete(
10001000 choices : Tensor ,
10011001 max_batch_size : int = 2048 ,
10021002 unique : bool = True ,
1003+ X_avoid : Tensor | None = None ,
1004+ inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
10031005) -> tuple [Tensor , Tensor ]:
10041006 r"""Optimize over a discrete set of points using batch evaluation.
10051007
@@ -1017,6 +1019,12 @@ def optimize_acqf_discrete(
10171019 a large training set.
10181020 unique: If True return unique choices, o/w choices may be repeated
10191021 (only relevant if `q > 1`).
1022+ X_avoid: An `n x d` tensor of candidates that we aren't allowed to pick.
1023+ These will be removed from the set of choices.
1024+ inequality constraints: A list of tuples (indices, coefficients, rhs),
1025+ with each tuple encoding an inequality constraint of the form
1026+ `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
1027+ Infeasible points will be removed from the set of choices.
10201028
10211029 Returns:
10221030 A two-element tuple containing
@@ -1029,8 +1037,31 @@ def optimize_acqf_discrete(
10291037 "Discrete optimization is not supported for"
10301038 "one-shot acquisition functions."
10311039 )
1032- if choices .numel () == 0 :
1033- raise InputDataError ("`choices` must be non-emtpy." )
1040+ if X_avoid is not None and unique :
1041+ choices = _filter_invalid (X = choices , X_avoid = X_avoid )
1042+ if inequality_constraints is not None :
1043+ choices = _filter_infeasible (
1044+ X = choices , inequality_constraints = inequality_constraints
1045+ )
1046+ len_choices = len (choices )
1047+ if len_choices == 0 :
1048+ message = "`choices` must be non-empty."
1049+ if X_avoid is not None or inequality_constraints is not None :
1050+ message += (
1051+ " No feasible points remain after removing `X_avoid` and "
1052+ "filtering out infeasible points."
1053+ )
1054+ raise InputDataError (message )
1055+ elif len_choices < q and unique :
1056+ warnings .warn (
1057+ (
1058+ f"Requested { q = } candidates from fully discrete search "
1059+ f"space, but only { len_choices } possible choices remain. "
1060+ ),
1061+ OptimizationWarning ,
1062+ stacklevel = 2 ,
1063+ )
1064+ q = len_choices
10341065 choices_batched = choices .unsqueeze (- 2 )
10351066 if q > 1 :
10361067 candidate_list , acq_value_list = [], []
@@ -1081,7 +1112,7 @@ def _generate_neighbors(
10811112 discrete_choices : list [Tensor ],
10821113 X_avoid : Tensor ,
10831114 inequality_constraints : list [tuple [Tensor , Tensor , float ]],
1084- ):
1115+ ) -> Tensor :
10851116 # generate all 1D perturbations
10861117 npts = sum ([len (c ) for c in discrete_choices ])
10871118 X_loc = x .repeat (npts , 1 )
@@ -1097,15 +1128,15 @@ def _generate_neighbors(
10971128
10981129def _filter_infeasible (
10991130 X : Tensor , inequality_constraints : list [tuple [Tensor , Tensor , float ]]
1100- ):
1131+ ) -> Tensor :
11011132 """Remove all points from `X` that don't satisfy the constraints."""
11021133 is_feasible = torch .ones (X .shape [0 ], dtype = torch .bool , device = X .device )
11031134 for inds , weights , bound in inequality_constraints :
11041135 is_feasible &= (X [..., inds ] * weights ).sum (dim = - 1 ) >= bound
11051136 return X [is_feasible ]
11061137
11071138
1108- def _filter_invalid (X : Tensor , X_avoid : Tensor ):
1139+ def _filter_invalid (X : Tensor , X_avoid : Tensor ) -> Tensor :
11091140 """Remove all occurences of `X_avoid` from `X`."""
11101141 return X [~ (X == X_avoid .unsqueeze (- 2 )).all (dim = - 1 ).any (dim = - 2 )]
11111142
0 commit comments