@@ -110,6 +110,8 @@ class MinLengthLogitsProcessor(LogitsProcessor):
             The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
         eos_token_id (`Union[int, List[int], torch.Tensor]`):
             The id(s) of the *end-of-sequence* token.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device to allocate the tensors.
 
     Examples:
 
@@ -137,22 +139,21 @@ class MinLengthLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
+    def __init__(self, min_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
         if not isinstance(min_length, int) or min_length < 0:
             raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
 
         if not isinstance(eos_token_id, torch.Tensor):
             if isinstance(eos_token_id, int):
                 eos_token_id = [eos_token_id]
-            eos_token_id = torch.tensor(eos_token_id)
+            eos_token_id = torch.tensor(eos_token_id, device=device)
 
         self.min_length = min_length
         self.eos_token_id = eos_token_id
 
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        self.eos_token_id = self.eos_token_id.to(scores.device)
         eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
         scores_processed = scores.clone()
         if input_ids.shape[-1] < self.min_length:
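
Note: with this change the processor can be built directly on the target device, and `__call__` no longer pays a per-step `.to(scores.device)` copy. A minimal usage sketch (the token id, vocab size, and CUDA fallback are illustrative assumptions, not values from this PR):

```python
import torch
from transformers import MinLengthLogitsProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# The eos tensor is allocated on `device` once, at construction time.
processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=2, device=device)

input_ids = torch.tensor([[0, 5, 9]], device=device)  # (batch, cur_len)
scores = torch.randn(1, 50, device=device)            # (batch, vocab_size)
scores = processor(input_ids, scores)                 # eos is -inf while cur_len < 10
```
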
@@ -173,6 +174,8 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
             The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
         eos_token_id (`Union[int, List[int], torch.Tensor]`):
             The id(s) of the *end-of-sequence* token.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device to allocate the tensors.
 
     Examples:
 
@@ -196,7 +199,11 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
196199 """
197200
198201 def __init__ (
199- self , prompt_length_to_skip : int , min_new_tokens : int , eos_token_id : Union [int , List [int ], torch .Tensor ]
202+ self ,
203+ prompt_length_to_skip : int ,
204+ min_new_tokens : int ,
205+ eos_token_id : Union [int , List [int ], torch .Tensor ],
206+ device : str = "cpu" ,
200207 ):
201208 for arg_name , arg_value in [
202209 ("prompt_length_to_skip" , prompt_length_to_skip ),
@@ -208,7 +215,7 @@ def __init__(
         if not isinstance(eos_token_id, torch.Tensor):
             if isinstance(eos_token_id, int):
                 eos_token_id = [eos_token_id]
-            eos_token_id = torch.tensor(eos_token_id)
+            eos_token_id = torch.tensor(eos_token_id, device=device)
 
         self.prompt_length_to_skip = prompt_length_to_skip
         self.min_new_tokens = min_new_tokens
@@ -219,7 +226,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
         scores_processed = scores.clone()
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        self.eos_token_id = self.eos_token_id.to(scores.device)
         eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
         if new_tokens_length < self.min_new_tokens:
             scores_processed = torch.where(eos_token_mask, -math.inf, scores)
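
The same construction applies to the new-tokens variant; here the minimum counts only tokens generated past `prompt_length_to_skip`. A sketch under assumed values (the ids and lengths are made up):

```python
import torch
from transformers import MinNewTokensLengthLogitsProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Block EOS (id 2, assumed) until 4 tokens exist beyond the 3-token prompt.
processor = MinNewTokensLengthLogitsProcessor(
    prompt_length_to_skip=3, min_new_tokens=4, eos_token_id=2, device=device
)

input_ids = torch.tensor([[0, 1, 4, 7, 8]], device=device)  # 3 prompt + 2 new tokens
scores = processor(input_ids, torch.randn(1, 50, device=device))
assert scores[0, 2] == -float("inf")  # only 2 new tokens, so EOS stays blocked
```
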
@@ -779,6 +785,8 @@ class EtaLogitsWarper(LogitsWarper):
             Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
             For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
             even if all tokens have probabilities below the cutoff `eta`.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device to allocate the tensors.
 
     Examples:
     ```python
@@ -806,7 +814,9 @@ class EtaLogitsWarper(LogitsWarper):
     ```
     """
 
-    def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+    def __init__(
+        self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, device: str = "cpu"
+    ):
         epsilon = float(epsilon)
         if epsilon <= 0 or epsilon >= 1:
             raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
@@ -817,13 +827,12 @@ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
817827 f"`min_tokens_to_keep` has to be a strictly positive integer, but is { min_tokens_to_keep } "
818828 )
819829
820- self .epsilon = torch .tensor (epsilon )
830+ self .epsilon = torch .tensor (epsilon , device = device )
821831 self .filter_value = filter_value
822832 self .min_tokens_to_keep = min_tokens_to_keep
823833
824834 @add_start_docstrings (LOGITS_PROCESSOR_INPUTS_DOCSTRING )
825835 def __call__ (self , input_ids : torch .LongTensor , scores : torch .FloatTensor ) -> torch .FloatTensor :
826- # Calculate the adaptive cutoff
827836 probabilities = scores .softmax (dim = - 1 )
828837 entropy = torch .distributions .Categorical (logits = scores ).entropy ()
829838 eta = torch .min (self .epsilon , torch .sqrt (self .epsilon ) * torch .exp (- entropy ))[..., None ]
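
Unlike the eos ids, `epsilon` feeds straight into the cutoff math: `torch.min(self.epsilon, ...)` combines it with the entropy computed on `scores.device`, which is why allocating it on the right device up front lets the per-call transfer go away. Illustrative usage (the cutoff value and shapes are assumptions):

```python
import torch
from transformers import EtaLogitsWarper

device = "cuda" if torch.cuda.is_available() else "cpu"

# `epsilon` is stored as a tensor on `device`, matching where `scores` will live.
warper = EtaLogitsWarper(epsilon=0.1, device=device)

input_ids = torch.tensor([[3, 14, 15]], device=device)
scores = torch.randn(1, 50, device=device)
filtered = warper(input_ids, scores)  # tokens below the adaptive cutoff -> filter_value
```
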
@@ -1530,6 +1539,8 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
             The maximum length of the sequence to be generated.
         eos_token_id (`Union[int, List[int], torch.Tensor]`):
             The id(s) of the *end-of-sequence* token.
+        device (`str`, *optional*, defaults to `"cpu"`):
+            The device to allocate the tensors.
 
     Examples:
 
@@ -1553,13 +1564,13 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
+    def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor], device: str = "cpu"):
         self.max_length = max_length
 
         if not isinstance(eos_token_id, torch.Tensor):
             if isinstance(eos_token_id, int):
                 eos_token_id = [eos_token_id]
-            eos_token_id = torch.tensor(eos_token_id)
+            eos_token_id = torch.tensor(eos_token_id, device=device)
         self.eos_token_id = eos_token_id
 
         if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -1568,7 +1579,6 @@ def __init__(self, max_length: int, eos_token_id: Union[int, List[int], torch.Tensor]):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         cur_len = input_ids.shape[-1]
-        self.eos_token_id = self.eos_token_id.to(scores.device)
         scores_processed = scores
         if cur_len == self.max_length - 1:
             scores_processed = torch.full_like(scores, -math.inf)
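
Usage mirrors the other eos-based processors; a short sketch (the max length and ids are made-up values):

```python
import torch
from transformers import ForcedEOSTokenLogitsProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = ForcedEOSTokenLogitsProcessor(max_length=4, eos_token_id=2, device=device)

# At cur_len == max_length - 1 all non-eos logits are forced to -inf.
input_ids = torch.tensor([[0, 5, 7]], device=device)  # cur_len == 3 == max_length - 1
scores = processor(input_ids, torch.randn(1, 50, device=device))
assert scores[0, 3] == -float("inf")  # a non-eos token is masked out
```
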
@@ -1770,8 +1780,8 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
     ```
     """
 
-    def __init__(self, begin_suppress_tokens, begin_index):
-        self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens))
+    def __init__(self, begin_suppress_tokens, begin_index, device: str = "cpu"):
+        self.begin_suppress_tokens = torch.tensor(list(begin_suppress_tokens), device=device)
         self.begin_index = begin_index
 
     def set_begin_index(self, begin_index):
@@ -1780,7 +1790,6 @@ def set_begin_index(self, begin_index):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        self.begin_suppress_tokens = self.begin_suppress_tokens.to(scores.device)
         suppress_token_mask = torch.isin(vocab_tensor, self.begin_suppress_tokens)
         scores_processed = scores
         if input_ids.shape[-1] == self.begin_index:
@@ -1818,13 +1827,12 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
18181827 ```
18191828 """
18201829
1821- def __init__ (self , suppress_tokens ):
1822- self .suppress_tokens = torch .tensor (list (suppress_tokens ))
1830+ def __init__ (self , suppress_tokens , device : str = "cpu" ):
1831+ self .suppress_tokens = torch .tensor (list (suppress_tokens ), device = device )
18231832
18241833 @add_start_docstrings (LOGITS_PROCESSOR_INPUTS_DOCSTRING )
18251834 def __call__ (self , input_ids : torch .LongTensor , scores : torch .FloatTensor ) -> torch .FloatTensor :
18261835 vocab_tensor = torch .arange (scores .shape [- 1 ], device = scores .device )
1827- self .suppress_tokens = self .suppress_tokens .to (scores .device )
18281836 suppress_token_mask = torch .isin (vocab_tensor , self .suppress_tokens )
18291837 scores = torch .where (suppress_token_mask , - float ("inf" ), scores )
18301838 return scores
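
Both suppress processors now build their token tensor once in `__init__`; a sketch for the unconditional variant (the suppressed ids are arbitrary):

```python
import torch
from transformers import SuppressTokensLogitsProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# The suppress list becomes a tensor on `device` at construction; __call__
# no longer copies it to scores.device on every step.
processor = SuppressTokensLogitsProcessor([7, 11, 13], device=device)

input_ids = torch.tensor([[0, 1]], device=device)
scores = processor(input_ids, torch.randn(1, 50, device=device))
assert torch.isinf(scores[0, [7, 11, 13]]).all()
```
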
@@ -1915,7 +1923,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
19151923 """
19161924
19171925 def __init__ (
1918- self , generate_config , begin_index : Optional [int ] = None , _detect_timestamp_from_logprob : Optional [bool ] = None
1926+ self ,
1927+ generate_config ,
1928+ begin_index : Optional [int ] = None ,
1929+ _detect_timestamp_from_logprob : Optional [bool ] = None ,
19191930 ): # support for the kwargs
19201931 self .no_timestamps_token_id = generate_config .no_timestamps_token_id
19211932 self .timestamp_begin = generate_config .no_timestamps_token_id + 1
@@ -2292,11 +2303,11 @@ class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
             Minimum end of speech threshold.
     """
 
-    def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
+    def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float, device: str = "cpu"):
         if not isinstance(eos_token_id, torch.Tensor):
             if isinstance(eos_token_id, int):
                 eos_token_id = [eos_token_id]
-            eos_token_id = torch.tensor(eos_token_id)
+            eos_token_id = torch.tensor(eos_token_id, device=device)
         self.eos_token_id = eos_token_id
 
         if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
@@ -2309,7 +2320,6 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p: float):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         scores_processed = scores
-        self.eos_token_id = self.eos_token_id.to(scores.device)
         if self.min_eos_p:
             probs = torch.nn.functional.softmax(scores.float(), dim=-1)
             # create scores full of -inf except for the eos_token_id
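
The same pattern applies to the Bark prioritizer; a minimal sketch (the token id, threshold, and tiny vocab are illustrative, and the import path assumes the class is not re-exported at the top level):

```python
import torch
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = BarkEosPrioritizerLogitsProcessor(eos_token_id=2, min_eos_p=0.2, device=device)

# Once P(eos) clears min_eos_p, non-eos logits are suppressed in favor of EOS.
input_ids = torch.tensor([[0, 1]], device=device)
scores = torch.zeros(1, 5, device=device)
scores[0, 2] = 5.0  # make EOS dominant so its probability exceeds 0.2
out = processor(input_ids, scores)
```
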