# forked from facebookresearch/dlrm
# sharders.py — 196 lines (154 loc), 8.21 KB
from dataclasses import dataclass
import numpy as np
import sys, GPUtil
_sharders = {}
def get_device_total_memory(factor=0.8):
gpus = GPUtil.getGPUs()
return [(gpu.memoryTotal * factor * 1024 * 1024) for gpu in gpus] # MB -> bytes
DEVICE_MEMORY = get_device_total_memory()
@dataclass
class TableInfo:
def __init__(self, index, D, E, L, rfs):
self.index = index
self.cost = 0
self.D = D
self.E = E
self.L = L
self.rfs = [float(x) for x in rfs.split('-')] if isinstance(rfs, str) else rfs
self.size = D * E * 4 # float32
def __str__(self):
return f"""
index: {self.index}
cost: {self.cost}
D: {self.D}
E: {self.E}
L: {self.L}
rfs: {self.rfs}
"""
def plan2allocation(plan):
num_tables = sum([len(shard) for shard in plan])
table_device_indices = [-1] * num_tables
for bin_id, partition in enumerate(plan):
for index in partition:
table_device_indices[index] = bin_id
return table_device_indices
def _cost_func_multi_cost(
emb_dim: int,
pooling_factor: int,
ndevices: int = 4,
use_uvm_cache: bool = False
):
k_fwd = 0.5 * (8 / ndevices)
k_pf = 3.0 if use_uvm_cache else 1.0
emb_fwd_cost = k_fwd * float((emb_dim + 127) / 128) * pooling_factor
emb_bwd_cost = 2 * emb_fwd_cost # + 10 if weighted
return [emb_dim, k_pf * pooling_factor, 0 * emb_fwd_cost, emb_bwd_cost]
def get_splits(T, ndevices):
k, m = divmod(T, ndevices)
if m == 0:
splits = [k] * ndevices
else:
splits = [(k + 1) if i < m else k for i in range(ndevices)]
return splits
def greedy_partition(ndevices, table_info_list, sort_type="size_greedy"):
if sort_type == "dim_greedy":
key = lambda x: (x.D, x.index)
elif sort_type == "size_greedy":
key = lambda x: (x.E, x.index)
elif sort_type == "lookup_greedy":
key = lambda x: (x.L * x.D, x.index)
elif sort_type == "norm_lookup_greedy":
key = lambda x: (x.L / x.E, x.index)
elif sort_type == "size_lookup_greedy":
key = lambda x: (x.L * x.D * np.log10(x.E), x.index)
else:
sys.exit("Sharders: Unrecognized sort type!")
sorted_table_info_list = sorted(table_info_list, key=key)
num_bins = ndevices
partitions = [[] for _ in range(num_bins)]
partition_sums = [0.0] * num_bins
partition_size_sums = [0.0] * num_bins
while sorted_table_info_list:
table_info = sorted_table_info_list.pop()
min_sum = np.inf
min_size_taken = np.inf
min_r = -1
for r in range(num_bins):
if partition_size_sums[r] + table_info.size <= DEVICE_MEMORY[r]:
if partition_sums[r] < min_sum or (
partition_sums[r] == min_sum
and partition_size_sums[r] < min_size_taken
):
min_sum = partition_sums[r]
min_r = r
min_size_taken = partition_size_sums[r]
partitions[min_r].append(table_info)
partition_sums[min_r] += table_info.cost
partition_size_sums[min_r] += table_info.size
partitions = [[table_info.index for table_info in partition] for partition in partitions]
return plan2allocation(partitions)
def register_sharder(sharder_name):
def decorate(func):
_sharders[sharder_name] = func
return func
return decorate
# get device indices for tables
# e.g 8 tables, No. [1,3,5,6] on device 0, No. [2,4,7,8] on device 1, then
# return [0, 1, 0, 1, 0, 0, 1, 1]
# N.B.: only for single-node multi-GPU for now.
def shard(ndevices, table_info_list, alg="naive"):
if alg not in _sharders:
sys.exit("ERROR: sharder not found")
return _sharders[alg](ndevices, table_info_list)
@register_sharder("naive")
def naive_shard(ndevices, table_info_list):
return [(x % ndevices) for x in range(len(table_info_list))]
@register_sharder("naive_chunk")
def naive_shard(ndevices, table_info_list):
T = len(table_info_list)
splits = get_splits(T, ndevices)
table_device_indices = []
for idx, s in enumerate(splits):
table_device_indices.extend([idx] * s)
return table_device_indices
@register_sharder("hardcode")
def hardcode_shard(ndevices, table_info_list):
T = len(table_info_list)
return [0] * int(T/2) + [1] * (T - int(T/2))
@register_sharder("random")
def random_shard(ndevices, table_info_list):
table_device_indices = []
for _ in range(len(table_info_list)):
table_device_indices.append(np.random.randint(ndevices))
return table_device_indices
@register_sharder("dim_greedy")
def dim_greedy_shard(ndevices, table_info_list):
return greedy_partition(ndevices, table_info_list, sort_type="dim_greedy")
@register_sharder("size_greedy")
def size_greedy_shard(ndevices, table_info_list):
return greedy_partition(ndevices, table_info_list, sort_type="size_greedy")
@register_sharder("lookup_greedy")
def lookup_greedy_shard(ndevices, table_info_list):
return greedy_partition(ndevices, table_info_list, sort_type="lookup_greedy")
@register_sharder("norm_lookup_greedy")
def lookup_greedy_shard(ndevices, table_info_list):
return greedy_partition(ndevices, table_info_list, sort_type="norm_lookup_greedy")
@register_sharder("size_lookup_greedy")
def size_lookup_greedy_shard(ndevices, table_info_list):
return greedy_partition(ndevices, table_info_list, sort_type="size_lookup_greedy")
if __name__ == '__main__':
nums = list(range(16))
Ds = [128, 64, 32, 32, 32, 128, 64, 32, 128, 32, 64, 32, 256, 64, 64, 32]
Es = [100, 5999681, 2714889, 6516585, 5820669, 5999994, 1136544, 5999981, 426, 5999995, 100, 5999878, 100, 5999808, 5999929, 5999929]
Ls = [27.639102711397058, 2.549273322610294, 48.56380687040441, 106.00396369485294, 17.78176700367647, 5.277961282169118, 14.807014016544118, 9.307516659007353, 9.802016314338236, 8.356847426470589, 7.246998506433823, 0.6754222196691176, 13.421200022977942, 0.2587747012867647, 1.6537511488970589, 6.931540096507353]
rfs = [[0.7847, 0.1148, 0.0548, 0.0273, 0.0148, 0.0032, 0.0004, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.062, 0.0148, 0.0108, 0.0067, 0.0039, 0.0014, 0.0003, 0.0001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7853, 0.1067, 0.0602, 0.0283, 0.0119, 0.005, 0.0018, 0.0006, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5189, 0.222, 0.1711, 0.073, 0.0137, 0.001, 0.0002, 0.0001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.5539, 0.2046, 0.133, 0.0699, 0.0274, 0.0085, 0.0018, 0.0009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3615, 0.162, 0.1888, 0.1727, 0.0693, 0.0159, 0.0094, 0.0083, 0.0081, 0.004, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.1198, 0.071, 0.0679, 0.0795, 0.1099, 0.1277, 0.1051, 0.1252, 0.1401, 0.0538, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3135, 0.128, 0.1116, 0.1121, 0.1144, 0.0999, 0.104, 0.0159, 0.0005, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.7301, 0.1464, 0.0734, 0.0359, 0.0131, 0.0011, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.3957, 0.1599, 0.1168, 0.1043, 0.0777, 0.0625, 0.04, 0.0135, 0.0223, 0.0075, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6039, 0.2084, 0.1182, 0.0553, 0.013, 0.0011, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0615, 0.0209, 0.0122, 0.0046, 0.0008, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0006, 0.001, 0.0095, 0.0718, 0.3252, 0.3321, 0.1466, 0.0151, 0.0554, 0.0322, 0.0106, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0556, 0.0215, 0.0139, 0.0078, 0.0012, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0593, 0.0219, 0.013, 0.0043, 0.0011, 0.0002, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.6013, 0.2063, 0.1193, 0.0576, 0.0139, 0.0014, 0.0001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
tables = [TableInfo(idx, D, E, L, rf) for (idx, D, E, L, rf) in zip(nums, Ds, Es, Ls, rfs)]
print(shard(2, tables, 'naive'))
print(shard(2, tables, 'naive_chunk'))
print(shard(2, tables, 'hardcode'))
print(shard(2, tables, 'random'))
print(shard(2, tables, 'dim_greedy'))
print(shard(2, tables, 'size_greedy'))
print(shard(2, tables, 'lookup_greedy'))
print(shard(2, tables, 'norm_lookup_greedy'))
print(shard(2, tables, 'size_lookup_greedy'))