 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from torch.cuda.amp import custom_bwd, custom_fwd


 def matmul_2d(a,
@@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
6061 """Matrix multiplication for :math:`C = AB`
6162 """
6263 @staticmethod
64+ @custom_fwd (cast_inputs = torch .float16 )
6365 def forward (ctx : Any ,
6466 A : Tensor ,
6567 B : Tensor ,
@@ -120,39 +122,40 @@ def forward(ctx: Any,
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
-            output_grad, B,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
-            A, output_grad,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
+                output_grad, B,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_ATB_2D.apply(
+                A, output_grad,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None


 class Matmul_ABT_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB^T`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -214,39 +217,41 @@ def forward(ctx: Any,
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_AB_2D.forward(
-            None,
-            output_grad, B,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
-            output_grad, A,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+
+        with torch.no_grad():
+            A_grad = Matmul_AB_2D.apply(
+                output_grad, B,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_ATB_2D.apply(
+                output_grad, A,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None


 class Matmul_ATB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = A^TB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,
@@ -308,39 +313,41 @@ def forward(ctx: Any,
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
-            B, output_grad,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_AB_2D.forward(
-            None,
-            A, output_grad,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
+                B, output_grad,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_AB_2D.apply(
+                A, output_grad,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None


 class Add_Bias_2D(torch.autograd.Function):
     """Matrix add bias: :math:`C = A + b`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 input: Tensor,
                 bias: Tensor,
@@ -384,6 +391,7 @@ def forward(ctx: Any,
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_rank = ctx.row_rank
         col_rank = ctx.col_rank
@@ -423,6 +431,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
 class _LayerNorm_2D(torch.autograd.Function):

     @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any,
                 input: Tensor,
                 E_x: Tensor,
@@ -440,6 +449,7 @@ def forward(ctx: Any,
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_parallel_mode = ctx.row_parallel_mode
         col_parallel_mode = ctx.col_parallel_mode
@@ -492,6 +502,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
 class _ViT_Split_Input_2D(torch.autograd.Function):

     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 inputs: Tensor,
                 batch_size: int,
@@ -509,6 +520,7 @@ def forward(ctx: Any,
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         # output_grad: [b/q, s, h/q]
         # grads: [b, s, h/q]
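
The pattern this commit applies throughout is the one PyTorch's AMP notes describe for custom autograd Functions: decorate forward with custom_fwd (using cast_inputs where a fixed precision is wanted) and backward with custom_bwd so the Function behaves predictably inside torch.cuda.amp.autocast regions, and compute gradients in backward by calling the sibling Functions through .apply() under torch.no_grad() rather than invoking their forward() with a dummy ctx. Below is a minimal, self-contained sketch of the same recipe; it is not ColossalAI code, and the toy _MatmulAB / _MatmulABT / _MatmulATB Functions, shapes, and names are assumptions chosen only for illustration.

# Sketch of the custom_fwd/custom_bwd + no_grad()/apply() pattern used in this commit.
from typing import Any, Tuple

import torch
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd


class _MatmulABT(torch.autograd.Function):
    """C = A @ B^T (backward omitted: only run under torch.no_grad() in this sketch)."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any, A: Tensor, B: Tensor) -> Tensor:
        return torch.matmul(A, B.transpose(-1, -2))


class _MatmulATB(torch.autograd.Function):
    """C = A^T @ B (backward omitted: only run under torch.no_grad() in this sketch)."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any, A: Tensor, B: Tensor) -> Tensor:
        return torch.matmul(A.transpose(-1, -2), B)


class _MatmulAB(torch.autograd.Function):
    """C = A @ B."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)  # under autocast, inputs are cast to fp16
    def forward(ctx: Any, A: Tensor, B: Tensor) -> Tensor:
        ctx.save_for_backward(A, B)
        return torch.matmul(A, B)

    @staticmethod
    @custom_bwd  # backward runs in the same autocast state custom_fwd saw
    def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, Tensor]:
        A, B = ctx.saved_tensors
        # Mirror the diff: reuse sibling Functions via .apply() under no_grad(),
        # so the gradient computation itself records no autograd graph.
        with torch.no_grad():
            A_grad = _MatmulABT.apply(grad_out, B)  # dA = dC @ B^T
            B_grad = _MatmulATB.apply(A, grad_out)  # dB = A^T @ dC
        return A_grad, B_grad


if __name__ == "__main__" and torch.cuda.is_available():
    a = torch.randn(4, 8, device="cuda", dtype=torch.float16, requires_grad=True)
    b = torch.randn(8, 16, device="cuda", dtype=torch.float16, requires_grad=True)
    with torch.cuda.amp.autocast():
        c = _MatmulAB.apply(a, b)
    c.float().sum().backward()
    print(a.grad.shape, b.grad.shape)  # torch.Size([4, 8]) torch.Size([8, 16])

Two details the sketch mirrors: .apply() keeps the call on autograd's normal entry point and gives each callee a real ctx instead of a hand-passed None, and the enclosing torch.no_grad() keeps the gradient math out of the graph. _LayerNorm_2D is the one Function cast to float32 rather than float16, presumably because its mean/variance reductions are precision-sensitive. Newer PyTorch releases expose the same decorators as torch.amp.custom_fwd / torch.amp.custom_bwd with a device_type argument.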