Skip to content

Commit ee40088

Browse files
authored
enable deterministic in bnb 4 bit tests (#11738)
* enable deterministic in bnb 4 bit tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix 8bit test Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 7fc53b5 commit ee40088

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

tests/quantization/bnb/test_4bit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ class Base4bitTests(unittest.TestCase):
9696
num_inference_steps = 10
9797
seed = 0
9898

99+
@classmethod
100+
def setUpClass(cls):
101+
torch.use_deterministic_algorithms(True)
102+
99103
def get_dummy_inputs(self):
100104
prompt_embeds = load_pt(
101105
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
@@ -480,7 +484,6 @@ def test_generate_quality_dequantize(self):
480484
r"""
481485
Test that loading the model and unquantize it produce correct results.
482486
"""
483-
torch.use_deterministic_algorithms(True)
484487
self.pipeline_4bit.transformer.dequantize()
485488
output = self.pipeline_4bit(
486489
prompt=self.prompt,

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ class Base8bitTests(unittest.TestCase):
9797
num_inference_steps = 10
9898
seed = 0
9999

100+
@classmethod
101+
def setUpClass(cls):
102+
torch.use_deterministic_algorithms(True)
103+
100104
def get_dummy_inputs(self):
101105
prompt_embeds = load_pt(
102106
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
@@ -485,7 +489,6 @@ def test_generate_quality_dequantize(self):
485489
r"""
486490
Test that loading the model and unquantize it produce correct results.
487491
"""
488-
torch.use_deterministic_algorithms(True)
489492
self.pipeline_8bit.transformer.dequantize()
490493
output = self.pipeline_8bit(
491494
prompt=self.prompt,

0 commit comments

Comments
 (0)