
Commit 83238ee

zucchini-nlp and gante authored

Pass device in Logits Processor's init (#29804)
* add device in logits processor
* remove device when not needed
* codestyle
* tests
* forgot `melody` version
* Update src/transformers/models/whisper/generation_whisper.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* codestyle
* updates

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
1 parent c73ee13 commit 83238ee
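This change moves the device placement of each processor's constant tensors from every `__call__` into `__init__`, so `generate` can allocate them once on the right device instead of calling `.to(scores.device)` at each decoding step. A minimal sketch of the before/after pattern (the class and variable names here are illustrative, not the library code):

```python
import torch

# Before: constants start on CPU and are migrated inside every call.
class BeforeProcessor:
    def __init__(self, token_ids):
        self.token_ids = torch.tensor(token_ids)  # always allocated on CPU

    def __call__(self, scores):
        self.token_ids = self.token_ids.to(scores.device)  # per-step transfer
        vocab = torch.arange(scores.shape[-1], device=scores.device)
        return scores.masked_fill(torch.isin(vocab, self.token_ids), -float("inf"))

# After: the caller passes the target device once, at construction time.
class AfterProcessor:
    def __init__(self, token_ids, device: str = "cpu"):
        self.token_ids = torch.tensor(token_ids, device=device)  # allocated in place

    def __call__(self, scores):
        vocab = torch.arange(scores.shape[-1], device=scores.device)
        return scores.masked_fill(torch.isin(vocab, self.token_ids), -float("inf"))
```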

File tree

7 files changed: +119, -52 lines changed


src/transformers/generation/logits_process.py

Lines changed: 32 additions & 22 deletions
@@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
             The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
         eos_token_id (`Union[int, List[int], torch.Tensor]`):
             The id(s) of the *end-of-sequence* token.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device to allocate the tensors.
 
     Examples:
 
@@ -137,22 +139,21 @@ class MinLengthLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
+    def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
         if not isinstance(min_length, int) or min_length < 0:
             raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
 
         if not isinstance(eos_token_id, torch.Tensor):
             if isinstance(eos_token_id, int):
                 eos_token_id = [eos_token_id]
-            eos_token_id = torch.tensor(eos_token_id)
+            eos_token_id = torch.tensor(eos_token_id, device=device)
 
         self.min_length = min_length
         self.eos_token_id = eos_token_id
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        self.eos_token_id = self.eos_token_id.to(scores.device)
         eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
         scores_processed = scores.clone()
         if input_ids.shape[-1] < self.min_length:
@@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
             The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
         eos_token_id (`Union[int, List[int], torch.Tensor]`):
             The id(s) of the *end-of-sequence* token.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device to allocate the tensors.
 
     Examples:
 
@@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
     """
 
     def __init__(
-        self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], torch.Tensor]
+        self,
+        prompt_length_to_skip: int,
+        min_new_tokens: int,
+        eos_token_id: Union[int, List[int], torch.Tensor],
+        device: str = "cpu",
     ):
         for arg_name, arg_value in [
             ("prompt_length_to_skip", prompt_length_to_skip),
@@ -208,7 +215,7 @@ def __init__(
         if not isinstance(eos_token_id, torch.Tensor):
             if isinstance(eos_token_id, int):
                 eos_token_id = [eos_token_id]
-            eos_token_id = torch.tensor(eos_token_id)
+            eos_token_id = torch.tensor(eos_token_id, device=device)
 
         self.prompt_length_to_skip = prompt_length_to_skip
         self.min_new_tokens = min_new_tokens
@@ -219,7 +226,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
         scores_processed = scores.clone()
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        self.eos_token_id = self.eos_token_id.to(scores.device)
         eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
         if new_tokens_length < self.min_new_tokens:
             scores_processed = torch.where(eos_token_mask, -math.inf, scores)
@@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper):
             Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
             For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
             even if all tokens have probabilities below the cutoff `eta`.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device to allocate the tensors.
 
     Examples:
     ```python
@@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper):
     ```
     """
 
-    def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+    def __init__(
+        self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu"
+    ):
         epsilon = float(epsilon)
         if epsilon <= 0 or epsilon >= 1:
             raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
@@ -817,13 +827,12 @@ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
                 f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
             )
 
-        self.epsilon = torch.tensor(epsilon)
+        self.epsilon = torch.tensor(epsilon, device=device)
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        # Calculate the adaptive cutoff
         probabilities = scores.softmax(dim=-1)
         entropy = torch.distributions.Categorical(logits=scores).entropy()
         eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
@@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
             The maximum length of the sequence to be generated.
         eos_token_id (`Union[int, List[int], torch.Tensor]`):
             The id(s) of the *end-of-sequence* token.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device to allocate the tensors.
 
     Examples:
 
@@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
+    def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
         self.max_length = max_length
 
         if not isinstance(eos_token_id, torch.Tensor):
             if isinstance(eos_token_id, int):
                 eos_token_id = [eos_token_id]
-            eos_token_id = torch.tensor(eos_token_id)
+            eos_token_id = torch.tensor(eos_token_id, device=device)
         self.eos_token_id = eos_token_id
 
         if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -1568,7 +1579,6 @@ def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
-        self.eos_token_id = self.eos_token_id.to(scores.device)
         scores_processed = scores
         if cur_len == self.max_length - 1:
             scores_processed = torch.full_like(scores, -math.inf)
@@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, begin_suppress_tokens, begin_index):
-        self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
+    def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
+        self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device)
         self.begin_index = begin_index
 
     def set_begin_index(self, begin_index):
@@ -1780,7 +1790,6 @@ def set_begin_index(self, begin_index):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
         suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
         scores_processed = scores
         if input_ids.shape[-1] == self.begin_index:
@@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, suppress_tokens):
-        self.suppress_tokens = torch.tensor(list(suppress_tokens))
+    def __init__(self, suppress_tokens, device: str = "cpu"):
+        self.suppress_tokens = torch.tensor(list(suppress_tokens), device=device)
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        self.suppress_tokens = self.suppress_tokens.to(scores.device)
         suppress_token_mask = torch.isin(vocab_tensor, self.suppress_tokens)
         scores = torch.where(suppress_token_mask, -float("inf"), scores)
         return scores
@@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
     """
 
     def __init__(
-        self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
+        self,
+        generate_config,
+        begin_index: Optional[int] = None,
+        _detect_timestamp_from_logprob: Optional[bool] = None,
     ):  # support for the kwargs
         self.no_timestamps_token_id = generate_config.no_timestamps_token_id
         self.timestamp_begin = generate_config.no_timestamps_token_id + 1
@@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
             Minimum end of speech threshold.
     """
 
-    def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
+    def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float, device: str = "cpu"):
         if not isinstance(eos_token_id, torch.Tensor):
             if isinstance(eos_token_id, int):
                 eos_token_id = [eos_token_id]
-            eos_token_id = torch.tensor(eos_token_id)
+            eos_token_id = torch.tensor(eos_token_id, device=device)
         self.eos_token_id = eos_token_id
 
         if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -2309,7 +2320,6 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         scores_processed = scores
-        self.eos_token_id = self.eos_token_id.to(scores.device)
        if self.min_eos_p:
            probs = torch.nn.functional.softmax(scores.float(), dim=-1)
            # create scores full of -inf except for the eos_token_id
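With the new keyword, a processor's constant tensors live on the target device from construction. A small usage sketch of the updated `MinLengthLogitsProcessor` signature (the toy shapes, EOS id, and CUDA fallback are assumptions for illustration):

```python
import torch
from transformers import MinLengthLogitsProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = MinLengthLogitsProcessor(min_length=5, eos_token_id=9, device=device)

# Toy batch: 2 sequences of 3 tokens each over a 10-token vocabulary.
input_ids = torch.zeros((2, 3), dtype=torch.long, device=device)
scores = torch.randn(2, 10, device=device)

# The sequences are shorter than min_length, so the EOS column is set to -inf;
# no .to(scores.device) transfer happens inside __call__ anymore.
processed = processor(input_ids, scores)
print(processed[:, 9])  # tensor([-inf, -inf], ...)
```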

src/transformers/generation/utils.py

Lines changed: 60 additions & 14 deletions
@@ -723,6 +723,7 @@ def _get_candidate_generator(
     def _get_logits_warper(
         self,
         generation_config: GenerationConfig,
+        device: str,
     ) -> LogitsProcessorList:
         """
         This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -765,7 +766,9 @@ def _get_logits_warper(
             )
         if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
             warpers.append(
-                EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+                EtaLogitsWarper(
+                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+                )
             )
         # `LogitNormalization` should always be the last logit processor, when present
         if generation_config.renormalize_logits is True:
@@ -818,7 +821,8 @@ def _get_logits_processor(
         ):
             processors.append(
                 EncoderRepetitionPenaltyLogitsProcessor(
-                    penalty=generation_config.encoder_repetition_penalty, encoder_input_ids=encoder_input_ids
+                    penalty=generation_config.encoder_repetition_penalty,
+                    encoder_input_ids=encoder_input_ids,
                 )
             )
         if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
@@ -830,39 +834,63 @@ def _get_logits_processor(
             and generation_config.encoder_no_repeat_ngram_size > 0
         ):
             processors.append(
-                EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids)
+                EncoderNoRepeatNGramLogitsProcessor(
+                    generation_config.encoder_no_repeat_ngram_size,
+                    encoder_input_ids,
+                )
             )
         if generation_config.bad_words_ids is not None:
             processors.append(
-                NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
+                NoBadWordsLogitsProcessor(
+                    generation_config.bad_words_ids,
+                    generation_config.eos_token_id,
+                )
             )
         if (
             generation_config.min_length is not None
             and generation_config.eos_token_id is not None
             and generation_config.min_length > 0
         ):
-            processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
+            processors.append(
+                MinLengthLogitsProcessor(
+                    generation_config.min_length,
+                    generation_config.eos_token_id,
+                    device=device,
+                )
+            )
         if (
             generation_config.min_new_tokens is not None
             and generation_config.eos_token_id is not None
             and generation_config.min_new_tokens > 0
         ):
             processors.append(
                 MinNewTokensLengthLogitsProcessor(
-                    input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id
+                    input_ids_seq_length,
+                    generation_config.min_new_tokens,
+                    generation_config.eos_token_id,
+                    device=device,
                 )
             )
         if prefix_allowed_tokens_fn is not None:
             processors.append(
                 PrefixConstrainedLogitsProcessor(
-                    prefix_allowed_tokens_fn, generation_config.num_beams // generation_config.num_beam_groups
+                    prefix_allowed_tokens_fn,
+                    generation_config.num_beams // generation_config.num_beam_groups,
                )
            )
         if generation_config.forced_bos_token_id is not None:
-            processors.append(ForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
+            processors.append(
+                ForcedBOSTokenLogitsProcessor(
+                    generation_config.forced_bos_token_id,
+                )
+            )
         if generation_config.forced_eos_token_id is not None:
             processors.append(
-                ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
+                ForcedEOSTokenLogitsProcessor(
+                    generation_config.max_length,
+                    generation_config.forced_eos_token_id,
+                    device=device,
+                )
             )
         if generation_config.remove_invalid_values is True:
             processors.append(InfNanRemoveLogitsProcessor())
@@ -875,7 +903,12 @@ def _get_logits_processor(
                 )
             )
         if generation_config.suppress_tokens is not None:
-            processors.append(SuppressTokensLogitsProcessor(generation_config.suppress_tokens))
+            processors.append(
+                SuppressTokensLogitsProcessor(
+                    generation_config.suppress_tokens,
+                    device=device,
+                )
+            )
         if generation_config.begin_suppress_tokens is not None:
             begin_index = input_ids_seq_length
             begin_index = (
@@ -887,7 +920,11 @@ def _get_logits_processor(
                 # generation starts after the last token that is forced
                 begin_index += generation_config.forced_decoder_ids[-1][0]
             processors.append(
-                SuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
+                SuppressTokensAtBeginLogitsProcessor(
+                    generation_config.begin_suppress_tokens,
+                    begin_index,
+                    device=device,
+                )
             )
         if generation_config.forced_decoder_ids is not None:
             # TODO(Sanchit): deprecate in v4.40 by removing this logic
@@ -1779,7 +1816,12 @@ def generate(
 
             # 12. prepare logits warper (if `do_sample` is `True`)
             prepared_logits_warper = (
-                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+                self._get_logits_warper(
+                    generation_config,
+                    device=input_ids.device,
+                )
+                if generation_config.do_sample
+                else None
             )
 
             # 13. run assisted generate
@@ -1812,7 +1854,9 @@ def generate(
         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
             # 11. prepare logits warper
             prepared_logits_warper = (
-                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+                self._get_logits_warper(generation_config, device=input_ids.device)
+                if generation_config.do_sample
+                else None
             )
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1838,7 +1882,9 @@ def generate(
         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
             # 11. prepare logits warper
             prepared_logits_warper = (
-                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+                self._get_logits_warper(generation_config, device=input_ids.device)
+                if generation_config.do_sample
+                else None
             )
 
             # 12. prepare beam search scorer
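On the `generate` side, warpers and processors are now built with `device=input_ids.device`, so their constants land on the model's device up front. A quick check, sketched under the assumptions of a CUDA machine and the `gpt2` checkpoint (`_get_logits_warper` is a private helper and may change):

```python
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
generation_config = GenerationConfig(do_sample=True, eta_cutoff=0.1)

# eta_cutoff in (0, 1) makes _get_logits_warper append an EtaLogitsWarper(device=...).
warpers = model._get_logits_warper(generation_config, device="cuda")
eta = next(w for w in warpers if type(w).__name__ == "EtaLogitsWarper")
print(eta.epsilon.device)  # cuda:0 -- allocated at init, no per-step transfer
```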
