Skip to content

Commit 9354fd7

Browse files
esantorellafacebook-github-bot
authored andcommitted
Set seed in test_learned_preference_objective to stop it from flaking (#2145)
Summary: ## Motivation * The test was flaky due to a varying amount of numerical error depending on random inputs, so I set a seed to a random number between 0 and 10 * Changed some data to double precision to avoid a warning Pull Request resolved: #2145 Test Plan: * checked that the test passes for each seed between 0 and 10 * I confirmed that there are seeds that do cause it to fail * Increased the number of samples a lot to confirm that numerical error because small when the number of samples is large -- in other words, the error is due to a low number of samples Reviewed By: Balandat Differential Revision: D52002349 Pulled By: esantorella fbshipit-source-id: c1908bdf649db0d51c8e8c2806b9e55258ffb855
1 parent f2003dd commit 9354fd7

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

test/acquisition/test_objective.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ def _get_pref_model(
454454
return pref_model
455455

456456
def test_learned_preference_objective(self) -> None:
457+
seed = torch.randint(low=0, high=10, size=torch.Size([1]))
458+
torch.manual_seed(seed)
457459
pref_model = self._get_pref_model(dtype=torch.float64)
458460

459461
og_sample_shape = 3
@@ -492,7 +494,7 @@ def test_learned_preference_objective(self) -> None:
492494
with self.assertRaisesRegex(
493495
ValueError, "samples should have at least 3 dimensions."
494496
):
495-
pref_obj(torch.rand(q, self.x_dim))
497+
pref_obj(torch.rand(q, self.x_dim, dtype=torch.float64))
496498

497499
# test when sampler has multiple preference samples
498500
with self.subTest("Multiple samples"):

0 commit comments

Comments
 (0)