26 changes: 10 additions & 16 deletions colossalai/shardformer/modeling/t5.py
@@ -118,15 +118,12 @@ def t5_stack_forward(
         # required mask seq length can be calculated via length of past
         mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
-        if in_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length, device=device, dtype=torch.long)
-
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)

+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=device)
+
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
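Note on the first hunk: it only reorders the defaults. `attention_mask` is now filled in after `past_key_values` is initialized, and the early decoder-side `encoder_attention_mask` default is dropped in favor of the one created later (next hunk), which carries the correct dtype. A minimal sketch, with toy shapes that are assumptions for illustration rather than ColossalAI code, of why `mask_seq_length` must cover the cached length plus the new tokens:

```python
import torch

# Toy shapes (assumed): decoding one new token against a KV cache of length 7.
batch_size, seq_length = 2, 1
past_key = torch.zeros(batch_size, 8, 7, 64)  # [batch, heads, past_len, head_dim]

# Mirrors the forward above: the default mask spans cached + current positions.
mask_seq_length = past_key.shape[2] + seq_length  # 7 + 1 = 8
attention_mask = torch.ones(batch_size, mask_seq_length)
print(attention_mask.shape)  # torch.Size([2, 8])
```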
@@ -138,7 +135,9 @@ def t5_stack_forward(
             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+                )
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
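Note on the second hunk: the default encoder mask now gets `dtype=torch.long`, matching what tokenizers produce, before `invert_attention_mask` converts it into an additive bias. A minimal sketch of that conversion; it mirrors the behavior of transformers' `ModuleUtilsMixin.invert_attention_mask` but is a simplification, not the actual implementation:

```python
import torch

# 1 = attend, 0 = padding; long dtype matches tokenizer output.
encoder_attention_mask = torch.ones(2, 5, dtype=torch.long)
encoder_attention_mask[:, -1] = 0  # pretend the last position is padding

# Broadcast to [batch, 1, 1, seq] and turn padded positions into a large negative bias.
extended = encoder_attention_mask[:, None, None, :].to(torch.float32)
extended = (1.0 - extended) * torch.finfo(torch.float32).min
print(extended[0, 0, 0])  # approx. [0, 0, 0, 0, -3.4e38]
```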
@@ -162,15 +161,8 @@ def t5_stack_forward(
                 torch.cuda.set_device(hidden_states.device)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-
-                    return custom_forward
-
-                layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.forward,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -180,6 +172,8 @@ def custom_forward(*inputs):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
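Note on the last two hunks: they replace the closure-based `torch.utils.checkpoint.checkpoint` idiom with `self._gradient_checkpointing_func`, which recent transformers releases attach to the module when gradient checkpointing is enabled, so `use_cache` and `output_attentions` travel as ordinary positional arguments instead of captured closure state. A side-by-side sketch with a toy layer; the layer and the stand-in for `_gradient_checkpointing_func` are assumptions for illustration, not the real T5 block:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(4, 4)
hidden_states = torch.randn(2, 4, requires_grad=True)

# Old idiom: a closure captures the non-tensor flags so `checkpoint` sees only tensors.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)  # use_cache / output_attentions would be captured here
    return custom_forward

out_old = checkpoint(create_custom_forward(layer), hidden_states, use_reentrant=False)

# New idiom: pass the bound forward plus every argument positionally.
gradient_checkpointing_func = checkpoint  # stand-in for self._gradient_checkpointing_func
out_new = gradient_checkpointing_func(layer.forward, hidden_states, use_reentrant=False)

assert torch.allclose(out_old, out_new)
```

Passing the flags positionally matches the upstream transformers T5 refactor and avoids rebuilding a wrapper function on every layer iteration.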