[GPT2] Fix gradient checkpointing bug #21772

Merged 5 commits on Feb 24, 2023
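All four files receive the same fix: the `use_cache` / gradient-checkpointing guard is hoisted out of the per-layer loop and run once before the cache containers are initialized. In the old order, `presents` (or `present_key_values` in the ProphetNet decoders) was created while `use_cache` was still `True`, and only afterwards did the loop flip `use_cache` to `False`, so callers received an empty tuple instead of `None`. Below is a minimal, self-contained sketch of that ordering issue, using hypothetical helper names rather than the actual `forward()` methods:

# Hypothetical helpers illustrating the ordering bug and the fix; not Transformers code.
def init_cache_buggy(use_cache, gradient_checkpointing, training):
    # Old order: the cache container is created while use_cache is still True ...
    presents = () if use_cache else None
    for _layer in range(2):  # stand-in for the per-block loop
        if gradient_checkpointing and training:
            use_cache = False  # ... and only disabled inside the loop, too late
    return presents  # () instead of None when checkpointing is active

def init_cache_fixed(use_cache, gradient_checkpointing, training):
    # New order: disable the cache once, up front, before any container is built
    if gradient_checkpointing and training and use_cache:
        use_cache = False
    presents = () if use_cache else None
    return presents  # None, as callers expect when no cache is returned

assert init_cache_buggy(True, True, True) == ()    # the bug: an empty tuple leaks out
assert init_cache_fixed(True, True, True) is None  # the fix: past_key_values is None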
@@ -607,6 +607,13 @@ def forward(

         output_shape = input_shape + (hidden_states.size(-1),)

+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -627,11 +634,6 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
12 changes: 7 additions & 5 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -851,6 +851,13 @@ def forward(

         output_shape = input_shape + (hidden_states.size(-1),)

+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -871,11 +878,6 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
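A rough usage sketch of the observable effect on GPT-2 follows; the claim about the pre-fix return value (an empty tuple) is my reading of the old ordering, not something stated in this diff. With gradient checkpointing enabled during training, `use_cache` is now dropped before the cache tuple is built, so `past_key_values` comes back as `None`:

# Rough usage sketch (assumes the standard "gpt2" checkpoint is available locally or
# can be downloaded from the Hub).
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.gradient_checkpointing_enable()
model.train()

inputs = tokenizer("gradient checkpointing test", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"], use_cache=True)

# With this fix, the incompatible use_cache flag is disabled before the cache tuple
# is initialized, so no stale empty tuple is returned alongside the loss.
print(outputs.past_key_values)  # expected: None (previously an empty tuple)
outputs.loss.backward()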
13 changes: 8 additions & 5 deletions src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -1569,6 +1569,14 @@ def forward(
         all_main_stream_attns = () if output_attentions else None
         all_ngram_stream_attns = () if output_attentions else None
         all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         present_key_values = () if use_cache else None

         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
@@ -1588,11 +1596,6 @@ def forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1592,6 +1592,14 @@ def forward(
         all_main_stream_attns = () if output_attentions else None
         all_ngram_stream_attns = () if output_attentions else None
         all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         present_key_values = () if use_cache else None

         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
@@ -1611,11 +1619,6 @@ def forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):