Skip to content

Commit

Permalink
TF: remove set_tensor_by_indices_to_value (huggingface#16729)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored Apr 12, 2022
1 parent a315988 commit d7f7f29
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 27 deletions.
11 changes: 3 additions & 8 deletions src/transformers/generation_tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import numpy as np
import tensorflow as tf

from .tf_utils import set_tensor_by_indices_to_value
from .utils import add_start_docstrings
from .utils.logging import get_logger

Expand Down Expand Up @@ -221,7 +220,7 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.
# generate is not XLA - compileable anyways
if cur_len < self.min_length:
eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))
scores = tf.where(eos_token_id_mask, float("-inf"), scores)

return scores

Expand Down Expand Up @@ -339,9 +338,7 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)

scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)

return scores

Expand Down Expand Up @@ -397,9 +394,7 @@ def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)

scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
)
scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)

return scores

Expand Down
17 changes: 8 additions & 9 deletions src/transformers/generation_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from .tf_utils import set_tensor_by_indices_to_value, shape_list
from .tf_utils import shape_list
from .utils import ModelOutput, logging


Expand Down Expand Up @@ -952,8 +952,7 @@ def _generate_beam_search(
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])

scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
scores = tf.where(eos_token_indices_mask, -float("inf"), scores)

if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
Expand All @@ -969,8 +968,8 @@ def _generate_beam_search(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)

scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
scores = tf.where(
tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
)

if bad_words_ids is not None:
Expand All @@ -983,8 +982,8 @@ def _generate_beam_search(
[True if token in banned_tokens_slice else False for token in range(vocab_size)]
)

scores = set_tensor_by_indices_to_value(
scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
scores = tf.where(
tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores
)

assert shape_list(scores) == [batch_size * num_beams, vocab_size]
Expand Down Expand Up @@ -2950,7 +2949,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
logits = tf.where(indices_to_remove, filter_value, logits)
if top_p < 1.0:
sorted_indices = tf.argsort(logits, direction="DESCENDING")
sorted_logits = tf.gather(
Expand Down Expand Up @@ -2979,7 +2978,7 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
)
# scatter sorted tensors to original indexing
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
logits = tf.where(indices_to_remove, filter_value, logits)
return logits


Expand Down
5 changes: 0 additions & 5 deletions src/transformers/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@
logger = logging.get_logger(__name__)


def set_tensor_by_indices_to_value(tensor: tf.Tensor, indices: tf.Tensor, value: Union[tf.Tensor, int, float]):
# create value_tensor since tensor value assignment is not possible in TF
return tf.where(indices, value, tensor)


def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.
Expand Down
9 changes: 4 additions & 5 deletions tests/generation/test_generation_tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
TFTopKLogitsWarper,
TFTopPLogitsWarper,
)
from transformers.tf_utils import set_tensor_by_indices_to_value

from ..test_modeling_tf_common import ids_tensor

Expand Down Expand Up @@ -112,9 +111,9 @@ def test_repetition_penalty_dist_process(self):
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)

mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size)
scores = tf.where(mask, -1 / vocab_size, scores)
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size)
scores = tf.where(mask, 4 / vocab_size, scores)

rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)

Expand Down Expand Up @@ -340,8 +339,8 @@ def test_processor_list(self):
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)

# remove inf
scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)
scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9)
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)

# scores should be equal
tf.debugging.assert_near(scores, scores_comp, atol=1e-3)
Expand Down

0 comments on commit d7f7f29

Please sign in to comment.