Commit 89f0fda: Fix the gradient checkpointing bug of the llama model (huggingface#22270)

fix grad ckpt bug of llama

yqy2001 authored Mar 20, 2023
1 parent cf0af9a · commit 89f0fda
Showing 1 changed file with 1 addition and 1 deletion: src/transformers/models/llama/modeling_llama.py
@@ -372,7 +372,7 @@ def _init_weights(self, module):
            module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LlamaDecoderLayer)):
+        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value
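The one-line change matters because the `gradient_checkpointing` flag is read by the model's forward pass, not by the individual decoder layers, so setting it on `LlamaDecoderLayer` instances had no effect. A minimal, self-contained sketch of the failure mode, using stub classes in place of the real `transformers` ones (an illustrative assumption, not the library code):

```python
# Stubs standing in for the real transformers classes (hypothetical,
# stripped of weights, config, and forward logic for illustration).
class LlamaDecoderLayer:
    pass

class LlamaModel:
    def __init__(self):
        # In the real model, forward() consults this flag to decide whether
        # to wrap each decoder layer in torch.utils.checkpoint.checkpoint.
        self.gradient_checkpointing = False

def set_gc_buggy(module, value=False):
    # Pre-fix check: matches decoder layers, which never read the flag.
    if isinstance(module, LlamaDecoderLayer):
        module.gradient_checkpointing = value

def set_gc_fixed(module, value=False):
    # Post-fix check: matches the model, where the flag is actually consumed.
    if isinstance(module, LlamaModel):
        module.gradient_checkpointing = value

model = LlamaModel()
set_gc_buggy(model, True)
print(model.gradient_checkpointing)  # False: the buggy check never touches the model
set_gc_fixed(model, True)
print(model.gradient_checkpointing)  # True: the flag now lands where forward() reads it
```

With the old check, enabling gradient checkpointing silently did nothing; the fixed check sets the flag on the object whose forward pass actually uses it.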

