Skip to content

Commit 32e5431

Browse files
Mengjiao Zhou and meta-codesync[bot]
authored and committed
sync opt_in metric fix
Summary: sync opt_in metric fix D84439437 to github - Note: because the OSS version has drifted considerably from the internal version, additional changes had to be migrated. Reviewed By: kausv Differential Revision: D86741254 fbshipit-source-id: e3c931a00881c0ea1fc81eee67f023f5d2305d57
1 parent 7de7d04 commit 32e5431

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

torchrec/modules/hash_mc_metrics.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def __init__(
4242
zch_size: int,
4343
frequency: int,
4444
start_bucket: int,
45+
num_buckets_per_rank: int,
46+
num_reserved_slots_per_bucket: int,
47+
device: torch.device,
4548
disable_fallback: bool,
4649
log_file_path: str = "",
4750
) -> None:
@@ -57,8 +60,32 @@ def __init__(
5760
self._zch_size: int = zch_size
5861
self._frequency: int = frequency
5962
self._start_bucket: int = start_bucket
63+
self._num_buckets_per_rank: int = num_buckets_per_rank
64+
self._num_reserved_slots_per_bucket: int = num_reserved_slots_per_bucket
65+
self._device: torch.device = device
6066
self._disable_fallback: bool = disable_fallback
6167

68+
assert (
69+
self._zch_size % self._num_buckets_per_rank == 0
70+
), f"{self._zch_size} must be divisible by {self._num_buckets_per_rank}"
71+
indice_per_bucket = torch.tensor(
72+
[
73+
(self._zch_size // self._num_buckets_per_rank) * bucket
74+
for bucket in range(1, self._num_buckets_per_rank + 1)
75+
],
76+
dtype=torch.int64,
77+
device=self._device,
78+
)
79+
80+
self._opt_in_ranges: torch.Tensor = torch.sub(
81+
indice_per_bucket,
82+
(
83+
self._num_reserved_slots_per_bucket
84+
if self._num_reserved_slots_per_bucket > 0
85+
else 0
86+
),
87+
)
88+
6289
self._dtype_checked: bool = False
6390
self._total_cnt: int = 0
6491
self._hit_cnt: int = 0
@@ -77,6 +104,13 @@ def __init__(
77104
) # initialize file handler
78105
self.logger.addHandler(file_handler) # add file handler to logger
79106

107+
self.logger.info(
108+
f"ScalarLogger: {self._name=}, {self._device=}, "
109+
f"{self._zch_size=}, {self._frequency=}, {self._start_bucket=}, "
110+
f"{self._num_buckets_per_rank=}, {self._num_reserved_slots_per_bucket=}, "
111+
f"{self._opt_in_ranges=}, {self._disable_fallback=}"
112+
)
113+
80114
def should_report(self) -> bool:
81115
# We only need to report metrics from rank0 (start_bucket = 0)
82116

@@ -95,9 +129,9 @@ def update(
95129
identities_1: torch.Tensor,
96130
values: torch.Tensor,
97131
remapped_ids: torch.Tensor,
132+
hit_indices: torch.Tensor,
98133
evicted_emb_indices: Optional[torch.Tensor],
99134
metadata: Optional[torch.Tensor],
100-
num_reserved_slots: int,
101135
eviction_config: Optional[HashZchEvictionConfig] = None,
102136
) -> None:
103137
if not self._dtype_checked:
@@ -125,9 +159,17 @@ def update(
125159
self._hit_cnt += hit_cnt
126160
self._collision_cnt += values.numel() - hit_cnt - insert_cnt
127161

128-
opt_in_range = self._zch_size - num_reserved_slots
129-
opt_in_ids = torch.lt(remapped_ids, opt_in_range)
130-
self._opt_in_cnt += int(torch.sum(opt_in_ids).item())
162+
if self._disable_fallback:
163+
hit_values = values[hit_indices]
164+
train_buckets = hit_values % self._num_buckets_per_rank
165+
else:
166+
train_buckets = values % self._num_buckets_per_rank
167+
168+
opt_in_ranges = self._opt_in_ranges.index_select(dim=0, index=train_buckets)
169+
opt_in_ids = torch.lt(remapped_ids, opt_in_ranges)
170+
opt_in_ids_cnt = int(torch.sum(opt_in_ids).item())
171+
# opt_in_cnt: # of ids assigned to opt-in block
172+
self._opt_in_cnt += opt_in_ids_cnt
131173

132174
if evicted_emb_indices is not None and evicted_emb_indices.numel() > 0:
133175
deduped_evicted_indices = torch.unique(evicted_emb_indices)

torchrec/modules/hash_mc_modules.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ def __init__(
286286
zch_size=self._zch_size,
287287
frequency=self._tb_logging_frequency,
288288
start_bucket=self._start_bucket,
289+
num_buckets_per_rank=self._end_bucket - self._start_bucket,
290+
num_reserved_slots_per_bucket=self.get_reserved_slots_per_bucket(),
291+
device=self._device,
289292
disable_fallback=self._disable_fallback,
290293
)
291294
else:
@@ -542,9 +545,9 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
542545
# record the on-device remapped ids
543546
self.table_name_on_device_remapped_ids_dict[name] = remapped_ids.clone()
544547
lengths: torch.Tensor = feature.lengths()
548+
hit_indices = remapped_ids != -1
545549
if self._disable_fallback:
546550
# Only works on GPU when read only is true.
547-
hit_indices = remapped_ids != -1
548551
remapped_ids = remapped_ids[hit_indices]
549552
lengths = torch.masked_fill(lengths, ~hit_indices, 0)
550553
if self._scalar_logger is not None:
@@ -554,9 +557,9 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
554557
identities_1=self._hash_zch_identities,
555558
values=values,
556559
remapped_ids=remapped_ids,
560+
hit_indices=hit_indices,
557561
evicted_emb_indices=evictions,
558562
metadata=metadata,
559-
num_reserved_slots=num_reserved_slots,
560563
eviction_config=self._eviction_config,
561564
)
562565

0 commit comments

Comments (0)