
Commit 24529f2

Jasper Shan authored and facebook-github-bot committed
Refactoring ITEP / PTP Pruning Scuba Logger [3/N] (#3002)
Summary: refactor 3/n

Reviewed By: AKhazane

Differential Revision: D75108474
1 parent 67ebc8c commit 24529f2

2 files changed: +190 -205 lines changed

torchrec/modules/itep_modules.py

Lines changed: 165 additions & 169 deletions
@@ -21,7 +21,7 @@
 from torchrec.distributed.embedding_types import ShardedEmbeddingTable, ShardingType
 from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
 from torchrec.modules.embedding_modules import reorder_inverse_indices
-from torchrec.modules.pruning_logger import PruningLogger, PruningLoggerDefault
+from torchrec.modules.pruning_logger import PruningLogger

 from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor

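This file no longer imports PruningLoggerDefault: the hunks below wrap their work in PruningLogger.pruning_logger(event=...) context managers instead of holding a logger instance on the module. The pruning_logger module itself is not shown in this diff, so the following is only a minimal sketch of an interface consistent with the new call sites; the classmethod/contextmanager layering, the SimpleNamespace details record, and the log-on-exit step are assumptions, not the actual implementation.

# Illustrative sketch only -- torchrec/modules/pruning_logger.py is not part of this
# diff. It assumes PruningLogger.pruning_logger is a classmethod context manager that
# yields a mutable "details" record which call sites decorate with attributes.
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Iterator


class PruningLogger:
    @classmethod
    @contextmanager
    def pruning_logger(cls, event: str) -> Iterator[SimpleNamespace]:
        # Yield a blank record tagged with the event name; call sites attach
        # attributes such as enable_pruning or pruning_interval to it.
        details = SimpleNamespace(event=event)
        try:
            yield details
        finally:
            # A real implementation would emit `details.__dict__` (e.g. to Scuba)
            # here; this sketch intentionally does nothing.
            pass

Under that reading, attributes attached to the yielded record are emitted as a single event when the block exits, which is why the constructor body and the eviction-stats body move inside the with blocks in the hunks below.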
@@ -73,113 +73,117 @@ def __init__(
         pruning_interval: int = 1001,  # Default pruning interval 1001 iterations
         pg: Optional[dist.ProcessGroup] = None,
         table_name_to_sharding_type: Optional[Dict[str, str]] = None,
-        scuba_logger: Optional[PruningLogger] = None,
     ) -> None:
-        super(GenericITEPModule, self).__init__()
-
-        if not table_name_to_sharding_type:
-            table_name_to_sharding_type = {}
-
-        # Construct in-training embedding pruning args
-        self.enable_pruning: bool = enable_pruning
-        self.rank_to_virtual_index_mapping: Dict[str, Dict[int, int]] = {}
-        self.pruning_interval: int = pruning_interval
-        self.lookups: Optional[List[nn.Module]] = None if not lookups else lookups
-        self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
-            table_name_to_unpruned_hash_sizes
-        )
-        self.table_name_to_sharding_type = table_name_to_sharding_type
-
-        self.scuba_logger: PruningLogger = (
-            scuba_logger if scuba_logger is not None else PruningLoggerDefault()
-        )
-        self.scuba_logger.log_run_info()
-
-        # Map each feature to a physical address_lookup/row_util buffer
-        self.feature_table_map: Dict[str, int] = {}
-        self.table_name_to_idx: Dict[str, int] = {}
-        self.buffer_offsets_list: List[int] = []
-        self.idx_to_table_name: Dict[int, str] = {}
-        # Prevent multi-pruning, after moving iteration counter to outside.
-        self.last_pruned_iter = -1
-        self.pg = pg
-
-        if self.lookups is not None:
-            self.init_itep_state()
-        else:
-            logger.info(
-                "ITEP init: no lookups provided. Skipping init for dummy module."
+        with PruningLogger.pruning_logger(event="ITEP_MODULE_INIT") as log_details:
+            log_details.__setattr__("enable_pruning", enable_pruning)
+            log_details.__setattr__("pruning_interval", pruning_interval)
+
+            super(GenericITEPModule, self).__init__()
+
+            if not table_name_to_sharding_type:
+                table_name_to_sharding_type = {}
+
+            # Construct in-training embedding pruning args
+            self.enable_pruning: bool = enable_pruning
+            self.rank_to_virtual_index_mapping: Dict[str, Dict[int, int]] = {}
+            self.pruning_interval: int = pruning_interval
+            self.lookups: List[nn.Module] = [] if not lookups else lookups
+            self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
+                table_name_to_unpruned_hash_sizes
+            )
+            self.table_name_to_sharding_type: Dict[str, str] = (
+                table_name_to_sharding_type
             )

+            # Map each feature to a physical address_lookup/row_util buffer
+            self.feature_table_map: Dict[str, int] = {}
+            self.table_name_to_idx: Dict[str, int] = {}
+            self.buffer_offsets_list: List[int] = []
+            self.idx_to_table_name: Dict[int, str] = {}
+            # Prevent multi-pruning, after moving iteration counter to outside.
+            self.last_pruned_iter: int = -1
+            self.pg: Optional[dist.ProcessGroup] = pg
+
+            if self.lookups is not None:
+                self.init_itep_state()
+            else:
+                logger.info(
+                    "ITEP init: no lookups provided. Skipping init for dummy module."
+                )
+
     def print_itep_eviction_stats(
         self,
         pruned_indices_offsets: torch.Tensor,
         pruned_indices_total_length: torch.Tensor,
         cur_iter: int,
     ) -> None:
-        table_name_to_eviction_ratio = {}
-        buffer_idx_to_eviction_ratio = {}
-        buffer_idx_to_sizes = {}
-
-        num_buffers = len(self.buffer_offsets_list) - 1
-        for buffer_idx in range(num_buffers):
-            pruned_start = pruned_indices_offsets[buffer_idx]
-            pruned_end = pruned_indices_offsets[buffer_idx + 1]
-            pruned_length = pruned_end - pruned_start
-
-            if pruned_length > 0:
-                start = self.buffer_offsets_list[buffer_idx]
-                end = self.buffer_offsets_list[buffer_idx + 1]
-                buffer_length = end - start
-                assert buffer_length > 0
-                eviction_ratio = pruned_length.item() / buffer_length
-                table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
-                    eviction_ratio
+        with PruningLogger.pruning_logger(event="ITEP_EVICTION"):
+            table_name_to_eviction_ratio = {}
+            buffer_idx_to_eviction_ratio = {}
+            buffer_idx_to_sizes = {}
+
+            num_buffers = len(self.buffer_offsets_list) - 1
+            for buffer_idx in range(num_buffers):
+                pruned_start = pruned_indices_offsets[buffer_idx]
+                pruned_end = pruned_indices_offsets[buffer_idx + 1]
+                pruned_length = pruned_end - pruned_start
+
+                if pruned_length > 0:
+                    start = self.buffer_offsets_list[buffer_idx]
+                    end = self.buffer_offsets_list[buffer_idx + 1]
+                    buffer_length = end - start
+                    assert buffer_length > 0
+                    eviction_ratio = pruned_length.item() / buffer_length
+                    table_name_to_eviction_ratio[self.idx_to_table_name[buffer_idx]] = (
+                        eviction_ratio
+                    )
+                    buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
+                    buffer_idx_to_sizes[buffer_idx] = (
+                        pruned_length.item(),
+                        buffer_length,
+                    )
+
+            # Sort the mapping by eviction ratio in descending order
+            sorted_mapping = dict(
+                sorted(
+                    table_name_to_eviction_ratio.items(),
+                    key=lambda item: item[1],
+                    reverse=True,
                 )
-                buffer_idx_to_eviction_ratio[buffer_idx] = eviction_ratio
-                buffer_idx_to_sizes[buffer_idx] = (pruned_length.item(), buffer_length)
-
-        # Sort the mapping by eviction ratio in descending order
-        sorted_mapping = dict(
-            sorted(
-                table_name_to_eviction_ratio.items(),
-                key=lambda item: item[1],
-                reverse=True,
             )
-        )

-        logged_eviction_mapping = {}
-        for idx in sorted_mapping.keys():
-            try:
-                logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
-                    sorted_mapping[idx]
-                )
-            except KeyError:
-                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
-                pass
-
-        table_to_sizes_mapping = {}
-        for idx in buffer_idx_to_sizes.keys():
-            try:
-                table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
-                    buffer_idx_to_sizes[idx]
-                )
-            except KeyError:
-                # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
-                pass
-
-        # Print the sorted mapping
-        logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
-
-        # Calculate percentage of indiced updated/evicted during ITEP iter
-        pruned_indices_ratio = (
-            pruned_indices_total_length / self.buffer_offsets_list[-1]
-            if self.buffer_offsets_list[-1] > 0
-            else 0
-        )
-        logger.info(
-            f"Performed ITEP in iter {cur_iter}, evicted {pruned_indices_total_length} ({pruned_indices_ratio:%}) indices."
-        )
+            logged_eviction_mapping = {}
+            for idx in sorted_mapping.keys():
+                try:
+                    logged_eviction_mapping[self.reversed_feature_table_map[idx]] = (
+                        sorted_mapping[idx]
+                    )
+                except KeyError:
+                    # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                    pass
+
+            table_to_sizes_mapping = {}
+            for idx in buffer_idx_to_sizes.keys():
+                try:
+                    table_to_sizes_mapping[self.reversed_feature_table_map[idx]] = (
+                        buffer_idx_to_sizes[idx]
+                    )
+                except KeyError:
+                    # in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
+                    pass
+
+            # Print the sorted mapping
+            logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")
+
+            # Calculate percentage of indiced updated/evicted during ITEP iter
+            pruned_indices_ratio = (
+                pruned_indices_total_length / self.buffer_offsets_list[-1]
+                if self.buffer_offsets_list[-1] > 0
+                else 0
+            )
+            logger.info(
+                f"Performed ITEP in iter {cur_iter}, evicted {pruned_indices_total_length} ({pruned_indices_ratio:%}) indices."
+            )

     def get_table_hash_sizes(self, table: ShardedEmbeddingTable) -> Tuple[int, int]:
         unpruned_hash_size = table.num_embeddings
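As a quick aid for reading the eviction hunk above: each table's eviction ratio is the number of rows pruned from its buffer divided by that buffer's length, with both quantities taken from prefix-sum offset arrays (pruned_indices_offsets and buffer_offsets_list). A toy, self-contained illustration with made-up offsets, not real torchrec data:

# Illustrative example of the eviction-ratio arithmetic in print_itep_eviction_stats.
import torch

buffer_offsets_list = [0, 100, 300]                  # two buffers: sizes 100 and 200
pruned_indices_offsets = torch.tensor([0, 10, 10])   # 10 evictions in buffer 0, none in buffer 1
idx_to_table_name = {0: "table_a", 1: "table_b"}

table_name_to_eviction_ratio = {}
num_buffers = len(buffer_offsets_list) - 1
for buffer_idx in range(num_buffers):
    pruned_length = (
        pruned_indices_offsets[buffer_idx + 1] - pruned_indices_offsets[buffer_idx]
    )
    if pruned_length > 0:
        buffer_length = buffer_offsets_list[buffer_idx + 1] - buffer_offsets_list[buffer_idx]
        table_name_to_eviction_ratio[idx_to_table_name[buffer_idx]] = (
            pruned_length.item() / buffer_length
        )

print(table_name_to_eviction_ratio)  # {'table_a': 0.1}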
@@ -251,7 +255,6 @@ def init_itep_state(self) -> None:
         self.current_device = None

         # Iterate over all tables
-        # pyre-ignore
         for lookup in self.lookups:
             while isinstance(lookup, DistributedDataParallel):
                 lookup = lookup.module
@@ -337,55 +340,49 @@ def reset_weight_momentum(
         pruned_indices: torch.Tensor,
         pruned_indices_offsets: torch.Tensor,
     ) -> None:
-        if self.lookups is not None:
-            # pyre-ignore
-            for lookup in self.lookups:
-                while isinstance(lookup, DistributedDataParallel):
-                    lookup = lookup.module
-                for emb in lookup._emb_modules:
-                    emb_tables: List[ShardedEmbeddingTable] = (
-                        emb._config.embedding_tables
-                    )
+        for lookup in self.lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            for emb in lookup._emb_modules:
+                emb_tables: List[ShardedEmbeddingTable] = emb._config.embedding_tables

-                    logical_idx = 0
-                    logical_table_ids = []
-                    buffer_ids = []
-                    for table in emb_tables:
-                        name = table.name
-                        if name in self.table_name_to_idx:
-                            buffer_idx = self.table_name_to_idx[name]
-                            start = pruned_indices_offsets[buffer_idx]
-                            end = pruned_indices_offsets[buffer_idx + 1]
-                            length = end - start
-                            if length > 0:
-                                logical_table_ids.append(logical_idx)
-                                buffer_ids.append(buffer_idx)
-                        logical_idx += table.num_features()
-
-                    if len(logical_table_ids) > 0:
-                        emb.emb_module.reset_embedding_weight_momentum(
-                            pruned_indices,
-                            pruned_indices_offsets,
-                            torch.tensor(
-                                logical_table_ids,
-                                dtype=torch.int32,
-                                requires_grad=False,
-                            ),
-                            torch.tensor(
-                                buffer_ids, dtype=torch.int32, requires_grad=False
-                            ),
-                        )
+                logical_idx = 0
+                logical_table_ids = []
+                buffer_ids = []
+                for table in emb_tables:
+                    name = table.name
+                    if name in self.table_name_to_idx:
+                        buffer_idx = self.table_name_to_idx[name]
+                        start = pruned_indices_offsets[buffer_idx]
+                        end = pruned_indices_offsets[buffer_idx + 1]
+                        length = end - start
+                        if length > 0:
+                            logical_table_ids.append(logical_idx)
+                            buffer_ids.append(buffer_idx)
+                    logical_idx += table.num_features()
+
+                if len(logical_table_ids) > 0:
+                    emb.emb_module.reset_embedding_weight_momentum(
+                        pruned_indices,
+                        pruned_indices_offsets,
+                        torch.tensor(
+                            logical_table_ids,
+                            dtype=torch.int32,
+                            requires_grad=False,
+                        ),
+                        torch.tensor(
+                            buffer_ids, dtype=torch.int32, requires_grad=False
+                        ),
+                    )

     # Flush UVM cache after ITEP eviction to remove stale states
     def flush_uvm_cache(self) -> None:
-        if self.lookups is not None:
-            # pyre-ignore
-            for lookup in self.lookups:
-                while isinstance(lookup, DistributedDataParallel):
-                    lookup = lookup.module
-                for emb in lookup._emb_modules:
-                    emb.emb_module.flush()
-                    emb.emb_module.reset_cache_states()
+        for lookup in self.lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            for emb in lookup._emb_modules:
+                emb.emb_module.flush()
+                emb.emb_module.reset_cache_states()

     def get_remap_info(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
         keys = features.keys()
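Both reset_weight_momentum and flush_uvm_cache lose their `if self.lookups is not None` guards because self.lookups is now typed List[nn.Module] and defaults to [], so the loop is simply a no-op for a dummy module. A minimal sketch of the shared traversal pattern; visit_emb_modules and the callback are illustrative names, while _emb_modules is the attribute the diff itself touches:

# Sketch of the traversal used above: unwrap (possibly nested) DDP wrappers,
# then visit each lookup's embedding modules.
from typing import Callable, List

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def visit_emb_modules(lookups: List[nn.Module], fn: Callable[[nn.Module], None]) -> None:
    # With lookups defaulting to [], a dummy (lookup-less) module makes this a no-op,
    # which is why the `is not None` guard could be removed.
    for lookup in lookups:
        while isinstance(lookup, DistributedDataParallel):
            lookup = lookup.module  # peel off DDP wrapping
        for emb in lookup._emb_modules:
            fn(emb)  # e.g. lambda emb: emb.emb_module.flush()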
@@ -460,7 +457,7 @@ def forward(
         We use the same forward method for sharded and non-sharded case.
         """

-        if not self.enable_pruning or self.lookups is None:
+        if not self.enable_pruning or not self.lookups:
             return sparse_features

         num_buffers = self.buffer_offsets.size(dim=0) - 1
@@ -695,7 +692,7 @@ def get_key_from_table_name_and_suffix(
             key = get_key_from_table_name_and_suffix(table.name, prefix, suffix)
             param_idx = self.table_name_to_idx[table.name]
             buffer_param: torch.Tensor = get_param(params, param_idx)
-            sharding_type = self.table_name_to_sharding_type[table.name]  # pyre-ignore
+            sharding_type = self.table_name_to_sharding_type[table.name]

             # For inference there is no pg, all tensors are local
             if table.global_metadata is not None and pg is not None:
@@ -806,29 +803,28 @@ def state_dict(
             destination = OrderedDict()
             # pyre-ignore [16]
             destination._metadata = OrderedDict()
-        if self.lookups is not None:
-            # pyre-ignore [16]
-            for lookup in self.lookups:
-                list_of_tables: List[ShardedEmbeddingTable] = []
-                for emb_config in lookup.grouped_configs:
-                    list_of_tables.extend(emb_config.embedding_tables)
-
-                destination = self.get_itp_state_dict(
-                    list_of_tables,
-                    self.address_lookup,  # pyre-ignore
-                    self.pg,
-                    destination,
-                    prefix,
-                    suffix="_itp_address_lookup",
-                    dtype=torch.int64,
-                )
-                destination = self.get_itp_state_dict(
-                    list_of_tables,
-                    self.row_util,  # pyre-ignore
-                    self.pg,
-                    destination,
-                    prefix,
-                    suffix="_itp_row_util",
-                    dtype=torch.float32,
-                )
+        for lookup in self.lookups:
+            list_of_tables: List[ShardedEmbeddingTable] = []
+            # pyre-ignore [29]
+            for emb_config in lookup.grouped_configs:
+                list_of_tables.extend(emb_config.embedding_tables)
+
+            destination = self.get_itp_state_dict(
+                list_of_tables,
+                self.address_lookup,  # pyre-ignore
+                self.pg,
+                destination,
+                prefix,
+                suffix="_itp_address_lookup",
+                dtype=torch.int64,
+            )
+            destination = self.get_itp_state_dict(
+                list_of_tables,
+                self.row_util,  # pyre-ignore
+                self.pg,
+                destination,
+                prefix,
+                suffix="_itp_row_util",
+                dtype=torch.float32,
+            )
         return destination
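The rewritten state_dict still emits two ITEP buffers per table, using the _itp_address_lookup (torch.int64) and _itp_row_util (torch.float32) suffixes. Below is a hypothetical way to spot those entries in a checkpoint; the exact key prefix is produced by get_key_from_table_name_and_suffix, which this diff shows only in part, so the example keys are stand-ins:

# Hypothetical helper for inspecting ITEP entries in a state_dict. Only the
# "_itp_address_lookup"/"_itp_row_util" suffixes come from the diff; the example
# keys and tensors are stand-ins, not real checkpoint contents.
from typing import Dict

import torch


def summarize_itep_buffers(state_dict: Dict[str, torch.Tensor]) -> None:
    for key, tensor in state_dict.items():
        if key.endswith(("_itp_address_lookup", "_itp_row_util")):
            print(f"{key}: dtype={tensor.dtype}, numel={tensor.numel()}")


summarize_itep_buffers(
    {
        "table_a_itp_address_lookup": torch.zeros(8, dtype=torch.int64),
        "table_a_itp_row_util": torch.zeros(8, dtype=torch.float32),
    }
)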
