Commit 9a5edc3

[shardformer] update gptj model (#5503)
1 parent fd44440

File tree
  • colossalai/shardformer/modeling

1 file changed: +9 −15 lines

colossalai/shardformer/modeling/gptj.py

Lines changed: 9 additions & 15 deletions
@@ -123,11 +123,9 @@ def gptj_model_forward(
     head_mask = self.get_head_mask(head_mask, self.config.n_layer)

     # position id to be assigned not just for the first stage for attn input
-    if position_ids is not None:
-        position_ids = position_ids.view(-1, seq_length)
-    else:
+    if position_ids is None:
         position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
-        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+        position_ids = position_ids.unsqueeze(0)
     if stage_manager.is_first_stage():
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)
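
The rewritten branch only fills in position_ids when the caller has not supplied them, and it drops the extra view(...) reshapes: an arange plus unsqueeze(0) already gives a (1, seq_length) tensor that broadcasts across the batch, and a user-provided position_ids now passes through unchanged. A minimal sketch of the new default path (seq_length and device are hypothetical stand-ins for the function's locals):

import torch

# Hypothetical stand-ins for the function's local state.
seq_length, device = 8, "cpu"

# New default path: [0, 1, ..., seq_length - 1] with a leading batch axis,
# so one row of position ids broadcasts over every sequence in the batch.
position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
assert position_ids.shape == (1, seq_length)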
@@ -172,21 +170,15 @@ def gptj_model_forward(
             all_hidden_states = all_hidden_states + (hidden_states,)

         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs, use_cache, output_attentions)
-
-                return custom_forward
-
-            outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(block),
+            outputs = self._gradient_checkpointing_func(
+                block.__call__,
                 hidden_states,
                 None,
                 attention_mask,
                 position_ids,
                 head_mask[i],
+                use_cache,
+                output_attentions,
             )
         else:
             outputs = block(
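
The closure-based create_custom_forward helper is replaced by self._gradient_checkpointing_func, the callable that newer transformers releases attach to the model when gradient checkpointing is enabled, with use_cache and output_attentions now passed positionally instead of being captured by the closure. A rough sketch of the mechanism, assuming (as recent library versions do, details may vary) that the helper is essentially a partial over torch.utils.checkpoint.checkpoint:

import functools

import torch
import torch.utils.checkpoint as cp

# Roughly what ends up as model._gradient_checkpointing_func once gradient
# checkpointing is enabled (an illustrative assumption, not the exact code).
gradient_checkpointing_func = functools.partial(cp.checkpoint, use_reentrant=False)

def layer(x, scale):
    # Stand-in for block.__call__: positional args flow straight through
    # torch.utils.checkpoint.checkpoint to the wrapped callable.
    return x * scale

x = torch.randn(2, 4, requires_grad=True)
out = gradient_checkpointing_func(layer, x, 2.0)  # activations recomputed in backward
out.sum().backward()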
@@ -603,7 +595,9 @@ def forward(
         value = torch.cat((past_value, value), dim=1)

         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None

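The only behavioral change in this hunk is the cast when populating the KV cache: under mixed precision the key may be carried in float32 through the rotary embedding (following the mesh-transformer-jax reference linked in the comment), so casting it to hidden_states.dtype keeps both cache entries in the activation dtype. A small illustration with hypothetical shapes and dtypes:

import torch

# Hypothetical tensors under fp16 autocast: the key is kept in float32
# through RoPE while the other activations are float16.
hidden_states = torch.randn(1, 4, 16, dtype=torch.float16)
key = torch.randn(1, 4, 16, dtype=torch.float32)
value = torch.randn(1, 4, 16, dtype=torch.float16)

# Cast back before caching so the KV cache stays in the activation dtype.
present = (key.to(hidden_states.dtype), value)
assert present[0].dtype == torch.float16 and present[1].dtype == torch.float16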