11from functools import cached_property
2- from typing import List , Optional , Tuple
2+ from typing import Dict , List , Optional , Tuple
33
44import torch
55import torch .jit
@@ -36,7 +36,7 @@ def forward(
3636 bonus_token_ids : torch .Tensor ,
3737 draft_probs : torch .Tensor ,
3838 draft_token_ids : torch .Tensor ,
39- generators : List [ Optional [torch .Generator ]],
39+ seeded_seqs : Optional [Dict [ int , torch .Generator ]] = None ,
4040 ) -> torch .Tensor :
4141 """Sample token ids using rejection sampling. This accepts or rejects
4242 tokens proposed by the draft model using the probability of each token
@@ -66,6 +66,9 @@ def forward(
6666 probabilities.
6767 shape = [batch_size, num_speculative_tokens]
6868
69+ seeded_seqs: Dict of batch row index to torch generator, for
70+ sequences using seeded generation.
71+
6972 Returns:
7073 output_token_ids: The token ids sampled via rejection sampling,
7174 or -1 if unable to sample a token because the previous token
@@ -83,7 +86,7 @@ def forward(
8386 target_probs ,
8487 draft_probs ,
8588 draft_token_ids ,
86- generators ,
89+ seeded_seqs ,
8790 ))
8891
8992 output_token_ids = self ._create_output (
@@ -100,7 +103,7 @@ def _batch_modified_rejection_sampling(
100103 target_probs : torch .Tensor , # [batch_size, k, vocab_size]
101104 draft_probs : torch .Tensor , # [batch_size, k, vocab_size]
102105 draft_token_ids : torch .Tensor , # [batch_size, k]
103- generators : List [ Optional [torch .Generator ]],
106+ seeded_seqs : Optional [Dict [ int , torch .Generator ]],
104107 ) -> Tuple [torch .Tensor , torch .Tensor ]:
105108 """Perform modified rejection sampling on each sequence.
106109
@@ -117,23 +120,17 @@ def _batch_modified_rejection_sampling(
117120
118121 # shape [batch_size, k]
119122 accepted = self ._get_accepted (target_probs , draft_probs ,
120- draft_token_ids , generators )
123+ draft_token_ids , seeded_seqs )
121124
122125 recovered_probs = self ._get_recovered_probs (
123126 target_probs , draft_probs ).reshape (batch_size * k , vocab_size )
124127
125- seed_indices , non_seed_indices = self ._split_batch_by_seeded (
126- generators , k = k )
127-
128128 # NOTE: the recovered_probs are overwritten by this method.
129129 recovered_token_ids = _multinomial (
130130 recovered_probs ,
131131 num_samples = 1 ,
132132 k = k ,
133- generators = generators ,
134- seed_indices = seed_indices ,
135- # this arg is unused when None but torch.jit requires a list
136- non_seed_indices = non_seed_indices or [],
133+ seeded_seqs = seeded_seqs or {},
137134 ).reshape (batch_size , k )
138135
139136 return accepted , recovered_token_ids
@@ -143,7 +140,7 @@ def _get_accepted(
143140 target_probs : torch .Tensor , # [batch_size, k, vocab_size]
144141 draft_probs : torch .Tensor , # [batch_size, k, vocab_size]
145142 draft_token_ids : torch .Tensor , # [batch_size, k]
146- generators : List [ Optional [torch .Generator ]],
143+ seeded_seqs : Optional [Dict [ int , torch .Generator ]],
147144 ) -> torch .Tensor :
148145 r"""Create bool matrix over the proposed draft tokens. If
149146 True, then a token can be accepted, else it should be
@@ -178,24 +175,26 @@ def _get_accepted(
178175 selected_target_probs = target_probs [batch_indices , probs_indicies ,
179176 draft_token_ids ]
180177
181- seed_indices , non_seed_indices = self ._split_batch_by_seeded (
182- generators )
183-
184- if len (seed_indices ) == 0 :
178+ if not seeded_seqs :
185179 uniform_rand = torch .rand_like (selected_target_probs )
186180 else :
187181 uniform_rand = torch .empty_like (selected_target_probs )
188182
189- for idx in seed_indices :
190- uniform_rand [idx , :] = torch .rand (1 ,
191- k ,
192- dtype = self .probs_dtype ,
193- device = target_probs .device ,
194- generator = generators [idx ])
195-
196- if non_seed_indices :
197- uniform_rand [non_seed_indices , :] = torch .rand (
198- len (non_seed_indices ),
183+ non_seeded_indices = []
184+ for idx in range (batch_size ):
185+ generator = seeded_seqs .get (idx )
186+ if generator is None :
187+ non_seeded_indices .append (idx )
188+ else :
189+ uniform_rand [idx , :] = torch .rand (
190+ 1 ,
191+ k ,
192+ dtype = self .probs_dtype ,
193+ device = target_probs .device ,
194+ generator = generator )
195+ if non_seeded_indices :
196+ uniform_rand [non_seeded_indices , :] = torch .rand (
197+ len (non_seeded_indices ),
199198 k ,
200199 dtype = self .probs_dtype ,
201200 device = target_probs .device )
@@ -272,27 +271,6 @@ def _smallest_positive_value(self) -> float:
272271 """
273272 return torch .finfo (self .probs_dtype ).tiny
274273
275- # partition batch into indices for which a generator is provided
276- # and indicies for which no generator is provided
277- @staticmethod
278- def _split_batch_by_seeded (
279- generators : List [Optional [torch .Generator ]],
280- k : int = 1 ,
281- ) -> Tuple [List [int ], Optional [List [int ]]]:
282-
283- if all (generator is None for generator in generators ):
284- seed_indices : List [int ] = []
285- non_seed_indices : Optional [List [int ]] = None
286- else :
287- seed_indices , non_seed_indices = [], []
288- for i , generator in enumerate (generators ):
289- if generator is None :
290- non_seed_indices .extend (range (k * i , k * (i + 1 )))
291- else :
292- seed_indices .extend (range (k * i , k * (i + 1 )))
293-
294- return seed_indices , non_seed_indices
295-
296274
297275# torch.multinomial forces a GPU<->CPU sync.
298276# Therefore, we use an optimized implementation instead that skips the sync.
@@ -304,9 +282,7 @@ def _multinomial(
304282 probs : torch .Tensor ,
305283 num_samples : int ,
306284 k : int ,
307- generators : List [Optional [torch .Generator ]],
308- seed_indices : List [int ],
309- non_seed_indices : List [int ],
285+ seeded_seqs : Dict [int , torch .Generator ],
310286) -> torch .Tensor :
311287
312288 if num_samples > 1 :
@@ -315,13 +291,20 @@ def _multinomial(
315291 probs = probs [:, None , :].expand (probs .shape [0 ], num_samples ,
316292 probs .shape [1 ]).contiguous ().view (
317293 - 1 , probs .shape [1 ])
318-
319294 q = torch .empty_like (probs )
320- if len ( seed_indices ) == 0 :
295+ if not seeded_seqs :
321296 q .exponential_ (1.0 )
322297 else :
323- q [non_seed_indices ].exponential_ (1.0 )
324- for idx in seed_indices :
325- q [idx ].exponential_ (1.0 , generator = generators [idx // k ])
298+ non_seeded_indices : List [int ] = []
299+ start = 0
300+ for idx in range (len (q ) // k ):
301+ end = start + k
302+ generator = seeded_seqs .get (idx )
303+ if generator is None :
304+ non_seeded_indices .extend (list (range (start , end )))
305+ else :
306+ q [start :end ].exponential_ (1.0 , generator = generator )
307+ start = end
308+ q [non_seeded_indices ].exponential_ (1.0 )
326309
327310 return probs .div_ (q ).argmax (dim = 1 ).view (- 1 , num_samples )
0 commit comments