@@ -118,15 +118,12 @@ def t5_stack_forward(
     # required mask seq length can be calculated via length of past
     mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
-    if attention_mask is None:
-        attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
-    if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-        encoder_seq_length = encoder_hidden_states.shape[1]
-        encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
-
     # initialize past_key_values with `None` if past does not exist
     if past_key_values is None:
         past_key_values = [None] * len(self.block)
+
+    if attention_mask is None:
+        attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
 
     # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
     # ourselves in which case we just need to make it broadcastable to all heads.
@@ -138,7 +135,9 @@ def t5_stack_forward(
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
         encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+            encoder_attention_mask = torch.ones(
+                encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+            )
         encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
     else:
         encoder_extended_attention_mask = None
@@ -162,15 +161,8 @@ def t5_stack_forward(
             torch.cuda.set_device(hidden_states.device)
 
         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return tuple(module(*inputs, use_cache, output_attentions))
-
-                return custom_forward
-
-            layer_outputs = checkpoint(
-                create_custom_forward(layer_module),
+            layer_outputs = self._gradient_checkpointing_func(
+                layer_module.forward,
                 hidden_states,
                 extended_attention_mask,
                 position_bias,
@@ -180,6 +172,8 @@ def custom_forward(*inputs):
                 layer_head_mask,
                 cross_attn_layer_head_mask,
                 None,  # past_key_value is always None with gradient checkpointing
+                use_cache,
+                output_attentions,
             )
         else:
             layer_outputs = layer_module(
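For context, the gradient-checkpointing hunks above drop the closure-based `create_custom_forward` wrapper and instead pass the layer's bound `forward` plus all of its arguments (including `use_cache` and `output_attentions`) to `self._gradient_checkpointing_func`, which in recent transformers versions is essentially `torch.utils.checkpoint.checkpoint` with optional user-supplied kwargs. The following is a minimal, self-contained sketch of that pattern, not the actual T5 code: `ToyLayer`, the tensor shapes, and the direct use of `torch.utils.checkpoint.checkpoint` (standing in for `self._gradient_checkpointing_func`) are illustrative assumptions.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Stand-in for a T5 block: takes hidden states plus two boolean flags."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, use_cache=False, output_attentions=False):
        # Flags are accepted positionally, mirroring how the stack forwards
        # use_cache / output_attentions into each block.
        return (self.proj(hidden_states),)


layer = ToyLayer()
x = torch.randn(2, 4, 8, requires_grad=True)


# Old pattern (removed by the diff): wrap the module in a closure so the
# non-tensor flags are captured, then checkpoint the wrapper.
def create_custom_forward(module, use_cache, output_attentions):
    def custom_forward(*inputs):
        return module(*inputs, use_cache, output_attentions)

    return custom_forward


out_old = checkpoint(create_custom_forward(layer, False, False), x, use_reentrant=False)

# New pattern (added by the diff): hand the bound forward and *all* of its
# arguments, flags included, to the checkpointing function directly.
out_new = checkpoint(layer.forward, x, False, False, use_reentrant=False)

assert torch.allclose(out_old[0], out_new[0])
```

Both calls produce the same output; the new form simply avoids defining a throwaway closure on every layer iteration and lets the checkpointing function receive the flags as ordinary positional arguments.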