@@ -148,7 +148,8 @@ def _split_by_proposal_len(
148
148
nonzero_proposal_len_indices ,
149
149
)
150
150
151
- def _remove_no_proposal_seqs (self , proposal_lens , maybe_sampler_output ,
151
+ @staticmethod
152
+ def _remove_no_proposal_seqs (proposal_lens , maybe_sampler_output ,
152
153
nonzero_proposal_len_indices , transposed ):
153
154
"""Remove sequences from nonzero_proposal_len_indices and reset
154
155
their proposal_len to 0 if the draft worker does not provide a proposal
@@ -207,7 +208,7 @@ def _merge_outputs(
207
208
self ,
208
209
batch_size : int ,
209
210
proposal_len : int ,
210
- maybe_sampler_output : Optional [SamplerOutput ],
211
+ maybe_sampler_output : Optional [List [ SamplerOutput ] ],
211
212
proposal_lens : List [int ],
212
213
nonzero_proposal_len_indices : List [int ],
213
214
sampler_transposed : bool ,
@@ -218,25 +219,19 @@ def _merge_outputs(
218
219
if maybe_sampler_output is None :
219
220
# If no speculative tokens, the sampler output will be None.
220
221
# In this case we return empty proposals.
221
- proposal_tokens = torch .full (
222
- size = (
223
- batch_size ,
224
- proposal_len ,
225
- ),
226
- fill_value = - 1 ,
227
- dtype = torch .long ,
228
- device = self ._device ,
229
- )
230
- proposal_probs = torch .zeros (
231
- batch_size ,
232
- proposal_len ,
233
- self ._vocab_size ,
234
- dtype = torch .float32 ,
235
- device = self ._device ,
236
- )
237
- proposal_lens_tensor = torch .zeros (len (proposal_lens ),
238
- dtype = torch .long ,
239
- device = self ._device )
222
+ proposal_tokens = torch .tensor (- 1 ,
223
+ dtype = torch .long ,
224
+ device = self ._device ).expand (
225
+ batch_size , proposal_len )
226
+ proposal_probs = torch .tensor (0 ,
227
+ dtype = torch .float32 ,
228
+ device = self ._device ).expand (
229
+ batch_size , proposal_len ,
230
+ self ._vocab_size )
231
+ proposal_lens_tensor = torch .tensor (0 ,
232
+ dtype = torch .long ,
233
+ device = self ._device ).expand (
234
+ len (proposal_lens ))
240
235
return proposal_tokens , proposal_probs , proposal_lens_tensor
241
236
242
237
sampler_output = maybe_sampler_output
@@ -246,18 +241,14 @@ def _merge_outputs(
246
241
# Now, reformat the output GPU tensors such that each sequence has
247
242
# a proposal. The proposal can be empty, e.g. [-1, -1, -1]
248
243
249
- entire_proposal_tokens = torch . full (
244
+ entire_proposal_tokens = proposal_tokens . new_full (
250
245
size = (batch_size , * proposal_tokens .shape [1 :]),
251
246
fill_value = - 1 ,
252
- dtype = torch .long ,
253
- device = self ._device ,
254
247
)
255
248
entire_proposal_tokens [nonzero_proposal_len_indices ] = proposal_tokens
256
- entire_proposal_probs = torch . zeros (
249
+ entire_proposal_probs = proposal_probs . new_zeros (
257
250
batch_size ,
258
251
* proposal_probs .shape [1 :],
259
- dtype = torch .float32 ,
260
- device = self ._device ,
261
252
)
262
253
entire_proposal_probs [nonzero_proposal_len_indices ] = proposal_probs
263
254
0 commit comments