Skip to content

Commit fa53de4

Browse files
henrylhtsang and facebook-github-bot
authored and committed
Find minimal col dim when it's not provided (#1295)
Summary: Pull Request resolved: #1295. This diff has two parts. In the first part, we handle the case when the per-table base_dim is not given. We use a for loop to find the base_dim that is at least 128, divides the table dim, and is divisible by 4. This helps with the case when the table dim is not divisible by 128, since in that case the current code would just use the table dim. In the second part, we replace `col_wise_shard_dim` with `_find_base_dim(col_wise_shard_dim, columns)`. In cases where `col_wise_shard_dim` is correctly provided, this does nothing. But this allows the user to put in a more flexible `col_wise_shard_dim`. For example, users can just put in `col_wise_shard_dim = 40` for all tables, and the planner will find the per-table base dim that (1) divides the table dim, (2) is divisible by 4, and (3) is larger than or equal to 40. Reviewed By: bigning. Differential Revision: D47887146. fbshipit-source-id: deaafa143a03f14c9dcada17743033c623c8ae68
1 parent b458289 commit fa53de4

File tree

3 files changed

+146
-136
lines changed

3 files changed

+146
-136
lines changed

torchrec/distributed/planner/tests/test_enumerators.py

Lines changed: 111 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434

3535

3636
EXPECTED_RW_SHARD_SIZES = [
37-
[[13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [9, 10]],
38-
[[14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [12, 20]],
39-
[[15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30]],
40-
[[17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [11, 40]],
37+
[[13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [9, 20]],
38+
[[14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [12, 40]],
39+
[[15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60]],
40+
[[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]],
4141
]
4242

4343
EXPECTED_RW_SHARD_OFFSETS = [
@@ -49,97 +49,97 @@
4949

5050
EXPECTED_RW_SHARD_STORAGE = [
5151
[
52-
Storage(hbm=84488, ddr=0),
53-
Storage(hbm=84488, ddr=0),
54-
Storage(hbm=84488, ddr=0),
55-
Storage(hbm=84488, ddr=0),
56-
Storage(hbm=84488, ddr=0),
57-
Storage(hbm=84488, ddr=0),
58-
Storage(hbm=84488, ddr=0),
59-
Storage(hbm=84328, ddr=0),
52+
Storage(hbm=166928, ddr=0),
53+
Storage(hbm=166928, ddr=0),
54+
Storage(hbm=166928, ddr=0),
55+
Storage(hbm=166928, ddr=0),
56+
Storage(hbm=166928, ddr=0),
57+
Storage(hbm=166928, ddr=0),
58+
Storage(hbm=166928, ddr=0),
59+
Storage(hbm=166608, ddr=0),
6060
],
6161
[
62-
Storage(hbm=511072, ddr=0),
63-
Storage(hbm=511072, ddr=0),
64-
Storage(hbm=511072, ddr=0),
65-
Storage(hbm=511072, ddr=0),
66-
Storage(hbm=511072, ddr=0),
67-
Storage(hbm=511072, ddr=0),
68-
Storage(hbm=511072, ddr=0),
69-
Storage(hbm=510912, ddr=0),
62+
Storage(hbm=1003712, ddr=0),
63+
Storage(hbm=1003712, ddr=0),
64+
Storage(hbm=1003712, ddr=0),
65+
Storage(hbm=1003712, ddr=0),
66+
Storage(hbm=1003712, ddr=0),
67+
Storage(hbm=1003712, ddr=0),
68+
Storage(hbm=1003712, ddr=0),
69+
Storage(hbm=1003392, ddr=0),
7070
],
7171
[
72-
Storage(hbm=513800, ddr=0),
73-
Storage(hbm=513800, ddr=0),
74-
Storage(hbm=513800, ddr=0),
75-
Storage(hbm=513800, ddr=0),
76-
Storage(hbm=513800, ddr=0),
77-
Storage(hbm=513800, ddr=0),
78-
Storage(hbm=513800, ddr=0),
79-
Storage(hbm=513800, ddr=0),
72+
Storage(hbm=1007120, ddr=0),
73+
Storage(hbm=1007120, ddr=0),
74+
Storage(hbm=1007120, ddr=0),
75+
Storage(hbm=1007120, ddr=0),
76+
Storage(hbm=1007120, ddr=0),
77+
Storage(hbm=1007120, ddr=0),
78+
Storage(hbm=1007120, ddr=0),
79+
Storage(hbm=1007120, ddr=0),
8080
],
8181
[
82-
Storage(hbm=1340064, ddr=0),
83-
Storage(hbm=1340064, ddr=0),
84-
Storage(hbm=1340064, ddr=0),
85-
Storage(hbm=1340064, ddr=0),
86-
Storage(hbm=1340064, ddr=0),
87-
Storage(hbm=1340064, ddr=0),
88-
Storage(hbm=1340064, ddr=0),
89-
Storage(hbm=1339104, ddr=0),
82+
Storage(hbm=2653504, ddr=0),
83+
Storage(hbm=2653504, ddr=0),
84+
Storage(hbm=2653504, ddr=0),
85+
Storage(hbm=2653504, ddr=0),
86+
Storage(hbm=2653504, ddr=0),
87+
Storage(hbm=2653504, ddr=0),
88+
Storage(hbm=2653504, ddr=0),
89+
Storage(hbm=2651584, ddr=0),
9090
],
9191
]
9292

9393

9494
EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
9595
[
96-
Storage(hbm=84072, ddr=520),
97-
Storage(hbm=84072, ddr=520),
98-
Storage(hbm=84072, ddr=520),
99-
Storage(hbm=84072, ddr=520),
100-
Storage(hbm=84072, ddr=520),
101-
Storage(hbm=84072, ddr=520),
102-
Storage(hbm=84072, ddr=520),
103-
Storage(hbm=84040, ddr=360),
96+
Storage(hbm=166096, ddr=1040),
97+
Storage(hbm=166096, ddr=1040),
98+
Storage(hbm=166096, ddr=1040),
99+
Storage(hbm=166096, ddr=1040),
100+
Storage(hbm=166096, ddr=1040),
101+
Storage(hbm=166096, ddr=1040),
102+
Storage(hbm=166096, ddr=1040),
103+
Storage(hbm=166032, ddr=720),
104104
],
105105
[
106-
Storage(hbm=510176, ddr=1120),
107-
Storage(hbm=510176, ddr=1120),
108-
Storage(hbm=510176, ddr=1120),
109-
Storage(hbm=510176, ddr=1120),
110-
Storage(hbm=510176, ddr=1120),
111-
Storage(hbm=510176, ddr=1120),
112-
Storage(hbm=510176, ddr=1120),
113-
Storage(hbm=510144, ddr=960),
106+
Storage(hbm=1001920, ddr=2240),
107+
Storage(hbm=1001920, ddr=2240),
108+
Storage(hbm=1001920, ddr=2240),
109+
Storage(hbm=1001920, ddr=2240),
110+
Storage(hbm=1001920, ddr=2240),
111+
Storage(hbm=1001920, ddr=2240),
112+
Storage(hbm=1001920, ddr=2240),
113+
Storage(hbm=1001856, ddr=1920),
114114
],
115115
[
116-
Storage(hbm=512360, ddr=1800),
117-
Storage(hbm=512360, ddr=1800),
118-
Storage(hbm=512360, ddr=1800),
119-
Storage(hbm=512360, ddr=1800),
120-
Storage(hbm=512360, ddr=1800),
121-
Storage(hbm=512360, ddr=1800),
122-
Storage(hbm=512360, ddr=1800),
123-
Storage(hbm=512360, ddr=1800),
116+
Storage(hbm=1004240, ddr=3600),
117+
Storage(hbm=1004240, ddr=3600),
118+
Storage(hbm=1004240, ddr=3600),
119+
Storage(hbm=1004240, ddr=3600),
120+
Storage(hbm=1004240, ddr=3600),
121+
Storage(hbm=1004240, ddr=3600),
122+
Storage(hbm=1004240, ddr=3600),
123+
Storage(hbm=1004240, ddr=3600),
124124
],
125125
[
126-
Storage(hbm=1337888, ddr=2720),
127-
Storage(hbm=1337888, ddr=2720),
128-
Storage(hbm=1337888, ddr=2720),
129-
Storage(hbm=1337888, ddr=2720),
130-
Storage(hbm=1337888, ddr=2720),
131-
Storage(hbm=1337888, ddr=2720),
132-
Storage(hbm=1337888, ddr=2720),
133-
Storage(hbm=1337696, ddr=1760),
126+
Storage(hbm=2649152, ddr=5440),
127+
Storage(hbm=2649152, ddr=5440),
128+
Storage(hbm=2649152, ddr=5440),
129+
Storage(hbm=2649152, ddr=5440),
130+
Storage(hbm=2649152, ddr=5440),
131+
Storage(hbm=2649152, ddr=5440),
132+
Storage(hbm=2649152, ddr=5440),
133+
Storage(hbm=2648768, ddr=3520),
134134
],
135135
]
136136

137137

138138
EXPECTED_TWRW_SHARD_SIZES = [
139-
[[25, 10], [25, 10], [25, 10], [25, 10]],
140-
[[28, 20], [28, 20], [28, 20], [26, 20]],
141-
[[30, 30], [30, 30], [30, 30], [30, 30]],
142-
[[33, 40], [33, 40], [33, 40], [31, 40]],
139+
[[25, 20], [25, 20], [25, 20], [25, 20]],
140+
[[28, 40], [28, 40], [28, 40], [26, 40]],
141+
[[30, 60], [30, 60], [30, 60], [30, 60]],
142+
[[33, 80], [33, 80], [33, 80], [31, 80]],
143143
]
144144

145145
EXPECTED_TWRW_SHARD_OFFSETS = [
@@ -151,75 +151,69 @@
151151

152152
EXPECTED_TWRW_SHARD_STORAGE = [
153153
[
154-
Storage(hbm=87016, ddr=0),
155-
Storage(hbm=87016, ddr=0),
156-
Storage(hbm=87016, ddr=0),
157-
Storage(hbm=87016, ddr=0),
154+
Storage(hbm=169936, ddr=0),
155+
Storage(hbm=169936, ddr=0),
156+
Storage(hbm=169936, ddr=0),
157+
Storage(hbm=169936, ddr=0),
158158
],
159159
[
160-
Storage(hbm=530624, ddr=0),
161-
Storage(hbm=530624, ddr=0),
162-
Storage(hbm=530624, ddr=0),
163-
Storage(hbm=530464, ddr=0),
160+
Storage(hbm=1024384, ddr=0),
161+
Storage(hbm=1024384, ddr=0),
162+
Storage(hbm=1024384, ddr=0),
163+
Storage(hbm=1024064, ddr=0),
164164
],
165165
[
166-
Storage(hbm=536080, ddr=0),
167-
Storage(hbm=536080, ddr=0),
168-
Storage(hbm=536080, ddr=0),
169-
Storage(hbm=536080, ddr=0),
166+
Storage(hbm=1031200, ddr=0),
167+
Storage(hbm=1031200, ddr=0),
168+
Storage(hbm=1031200, ddr=0),
169+
Storage(hbm=1031200, ddr=0),
170170
],
171171
[
172-
Storage(hbm=1369248, ddr=0),
173-
Storage(hbm=1369248, ddr=0),
174-
Storage(hbm=1369248, ddr=0),
175-
Storage(hbm=1368928, ddr=0),
172+
Storage(hbm=2685248, ddr=0),
173+
Storage(hbm=2685248, ddr=0),
174+
Storage(hbm=2685248, ddr=0),
175+
Storage(hbm=2684608, ddr=0),
176176
],
177177
]
178178

179179
EXPECTED_CW_SHARD_SIZES = [
180-
[[100, 10]],
181-
[[110, 10], [110, 10]],
182-
[[120, 10], [120, 10], [120, 10]],
183-
[[130, 20], [130, 20]],
180+
[[100, 20]],
181+
[[110, 20], [110, 20]],
182+
[[120, 20], [120, 20], [120, 20]],
183+
[[130, 40], [130, 40]],
184184
]
185185

186186
EXPECTED_CW_SHARD_OFFSETS = [
187187
[[0, 0]],
188-
[[0, 0], [0, 10]],
189-
[[0, 0], [0, 10], [0, 20]],
190188
[[0, 0], [0, 20]],
189+
[[0, 0], [0, 20], [0, 40]],
190+
[[0, 0], [0, 40]],
191191
]
192192

193193
EXPECTED_CW_SHARD_STORAGE = [
194-
[Storage(hbm=102304, ddr=0)],
195-
[Storage(hbm=397616, ddr=0), Storage(hbm=397616, ddr=0)],
194+
[Storage(hbm=188224, ddr=0)],
195+
[Storage(hbm=647776, ddr=0), Storage(hbm=647776, ddr=0)],
196196
[
197-
Storage(hbm=332480, ddr=0),
198-
Storage(hbm=332480, ddr=0),
199-
Storage(hbm=332480, ddr=0),
200-
],
201-
[
202-
Storage(hbm=878752, ddr=0),
203-
Storage(hbm=878752, ddr=0),
197+
Storage(hbm=501120, ddr=0),
198+
Storage(hbm=501120, ddr=0),
199+
Storage(hbm=501120, ddr=0),
204200
],
201+
[Storage(hbm=1544512, ddr=0), Storage(hbm=1544512, ddr=0)],
205202
]
206203

207204
EXPECTED_TWCW_SHARD_SIZES: List[List[List[int]]] = EXPECTED_CW_SHARD_SIZES
208205

209206
EXPECTED_TWCW_SHARD_OFFSETS: List[List[List[int]]] = EXPECTED_CW_SHARD_OFFSETS
210207

211208
EXPECTED_TWCW_SHARD_STORAGE = [
212-
[Storage(hbm=102304, ddr=0)],
213-
[Storage(hbm=397616, ddr=0), Storage(hbm=397616, ddr=0)],
209+
[Storage(hbm=188224, ddr=0)],
210+
[Storage(hbm=647776, ddr=0), Storage(hbm=647776, ddr=0)],
214211
[
215-
Storage(hbm=332480, ddr=0),
216-
Storage(hbm=332480, ddr=0),
217-
Storage(hbm=332480, ddr=0),
218-
],
219-
[
220-
Storage(hbm=878752, ddr=0),
221-
Storage(hbm=878752, ddr=0),
212+
Storage(hbm=501120, ddr=0),
213+
Storage(hbm=501120, ddr=0),
214+
Storage(hbm=501120, ddr=0),
222215
],
216+
[Storage(hbm=1544512, ddr=0), Storage(hbm=1544512, ddr=0)],
223217
]
224218

225219

@@ -343,20 +337,20 @@ def setUp(self) -> None:
343337
self.world_size = 8
344338
self.local_world_size = 4
345339
self.constraints = {
346-
"table_0": ParameterConstraints(min_partition=10),
340+
"table_0": ParameterConstraints(min_partition=20),
347341
"table_1": ParameterConstraints(
348-
min_partition=10, pooling_factors=[1, 3, 5]
342+
min_partition=20, pooling_factors=[1, 3, 5]
349343
),
350-
"table_2": ParameterConstraints(min_partition=10, pooling_factors=[8, 2]),
344+
"table_2": ParameterConstraints(min_partition=20, pooling_factors=[8, 2]),
351345
"table_3": ParameterConstraints(
352-
min_partition=20, pooling_factors=[2, 1, 3, 7]
346+
min_partition=40, pooling_factors=[2, 1, 3, 7]
353347
),
354348
}
355349
self.num_tables = 4
356350
tables = [
357351
EmbeddingBagConfig(
358352
num_embeddings=100 + i * 10,
359-
embedding_dim=10 + i * 10,
353+
embedding_dim=20 + i * 20,
360354
name="table_" + str(i),
361355
feature_names=["feature_" + str(i)],
362356
)
@@ -443,7 +437,6 @@ def test_dp_sharding(self) -> None:
443437
expected_storage = [
444438
Storage(hbm=storage_size, ddr=0) for storage_size in storage_sizes
445439
]
446-
447440
self.assertEqual(
448441
[shard.storage for shard in sharding_option.shards], expected_storage
449442
)

0 commit comments

Comments
 (0)