@@ -1046,40 +1046,89 @@ def test_gen_one_shot_kg_initial_conditions(self):
1046
1046
raw_samples = raw_samples ,
1047
1047
options = {"frac_random" : 2.0 },
1048
1048
)
1049
+
1050
+ forbidden_indices = torch .tensor ([[0 , 1 ], [0 , 1 ]])
1051
+ for constraint_type in [
1052
+ "inequality_constraints" ,
1053
+ "equality_constraints" ,
1054
+ ]:
1055
+ constraint_kwarg = {
1056
+ constraint_type : [
1057
+ (
1058
+ forbidden_indices ,
1059
+ torch .ones (2 ).to (mean ),
1060
+ 0.5 ,
1061
+ )
1062
+ ]
1063
+ }
1064
+ with self .subTest (
1065
+ f"intra-point { constraint_type } not supported"
1066
+ ), self .assertRaisesRegex (
1067
+ NotImplementedError ,
1068
+ "Indices must be one-dimensional "
1069
+ "in gen_one_shot_kg_initial_conditions. "
1070
+ "Received indices" ,
1071
+ ):
1072
+ gen_one_shot_kg_initial_conditions (
1073
+ acq_function = mock_kg ,
1074
+ bounds = bounds ,
1075
+ q = 1 ,
1076
+ num_restarts = num_restarts ,
1077
+ raw_samples = raw_samples ,
1078
+ ** constraint_kwarg ,
1079
+ )
1080
+
1049
1081
# test generation logic
1050
1082
q = 2
1051
1083
mock_random_ics = torch .rand (num_restarts , q + num_fantasies , 2 )
1052
1084
mock_fantasy_cands = torch .ones (20 , 1 , 2 )
1053
1085
mock_fantasy_vals = torch .randn (20 )
1054
- with ExitStack () as es :
1055
- mock_gbics = es .enter_context (
1056
- mock .patch (
1057
- "botorch.optim.initializers.gen_batch_initial_conditions" ,
1058
- return_value = mock_random_ics ,
1059
- )
1060
- )
1061
- mock_optacqf = es .enter_context (
1062
- mock .patch (
1063
- "botorch.optim.optimize.optimize_acqf" ,
1064
- return_value = (mock_fantasy_cands , mock_fantasy_vals ),
1086
+ for constraint_kwargs in [
1087
+ {},
1088
+ {
1089
+ "inequality_constraints" : [
1090
+ (torch .tensor ([0 , 1 ]), torch .ones (2 ).to (mean ), 0.5 )
1091
+ ]
1092
+ }, # test that no error is raised
1093
+ {
1094
+ "equality_constraints" : [
1095
+ (torch .tensor ([0 , 1 ]), torch .ones (2 ).to (mean ), 0.5 )
1096
+ ]
1097
+ # test that no error is raised
1098
+ },
1099
+ ]:
1100
+ with ExitStack () as es :
1101
+ mock_gbics = es .enter_context (
1102
+ mock .patch (
1103
+ "botorch.optim.initializers.gen_batch_initial_conditions" ,
1104
+ return_value = mock_random_ics ,
1105
+ )
1065
1106
)
1066
- )
1067
- ics = gen_one_shot_kg_initial_conditions (
1068
- acq_function = mock_kg ,
1069
- bounds = bounds ,
1070
- q = q ,
1071
- num_restarts = num_restarts ,
1072
- raw_samples = raw_samples ,
1073
- )
1074
- mock_gbics .assert_called_once ()
1075
- mock_optacqf .assert_called_once ()
1076
- n_value = int ((1 - 0.1 ) * num_fantasies )
1077
- self .assertTrue (
1078
- torch .equal (
1079
- ics [..., :- n_value , :], mock_random_ics [..., :- n_value , :]
1107
+ mock_optacqf = es .enter_context (
1108
+ mock .patch (
1109
+ "botorch.optim.optimize.optimize_acqf" ,
1110
+ return_value = (mock_fantasy_cands , mock_fantasy_vals ),
1111
+ )
1080
1112
)
1081
- )
1082
- self .assertTrue (torch .all (ics [..., - n_value :, :] == 1 ))
1113
+ with self .subTest (f"Main test with { constraint_kwargs } " ):
1114
+ ics = gen_one_shot_kg_initial_conditions (
1115
+ acq_function = mock_kg ,
1116
+ bounds = bounds ,
1117
+ q = q ,
1118
+ num_restarts = num_restarts ,
1119
+ raw_samples = raw_samples ,
1120
+ ** constraint_kwargs ,
1121
+ )
1122
+ mock_gbics .assert_called_once ()
1123
+ mock_optacqf .assert_called_once ()
1124
+ n_value = int ((1 - 0.1 ) * num_fantasies )
1125
+ self .assertTrue (
1126
+ torch .equal (
1127
+ ics [..., :- n_value , :],
1128
+ mock_random_ics [..., :- n_value , :],
1129
+ )
1130
+ )
1131
+ self .assertTrue (torch .all (ics [..., - n_value :, :] == 1 ))
1083
1132
1084
1133
1085
1134
class TestGenOneShotHVKGInitialConditions (BotorchTestCase ):
0 commit comments