# Move logits.float() call (#308)
## Summary
In the Hugging Face modeling source code, the analogous `logits.float()` calls were moved inside the `if labels is not None` block so that logits are only upcast when they are actually used in a loss calculation; this avoids a memory spike during inference when the model is in lower precision. This commit applies the same change to the Liger-Kernel model forwards (see the sketch after the links below).

* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/llama/modeling_llama.py#L1211-L1212
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/mixtral/modeling_mixtral.py#L1329-L1330
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/phi3/modeling_phi3.py#L1303-L1304
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/qwen2/modeling_qwen2.py#L1206-L1207

Some of your models already have this change:

* https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/mistral.py#L114-L116
* https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/gemma.py#L114-L116

See also:

* huggingface/transformers#30860
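
For illustration, here is a minimal sketch of the before/after pattern (a simplified stand-in, not the Liger or Hugging Face code; `compute_logits_and_loss` and its arguments are hypothetical):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_logits_and_loss(lm_head: nn.Linear, hidden_states: torch.Tensor, labels=None):
    logits = lm_head(hidden_states)  # stays in bf16/fp16 if the model is in lower precision

    # Before this change, an unconditional `logits = logits.float()` here
    # materialized a full fp32 copy of the logits even for pure inference.

    loss = None
    if labels is not None:
        # Upcast to float only when computing the loss, to avoid potential precision issues
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    return logits, loss
```

During inference (no labels), the logits tensor of shape (batch, seq_len, vocab_size) is never duplicated in fp32, which is where the memory spike came from.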


## Testing Done


- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
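
As a side note, a hypothetical way to eyeball the inference-memory effect described in the summary (not part of the repo's test suite or the checklist above):

```python
import torch


@torch.no_grad()
def peak_inference_memory_gb(model, input_ids):
    # Measure peak GPU memory of a label-free forward pass (hypothetical helper).
    torch.cuda.reset_peak_memory_stats()
    model(input_ids=input_ids)
    return torch.cuda.max_memory_allocated() / 1e9
```

With the unconditional upcast, the fp32 copy of the logits (batch × seq_len × vocab_size × 4 bytes) shows up in this peak for a bf16/fp16 model; after the move it should not.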
ringohoffman authored Oct 14, 2024
1 parent ff6650b commit 04d5a0e
Showing 5 changed files with 10 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/llama.py
@@ -120,8 +120,9 @@ def lce_forward(
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/mixtral.py
@@ -103,7 +103,6 @@ def lce_forward(
 
     hidden_states = outputs[0]
     logits = self.lm_head(hidden_states)
-    logits = logits.float()
 
     loss = None
     if self.training and (labels is not None):
@@ -116,6 +115,8 @@ def lce_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     elif labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/phi3.py
@@ -108,10 +108,11 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/qwen2.py
@@ -109,8 +109,9 @@ def lce_forward(
 
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/qwen2_vl.py
@@ -150,8 +150,9 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
