 
 import copy
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
 
 import torch
@@ -46,6 +47,47 @@ def _append_table_shard(
     d[table_name].append(shard)
 
 
+def post_state_dict_hook(
+    # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]
+    # pyre-ignore [24]
+    module: ShardedEmbeddingModule,
+    destination: Dict[str, torch.Tensor],
+    prefix: str,
+    _local_metadata: Dict[str, Any],
+    tables_weights_prefix: str,  # "embedding_bags" or "embeddings"
+) -> None:
+    for (
+        table_name,
+        sharded_t,
+    ) in module._table_name_to_sharded_tensor.items():
+        destination[f"{prefix}{tables_weights_prefix}.{table_name}.weight"] = sharded_t
+
+    for sfx, dict_sharded_t, dict_t_list in [
+        (
+            "weight_qscale",
+            module._table_name_to_sharded_tensor_qscale,
+            module._table_name_to_tensors_list_qscale,
+        ),
+        (
+            "weight_qbias",
+            module._table_name_to_sharded_tensor_qbias,
+            module._table_name_to_tensors_list_qbias,
+        ),
+    ]:
+        for (
+            table_name,
+            sharded_t,
+        ) in dict_sharded_t.items():
+            destination[f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"] = (
+                sharded_t
+            )
+        for (
+            table_name,
+            t_list,
+        ) in dict_t_list.items():
+            destination[f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"] = t_list
+
+
 class ShardedQuantEmbeddingModuleState(
     ShardedEmbeddingModule[CompIn, DistOut, Out, ShrdCtx]
 ):
@@ -82,17 +124,6 @@ def _initialize_torch_state( # noqa: C901
         ] = {}
         self._table_name_to_tensors_list_qbias: Dict[str, List[torch.Tensor]] = {}
 
-        # pruning_index_remappings
-        self._table_name_to_local_shards_pruning_index_remappings: Dict[
-            str, List[Shard]
-        ] = {}
-        self._table_name_to_sharded_tensor_pruning_index_remappings: Dict[
-            str, Union[torch.Tensor, ShardedTensorBase]
-        ] = {}
-        self._table_name_to_tensors_list_pruning_index_remappings: Dict[
-            str, List[torch.Tensor]
-        ] = {}
-
         for tbe, config in tbes.items():
             for (tbe_split_w, tbe_split_qscale, tbe_split_qbias), table in zip(
                 tbe.split_embedding_weights_with_scale_bias(split_scale_bias_mode=2),
@@ -184,43 +215,6 @@ def _initialize_torch_state( # noqa: C901
                             Shard(tensor=tbe_split_qparam, metadata=qmetadata),
                         )
                     # end of weight_qscale & weight_qbias section
-                if table.pruning_indices_remapping is not None:
-                    for (
-                        qparam,
-                        table_name_to_local_shards,
-                        _,
-                    ) in [
-                        (
-                            table.pruning_indices_remapping,
-                            self._table_name_to_local_shards_pruning_index_remappings,
-                            self._table_name_to_tensors_list_pruning_index_remappings,
-                        )
-                    ]:
-                        parameter_sharding: ParameterSharding = (
-                            table_name_to_parameter_sharding[table.name]
-                        )
-                        sharding_type: str = parameter_sharding.sharding_type
-
-                        assert sharding_type in [
-                            ShardingType.TABLE_WISE.value,
-                            ShardingType.COLUMN_WISE.value,
-                        ]
-
-                        qmetadata = ShardMetadata(
-                            shard_offsets=[0],
-                            shard_sizes=[
-                                qparam.shape[0],
-                            ],
-                            placement=table.local_metadata.placement,
-                        )
-                        # TODO(ivankobzarev): "meta" sharding support: cleanup when copy to "meta" moves all tensors to "meta"
-                        if qmetadata.placement.device != qparam.device:
-                            qmetadata.placement = _remote_device(qparam.device)
-                        _append_table_shard(
-                            table_name_to_local_shards,
-                            table.name,
-                            Shard(tensor=qparam, metadata=qmetadata),
-                        )
 
         for table_name_to_local_shards, table_name_to_sharded_tensor in [
             (self._table_name_to_local_shards, self._table_name_to_sharded_tensor),
@@ -263,65 +257,9 @@ def _initialize_torch_state( # noqa: C901
                     )
                 )
 
-        for table_name_to_local_shards, table_name_to_sharded_tensor in [
-            (
-                self._table_name_to_local_shards_pruning_index_remappings,
-                self._table_name_to_sharded_tensor_pruning_index_remappings,
-            ),
-        ]:
-            for table_name, local_shards in table_name_to_local_shards.items():
-                # Single Tensor per table (TW sharding)
-                table_name_to_sharded_tensor[table_name] = local_shards[0].tensor
-                continue
-
-        def post_state_dict_hook(
-            # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]
-            module: ShardedQuantEmbeddingModuleState[CompIn, DistOut, Out, ShrdCtx],
-            destination: Dict[str, torch.Tensor],
-            prefix: str,
-            _local_metadata: Dict[str, Any],
-        ) -> None:
-            for (
-                table_name,
-                sharded_t,
-            ) in module._table_name_to_sharded_tensor.items():
-                destination[f"{prefix}{tables_weights_prefix}.{table_name}.weight"] = (
-                    sharded_t
-                )
-
-            for sfx, dict_sharded_t, dict_t_list in [
-                (
-                    "weight_qscale",
-                    module._table_name_to_sharded_tensor_qscale,
-                    module._table_name_to_tensors_list_qscale,
-                ),
-                (
-                    "weight_qbias",
-                    module._table_name_to_sharded_tensor_qbias,
-                    module._table_name_to_tensors_list_qbias,
-                ),
-                (
-                    "index_remappings_array",
-                    module._table_name_to_sharded_tensor_pruning_index_remappings,
-                    module._table_name_to_tensors_list_pruning_index_remappings,
-                ),
-            ]:
-                for (
-                    table_name,
-                    sharded_t,
-                ) in dict_sharded_t.items():
-                    destination[
-                        f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"
-                    ] = sharded_t
-                for (
-                    table_name,
-                    t_list,
-                ) in dict_t_list.items():
-                    destination[
-                        f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"
-                    ] = t_list
-
-        self._register_state_dict_hook(post_state_dict_hook)
+        self._register_state_dict_hook(
+            partial(post_state_dict_hook, tables_weights_prefix=tables_weights_prefix)
+        )
 
     def _load_from_state_dict(
         # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]
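For context on the pattern this change adopts: `torch.nn.Module._register_state_dict_hook` invokes its hook as `hook(module, state_dict, prefix, local_metadata)`, so a hook that was previously a closure over method locals can be lifted to module level and bound to per-instance arguments (here `tables_weights_prefix`) with `functools.partial` at registration time. The sketch below is a minimal, self-contained illustration of that pattern under assumed names; `DemoModule` and `demo_post_state_dict_hook` are illustrative only and are not part of torchrec.

```python
# Minimal sketch (assumed names, not torchrec code) of binding an extra argument
# to a module-level state_dict hook with functools.partial, as the diff above does.
from functools import partial
from typing import Any, Dict

import torch


def demo_post_state_dict_hook(
    module: torch.nn.Module,
    destination: Dict[str, torch.Tensor],
    prefix: str,
    _local_metadata: Dict[str, Any],
    tables_weights_prefix: str,
) -> None:
    # Hooks registered via _register_state_dict_hook may mutate `destination`
    # in place; here we publish an extra key under the bound prefix.
    destination[f"{prefix}{tables_weights_prefix}.table_0.weight"] = module.weight.detach()


class DemoModule(torch.nn.Module):
    def __init__(self, tables_weights_prefix: str) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(4, 2))
        # Bind the instance-specific argument at registration time; the hook
        # itself stays a plain module-level function, as in the change above.
        self._register_state_dict_hook(
            partial(demo_post_state_dict_hook, tables_weights_prefix=tables_weights_prefix)
        )


if __name__ == "__main__":
    sd = DemoModule("embedding_bags").state_dict()
    print(sorted(sd.keys()))  # ['embedding_bags.table_0.weight', 'weight']
```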