Skip to content

Commit

Permalink
[PipelineTesterMixin] Handle non-image outputs for attn slicing test (h…
Browse files Browse the repository at this point in the history
…uggingface#2504)

* [PipelineTesterMixin] Handle non-image outputs for batch/sinle inference test

* style

---------

Co-authored-by: William Berman <WLBberman@gmail.com>
  • Loading branch information
2 people authored and mengfei25 committed Mar 27, 2023
1 parent 20a74fb commit f2eb38e
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,9 @@ def test_to_device(self):
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass()

def _test_attention_slicing_forward_pass(self, test_max_difference=True, expected_max_diff=1e-3):
def _test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return

Expand All @@ -474,7 +476,8 @@ def _test_attention_slicing_forward_pass(self, test_max_difference=True, expecte
max_diff = np.abs(output_with_slicing - output_without_slicing).max()
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")

assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])

@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
Expand Down

0 comments on commit f2eb38e

Please sign in to comment.