Refactoring ITEP / PTP Pruning Scuba Logger [1/N] #2986

Open · wants to merge 1 commit into base: main
1 change: 0 additions & 1 deletion torchrec/distributed/itep_embeddingbag.py
@@ -134,7 +134,6 @@ def __init__(
pruning_interval=module._itep_module.pruning_interval,
enable_pruning=module._itep_module.enable_pruning,
pg=env.process_group,
- itep_logger=module._itep_module.itep_logger,
)

def prefetch(
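Note: with the itep_logger kwarg dropped from this call site, the rebuilt GenericITEPModule falls back to its default logger. A minimal sketch of that fallback behavior, assuming the torchrec.modules.pruning_logger path introduced by this PR and that the default implements its hooks as no-ops (its class body, "noop logger as a default", is truncated in the diff below):

from torchrec.modules.pruning_logger import PruningLoggerDefault

# PruningLoggerDefault is the no-op fallback: constructing it and invoking
# its hooks has no side effects, so call sites that stop passing a logger
# lose nothing but telemetry.
default_logger = PruningLoggerDefault()
default_logger.log_run_info()  # no-op by design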
17 changes: 5 additions & 12 deletions torchrec/modules/itep_modules.py
@@ -21,7 +21,7 @@
from torchrec.distributed.embedding_types import ShardedEmbeddingTable, ShardingType
from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
from torchrec.modules.embedding_modules import reorder_inverse_indices
- from torchrec.modules.itep_logger import ITEPLogger, ITEPLoggerDefault
+ from torchrec.modules.pruning_logger import PruningLogger, PruningLoggerDefault

from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor

@@ -73,7 +73,7 @@ def __init__(
pruning_interval: int = 1001, # Default pruning interval 1001 iterations
pg: Optional[dist.ProcessGroup] = None,
table_name_to_sharding_type: Optional[Dict[str, str]] = None,
- itep_logger: Optional[ITEPLogger] = None,
+ scuba_logger: Optional[PruningLogger] = None,
) -> None:
super(GenericITEPModule, self).__init__()

@@ -90,10 +90,10 @@ def __init__(
)
self.table_name_to_sharding_type = table_name_to_sharding_type

- self.itep_logger: ITEPLogger = (
-     itep_logger if itep_logger is not None else ITEPLoggerDefault()
+ self.scuba_logger: PruningLogger = (
+     scuba_logger if scuba_logger is not None else PruningLoggerDefault()
)
- self.itep_logger.log_run_info()
+ self.scuba_logger.log_run_info()

# Map each feature to a physical address_lookup/row_util buffer
self.feature_table_map: Dict[str, int] = {}
@@ -168,13 +168,6 @@ def print_itep_eviction_stats(
# in dummy mode, we don't have the feature_table_map or reversed_feature_table_map
pass

- self.itep_logger.log_table_eviction_info(
-     iteration=None,
-     rank=None,
-     table_to_sizes_mapping=table_to_sizes_mapping,
-     eviction_tables=logged_eviction_mapping,
- )

# Print the sorted mapping
logger.info(f"ITEP: table name to eviction ratio {sorted_mapping}")

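The renamed constructor argument keeps the same optional-with-default pattern. A minimal standalone sketch of that pattern, assuming the torchrec.modules.pruning_logger module from this PR; the wrapper class here is hypothetical, not GenericITEPModule itself:

from typing import Optional

from torchrec.modules.pruning_logger import PruningLogger, PruningLoggerDefault


class PruningAwareModule:
    """Hypothetical class; illustrates only the logger fallback, not ITEP."""

    def __init__(self, scuba_logger: Optional[PruningLogger] = None) -> None:
        # An explicitly passed logger wins; otherwise install the no-op
        # default, mirroring GenericITEPModule.__init__ in the diff above.
        self.scuba_logger: PruningLogger = (
            scuba_logger if scuba_logger is not None else PruningLoggerDefault()
        )
        self.scuba_logger.log_run_info()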
torchrec/modules/pruning_logger.py (renamed from torchrec/modules/itep_logger.py, per the import change above)
@@ -12,7 +12,7 @@
logger: logging.Logger = logging.getLogger(__name__)


- class ITEPLogger(ABC):
+ class PruningLogger(ABC):
@abstractmethod
def log_table_eviction_info(
self,
@@ -30,7 +30,7 @@ def log_run_info(
pass


- class ITEPLoggerDefault(ITEPLogger):
+ class PruningLoggerDefault(PruningLogger):
"""
noop logger as a default
"""
@@ -39,7 +39,7 @@ def __init__(
self,
) -> None:
"""
- Initialize ITEPLoggerScuba.
+ Initialize PruningScubaLogger.
"""
pass

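For callers that want real telemetry instead of the no-op default, a subclass of PruningLogger only needs to implement the two hooks. A hedged sketch of a custom implementation; the parameter names come from the call removed in itep_modules.py above, while the types and defaults are assumptions, since the abstract signatures are truncated in this diff:

import logging
from typing import Dict, Optional, Tuple

from torchrec.modules.pruning_logger import PruningLogger

example_logger: logging.Logger = logging.getLogger(__name__)


class LoggingPruningLogger(PruningLogger):
    """Hypothetical subclass routing telemetry to the stdlib logger."""

    def log_table_eviction_info(
        self,
        iteration: Optional[int] = None,
        rank: Optional[int] = None,
        table_to_sizes_mapping: Optional[Dict[str, Tuple[int, int]]] = None,
        eviction_tables: Optional[Dict[str, float]] = None,
    ) -> None:
        # Emit one record per eviction event instead of writing to an
        # internal (Scuba) backend.
        example_logger.info(
            "pruning eviction: iteration=%s rank=%s sizes=%s evictions=%s",
            iteration,
            rank,
            table_to_sizes_mapping,
            eviction_tables,
        )

    def log_run_info(self) -> None:
        example_logger.info("pruning run started")

An instance of this subclass would be passed as the scuba_logger argument to GenericITEPModule to replace the default no-op behavior.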