Skip to content

Commit

Permalink
fix: multilingual midel convert to tflite get wrong token (huggingfac…
Browse files Browse the repository at this point in the history
…e#32079)

* fix: multilingual midel convert to tflite get wrong token

* fix: modify test_force_tokens_logits_processor the checking value as scores.dtype.min

---------

Co-authored-by: kent.sc.hung <kent.sc.hung@benq.com>
Co-authored-by: Aya <[kent831217@gmail.com]>
  • Loading branch information
3 people authored and zucchini-nlp committed Aug 30, 2024
1 parent 1fa3333 commit d37a8eb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/transformers/generation/tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def _force_token(generation_idx):
batch_size = scores.shape[0]
current_token = self.force_token_array[generation_idx]

new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")
new_scores = tf.zeros_like(scores, dtype=scores.dtype) + tf.constant([scores.dtype.min])
indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
updates = tf.zeros((batch_size,), dtype=scores.dtype)
new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
Expand Down
7 changes: 6 additions & 1 deletion tests/generation/test_tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,12 @@ def test_force_tokens_logits_processor(self, use_xla):

non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]]
self.assertTrue(
tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, [non_forced_inds], axis=1))),
tf.math.reduce_all(
tf.experimental.numpy.isclose(
tf.gather(scores, [non_forced_inds], axis=1),
tf.constant(scores.dtype.min),
)
)
)

# check that if the cur_len is not contained in the force_token_map, the logits are not modified
Expand Down

0 comments on commit d37a8eb

Please sign in to comment.