|
34 | 34 |
|
35 | 35 |
|
36 | 36 | EXPECTED_RW_SHARD_SIZES = [
|
37 |
| - [[13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [13, 10], [9, 10]], |
38 |
| - [[14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [14, 20], [12, 20]], |
39 |
| - [[15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30], [15, 30]], |
40 |
| - [[17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [17, 40], [11, 40]], |
| 37 | + [[13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [13, 20], [9, 20]], |
| 38 | + [[14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [14, 40], [12, 40]], |
| 39 | + [[15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60], [15, 60]], |
| 40 | + [[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]], |
41 | 41 | ]
|
42 | 42 |
|
43 | 43 | EXPECTED_RW_SHARD_OFFSETS = [
|
|
49 | 49 |
|
50 | 50 | EXPECTED_RW_SHARD_STORAGE = [
|
51 | 51 | [
|
52 |
| - Storage(hbm=84488, ddr=0), |
53 |
| - Storage(hbm=84488, ddr=0), |
54 |
| - Storage(hbm=84488, ddr=0), |
55 |
| - Storage(hbm=84488, ddr=0), |
56 |
| - Storage(hbm=84488, ddr=0), |
57 |
| - Storage(hbm=84488, ddr=0), |
58 |
| - Storage(hbm=84488, ddr=0), |
59 |
| - Storage(hbm=84328, ddr=0), |
| 52 | + Storage(hbm=166928, ddr=0), |
| 53 | + Storage(hbm=166928, ddr=0), |
| 54 | + Storage(hbm=166928, ddr=0), |
| 55 | + Storage(hbm=166928, ddr=0), |
| 56 | + Storage(hbm=166928, ddr=0), |
| 57 | + Storage(hbm=166928, ddr=0), |
| 58 | + Storage(hbm=166928, ddr=0), |
| 59 | + Storage(hbm=166608, ddr=0), |
60 | 60 | ],
|
61 | 61 | [
|
62 |
| - Storage(hbm=511072, ddr=0), |
63 |
| - Storage(hbm=511072, ddr=0), |
64 |
| - Storage(hbm=511072, ddr=0), |
65 |
| - Storage(hbm=511072, ddr=0), |
66 |
| - Storage(hbm=511072, ddr=0), |
67 |
| - Storage(hbm=511072, ddr=0), |
68 |
| - Storage(hbm=511072, ddr=0), |
69 |
| - Storage(hbm=510912, ddr=0), |
| 62 | + Storage(hbm=1003712, ddr=0), |
| 63 | + Storage(hbm=1003712, ddr=0), |
| 64 | + Storage(hbm=1003712, ddr=0), |
| 65 | + Storage(hbm=1003712, ddr=0), |
| 66 | + Storage(hbm=1003712, ddr=0), |
| 67 | + Storage(hbm=1003712, ddr=0), |
| 68 | + Storage(hbm=1003712, ddr=0), |
| 69 | + Storage(hbm=1003392, ddr=0), |
70 | 70 | ],
|
71 | 71 | [
|
72 |
| - Storage(hbm=513800, ddr=0), |
73 |
| - Storage(hbm=513800, ddr=0), |
74 |
| - Storage(hbm=513800, ddr=0), |
75 |
| - Storage(hbm=513800, ddr=0), |
76 |
| - Storage(hbm=513800, ddr=0), |
77 |
| - Storage(hbm=513800, ddr=0), |
78 |
| - Storage(hbm=513800, ddr=0), |
79 |
| - Storage(hbm=513800, ddr=0), |
| 72 | + Storage(hbm=1007120, ddr=0), |
| 73 | + Storage(hbm=1007120, ddr=0), |
| 74 | + Storage(hbm=1007120, ddr=0), |
| 75 | + Storage(hbm=1007120, ddr=0), |
| 76 | + Storage(hbm=1007120, ddr=0), |
| 77 | + Storage(hbm=1007120, ddr=0), |
| 78 | + Storage(hbm=1007120, ddr=0), |
| 79 | + Storage(hbm=1007120, ddr=0), |
80 | 80 | ],
|
81 | 81 | [
|
82 |
| - Storage(hbm=1340064, ddr=0), |
83 |
| - Storage(hbm=1340064, ddr=0), |
84 |
| - Storage(hbm=1340064, ddr=0), |
85 |
| - Storage(hbm=1340064, ddr=0), |
86 |
| - Storage(hbm=1340064, ddr=0), |
87 |
| - Storage(hbm=1340064, ddr=0), |
88 |
| - Storage(hbm=1340064, ddr=0), |
89 |
| - Storage(hbm=1339104, ddr=0), |
| 82 | + Storage(hbm=2653504, ddr=0), |
| 83 | + Storage(hbm=2653504, ddr=0), |
| 84 | + Storage(hbm=2653504, ddr=0), |
| 85 | + Storage(hbm=2653504, ddr=0), |
| 86 | + Storage(hbm=2653504, ddr=0), |
| 87 | + Storage(hbm=2653504, ddr=0), |
| 88 | + Storage(hbm=2653504, ddr=0), |
| 89 | + Storage(hbm=2651584, ddr=0), |
90 | 90 | ],
|
91 | 91 | ]
|
92 | 92 |
|
93 | 93 |
|
94 | 94 | EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [
|
95 | 95 | [
|
96 |
| - Storage(hbm=84072, ddr=520), |
97 |
| - Storage(hbm=84072, ddr=520), |
98 |
| - Storage(hbm=84072, ddr=520), |
99 |
| - Storage(hbm=84072, ddr=520), |
100 |
| - Storage(hbm=84072, ddr=520), |
101 |
| - Storage(hbm=84072, ddr=520), |
102 |
| - Storage(hbm=84072, ddr=520), |
103 |
| - Storage(hbm=84040, ddr=360), |
| 96 | + Storage(hbm=166096, ddr=1040), |
| 97 | + Storage(hbm=166096, ddr=1040), |
| 98 | + Storage(hbm=166096, ddr=1040), |
| 99 | + Storage(hbm=166096, ddr=1040), |
| 100 | + Storage(hbm=166096, ddr=1040), |
| 101 | + Storage(hbm=166096, ddr=1040), |
| 102 | + Storage(hbm=166096, ddr=1040), |
| 103 | + Storage(hbm=166032, ddr=720), |
104 | 104 | ],
|
105 | 105 | [
|
106 |
| - Storage(hbm=510176, ddr=1120), |
107 |
| - Storage(hbm=510176, ddr=1120), |
108 |
| - Storage(hbm=510176, ddr=1120), |
109 |
| - Storage(hbm=510176, ddr=1120), |
110 |
| - Storage(hbm=510176, ddr=1120), |
111 |
| - Storage(hbm=510176, ddr=1120), |
112 |
| - Storage(hbm=510176, ddr=1120), |
113 |
| - Storage(hbm=510144, ddr=960), |
| 106 | + Storage(hbm=1001920, ddr=2240), |
| 107 | + Storage(hbm=1001920, ddr=2240), |
| 108 | + Storage(hbm=1001920, ddr=2240), |
| 109 | + Storage(hbm=1001920, ddr=2240), |
| 110 | + Storage(hbm=1001920, ddr=2240), |
| 111 | + Storage(hbm=1001920, ddr=2240), |
| 112 | + Storage(hbm=1001920, ddr=2240), |
| 113 | + Storage(hbm=1001856, ddr=1920), |
114 | 114 | ],
|
115 | 115 | [
|
116 |
| - Storage(hbm=512360, ddr=1800), |
117 |
| - Storage(hbm=512360, ddr=1800), |
118 |
| - Storage(hbm=512360, ddr=1800), |
119 |
| - Storage(hbm=512360, ddr=1800), |
120 |
| - Storage(hbm=512360, ddr=1800), |
121 |
| - Storage(hbm=512360, ddr=1800), |
122 |
| - Storage(hbm=512360, ddr=1800), |
123 |
| - Storage(hbm=512360, ddr=1800), |
| 116 | + Storage(hbm=1004240, ddr=3600), |
| 117 | + Storage(hbm=1004240, ddr=3600), |
| 118 | + Storage(hbm=1004240, ddr=3600), |
| 119 | + Storage(hbm=1004240, ddr=3600), |
| 120 | + Storage(hbm=1004240, ddr=3600), |
| 121 | + Storage(hbm=1004240, ddr=3600), |
| 122 | + Storage(hbm=1004240, ddr=3600), |
| 123 | + Storage(hbm=1004240, ddr=3600), |
124 | 124 | ],
|
125 | 125 | [
|
126 |
| - Storage(hbm=1337888, ddr=2720), |
127 |
| - Storage(hbm=1337888, ddr=2720), |
128 |
| - Storage(hbm=1337888, ddr=2720), |
129 |
| - Storage(hbm=1337888, ddr=2720), |
130 |
| - Storage(hbm=1337888, ddr=2720), |
131 |
| - Storage(hbm=1337888, ddr=2720), |
132 |
| - Storage(hbm=1337888, ddr=2720), |
133 |
| - Storage(hbm=1337696, ddr=1760), |
| 126 | + Storage(hbm=2649152, ddr=5440), |
| 127 | + Storage(hbm=2649152, ddr=5440), |
| 128 | + Storage(hbm=2649152, ddr=5440), |
| 129 | + Storage(hbm=2649152, ddr=5440), |
| 130 | + Storage(hbm=2649152, ddr=5440), |
| 131 | + Storage(hbm=2649152, ddr=5440), |
| 132 | + Storage(hbm=2649152, ddr=5440), |
| 133 | + Storage(hbm=2648768, ddr=3520), |
134 | 134 | ],
|
135 | 135 | ]
|
136 | 136 |
|
137 | 137 |
|
138 | 138 | EXPECTED_TWRW_SHARD_SIZES = [
|
139 |
| - [[25, 10], [25, 10], [25, 10], [25, 10]], |
140 |
| - [[28, 20], [28, 20], [28, 20], [26, 20]], |
141 |
| - [[30, 30], [30, 30], [30, 30], [30, 30]], |
142 |
| - [[33, 40], [33, 40], [33, 40], [31, 40]], |
| 139 | + [[25, 20], [25, 20], [25, 20], [25, 20]], |
| 140 | + [[28, 40], [28, 40], [28, 40], [26, 40]], |
| 141 | + [[30, 60], [30, 60], [30, 60], [30, 60]], |
| 142 | + [[33, 80], [33, 80], [33, 80], [31, 80]], |
143 | 143 | ]
|
144 | 144 |
|
145 | 145 | EXPECTED_TWRW_SHARD_OFFSETS = [
|
|
151 | 151 |
|
152 | 152 | EXPECTED_TWRW_SHARD_STORAGE = [
|
153 | 153 | [
|
154 |
| - Storage(hbm=87016, ddr=0), |
155 |
| - Storage(hbm=87016, ddr=0), |
156 |
| - Storage(hbm=87016, ddr=0), |
157 |
| - Storage(hbm=87016, ddr=0), |
| 154 | + Storage(hbm=169936, ddr=0), |
| 155 | + Storage(hbm=169936, ddr=0), |
| 156 | + Storage(hbm=169936, ddr=0), |
| 157 | + Storage(hbm=169936, ddr=0), |
158 | 158 | ],
|
159 | 159 | [
|
160 |
| - Storage(hbm=530624, ddr=0), |
161 |
| - Storage(hbm=530624, ddr=0), |
162 |
| - Storage(hbm=530624, ddr=0), |
163 |
| - Storage(hbm=530464, ddr=0), |
| 160 | + Storage(hbm=1024384, ddr=0), |
| 161 | + Storage(hbm=1024384, ddr=0), |
| 162 | + Storage(hbm=1024384, ddr=0), |
| 163 | + Storage(hbm=1024064, ddr=0), |
164 | 164 | ],
|
165 | 165 | [
|
166 |
| - Storage(hbm=536080, ddr=0), |
167 |
| - Storage(hbm=536080, ddr=0), |
168 |
| - Storage(hbm=536080, ddr=0), |
169 |
| - Storage(hbm=536080, ddr=0), |
| 166 | + Storage(hbm=1031200, ddr=0), |
| 167 | + Storage(hbm=1031200, ddr=0), |
| 168 | + Storage(hbm=1031200, ddr=0), |
| 169 | + Storage(hbm=1031200, ddr=0), |
170 | 170 | ],
|
171 | 171 | [
|
172 |
| - Storage(hbm=1369248, ddr=0), |
173 |
| - Storage(hbm=1369248, ddr=0), |
174 |
| - Storage(hbm=1369248, ddr=0), |
175 |
| - Storage(hbm=1368928, ddr=0), |
| 172 | + Storage(hbm=2685248, ddr=0), |
| 173 | + Storage(hbm=2685248, ddr=0), |
| 174 | + Storage(hbm=2685248, ddr=0), |
| 175 | + Storage(hbm=2684608, ddr=0), |
176 | 176 | ],
|
177 | 177 | ]
|
178 | 178 |
|
179 | 179 | EXPECTED_CW_SHARD_SIZES = [
|
180 |
| - [[100, 10]], |
181 |
| - [[110, 10], [110, 10]], |
182 |
| - [[120, 10], [120, 10], [120, 10]], |
183 |
| - [[130, 20], [130, 20]], |
| 180 | + [[100, 20]], |
| 181 | + [[110, 20], [110, 20]], |
| 182 | + [[120, 20], [120, 20], [120, 20]], |
| 183 | + [[130, 40], [130, 40]], |
184 | 184 | ]
|
185 | 185 |
|
186 | 186 | EXPECTED_CW_SHARD_OFFSETS = [
|
187 | 187 | [[0, 0]],
|
188 |
| - [[0, 0], [0, 10]], |
189 |
| - [[0, 0], [0, 10], [0, 20]], |
190 | 188 | [[0, 0], [0, 20]],
|
| 189 | + [[0, 0], [0, 20], [0, 40]], |
| 190 | + [[0, 0], [0, 40]], |
191 | 191 | ]
|
192 | 192 |
|
193 | 193 | EXPECTED_CW_SHARD_STORAGE = [
|
194 |
| - [Storage(hbm=102304, ddr=0)], |
195 |
| - [Storage(hbm=397616, ddr=0), Storage(hbm=397616, ddr=0)], |
| 194 | + [Storage(hbm=188224, ddr=0)], |
| 195 | + [Storage(hbm=647776, ddr=0), Storage(hbm=647776, ddr=0)], |
196 | 196 | [
|
197 |
| - Storage(hbm=332480, ddr=0), |
198 |
| - Storage(hbm=332480, ddr=0), |
199 |
| - Storage(hbm=332480, ddr=0), |
200 |
| - ], |
201 |
| - [ |
202 |
| - Storage(hbm=878752, ddr=0), |
203 |
| - Storage(hbm=878752, ddr=0), |
| 197 | + Storage(hbm=501120, ddr=0), |
| 198 | + Storage(hbm=501120, ddr=0), |
| 199 | + Storage(hbm=501120, ddr=0), |
204 | 200 | ],
|
| 201 | + [Storage(hbm=1544512, ddr=0), Storage(hbm=1544512, ddr=0)], |
205 | 202 | ]
|
206 | 203 |
|
207 | 204 | EXPECTED_TWCW_SHARD_SIZES: List[List[List[int]]] = EXPECTED_CW_SHARD_SIZES
|
208 | 205 |
|
209 | 206 | EXPECTED_TWCW_SHARD_OFFSETS: List[List[List[int]]] = EXPECTED_CW_SHARD_OFFSETS
|
210 | 207 |
|
211 | 208 | EXPECTED_TWCW_SHARD_STORAGE = [
|
212 |
| - [Storage(hbm=102304, ddr=0)], |
213 |
| - [Storage(hbm=397616, ddr=0), Storage(hbm=397616, ddr=0)], |
| 209 | + [Storage(hbm=188224, ddr=0)], |
| 210 | + [Storage(hbm=647776, ddr=0), Storage(hbm=647776, ddr=0)], |
214 | 211 | [
|
215 |
| - Storage(hbm=332480, ddr=0), |
216 |
| - Storage(hbm=332480, ddr=0), |
217 |
| - Storage(hbm=332480, ddr=0), |
218 |
| - ], |
219 |
| - [ |
220 |
| - Storage(hbm=878752, ddr=0), |
221 |
| - Storage(hbm=878752, ddr=0), |
| 212 | + Storage(hbm=501120, ddr=0), |
| 213 | + Storage(hbm=501120, ddr=0), |
| 214 | + Storage(hbm=501120, ddr=0), |
222 | 215 | ],
|
| 216 | + [Storage(hbm=1544512, ddr=0), Storage(hbm=1544512, ddr=0)], |
223 | 217 | ]
|
224 | 218 |
|
225 | 219 |
|
@@ -343,20 +337,20 @@ def setUp(self) -> None:
|
343 | 337 | self.world_size = 8
|
344 | 338 | self.local_world_size = 4
|
345 | 339 | self.constraints = {
|
346 |
| - "table_0": ParameterConstraints(min_partition=10), |
| 340 | + "table_0": ParameterConstraints(min_partition=20), |
347 | 341 | "table_1": ParameterConstraints(
|
348 |
| - min_partition=10, pooling_factors=[1, 3, 5] |
| 342 | + min_partition=20, pooling_factors=[1, 3, 5] |
349 | 343 | ),
|
350 |
| - "table_2": ParameterConstraints(min_partition=10, pooling_factors=[8, 2]), |
| 344 | + "table_2": ParameterConstraints(min_partition=20, pooling_factors=[8, 2]), |
351 | 345 | "table_3": ParameterConstraints(
|
352 |
| - min_partition=20, pooling_factors=[2, 1, 3, 7] |
| 346 | + min_partition=40, pooling_factors=[2, 1, 3, 7] |
353 | 347 | ),
|
354 | 348 | }
|
355 | 349 | self.num_tables = 4
|
356 | 350 | tables = [
|
357 | 351 | EmbeddingBagConfig(
|
358 | 352 | num_embeddings=100 + i * 10,
|
359 |
| - embedding_dim=10 + i * 10, |
| 353 | + embedding_dim=20 + i * 20, |
360 | 354 | name="table_" + str(i),
|
361 | 355 | feature_names=["feature_" + str(i)],
|
362 | 356 | )
|
@@ -443,7 +437,6 @@ def test_dp_sharding(self) -> None:
|
443 | 437 | expected_storage = [
|
444 | 438 | Storage(hbm=storage_size, ddr=0) for storage_size in storage_sizes
|
445 | 439 | ]
|
446 |
| - |
447 | 440 | self.assertEqual(
|
448 | 441 | [shard.storage for shard in sharding_option.shards], expected_storage
|
449 | 442 | )
|
|
0 commit comments