Skip to content

Commit

Permalink
fix: modify test_force_tokens_logits_processor the checking value as …
Browse files Browse the repository at this point in the history
…scores.dtype.min
  • Loading branch information
Aya committed Aug 25, 2024
1 parent b4073dd commit dac3219
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/generation/test_tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,12 @@ def test_force_tokens_logits_processor(self, use_xla):

non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]]
self.assertTrue(
tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, [non_forced_inds], axis=1))),
tf.math.reduce_all(
tf.experimental.numpy.isclose(
tf.gather(scores, [non_forced_inds], axis=1),
tf.constant(scores.dtype.min),
)
)
)

# check that if the cur_len is not contained in the force_token_map, the logits are not modified
Expand Down

0 comments on commit dac3219

Please sign in to comment.