Skip to content

Commit 93838e8

Browse files
authored
Fix pyre error for logits
Differential Revision: D70183946
Pull Request resolved: #8687
1 parent 0fb4f85 commit 93838e8

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

examples/models/llama/llama_transformer.py

Lines changed: 23 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -232,27 +232,29 @@ def forward(
232 232
if self.apply_output:
233 233
logits = self.output(h)
234 234

235-
if self.output_prune_map is not None:
236-
# expand to original size so that downstream applications can use the logits as-is.
237-
if self.generate_full_logits:
238-
# (1, seq_len, pruned_size) -> (1, seq_len, original_size)
239-
expanded_logits = torch.full(
240-
[logits.shape[0], logits.shape[1], self.vocab_size],
241-
float("-inf"),
242-
device=logits.device,
243-
dtype=logits.dtype,
244-
)
245-
expanded_logits[:, :, list(self.output_prune_map.values())] = logits
246-
else:
247-
# (1, pruned_size) -> (1, original_size)
248-
expanded_logits = torch.full(
249-
[logits.shape[0], self.vocab_size],
250-
float("-inf"),
251-
device=logits.device,
252-
dtype=logits.dtype,
253-
)
254-
expanded_logits[:, list(self.output_prune_map.values())] = logits
255-
logits = expanded_logits
235+
if self.output_prune_map is not None:
236+
# expand to original size so that downstream applications can use the logits as-is.
237+
if self.generate_full_logits:
238+
# (1, seq_len, pruned_size) -> (1, seq_len, original_size)
239+
expanded_logits = torch.full(
240+
[logits.shape[0], logits.shape[1], self.vocab_size],
241+
float("-inf"),
242+
device=logits.device,
243+
dtype=logits.dtype,
244+
)
245+
expanded_logits[:, :, list(self.output_prune_map.values())] = logits
246+
else:
247+
# (1, pruned_size) -> (1, original_size)
248+
expanded_logits = torch.full(
249+
[logits.shape[0], self.vocab_size],
250+
float("-inf"),
251+
device=logits.device,
252+
dtype=logits.dtype,
253+
)
254+
expanded_logits[:, list(self.output_prune_map.values())] = logits
255+
logits = expanded_logits
256+
else:
257+
logits = h
256 258

257 259
if attn_options_update is not None:
258 260
return logits, attn_options_update

0 commit comments

Comments
 (0)