@@ -123,11 +123,9 @@ def gptj_model_forward(
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         # position id to be assigned not just for the first stage for attn input
-        if position_ids is not None:
-            position_ids = position_ids.view(-1, seq_length)
-        else:
+        if position_ids is None:
             position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+            position_ids = position_ids.unsqueeze(0)
         if stage_manager.is_first_stage():
             if inputs_embeds is None:
                 inputs_embeds = self.wte(input_ids)
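Aside, not part of the patch: a minimal stand-alone sketch of why the simplified default above is sufficient. A (1, seq_length) position_ids tensor broadcasts against a batched input, so the old .view(-1, input_shape[-1]) reshape of the freshly built default is redundant; the Embedding below is only a hypothetical stand-in for wherever the positions are consumed.

import torch

batch, seq_length, dim = 3, 5, 8
position_ids = torch.arange(0, seq_length, dtype=torch.long).unsqueeze(0)  # shape (1, 5)

pos_table = torch.nn.Embedding(seq_length, dim)       # hypothetical consumer of position_ids
pos_embeds = pos_table(position_ids)                  # shape (1, 5, 8)
hidden_states = torch.zeros(batch, seq_length, dim)
hidden_states = hidden_states + pos_embeds            # broadcasts to (3, 5, 8)
print(hidden_states.shape)                            # torch.Size([3, 5, 8])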
@@ -172,21 +170,15 @@ def gptj_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     None,
                     attention_mask,
                     position_ids,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
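Aside, not part of the patch: a minimal stand-alone sketch of why the hand-rolled closure can be dropped, assuming the behavior of recent transformers releases (around v4.35), with an nn.Linear standing in for a GPT-J block. gradient_checkpointing_enable() installs self._gradient_checkpointing_func as (roughly) a functools.partial over torch.utils.checkpoint.checkpoint, so non-tensor flags such as use_cache and output_attentions can be forwarded positionally to block.__call__ instead of being captured in create_custom_forward.

import functools
import torch
import torch.nn as nn

# Roughly what gradient_checkpointing_enable() wires up on the model (assumption).
gradient_checkpointing_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

block = nn.Linear(16, 16)                             # stand-in for a GPT-J block
hidden_states = torch.randn(2, 16, requires_grad=True)

# Old pattern: bake extra arguments into a closure before checkpointing.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

out_old = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, use_reentrant=False)

# New pattern: checkpoint the module's __call__ directly; extras pass through positionally.
out_new = gradient_checkpointing_func(block.__call__, hidden_states)

out_new.sum().backward()                              # gradients still flow through the checkpoint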
@@ -603,7 +595,9 @@ def forward(
            value = torch.cat((past_value, value), dim=1)

        if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+            present = (key.to(hidden_states.dtype), value)
        else:
            present = None

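Aside, not part of the patch: a tiny stand-alone illustration of the dtype issue the new cast guards against; the float32 sin table here is a hypothetical stand-in for the GPT-J rotary tables. Because the rotary embedding is applied in float32, the key can come out of ROPE as float32 even when the model runs in float16, so it is cast back to hidden_states.dtype before being stored in the cache.

import torch

hidden_states = torch.randn(1, 4, 8, dtype=torch.float16)
key = torch.randn(1, 4, 8, dtype=torch.float16)
sin = torch.randn(1, 4, 8, dtype=torch.float32)       # ROPE tables kept in float32

rotated_key = key * sin                                # type promotion -> torch.float32
present_key = rotated_key.to(hidden_states.dtype)      # cast back before caching
print(rotated_key.dtype, present_key.dtype)            # torch.float32 torch.float16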