
Commit d16f4ed

[shardformer] update t5 model (hpcaitech#5524)
1 parent 1a1fa6e

File tree

1 file changed: +10 -16 lines

  • colossalai/shardformer/modeling/t5.py


colossalai/shardformer/modeling/t5.py

Lines changed: 10 additions & 16 deletions
@@ -118,15 +118,12 @@ def t5_stack_forward(
     # required mask seq length can be calculated via length of past
     mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
-    if attention_mask is None:
-        attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
-    if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-        encoder_seq_length = encoder_hidden_states.shape[1]
-        encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
-
     # initialize past_key_values with `None` if past does not exist
     if past_key_values is None:
         past_key_values = [None] * len(self.block)
+
+    if attention_mask is None:
+        attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
 
     # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
     # ourselves in which case we just need to make it broadcastable to all heads.
@@ -138,7 +135,9 @@ def t5_stack_forward(
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
         encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
         if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+            encoder_attention_mask = torch.ones(
+                encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+            )
         encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
     else:
         encoder_extended_attention_mask = None
@@ -162,15 +161,8 @@ def t5_stack_forward(
             torch.cuda.set_device(hidden_states.device)
 
         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    return tuple(module(*inputs, use_cache, output_attentions))
-
-                return custom_forward
-
-            layer_outputs = checkpoint(
-                create_custom_forward(layer_module),
+            layer_outputs = self._gradient_checkpointing_func(
+                layer_module.forward,
                 hidden_states,
                 extended_attention_mask,
                 position_bias,
@@ -180,6 +172,8 @@ def custom_forward(*inputs):
                 layer_head_mask,
                 cross_attn_layer_head_mask,
                 None,  # past_key_value is always None with gradient checkpointing
+                use_cache,
+                output_attentions,
             )
         else:
             layer_outputs = layer_module(
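
Note on the gradient-checkpointing change: the removed create_custom_forward closure existed only to smuggle use_cache and output_attentions into the layer call, because the old checkpoint(...) invocation forwarded just the positional tensor arguments. The new self._gradient_checkpointing_func, which recent transformers releases install when gradient checkpointing is enabled (roughly a functools.partial over torch.utils.checkpoint.checkpoint), is called with layer_module.forward plus every argument spelled out, which is why use_cache and output_attentions now appear explicitly at the end of the argument list. Below is a minimal sketch of the two call patterns; ToyBlock and the use_reentrant=False choice are illustrative assumptions, not part of this commit.

# Minimal sketch contrasting the old and new call patterns around
# torch.utils.checkpoint; the real code checkpoints T5 blocks inside
# t5_stack_forward, ToyBlock here is a stand-in.
import functools

import torch
from torch.utils.checkpoint import checkpoint


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, hidden_states, use_cache=False, output_attentions=False):
        # The real T5 block returns a tuple; mirror that shape here.
        return (self.proj(hidden_states),)


block = ToyBlock()
x = torch.randn(2, 8, requires_grad=True)
use_cache, output_attentions = False, False

# Old pattern (removed by this commit): a closure captures use_cache and
# output_attentions from the enclosing scope so checkpoint() only has to
# forward the tensor arguments.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return tuple(module(*inputs, use_cache, output_attentions))

    return custom_forward

out_old = checkpoint(create_custom_forward(block), x, use_reentrant=False)

# New pattern: a checkpointing function (here a partial over
# torch.utils.checkpoint.checkpoint, standing in for
# self._gradient_checkpointing_func) receives the layer's forward and *all*
# arguments explicitly, including use_cache and output_attentions.
gradient_checkpointing_func = functools.partial(checkpoint, use_reentrant=False)
out_new = gradient_checkpointing_func(block.forward, x, use_cache, output_attentions)

assert torch.allclose(out_old[0], out_new[0])

Both calls recompute the block's activations during the backward pass; the non-reentrant variant is used here only so that the non-tensor flags can be passed through without warnings.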

0 commit comments
