@@ -232,27 +232,29 @@ def forward(
         if self.apply_output:
             logits = self.output(h)
 
-        if self.output_prune_map is not None:
-            # expand to original size so that downstream applications can use the logits as-is.
-            if self.generate_full_logits:
-                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], logits.shape[1], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
-            else:
-                # (1, pruned_size) -> (1, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, list(self.output_prune_map.values())] = logits
-            logits = expanded_logits
+            if self.output_prune_map is not None:
+                # expand to original size so that downstream applications can use the logits as-is.
+                if self.generate_full_logits:
+                    # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], logits.shape[1], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, :, list(self.output_prune_map.values())] = logits
+                else:
+                    # (1, pruned_size) -> (1, original_size)
+                    expanded_logits = torch.full(
+                        [logits.shape[0], self.vocab_size],
+                        float("-inf"),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    expanded_logits[:, list(self.output_prune_map.values())] = logits
+                logits = expanded_logits
+        else:
+            logits = h
 
         if attn_options_update is not None:
             return logits, attn_options_update
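For reviewers, a minimal, self-contained sketch of what the expansion branch does (the toy `output_prune_map` and shapes below are invented for illustration; in the real model the map is built at export time). Pruned logits are scattered back into a full-vocab tensor pre-filled with `-inf`, so downstream sampling can consume the result as ordinary full-size logits and softmax assigns pruned tokens zero probability:

```python
import torch

vocab_size = 10
# Hypothetical prune map: pruned output index -> original token id.
output_prune_map = {0: 2, 1: 5, 2: 7}

# Pruned logits for a single decode step: (batch, pruned_size).
pruned_logits = torch.tensor([[0.1, 0.9, -0.3]])

# Scatter the pruned logits into a full-vocab tensor; token ids that were
# pruned away keep -inf and thus get zero probability after softmax.
expanded = torch.full(
    [pruned_logits.shape[0], vocab_size],
    float("-inf"),
    device=pruned_logits.device,
    dtype=pruned_logits.dtype,
)
expanded[:, list(output_prune_map.values())] = pruned_logits

print(expanded)
# tensor([[-inf, -inf, 0.1000, -inf, -inf, 0.9000, -inf, -0.3000, -inf, -inf]])
```

The behavioral change in the diff itself is the new `else` branch: the expansion now runs only when `self.apply_output` is true, which is the only path on which `logits` is assigned, and when the output projection is skipped the hidden states `h` are returned directly instead of referencing an unbound `logits`.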