
Commit 893e94e

kurisusnowdeng authored and FrankLeeeee committed
optimized 3D layers to fix slow computation; tested ImageNet performance with 3D parallelism; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified APIs of 3D layers (#51)
1 parent c9684a1 commit 893e94e

File tree

34 files changed (+1808, −626 lines)


colossalai/builder/pipeline.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -198,7 +198,7 @@ def _partition_layers(self, method):
             for st, ed in self.parts[stage]:
                 for idx, layer in enumerate(self.layers[st: ed]):
                     log_str += f'\t{idx + st:2d}: {layer}\n'
-        self._logger.info(log_str)
+        self._logger.info(log_str, ranks=[0])
 
         # Save the partition
         self._interval = self.parts[pipeline_rank]
```

colossalai/communication/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,12 +1,12 @@
-from .collective import all_gather, reduce_scatter, scatter
+from .collective import all_gather, reduce_scatter, all_reduce
 from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
                   send_backward, send_backward_recv_backward, send_forward_recv_backward,
                   send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
 from .ring import ring_forward
 from .utils import send_tensor_meta, recv_tensor_meta
 
 __all__ = [
-    'all_gather', 'reduce_scatter', 'scatter',
+    'all_gather', 'reduce_scatter', 'all_reduce',
     'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
     'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
     'send_forward_recv_backward', 'recv_backward', 'recv_forward',
```

colossalai/communication/collective.py

Lines changed: 66 additions & 38 deletions
```diff
@@ -11,7 +11,7 @@
 
 
 def all_gather(tensor: Tensor, dim: int,
-               parallel_mode: ParallelMode) -> Tensor:
+               parallel_mode: ParallelMode, async_op=False) -> Tensor:
     """Gathers all tensors from the parallel group and concatenates them in a
     specific dimension.
 
@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
     """
     depth = gpc.get_world_size(parallel_mode)
     temp = tensor.clone()
-    shape = list(temp.shape)
-    shape[dim] *= depth
-    out = torch.empty(shape, dtype=temp.dtype, device=get_current_device())
-    out = list(torch.chunk(out, depth, dim=dim))
-    out = [val.contiguous() for val in out]
-    dist.all_gather(out, temp, group=gpc.get_group(parallel_mode))
-    out = torch.cat(out, dim=dim)
-    return out
+    # shape = list(temp.shape)
+    # shape[dim] *= depth
+    # out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
+    # out = list(torch.chunk(out, depth, dim=dim))
+    # out = [val.contiguous() for val in out]
+    shape = [1] * len(tensor.shape)
+    shape[dim] = depth
+    out = tensor.repeat(shape)
+    out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
+    op = dist.all_gather(tensor_list=out,
+                         tensor=temp,
+                         group=gpc.get_group(parallel_mode),
+                         async_op=async_op)
+    # out = torch.cat(out, dim=dim)
+    if async_op:
+        return out, op
+    else:
+        return out
 
 
 def reduce_scatter(tensor: Tensor, dim: int,
-                   parallel_mode: ParallelMode) -> Tensor:
+                   parallel_mode: ParallelMode, async_op=False) -> Tensor:
     """Reduces all tensors then scatters it in a specific dimension to all
     members in the parallel group.
 
@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
     :rtype: Tensor
     """
     depth = gpc.get_world_size(parallel_mode)
-    temp = list(torch.chunk(tensor, depth, dim=dim))
-    temp = [val.contiguous() for val in temp]
-    out = torch.empty(temp[0].shape,
-                      dtype=temp[0].dtype,
-                      device=get_current_device())
-    dist.reduce_scatter(output=out,
-                        input_list=temp,
-                        group=gpc.get_group(parallel_mode))
-    return out
+    # temp = list(torch.chunk(tensor, depth, dim=dim))
+    # temp = [val.contiguous() for val in temp]
+    # out = torch.zeros(temp[0].shape,
+    #                   dtype=temp[0].dtype,
+    #                   device=get_current_device())
+    temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
+    out = temp[0].clone()
+    op = dist.reduce_scatter(output=out,
+                             input_list=temp,
+                             group=gpc.get_group(parallel_mode),
+                             async_op=async_op)
+    if async_op:
+        return out, op
+    else:
+        return out
 
 
-def scatter(tensor: Tensor, src: int, dim: int,
-            parallel_mode: ParallelMode) -> Tensor:
-    """Scatters in a specific dimension from source rank to all ranks in
-    the parallel group.
+def all_reduce(tensor: Tensor,
+               parallel_mode: ParallelMode,
+               async_op=False) -> Tensor:
+    op = dist.all_reduce(tensor,
+                         group=gpc.get_group(parallel_mode),
+                         async_op=async_op)
+    if async_op:
+        return tensor, op
+    else:
+        return tensor
+
+
+# def scatter(tensor: Tensor, src: int, dim: int,
+#             parallel_mode: ParallelMode) -> Tensor:
+#     """Scatters in a specific dimension from source rank to all ranks in
+#     the parallel group.
 
-    :param tensor: Tensor to be scattered
-    :param dim: The dimension scattering in
-    :param parallel_mode: Parallel group mode used in this communication
-    :type tensor: Tensor
-    :type dim: int
-    :type parallel_mode: ParallelMode
-    :return: The tensor generated by scatter
-    :rtype: Tensor
-    """
-    depth = gpc.get_world_size(parallel_mode)
-    temp = tensor.clone()
-    dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
-    rank = gpc.get_local_rank(parallel_mode)
-    out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
-    return out
+#     :param tensor: Tensor to be scattered
+#     :param dim: The dimension scattering in
+#     :param parallel_mode: Parallel group mode used in this communication
+#     :type tensor: Tensor
+#     :type dim: int
+#     :type parallel_mode: ParallelMode
+#     :return: The tensor generated by scatter
+#     :rtype: Tensor
+#     """
+#     depth = gpc.get_world_size(parallel_mode)
+#     temp = tensor.clone()
+#     dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
+#     rank = gpc.get_local_rank(parallel_mode)
+#     out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
+#     return out
```
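For context, a minimal sketch of how the reworked collectives might be called after this change. The tensor, device, and ParallelMode value are illustrative placeholders, the import paths are assumptions, and an already-initialized parallel context is assumed.

```python
import torch

from colossalai.communication import all_gather, all_reduce
from colossalai.context import ParallelMode

x = torch.randn(4, 8, device='cuda')  # placeholder tensor

# all_gather now returns the list of gathered chunks; the caller concatenates.
chunks = all_gather(x, dim=0, parallel_mode=ParallelMode.PARALLEL_3D_INPUT)
gathered = torch.cat(chunks, dim=0)

# With async_op=True a (result, work-handle) pair is returned; wait before use.
chunks, handle = all_gather(x, dim=0,
                            parallel_mode=ParallelMode.PARALLEL_3D_INPUT,
                            async_op=True)
handle.wait()
gathered = torch.cat(chunks, dim=0)

# all_reduce reduces the tensor in place across the group and returns it.
y = all_reduce(x, ParallelMode.PARALLEL_3D_INPUT)
```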

colossalai/constants.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -25,7 +25,11 @@
 
 # 3D parallel
 DEPTH_3D = 'DEPTH_3D'
+INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
+WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
+OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
 
 # Tensor parallel attributes
 IS_TENSOR_PARALLEL = 'is_tensor_parallel'
-TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL]
+NUM_PARTITIONS = 'num_partitions'
+TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
```

colossalai/context/parallel_context.py

Lines changed: 16 additions & 10 deletions
```diff
@@ -277,10 +277,18 @@ def init_global_dist(self, addr=None, port=None):
         :type port: int, optional
         """
         # get config
-        rank = self._dist_args.local_rank
+        local_rank = self._dist_args.local_rank
+        rank = self._dist_args.rank
         world_size = self._dist_args.world_size
+        if local_rank is None:
+            local_rank = os.getenv('LOCAL_RANK')
+        if rank is None:
+            rank = os.getenv('RANK')
+        if world_size is None:
+            world_size = os.getenv('WORLD_SIZE')
         # default env config, overwrite by exporting
         # them in your bash script
+
         addr = os.getenv('MASTER_ADDR', 'localhost') if addr is None else addr
         port = os.getenv('MASTER_PORT', '8008') if port is None else port
         init_method = f'tcp://{addr}:{port}'
@@ -293,7 +301,8 @@ def init_global_dist(self, addr=None, port=None):
         # None will give the default global process group for pytorch dist operations
         self._register_dist(rank, world_size, None,
                             list(range(world_size)), ParallelMode.GLOBAL)
-        self._global_ranks[ParallelMode.GLOBAL] = rank
+        self.add_global_rank(ParallelMode.GLOBAL, rank)
+        # self._global_ranks[ParallelMode.GLOBAL] = rank
 
     def _register_dist(self, local_rank, world_size,
                        process_group, ranks_in_group, mode):
@@ -426,18 +435,15 @@ def set_seed(self):
         if torch.cuda.is_available():
             # create random seed for different parallel modes
             # data parallel seed are kept the same
-            parallel_seed = seed
+            tp_rank = self._local_ranks.get(ParallelMode.TENSOR, 0)
+            pp_rank = self._local_ranks.get(ParallelMode.PIPELINE, 0)
+            parallel_seed = seed + tp_rank + pp_rank * 1024
             add_seed(ParallelMode.DATA, parallel_seed)
 
-            # model parallel seeds are different across ranks
-            pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
-
             # add seed for data parallel and tensor parallel only
             if self.is_initialized(ParallelMode.TENSOR):
-                tp_rank = self.get_local_rank(ParallelMode.TENSOR)
-                # 100 is only to increase the diff in seeds between pipeline stages
-                tp_rank_with_offset = tp_rank + pipeline_offset * 1024
-                tp_seed = seed + tp_rank_with_offset
+                dp_rank = self._local_ranks.get(ParallelMode.DATA, 0) + 1
+                tp_seed = parallel_seed + dp_rank * 128
                 add_seed(ParallelMode.TENSOR, tp_seed)
 
             set_mode(ParallelMode.DATA)
```
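To make the new seeding scheme concrete, a small worked example with made-up ranks; only the formulas come from set_seed above.

```python
# Hypothetical ranks for one process; the arithmetic mirrors set_seed().
seed = 1024                                    # base seed from the config
tp_rank, pp_rank, dp_rank = 2, 1, 0            # tensor / pipeline / data local ranks

data_seed = seed + tp_rank + pp_rank * 1024    # 1024 + 2 + 1024 = 2050
tensor_seed = data_seed + (dp_rank + 1) * 128  # 2050 + 128 = 2178
```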

colossalai/context/process_group_initializer/initializer_3d.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -5,7 +5,7 @@
 import os
 
 import torch.distributed as dist
-from colossalai.constants import DEPTH_3D
+from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
 from colossalai.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
@@ -18,7 +18,7 @@ def _check_depth_env_var(depth):
 
     if env_depth:
         assert int(env_depth) == depth, \
-            'SUMMA_DIM has been set in the current environment and ' \
+            'DEPTH_3D has been set in the current environment and ' \
             'does not match with the value passed to this initialized'
     else:
         os.environ[DEPTH_3D] = str(depth)
@@ -43,6 +43,7 @@ def init_dist_group(self):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_INPUT
+        os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D
 
         for h in range(self.num_group):
             for i in range(self.depth):
@@ -82,6 +83,7 @@ def init_dist_group(self):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_WEIGHT
+        os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D
 
         for h in range(self.num_group):
             for k in range(self.depth):
@@ -121,6 +123,7 @@ def init_dist_group(self):
         process_group = None
         group_world_size = None
        mode = ParallelMode.PARALLEL_3D_OUTPUT
+        os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D
 
         for h in range(self.num_group):
             for i in range(self.depth):
```

colossalai/initialize.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -42,19 +42,21 @@ def parse_args():
                         type=str,
                         default=None,
                         help='the master port for distributed training')
-    parser.add_argument('--world_size', type=int, help='world size for ')
+    parser.add_argument('--world_size', type=int, help='world size for distributed training')
+    parser.add_argument('--rank', type=int, help='rank for the default process group')
     parser.add_argument('--local_rank',
                         type=int,
-                        help='rank for the default process group')
+                        help='local rank on the node')
     parser.add_argument('--backend',
                         type=str,
                         default='nccl',
-                        help='backend for torch.distributed')
+                        help='backend for distributed communication')
     return parser.parse_args()
 
 
 def init_dist(config: Union[str, dict] = None,
               local_rank: int = None,
+              rank: int = None,
               world_size: int = None,
               host: str = None,
               port: str = None,
@@ -86,6 +88,8 @@ def init_dist(config: Union[str, dict] = None,
         config = args.config
     if local_rank is None:
         local_rank = args.local_rank
+    if rank is None:
+        rank = args.rank
     if world_size is None:
         world_size = args.world_size
     if host is None:
@@ -99,12 +103,14 @@ def init_dist(config: Union[str, dict] = None,
             host=host,
             port=port,
             world_size=world_size,
+            rank=rank,
             local_rank=local_rank,
             backend=backend))
 
     # set distributed settings
     dist_args = Config(
         dict(local_rank=args.local_rank,
+             rank=rank,
              world_size=args.world_size,
              backend=args.backend))
 
@@ -178,6 +184,7 @@ def seed_worker(worker_id):
 
 def initialize(config: Union[str, dict] = None,
                local_rank: int = None,
+               rank: int = None,
                world_size: int = None,
                host: str = None,
                port: str = None,
@@ -209,6 +216,7 @@ def initialize(config: Union[str, dict] = None,
     # initialize distributed environment
     init_dist(config=config,
               local_rank=local_rank,
+              rank=rank,
              world_size=world_size,
              host=host,
              port=port,
```
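A hedged sketch of how a training script might pick up the new rank handling; the config path is a placeholder and the launcher is assumed to export the usual distributed environment variables.

```python
# train.py (sketch) -- rank, local_rank and world_size can be passed via the
# --rank / --local_rank / --world_size flags, or fall back to the RANK,
# LOCAL_RANK and WORLD_SIZE environment variables set by the launcher.
from colossalai.initialize import init_dist

init_dist(config='./config.py',   # placeholder config path
          host='localhost',
          port='8008',
          backend='nccl')
```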

colossalai/nn/init.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -0,0 +1,33 @@
+import math
+
+from torch import Tensor
+from torch.nn import init as init
+
+
+def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
+    if init_method == 'torch':
+        a = math.sqrt(5)
+        nonlinearity = 'leaky_relu'
+        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
+        bound = math.sqrt(3.0) * std
+        init.uniform_(tensor, -bound, bound)
+    elif init_method == 'jax':
+        std = math.sqrt(2.0 / float(fan_in + fan_out))
+        a = math.sqrt(3.0) * std
+        init.uniform_(tensor, -a, a)
+    elif init_method == 'jax_embed':
+        std = math.sqrt(1.0 / fan_in)
+        init.trunc_normal_(tensor, std=std / .87962566103423978)
+    elif init_method == 'zero':
+        init.zeros_(tensor)
+
+def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
+    if init_method == 'torch':
+        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+        init.uniform_(tensor, -bound, bound)
+    elif init_method == 'jax':
+        init.normal_(tensor, std=1e-6)
+    elif init_method == 'jax_embed':
+        init.trunc_normal_(tensor, std=.02)
+    elif init_method == 'zero':
+        init.zeros_(tensor)
```
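A minimal usage sketch of the new initializers; the layer and its dimensions are arbitrary. The 'jax' branch applies the Xavier-style uniform weight init and the tiny-normal bias init defined above.

```python
import torch.nn as nn

from colossalai.nn.init import init_bias_, init_weight_

layer = nn.Linear(256, 512)  # example layer, dimensions are arbitrary

init_weight_(layer.weight, fan_in=256, fan_out=512, init_method='jax')
init_bias_(layer.bias, fan_in=256, init_method='jax')
```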

colossalai/nn/layer/_common_utils.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -3,12 +3,12 @@
 
 import math
 
+import numpy as np
+from colossalai.utils.common import print_rank_0
 import torch
-from torch import Tensor
-from torch import nn
+from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
 from colossalai.utils import checkpoint
-
-from colossalai.constants import IS_TENSOR_PARALLEL
+from torch import Tensor, nn
 
 
 def divide(numerator, denominator):
@@ -33,9 +33,11 @@ def swish(x: Tensor) -> Tensor:
 ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
 
 
-def set_tensor_parallel_attribute(param):
-    if not hasattr(param, IS_TENSOR_PARALLEL):
-        setattr(param, IS_TENSOR_PARALLEL, True)
+def set_tensor_parallel_attribute(param, size):
+    # if not hasattr(param, IS_TENSOR_PARALLEL):
+    setattr(param, IS_TENSOR_PARALLEL, True)
+    # if not hasattr(param, NUM_PARTITIONS):
+    setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))
 
 
 class CheckpointModule(nn.Module):
```
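A sketch of the new call convention: the second argument is the full (unpartitioned) parameter size, so the number of partitions is derived as size // local_numel. The shapes below are made up for illustration.

```python
import torch
import torch.nn as nn

from colossalai.nn.layer._common_utils import set_tensor_parallel_attribute

full_size = 1024 * 1024                        # size of the weight before partitioning
shard = nn.Parameter(torch.empty(1024, 128))   # this rank's 8-way shard

set_tensor_parallel_attribute(shard, full_size)
print(getattr(shard, 'is_tensor_parallel'))    # True
print(getattr(shard, 'num_partitions'))        # 1048576 // 131072 == 8
```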
