From dac3219bd62888dc197b6388b1dec0745af66f2c Mon Sep 17 00:00:00 2001 From: Aya <[kent831217@gmail.com]> Date: Sun, 25 Aug 2024 12:57:19 +0800 Subject: [PATCH] fix: modify test_force_tokens_logits_processor the checking value as scores.dtype.min --- tests/generation/test_tf_logits_process.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_tf_logits_process.py b/tests/generation/test_tf_logits_process.py index e87c843d9cb4de..f06f5695b1cef8 100644 --- a/tests/generation/test_tf_logits_process.py +++ b/tests/generation/test_tf_logits_process.py @@ -406,7 +406,12 @@ def test_force_tokens_logits_processor(self, use_xla): non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]] self.assertTrue( - tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, [non_forced_inds], axis=1))), + tf.math.reduce_all( + tf.experimental.numpy.isclose( + tf.gather(scores, [non_forced_inds], axis=1), + tf.constant(scores.dtype.min), + ) + ) ) # check that if the cur_len is not contained in the force_token_map, the logits are not modified