Skip to content

Commit 7ce853c

Browse files
aliafzal authored and meta-codesync[bot] committed
Raw id tracker store (#3541)
Summary: Pull Request resolved: #3541 Introducing a separate store for the raw id tracker, specifically for tracking ids from RawIdTracker. Reviewed By: chouxi Differential Revision: D86524689 fbshipit-source-id: 143e22622a12cf29b1ff378ead1172c403428528
1 parent b9c1e52 commit 7ce853c

File tree

3 files changed

+78
-7
lines changed

3 files changed

+78
-7
lines changed

torchrec/distributed/model_tracker/delta_store.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
from torchrec.distributed.model_tracker.types import (
1616
IndexedLookup,
17+
RawIndexedLookup,
1718
UniqueRows,
1819
UpdateMode,
1920
)
@@ -90,7 +91,7 @@ def append(
9091
batch_idx: int,
9192
fqn: str,
9293
ids: torch.Tensor,
93-
states: Optional[torch.Tensor],
94+
states: Optional[torch.Tensor] = None,
9495
raw_ids: Optional[torch.Tensor] = None,
9596
) -> None:
9697
"""
@@ -162,12 +163,12 @@ def append(
162163
batch_idx: int,
163164
fqn: str,
164165
ids: torch.Tensor,
165-
states: Optional[torch.Tensor],
166+
states: Optional[torch.Tensor] = None,
166167
raw_ids: Optional[torch.Tensor] = None,
167168
) -> None:
168169
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
169170
table_fqn_lookup.append(
170-
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids)
171+
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
171172
)
172173
self.per_fqn_lookups[fqn] = table_fqn_lookup
173174

@@ -264,3 +265,64 @@ def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
264265
ids=compact_ids, states=compact_states, mode=self.updateMode
265266
)
266267
return delta_per_table_fqn
268+
269+
270+
class RawIdTrackerStore(DeltaStore):
    """
    RawIdTrackerStore is a concrete implementation of DeltaStore that stores and
    manages raw ids tracked by RawIdTracker.

    Lookups are kept per table FQN as a list of RawIndexedLookup entries, in the
    order in which they were appended. This store never holds states: `compact`
    is a no-op and `get_unique` always returns an empty dict.
    """

    def __init__(self, updateMode: UpdateMode = UpdateMode.NONE) -> None:
        super().__init__(updateMode)
        self.updateMode = updateMode
        # Table FQN -> lookups recorded for that table, in append order.
        self.per_fqn_lookups: Dict[str, List[RawIndexedLookup]] = {}

    def append(
        self,
        batch_idx: int,
        fqn: str,
        ids: torch.Tensor,
        states: Optional[torch.Tensor] = None,
        raw_ids: Optional[torch.Tensor] = None,
    ) -> None:
        """
        Record the ids (and optional raw ids) looked up for table `fqn` in batch
        `batch_idx`.

        Args:
            batch_idx: index of the batch the lookup belongs to.
            fqn: fully qualified name of the table the ids were looked up in.
            ids: ids looked up for this table in this batch.
            states: accepted for signature compatibility with DeltaStore.append;
                it is ignored and never stored by this store.
            raw_ids: raw ids to store alongside `ids`.
        """
        # setdefault creates the per-table list on first use and appends in place.
        self.per_fqn_lookups.setdefault(fqn, []).append(
            RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids)
        )

    def delete(self, up_to_idx: Optional[int] = None) -> None:
        """
        Delete all lookups with a batch index strictly below `up_to_idx`.

        Args:
            up_to_idx: exclusive upper bound; lookups whose batch_idx is
                >= up_to_idx are kept. If None, every lookup is deleted.
        """
        if up_to_idx is None:
            # If up_to_idx is None, delete all lookups
            self.per_fqn_lookups = {}
            return

        # Only values of existing keys are rebound, so the dict keeps its size
        # and can safely be updated while iterating over its items.
        for table_fqn, lookups in self.per_fqn_lookups.items():
            self.per_fqn_lookups[table_fqn] = [
                lookup for lookup in lookups if lookup.batch_idx >= up_to_idx
            ]

    def compact(self, start_idx: int, end_idx: int) -> None:
        """
        No-op: this store does not compact its lookups.
        """
        pass

    def get_indexed_lookups(
        self, start_idx: int, end_idx: int
    ) -> Dict[str, List[RawIndexedLookup]]:
        r"""
        Return, per table FQN, the stored lookups whose batch index falls in the
        half-open range [start_idx, end_idx).

        The range is located by binary search, so this relies on each table's
        lookups being ordered by batch_idx (NOTE(review): assumes callers append
        batches in non-decreasing batch_idx order -- confirm against the tracker).

        Args:
            start_idx: inclusive lower bound on batch_idx.
            end_idx: exclusive upper bound on batch_idx.

        Returns:
            A dict with one entry per tracked table FQN; the value is the
            (possibly empty) slice of lookups inside the requested range.
        """
        per_fqn_lookups: Dict[str, List[RawIndexedLookup]] = {}
        for table_fqn, lookups in self.per_fqn_lookups.items():
            batch_indices = [lookup.batch_idx for lookup in lookups]
            index_l = bisect_left(batch_indices, start_idx)
            index_r = bisect_left(batch_indices, end_idx)
            per_fqn_lookups[table_fqn] = lookups[index_l:index_r]
        return per_fqn_lookups

    def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
        """
        Always returns an empty dict; this store does not produce unique rows.
        """
        return {}

torchrec/distributed/model_tracker/trackers/raw_id_tracker.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
ShardedManagedCollisionEmbeddingBagCollection,
2424
)
2525
from torchrec.distributed.mc_modules import ShardedManagedCollisionCollection
26-
from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec
26+
from torchrec.distributed.model_tracker.delta_store import RawIdTrackerStore
2727

2828
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
2929
from torchrec.distributed.model_tracker.types import IndexedLookup, UniqueRows
@@ -71,7 +71,7 @@ def __init__(
7171
c: -1 for c in (self._consumers or [self.DEFAULT_CONSUMER])
7272
}
7373

74-
self.store: DeltaStoreTrec = DeltaStoreTrec()
74+
self.store: RawIdTrackerStore = RawIdTrackerStore()
7575

7676
# Mapping feature name to corresponding FQNs. This is used for retrieving
7777
# the FQN associated with a given feature name in record_lookup().
@@ -212,7 +212,6 @@ def record_lookup(
212212
batch_idx=self.curr_batch_idx,
213213
fqn=table_fqn,
214214
ids=torch.cat(ids_list),
215-
states=None,
216215
raw_ids=torch.cat(per_table_raw_ids[table_fqn]),
217216
)
218217

torchrec/distributed/model_tracker/types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,20 @@ class IndexedLookup:
2323
batch_idx: int
2424
ids: torch.Tensor
2525
states: Optional[torch.Tensor]
26-
raw_ids: Optional[torch.Tensor] = None
2726
compact: bool = False
2827

2928

29+
@dataclass
class RawIndexedLookup:
    r"""
    Data class for storing the ids looked up in a single batch together with
    their corresponding raw ids. Unlike IndexedLookup, it carries no states.
    """

    # Index of the batch this lookup was recorded for.
    batch_idx: int
    # Ids looked up in that batch.
    ids: torch.Tensor
    # Raw ids recorded alongside `ids`; None when none were provided.
    raw_ids: Optional[torch.Tensor] = None
38+
39+
3040
@dataclass
3141
class UniqueRows:
3242
r"""

0 commit comments

Comments
 (0)