@@ -406,6 +406,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
     base_model_prefix = "rwkv"
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -605,6 +606,8 @@ def __init__(self, config):
 
         self.layers_are_rescaled = False
 
+        self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -659,14 +662,35 @@ def forward(
             ]
             state[4] -= 1e30
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         hidden_states = inputs_embeds
 
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for idx, block in enumerate(self.blocks):
-            hidden_states, state, attentions = block(
-                hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # use_cache / output_attentions are captured by the closure;
+                        # checkpoint() only forwards the tensor inputs
+                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+                    return custom_forward
+
+                hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block), hidden_states, state
+                )
+            else:
+                hidden_states, state, attentions = block(
+                    hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
+                )
+
             if (
                 self.layers_are_rescaled
                 and self.config.rescale_every > 0
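
Not part of the diff: a minimal usage sketch of the new checkpointing path, assuming the generic `PreTrainedModel.gradient_checkpointing_enable()` entry point and an illustrative `RWKV/rwkv-4-169m-pile` checkpoint. With the flag on, each `RwkvBlock` recomputes its activations during the backward pass instead of storing them, and `use_cache` is forced off while training, as in the hunk above.

```python
from transformers import AutoTokenizer, RwkvForCausalLM

# Checkpoint name is illustrative; any RWKV checkpoint should behave the same way.
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")

# Flip the flag introduced in this diff; gradient_checkpointing_disable() undoes it.
model.gradient_checkpointing_enable()
model.train()

inputs = tokenizer("Gradient checkpointing trades extra compute for lower activation memory.", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"], use_cache=False)
outputs.loss.backward()
```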