
Commit 89a1f34

[RWKV] Add Gradient Checkpointing support for RWKV (#24955)
Add gradient checkpointing support for RWKV.
1 parent 9f912ef commit 89a1f34
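
Context for the change (not part of the commit): because `supports_gradient_checkpointing` is now set on `RwkvPreTrainedModel`, gradient checkpointing can be toggled through the standard `PreTrainedModel.gradient_checkpointing_enable()` API. A minimal usage sketch, assuming the `RWKV/rwkv-4-169m-pile` checkpoint for illustration (any RWKV checkpoint works the same way):

from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")  # checkpoint name used for illustration
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

model.gradient_checkpointing_enable()  # allowed now that supports_gradient_checkpointing is True
model.train()

inputs = tokenizer("Gradient checkpointing trades compute for memory.", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()  # RWKV block activations are recomputed during backward instead of being stored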

1 file changed: +27 -3 lines changed


src/transformers/models/rwkv/modeling_rwkv.py

Lines changed: 27 additions & 3 deletions
@@ -406,6 +406,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
     base_model_prefix = "rwkv"
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         """Initialize the weights."""

@@ -605,6 +606,8 @@ def __init__(self, config):
 
         self.layers_are_rescaled = False
 
+        self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()
 

@@ -659,14 +662,35 @@ def forward(
             ]
             state[4] -= 1e30
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         hidden_states = inputs_embeds
 
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for idx, block in enumerate(self.blocks):
-            hidden_states, state, attentions = block(
-                hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+                    return custom_forward
+
+                hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block), hidden_states, state
+                )
+            else:
+                hidden_states, state, attentions = block(
+                    hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
+                )
+
             if (
                 self.layers_are_rescaled
                 and self.config.rescale_every > 0
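
For readers unfamiliar with the pattern introduced in `forward`, here is a minimal, self-contained sketch of the same closure-plus-checkpoint idiom. `ToyBlock` is a made-up stand-in for `RwkvBlock`; only `torch.utils.checkpoint.checkpoint` and the closure trick mirror the commit. The closure bakes `use_cache` and `output_attentions` into a positional-only callable, since the reentrant checkpoint call only forwards positional inputs, and the wrapped block is re-executed during the backward pass instead of keeping its intermediate activations in memory.

import torch
import torch.utils.checkpoint


class ToyBlock(torch.nn.Module):
    """Made-up stand-in for RwkvBlock with the same (hidden_states, state, attentions) return shape."""

    def __init__(self, hidden_size=64):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, state=None, use_cache=False, output_attentions=False):
        hidden_states = torch.tanh(self.linear(hidden_states))
        return hidden_states, state, None


block = ToyBlock()
hidden_states = torch.randn(2, 8, 64, requires_grad=True)
state = None  # mirrors the RWKV call, where state may be None when use_cache=False


def create_custom_forward(module):
    # Bake keyword arguments into a closure so checkpoint() only sees positional inputs.
    def custom_forward(*inputs):
        return module(*inputs, use_cache=False, output_attentions=False)

    return custom_forward


hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden_states, state
)
hidden_states.sum().backward()  # ToyBlock.forward runs again here to rebuild the discarded activations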
