@@ -21,7 +21,7 @@
 from torchrec.distributed.embedding_types import ShardedEmbeddingTable, ShardingType
 from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
 from torchrec.modules.embedding_modules import reorder_inverse_indices
-from torchrec.modules.pruning_logger import PruningLogger, PruningLoggerDefault
+from torchrec.modules.pruning_logger import PruningLogger
 
 from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor
 
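Note: dropping `PruningLoggerDefault` from this import pairs with the `__init__` change below. Logging moves from a long-lived `scuba_logger` attribute to a scoped `PruningLogger.pruning_logger(...)` context manager. A minimal sketch of the new call pattern, mirroring the added lines (the event name and attribute keys are taken from this diff; the literal values are illustrative):

```python
from torchrec.modules.pruning_logger import PruningLogger

# Scoped logging: open a context per logged phase instead of holding a
# logger instance on the module. Attributes are attached to the event
# record via __setattr__, exactly as in the __init__ hunk below.
with PruningLogger.pruning_logger(event="ITEP_MODULE_INIT") as log_details:
    log_details.__setattr__("enable_pruning", True)
    log_details.__setattr__("pruning_interval", 1001)
```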
@@ -73,113 +73,117 @@ def __init__(
         pruning_interval: int = 1001,  # Default pruning interval 1001 iterations
         pg: Optional[dist.ProcessGroup] = None,
         table_name_to_sharding_type: Optional[Dict[str, str]] = None,
-        scuba_logger: Optional[PruningLogger] = None,
     ) -> None:
-        super(GenericITEPModule, self).__init__()
-
-        if not table_name_to_sharding_type:
-            table_name_to_sharding_type = {}
-
-        # Construct in-training embedding pruning args
-        self.enable_pruning: bool = enable_pruning
-        self.rank_to_virtual_index_mapping: Dict[str, Dict[int, int]] = {}
-        self.pruning_interval: int = pruning_interval
-        self.lookups: Optional[List[nn.Module]] = None if not lookups else lookups
-        self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
-            table_name_to_unpruned_hash_sizes
-        )
-        self.table_name_to_sharding_type = table_name_to_sharding_type
-
-        self.scuba_logger: PruningLogger = (
-            scuba_logger if scuba_logger is not None else PruningLoggerDefault()
-        )
-        self.scuba_logger.log_run_info()
-
-        # Map each feature to a physical address_lookup/row_util buffer
-        self.feature_table_map: Dict[str, int] = {}
-        self.table_name_to_idx: Dict[str, int] = {}
-        self.buffer_offsets_list: List[int] = []
-        self.idx_to_table_name: Dict[int, str] = {}
-        # Prevent multi-pruning, after moving iteration counter to outside.
-        self.last_pruned_iter = -1
-        self.pg = pg
-
-        if self.lookups is not None:
-            self.init_itep_state()
-        else:
-            logger.info(
-                "ITEP init: no lookups provided. Skipping init for dummy module."
+        with PruningLogger.pruning_logger(event="ITEP_MODULE_INIT") as log_details:
+            log_details.__setattr__("enable_pruning", enable_pruning)
+            log_details.__setattr__("pruning_interval", pruning_interval)
+
+            super(GenericITEPModule, self).__init__()
+
+            if not table_name_to_sharding_type:
+                table_name_to_sharding_type = {}
+
+            # Construct in-training embedding pruning args
+            self.enable_pruning: bool = enable_pruning
+            self.rank_to_virtual_index_mapping: Dict[str, Dict[int, int]] = {}
+            self.pruning_interval: int = pruning_interval
+            self.lookups: List[nn.Module] = [] if not lookups else lookups
+            self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
+                table_name_to_unpruned_hash_sizes
+            )
+            self.table_name_to_sharding_type: Dict[str, str] = (
+                table_name_to_sharding_type
             )
 
+            # Map each feature to a physical address_lookup/row_util buffer
+            self.feature_table_map: Dict[str, int] = {}
+            self.table_name_to_idx: Dict[str, int] = {}
+            self.buffer_offsets_list: List[int] = []
+            self.idx_to_table_name: Dict[int, str] = {}
+            # Prevent multi-pruning, after moving iteration counter to outside.
+            self.last_pruned_iter: int = -1
+            self.pg: Optional[dist.ProcessGroup] = pg
+
+            if self.lookups is not None:
+                self.init_itep_state()
+            else:
+                logger.info(
+                    "ITEP init: no lookups provided. Skipping init for dummy module."
+                )
+
     def print_itep_eviction_stats(
         self,
         pruned_indices_offsets: torch.Tensor,
         pruned_indices_total_length: torch.Tensor,
         cur_iter: int,
     ) -> None:
-        table_name_to_eviction_ratio = {}
-        buffer_idx_to_eviction_ratio = {}
-        buffer_idx_to_sizes = {}
-
-        num_buffers = len(self.buffer_offsets_list) - 1
-        for buffer_idx in range(num_buffers):
-            pruned_start = pruned_indices_offsets[buffer_idx]
-            pruned_end = pruned_indices_offsets[buffer_idx + 1]
-            pruned_length = pruned_end - pruned_start
-
-            if pruned_length > 0:
-                start = self.buffer_offsets_list[buffer_idx]
-                end = self.buffer_offsets_list[buffer_idx + 1]
-                buffer_length = end - start
-                assert buffer_length > 0
-                eviction_ratio = pruned_length.item() / buffer_length
-                table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
-                    eviction_ratio
+        with PruningLogger.pruning_logger(event="ITEP_EVICTION"):
+            table_name_to_eviction_ratio = {}
+            buffer_idx_to_eviction_ratio = {}
+            buffer_idx_to_sizes = {}
+
+            num_buffers = len(self.buffer_offsets_list) - 1
+            for buffer_idx in range(num_buffers):
+                pruned_start = pruned_indices_offsets[buffer_idx]
+                pruned_end = pruned_indices_offsets[buffer_idx + 1]
+                pruned_length = pruned_end - pruned_start
+
+                if pruned_length > 0:
+                    start = self.buffer_offsets_list[buffer_idx]
+                    end = self.buffer_offsets_list[buffer_idx + 1]
+                    buffer_length = end - start
+                    assert buffer_length > 0
+                    eviction_ratio = pruned_length.item() / buffer_length
+                    table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
+                        eviction_ratio
+                    )
+                    buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
+                    buffer_idx_to_sizes[buffer_idx] = (
+                        pruned_length.item(),
+                        buffer_length,
+                    )
+
+            # Sort the mapping by eviction ratio in descending order
+            sorted_mapping = dict(
+                sorted(
+                    table_name_to_eviction_ratio.items(),
+                    key=lambda item: item[1],
+                    reverse=True,
                 )
-                buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
-                buffer_idx_to_sizes[buffer_idx] = (pruned_length.item(), buffer_length)
-
-        # Sort the mapping by eviction ratio in descending order
-        sorted_mapping = dict(
-            sorted(
-                table_name_to_eviction_ratio.items(),
-                key=lambda item: item[1],
-                reverse=True,
             )
-        )
 
-        logged_eviction_mapping = {}
-        for idx in sorted_mapping.keys():
-            try:
-                logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
-                    sorted_mapping[idx]
-                )
-            except KeyError:
-                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
-                pass
-
-        table_to_sizes_mapping = {}
-        for idx in buffer_idx_to_sizes.keys():
-            try:
-                table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
-                    buffer_idx_to_sizes[idx]
-                )
-            except KeyError:
-                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
-                pass
-
-        # Print the sorted mapping
-        logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
-
-        # Calculate percentage of indices updated/evicted during ITEP iter
-        pruned_indices_ratio = (
-            pruned_indices_total_length / self.buffer_offsets_list[-1]
-            if self.buffer_offsets_list[-1] > 0
-            else 0
-        )
-        logger.info(
-            f"Performed ITEP in iter {cur_iter}, evicted {pruned_indices_total_length} ({pruned_indices_ratio:%}) indices."
-        )
+            logged_eviction_mapping = {}
+            for idx in sorted_mapping.keys():
+                try:
+                    logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
+                        sorted_mapping[idx]
+                    )
+                except KeyError:
+                    # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                    pass
+
+            table_to_sizes_mapping = {}
+            for idx in buffer_idx_to_sizes.keys():
+                try:
+                    table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
+                        buffer_idx_to_sizes[idx]
+                    )
+                except KeyError:
+                    # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                    pass
+
+            # Print the sorted mapping
+            logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
+
+            # Calculate percentage of indices updated/evicted during ITEP iter
+            pruned_indices_ratio = (
+                pruned_indices_total_length / self.buffer_offsets_list[-1]
+                if self.buffer_offsets_list[-1] > 0
+                else 0
+            )
+            logger.info(
+                f"Performed ITEP in iter {cur_iter}, evicted {pruned_indices_total_length} ({pruned_indices_ratio:%}) indices."
+            )
 
     def get_table_hash_sizes(self, table: ShardedEmbeddingTable) -> Tuple[int, int]:
         unpruned_hash_size = table.num_embeddings
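The eviction bookkeeping above reduces to a walk over two offset arrays: `pruned_indices_offsets` delimits each table's evicted slice and `buffer_offsets_list` delimits its backing buffer, so the per-table ratio is pruned slice length over buffer length. A standalone sketch with made-up offsets (two tables; the numbers are illustrative only):

```python
import torch

# Two tables: buffers of length 100 and 400 laid out back to back.
buffer_offsets_list = [0, 100, 500]
# 25 indices evicted from table 0, 200 from table 1.
pruned_indices_offsets = torch.tensor([0, 25, 225])

for buffer_idx in range(len(buffer_offsets_list) - 1):
    pruned_length = (
        pruned_indices_offsets[buffer_idx + 1] - pruned_indices_offsets[buffer_idx]
    )
    buffer_length = (
        buffer_offsets_list[buffer_idx + 1] - buffer_offsets_list[buffer_idx]
    )
    print(buffer_idx, pruned_length.item() / buffer_length)
# 0 0.25
# 1 0.5
```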
@@ -251,7 +255,6 @@ def init_itep_state(self) -> None:
         self.current_device = None
 
         # Iterate over all tables
-        # pyre-ignore
         for lookup in self.lookups:
             while isinstance(lookup, DistributedDataParallel):
                 lookup = lookup.module
@@ -337,55 +340,49 @@ def reset_weight_momentum(
         pruned_indices: torch.Tensor,
         pruned_indices_offsets: torch.Tensor,
     ) -> None:
-        if self.lookups is not None:
-            # pyre-ignore
-            for lookup in self.lookups:
-                while isinstance(lookup, DistributedDataParallel):
-                    lookup = lookup.module
-                for emb in lookup._emb_modules:
-                    emb_tables: List[ShardedEmbeddingTable] = (
-                        emb._config.embedding_tables
-                    )
+        for lookup in self.lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            for emb in lookup._emb_modules:
+                emb_tables: List[ShardedEmbeddingTable] = emb._config.embedding_tables
 
-                    logical_idx = 0
-                    logical_table_ids = []
-                    buffer_ids = []
-                    for table in emb_tables:
-                        name = table.name
-                        if name in self.table_name_to_idx:
-                            buffer_idx = self.table_name_to_idx[name]
-                            start = pruned_indices_offsets[buffer_idx]
-                            end = pruned_indices_offsets[buffer_idx + 1]
-                            length = end - start
-                            if length > 0:
-                                logical_table_ids.append(logical_idx)
-                                buffer_ids.append(buffer_idx)
-                        logical_idx += table.num_features()
-
-                    if len(logical_table_ids) > 0:
-                        emb.emb_module.reset_embedding_weight_momentum(
-                            pruned_indices,
-                            pruned_indices_offsets,
-                            torch.tensor(
-                                logical_table_ids,
-                                dtype=torch.int32,
-                                requires_grad=False,
-                            ),
-                            torch.tensor(
-                                buffer_ids, dtype=torch.int32, requires_grad=False
-                            ),
-                        )
+                logical_idx = 0
+                logical_table_ids = []
+                buffer_ids = []
+                for table in emb_tables:
+                    name = table.name
+                    if name in self.table_name_to_idx:
+                        buffer_idx = self.table_name_to_idx[name]
+                        start = pruned_indices_offsets[buffer_idx]
+                        end = pruned_indices_offsets[buffer_idx + 1]
+                        length = end - start
+                        if length > 0:
+                            logical_table_ids.append(logical_idx)
+                            buffer_ids.append(buffer_idx)
+                    logical_idx += table.num_features()
+
+                if len(logical_table_ids) > 0:
+                    emb.emb_module.reset_embedding_weight_momentum(
+                        pruned_indices,
+                        pruned_indices_offsets,
+                        torch.tensor(
+                            logical_table_ids,
+                            dtype=torch.int32,
+                            requires_grad=False,
+                        ),
+                        torch.tensor(
+                            buffer_ids, dtype=torch.int32, requires_grad=False
+                        ),
+                    )
 
     # Flush UVM cache after ITEP eviction to remove stale states
     def flush_uvm_cache(self) -> None:
-        if self.lookups is not None:
-            # pyre-ignore
-            for lookup in self.lookups:
-                while isinstance(lookup, DistributedDataParallel):
-                    lookup = lookup.module
-                for emb in lookup._emb_modules:
-                    emb.emb_module.flush()
-                    emb.emb_module.reset_cache_states()
+        for lookup in self.lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            for emb in lookup._emb_modules:
+                emb.emb_module.flush()
+                emb.emb_module.reset_cache_states()
 
     def get_remap_info(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
         keys = features.keys()
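Beyond dropping the `if self.lookups is not None` guard (safe now that `self.lookups` defaults to `[]`), the loop body is unchanged: each table whose pruned slice is non-empty contributes its first logical feature index and its buffer index. A toy version of just that selection step (the tuples stand in for `ShardedEmbeddingTable` objects; the data is illustrative):

```python
import torch

# (name, num_features) stand-ins for ShardedEmbeddingTable.
tables = [("t0", 2), ("t1", 1), ("t2", 3)]
table_name_to_idx = {"t0": 0, "t1": 1, "t2": 2}
# Only table t1 had evictions (40 indices).
pruned_indices_offsets = torch.tensor([0, 0, 40, 40])

logical_idx, logical_table_ids, buffer_ids = 0, [], []
for name, num_features in tables:
    buffer_idx = table_name_to_idx[name]
    length = (
        pruned_indices_offsets[buffer_idx + 1] - pruned_indices_offsets[buffer_idx]
    )
    if length > 0:
        logical_table_ids.append(logical_idx)  # first logical feature of this table
        buffer_ids.append(buffer_idx)
    logical_idx += num_features

assert logical_table_ids == [2] and buffer_ids == [1]
```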
@@ -460,7 +457,7 @@ def forward(
         We use the same forward method for sharded and non-sharded case.
         """
 
-        if not self.enable_pruning or self.lookups is None:
+        if not self.enable_pruning or not self.lookups:
             return sparse_features
 
         num_buffers = self.buffer_offsets.size(dim=0) - 1
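The guard change here follows from the `__init__` hunk: `self.lookups` is now typed `List[nn.Module]` and defaults to `[]`, so an `is None` test would never fire for a dummy module. The truthiness test covers both the old and the new sentinel:

```python
lookups = []  # new default from __init__ (was None for dummy modules)

assert lookups is not None  # the old `is None` guard no longer fires
assert not lookups          # truthiness handles [] and None alike
```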
@@ -695,7 +692,7 @@ def get_key_from_table_name_and_suffix(
             key = get_key_from_table_name_and_suffix(table.name, prefix, suffix)
             param_idx = self.table_name_to_idx[table.name]
             buffer_param: torch.Tensor = get_param(params, param_idx)
-            sharding_type = self.table_name_to_sharding_type[table.name]  # pyre-ignore
+            sharding_type = self.table_name_to_sharding_type[table.name]
 
             # For inference there is no pg, all tensors are local
             if table.global_metadata is not None and pg is not None:
@@ -806,29 +803,28 @@ def state_dict(
         destination = OrderedDict()
         # pyre-ignore [16]
         destination._metadata = OrderedDict()
-        if self.lookups is not None:
-            # pyre-ignore [16]
-            for lookup in self.lookups:
-                list_of_tables: List[ShardedEmbeddingTable] = []
-                for emb_config in lookup.grouped_configs:
-                    list_of_tables.extend(emb_config.embedding_tables)
-
-                destination = self.get_itp_state_dict(
-                    list_of_tables,
-                    self.address_lookup,  # pyre-ignore
-                    self.pg,
-                    destination,
-                    prefix,
-                    suffix="_itp_address_lookup",
-                    dtype=torch.int64,
-                )
-                destination = self.get_itp_state_dict(
-                    list_of_tables,
-                    self.row_util,  # pyre-ignore
-                    self.pg,
-                    destination,
-                    prefix,
-                    suffix="_itp_row_util",
-                    dtype=torch.float32,
-                )
+        for lookup in self.lookups:
+            list_of_tables: List[ShardedEmbeddingTable] = []
+            # pyre-ignore [29]
+            for emb_config in lookup.grouped_configs:
+                list_of_tables.extend(emb_config.embedding_tables)
+
+            destination = self.get_itp_state_dict(
+                list_of_tables,
+                self.address_lookup,  # pyre-ignore
+                self.pg,
+                destination,
+                prefix,
+                suffix="_itp_address_lookup",
+                dtype=torch.int64,
+            )
+            destination = self.get_itp_state_dict(
+                list_of_tables,
+                self.row_util,  # pyre-ignore
+                self.pg,
+                destination,
+                prefix,
+                suffix="_itp_row_util",
+                dtype=torch.float32,
+            )
         return destination
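For readers tracing checkpoint keys: each table exports two entries, suffixed `_itp_address_lookup` (int64) and `_itp_row_util` (float32), named via `get_key_from_table_name_and_suffix(table.name, prefix, suffix)` from the earlier hunk. That helper's body is not part of this diff, so the concatenation below is an assumption for illustration only:

```python
def get_key_from_table_name_and_suffix(table_name: str, prefix: str, suffix: str) -> str:
    # Assumed key layout; the real helper is defined elsewhere in this file.
    return f"{prefix}{table_name}{suffix}"

for suffix in ("_itp_address_lookup", "_itp_row_util"):
    print(get_key_from_table_name_and_suffix("table_0", "itep.", suffix))
# itep.table_0_itp_address_lookup
# itep.table_0_itp_row_util
```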