
Commit af88570

fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
1 parent c8cb9f9 commit af88570

File tree: 7 files changed, +112 −111 lines


colossalai/engine/schedule/_no_pipeline.py

Lines changed: 5 additions & 19 deletions
@@ -15,7 +15,7 @@
 from colossalai.nn import (ZeroRedundancyOptimizer_Level_2,
                            ZeroRedundancyOptimizer_Level_3)
 from colossalai.nn.optimizer._utils import clip_grad_norm_fp32
-from ._utils import convert_to_fp16
+from ._utils import convert_to_fp16, convert_to_fp32
 from ._base_schedule import BaseSchedule
 from ..amp import AMP_TYPE, GradScaler

@@ -43,12 +43,6 @@ def __init__(
         assert amp_type is None or isinstance(amp_type, AMP_TYPE), \
             'unrecognised value for argument fp16, it can only be None, torch or apex'

-        # LSG: check compatibility
-        # LSG: torch.cuda.amp and apex.amp cannot be used for tensor parallel
-        if gpc.is_initialized(ParallelMode.TENSOR) and gpc.get_world_size(
-                ParallelMode.TENSOR) > 1:
-            assert amp_type != AMP_TYPE.TORCH and amp_type != AMP_TYPE.APEX, \
-                'You can only AMP_TYPE.PARALLEL for tensor parallel training'
         self.use_zero_level_2_3 = False

         if amp_type is not None:

@@ -121,18 +115,6 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
         data, label = self.load_batch()
         loss = None

-        # LSG: leave for debug, make sure dataloader is deterministic
-        # if forward_only:
-        #     img = data[0]
-        #     rank = gpc.get_local_rank(ParallelMode.DATA)
-        #     world_size = gpc.get_world_size(ParallelMode.DATA)
-        #     group = gpc.get_group(ParallelMode.DATA)
-        #     input_list = [img.clone() for _ in range(world_size)]
-        #     output_list = [torch.empty_like(img) for _ in range(world_size)]
-        #     output_list[rank] = img.clone()
-        #     dist.all_to_all(output_tensor_list=output_list, input_tensor_list=input_list, group=group)
-        #     assert torch.equal(output_list[0], output_list[1])  # and torch.equal(output_list[1], output_list[2])
-
         # forward
         if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
             with torch_amp.autocast():

@@ -146,6 +128,10 @@ def forward_backward_step(self, forward_only=False, return_loss=True):
                 data = convert_to_fp16(data)

         output = self.model(*data)
+
+        if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
+            output = convert_to_fp32(output)
+
         if not isinstance(output, (tuple, list)):
             output = (output,)
         if return_loss:
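
For orientation, the forward path after this change looks roughly like the sketch below. It is a simplified reading of the hunks above (the loss/backward handling and removed debug code are omitted; the attribute and helper names are the ones visible in the diff), not the file's full implementation:

# simplified sketch of the schedule's forward_backward_step after this commit
if self.fp16 and self.amp_type == AMP_TYPE.TORCH:
    # torch.cuda.amp chooses fp16 kernels inside the autocast region
    with torch_amp.autocast():
        output = self.model(*data)
else:
    if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
        data = convert_to_fp16(data)   # manual fp16 cast of the inputs
    output = self.model(*data)

# new in this commit: bring fp16 activations back to fp32 before the loss
if self.use_zero_level_2_3 or self.amp_type == AMP_TYPE.PARALLEL:
    output = convert_to_fp32(output)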

colossalai/engine/schedule/_utils.py

Lines changed: 11 additions & 0 deletions
@@ -14,3 +14,14 @@ def convert_to_fp16(data: Union[Tensor, List[Tensor]]):
     else:
         raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
     return ret
+
+
+def convert_to_fp32(data: Union[Tensor, List[Tensor]]):
+    if isinstance(data, Tensor):
+        ret = data.float()
+    elif isinstance(data, (list, tuple)):
+        ret = [val.float() for val in data]
+    else:
+        raise TypeError(f"Expected argument 'data' to be a Tensor or a list/tuple of Tensor, but got {type(data)}")
+    return ret
+
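
The new convert_to_fp32 mirrors convert_to_fp16: it accepts a single Tensor or a list/tuple of Tensors and casts element-wise to fp32. A quick sanity check, assuming convert_to_fp16 applies .half() symmetrically (the tensors here are made up for illustration):

import torch
from colossalai.engine.schedule._utils import convert_to_fp16, convert_to_fp32

x = torch.randn(4, 4)
assert convert_to_fp16(x).dtype == torch.float16
assert convert_to_fp32(convert_to_fp16(x)).dtype == torch.float32

# list/tuple inputs are converted element-wise and returned as a list
outs = convert_to_fp32((torch.randn(2).half(), torch.randn(3).half()))
assert all(t.dtype == torch.float32 for t in outs)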

colossalai/nn/layer/parallel_2d/_operation.py

Lines changed: 84 additions & 72 deletions
@@ -7,6 +7,7 @@
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from torch.cuda.amp import custom_bwd, custom_fwd


 def matmul_2d(a,

@@ -60,6 +61,7 @@ class Matmul_AB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,

@@ -120,39 +122,40 @@ def forward(ctx: Any,
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
-            output_grad, B,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
-            A, output_grad,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
+                output_grad, B,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_ATB_2D.apply(
+                A, output_grad,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None


 class Matmul_ABT_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = AB^T`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,

@@ -214,39 +217,41 @@ def forward(ctx: Any,
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_AB_2D.forward(
-            None,
-            output_grad, B,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_ATB_2D.forward(
-            None,
-            output_grad, A,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+
+        with torch.no_grad():
+            A_grad = Matmul_AB_2D.apply(
+                output_grad, B,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_ATB_2D.apply(
+                output_grad, A,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None


 class Matmul_ATB_2D(torch.autograd.Function):
     """Matrix multiplication for :math:`C = A^TB`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 A: Tensor,
                 B: Tensor,

@@ -308,39 +313,41 @@ def forward(ctx: Any,
         return out

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         A, B = ctx.saved_tensors
-        A_grad = Matmul_ABT_2D.forward(
-            None,
-            B, output_grad,
-            ctx.summa_dim, ctx.A_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
-        B_grad = Matmul_AB_2D.forward(
-            None,
-            A, output_grad,
-            ctx.summa_dim, ctx.B_shape,
-            ctx.row_rank, ctx.col_rank,
-            ctx.row_parallel_mode,
-            ctx.col_parallel_mode,
-            ctx.data_parallel_rank,
-            ctx.pipeline_parallel_rank,
-            ctx.pipeline_parallel_size,
-            ctx.tensor_parallel_size
-        )
+
+        with torch.no_grad():
+            A_grad = Matmul_ABT_2D.apply(
+                B, output_grad,
+                ctx.summa_dim, ctx.A_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
+            B_grad = Matmul_AB_2D.apply(
+                A, output_grad,
+                ctx.summa_dim, ctx.B_shape,
+                ctx.row_rank, ctx.col_rank,
+                ctx.row_parallel_mode,
+                ctx.col_parallel_mode,
+                ctx.data_parallel_rank,
+                ctx.pipeline_parallel_rank,
+                ctx.pipeline_parallel_size,
+                ctx.tensor_parallel_size
+            )
         return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None


 class Add_Bias_2D(torch.autograd.Function):
     """Matrix add bias: :math:`C = A + b`
     """
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 input: Tensor,
                 bias: Tensor,

@@ -384,6 +391,7 @@ def forward(ctx: Any,
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_rank = ctx.row_rank
         col_rank = ctx.col_rank

@@ -423,6 +431,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
 class _LayerNorm_2D(torch.autograd.Function):

     @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any,
                 input: Tensor,
                 E_x: Tensor,

@@ -440,6 +449,7 @@ def forward(ctx: Any,
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         row_parallel_mode = ctx.row_parallel_mode
         col_parallel_mode = ctx.col_parallel_mode

@@ -492,6 +502,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
 class _ViT_Split_Input_2D(torch.autograd.Function):

     @staticmethod
+    @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any,
                 inputs: Tensor,
                 batch_size: int,

@@ -509,6 +520,7 @@ def forward(ctx: Any,
         return output

     @staticmethod
+    @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         # output_grad: [b/q, s, h/q]
         # grads: [b, s, h/q]
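
Every autograd Function in this file now follows the standard torch.cuda.amp recipe: forward is decorated with custom_fwd (optionally forcing an input dtype, fp16 for the matmul/bias/split ops and fp32 for the layernorm) and backward with custom_bwd, so the backward runs under the same autocast state as its forward. The backward passes also switch from calling another Function's forward(None, ...) to calling .apply(...) under torch.no_grad(), so the decorated forward (including its dtype casting) actually executes without recording a second graph. A minimal, self-contained example of the pattern with a toy op (not one of the SUMMA kernels above):

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class ScaledMatmul(torch.autograd.Function):
    """Toy illustration of the custom_fwd/custom_bwd pattern."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)   # inputs are cast to fp16 when autocast is active
    def forward(ctx, a, b, scale):
        ctx.save_for_backward(a, b)
        ctx.scale = scale
        return torch.matmul(a, b) * scale

    @staticmethod
    @custom_bwd                              # backward reruns under the forward's autocast state
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        with torch.no_grad():
            grad_a = torch.matmul(grad_out, b.transpose(-1, -2)) * ctx.scale
            grad_b = torch.matmul(a.transpose(-1, -2), grad_out) * ctx.scale
        return grad_a, grad_b, None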

colossalai/nn/lr_scheduler/cosine.py

Lines changed: 3 additions & 4 deletions
@@ -66,11 +66,10 @@ class CosineAnnealingWarmupLR(WarmupScheduler):
     :type last_epoch: int, optional
     """

-    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1,
-                 **kwargs):
+    def __init__(self, optimizer, total_steps: int, warmup_steps: int = 0, eta_min: int = 0, last_epoch: int = -1):
         base_scheduler = _CosineAnnealingLR(
-            optimizer, total_steps - warmup_steps, eta_min=eta_min)
-        super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
+            optimizer, total_steps - warmup_steps, eta_min=eta_min, last_epoch=last_epoch)
+        super().__init__(optimizer, warmup_steps, base_scheduler)


 @LR_SCHEDULERS.register_module
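
With this change, last_epoch is forwarded to the inner _CosineAnnealingLR instead of the WarmupScheduler wrapper, and the unused **kwargs is dropped. A minimal usage sketch (the toy model and optimizer are placeholders, and it assumes CosineAnnealingWarmupLR is exported by colossalai.nn.lr_scheduler):

import torch
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 10 linear warmup steps, then cosine decay over the remaining 90 steps
scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=100, warmup_steps=10, eta_min=0.0)

for _ in range(100):
    optimizer.step()
    scheduler.step()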

colossalai/nn/lr_scheduler/delayed.py

Lines changed: 4 additions & 11 deletions
@@ -55,7 +55,7 @@ def step(self, epoch=None):


 class WarmupScheduler(_LRScheduler):
-    """ Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler 
+    """ Starts with a linear warmup lr schedule until it reaches N epochs the applies a scheduler

     :param optimizer: Wrapped optimizer.
     :type optimizer: torch.optim.Optimizer

@@ -66,11 +66,8 @@ class WarmupScheduler(_LRScheduler):
     :param last_epoch: The index of last epoch, defaults to -1
     :type last_epoch: int, optional
     """
-
     def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
-        if warmup_epochs < 0:
-            raise ValueError(f'warmup_epochs must >= 0, got {warmup_epochs}')
-        self.warmup_epochs = warmup_epochs
+        self.warmup_epochs = int(warmup_epochs)
         self.after_scheduler = after_scheduler
         self.finished = False
         super().__init__(optimizer, last_epoch)

@@ -79,14 +76,10 @@ def get_lr(self):
         if self.last_epoch >= self.warmup_epochs:
             if not self.finished:
                 self.after_scheduler.base_lrs = self.base_lrs
-                # reset lr to base_lr
-                for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
-                    group['lr'] = base_lr
                 self.finished = True
-            with _enable_get_lr_call(self.after_scheduler):
-                return self.after_scheduler.get_lr()
+            return self.after_scheduler.get_lr()

-        return [(self.last_epoch + 1) / (self.warmup_epochs + 1) * lr for lr in self.base_lrs]
+        return [(self.last_epoch + 1) / self.warmup_epochs * lr for lr in self.base_lrs]

     def step(self, epoch=None):
         if self.finished:
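
The warmup factor changes from (last_epoch + 1) / (warmup_epochs + 1) to (last_epoch + 1) / warmup_epochs, so the learning rate now reaches the full base LR on the final warmup epoch instead of topping out just below it. A small worked example (numbers chosen purely for illustration):

warmup_epochs = 5
base_lr = 0.1

# new formula: reaches base_lr on the last warmup epoch
new = [(e + 1) / warmup_epochs * base_lr for e in range(warmup_epochs)]
# -> [0.02, 0.04, 0.06, 0.08, 0.10]

# old formula: never reaches base_lr during warmup
old = [(e + 1) / (warmup_epochs + 1) * base_lr for e in range(warmup_epochs)]
# -> [0.0167, 0.0333, 0.05, 0.0667, 0.0833] (approx.)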
