-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
Copy pathauto_tp.py
76 lines (67 loc) · 2.74 KB
/
auto_tp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Automatic Tensor Parallelism
import re
from torch import nn
class AutoTP:
    """Build a tensor-parallel policy by inspecting a model's module tree.

    Every method is stateless and is invoked on the class itself, e.g.
    ``AutoTP.tp_parser(model)``.  They are declared as static methods so
    that calling them through an instance behaves the same way.
    """

    @staticmethod
    def in_module_list(module, module_list):
        """Return True if ``module_list`` holds a module with the same class name.

        Only ``type(...).__name__`` is compared, so two distinct classes
        that share a name are treated as the same kind of module.
        """
        wanted = type(module).__name__
        return any(type(item).__name__ == wanted for item in module_list)

    @staticmethod
    def get_module_list(model):
        """Collect representative modules found inside ``nn.ModuleList`` children.

        Walks ``model.children()``.  For a child that is an
        ``nn.ModuleList``, each of its members is added unless a module
        with the same class name is already in the result.  Any other
        child is searched recursively and its results are concatenated
        as-is (no de-duplication across recursive results).

        Returns:
            list of modules, in discovery order.
        """
        mlist = []
        for child in model.children():
            if isinstance(child, nn.ModuleList):
                for module in child.children():
                    # An empty list never "contains" the module, so no
                    # special case is needed for the first entry.
                    if not AutoTP.in_module_list(module, mlist):
                        mlist.append(module)
            else:
                mlist.extend(AutoTP.get_module_list(child))
        return mlist

    @staticmethod
    def supported(model):
        """Return whether an automatic policy may be generated for ``model``.

        The model family is extracted from ``str(model)`` by trying, in
        order: ``": <name>Model"``, ``": <name>Stack"`` and finally a
        leading ``"<name>Model"``.  The lower-cased name is then checked
        against a fixed list of unsupported families.

        Returns:
            False if the family is in the unsupported list or if no
            family name can be extracted at all; True otherwise.
        """
        unsupported = ['bloom', 'codegen', 'flaubert', 'xlm']
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
            key = re.search(r": (.*?)Stack", model)
        if key is None:
            key = re.match(r"(.*?)Model", model)
        if key is None:
            # None of the patterns matched: the model family is unknown,
            # so report "not supported" instead of failing on
            # ``None.group`` with an AttributeError.
            return False
        return key.group(1).lower() not in unsupported

    @staticmethod
    def get_layers(parent, module):
        """Flatten the direct and nested sub-modules of ``module`` into names.

        For each entry of ``module._modules``:
          * an ``nn.Linear`` contributes ``parent + "." + key``;
          * an ``nn.LayerNorm``, or any sub-module registered under the
            key ``'LayerNorm'``, contributes the marker ``"ln"``;
          * anything else is descended into, with its own key as the new
            ``parent`` (so names carry only the immediate parent).

        Returns:
            list of str, in registration order.
        """
        layer_list = []
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list.append(parent + "." + key)
            elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm':
                layer_list.append("ln")
            else:
                layer_list.extend(AutoTP.get_layers(key, submodule))
        return layer_list

    @staticmethod
    def tp_parser(model):
        """Derive the policy list for ``model``.

        For every module returned by ``get_module_list`` the layer names
        are flattened with ``get_layers`` and the following are selected:
          * the entry immediately preceding each ``"ln"`` marker, unless
            that entry is itself ``"ln"``;
          * every entry whose name contains ``'out_proj'``.

        Returns:
            list of ``(module_class, [layer_name, ...])`` tuples; modules
            for which nothing was selected are omitted.

        Raises:
            AssertionError: if ``supported(model)`` is false.
        """
        assert AutoTP.supported(model), "Automatic policy not supported for model. Please provide policy."

        policy_list = []
        for module in AutoTP.get_module_list(model):
            # Start from a fresh list for every module so that one
            # module's policy never contains layers of a previous one.
            # An empty parent yields ".<key>" for direct linears.
            layer_list = AutoTP.get_layers("", module)

            gem_list = []
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    # NOTE: for i == 0 the index -1 wraps to the last
                    # entry, so a module that begins with a norm selects
                    # its final layer.  This matches the original
                    # behaviour and is kept on purpose.
                    previous = layer_list[i - 1]
                    if previous != 'ln':
                        gem_list.append(previous)
                elif 'out_proj' in layer:
                    gem_list.append(layer)

            # A "*out_proj" layer directly followed by a norm is selected
            # by both rules; keep the first occurrence only.
            gem_list = list(dict.fromkeys(gem_list))
            if gem_list:
                policy_list.append((type(module), gem_list))
        return policy_list