|
14 | 14 | import torch |
15 | 15 | from torchrec.distributed.model_tracker.types import ( |
16 | 16 | IndexedLookup, |
| 17 | + RawIndexedLookup, |
17 | 18 | UniqueRows, |
18 | 19 | UpdateMode, |
19 | 20 | ) |
@@ -90,7 +91,7 @@ def append( |
90 | 91 | batch_idx: int, |
91 | 92 | fqn: str, |
92 | 93 | ids: torch.Tensor, |
93 | | - states: Optional[torch.Tensor], |
| 94 | + states: Optional[torch.Tensor] = None, |
94 | 95 | raw_ids: Optional[torch.Tensor] = None, |
95 | 96 | ) -> None: |
96 | 97 | """ |
@@ -162,12 +163,12 @@ def append( |
162 | 163 | batch_idx: int, |
163 | 164 | fqn: str, |
164 | 165 | ids: torch.Tensor, |
165 | | - states: Optional[torch.Tensor], |
| 166 | + states: Optional[torch.Tensor] = None, |
166 | 167 | raw_ids: Optional[torch.Tensor] = None, |
167 | 168 | ) -> None: |
168 | 169 | table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) |
169 | 170 | table_fqn_lookup.append( |
170 | | - IndexedLookup(batch_idx=batch_idx, ids=ids, states=states, raw_ids=raw_ids) |
| 171 | + IndexedLookup(batch_idx=batch_idx, ids=ids, states=states) |
171 | 172 | ) |
172 | 173 | self.per_fqn_lookups[fqn] = table_fqn_lookup |
173 | 174 |
|
@@ -264,3 +265,64 @@ def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]: |
264 | 265 | ids=compact_ids, states=compact_states, mode=self.updateMode |
265 | 266 | ) |
266 | 267 | return delta_per_table_fqn |
| 268 | + |
| 269 | + |
| 270 | +class RawIdTrackerStore(DeltaStore): |
| 271 | + """ |
| 272 | + RawIdTrackerStore is a concrete implementation of DeltaStore that stores and manages raw ids tracked by RawIdTracker. |
| 273 | + """ |
| 274 | + |
| 275 | + def __init__(self, updateMode: UpdateMode = UpdateMode.NONE) -> None: |
| 276 | + super().__init__(updateMode) |
| 277 | + self.updateMode = updateMode |
| 278 | + self.per_fqn_lookups: Dict[str, List[RawIndexedLookup]] = {} |
| 279 | + |
| 280 | + def append( |
| 281 | + self, |
| 282 | + batch_idx: int, |
| 283 | + fqn: str, |
| 284 | + ids: torch.Tensor, |
| 285 | + states: Optional[torch.Tensor] = None, |
| 286 | + raw_ids: Optional[torch.Tensor] = None, |
| 287 | + ) -> None: |
| 288 | + table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) |
| 289 | + table_fqn_lookup.append( |
| 290 | + RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids) |
| 291 | + ) |
| 292 | + self.per_fqn_lookups[fqn] = table_fqn_lookup |
| 293 | + |
| 294 | + def delete(self, up_to_idx: Optional[int] = None) -> None: |
| 295 | + """ |
| 296 | + Delete all lookups whose batch_idx is below `up_to_idx` from the store; delete everything when `up_to_idx` is None. |
| 297 | + """ |
| 298 | + if up_to_idx is None: |
| 299 | + # If up_to_idx is None, delete all lookups |
| 300 | + self.per_fqn_lookups = {} |
| 301 | + else: |
| 302 | + # lookups are sorted by idx. |
| 303 | + up_to_idx = none_throws(up_to_idx) |
| 304 | + for table_fqn, lookups in self.per_fqn_lookups.items(): |
| 305 | + # remove all lookups up to up_to_idx |
| 306 | + self.per_fqn_lookups[table_fqn] = [ |
| 307 | + lookup for lookup in lookups if lookup.batch_idx >= up_to_idx |
| 308 | + ] |
| 309 | + |
| 310 | + def compact(self, start_idx: int, end_idx: int) -> None: |
| 311 | + pass |
| 312 | + |
| 313 | + def get_indexed_lookups( |
| 314 | + self, start_idx: int, end_idx: int |
| 315 | + ) -> Dict[str, List[RawIndexedLookup]]: |
| 316 | + r""" |
| 317 | + Return, per table, the stored lookups whose batch_idx lies in [start_idx, end_idx). |
| 318 | + """ |
| 319 | + per_fqn_lookups: Dict[str, List[RawIndexedLookup]] = {} |
| 320 | + for table_fqn, lookups in self.per_fqn_lookups.items(): |
| 321 | + indexices = [h.batch_idx for h in lookups] |
| 322 | + index_l = bisect_left(indexices, start_idx) |
| 323 | + index_r = bisect_left(indexices, end_idx) |
| 324 | + per_fqn_lookups[table_fqn] = lookups[index_l:index_r] |
| 325 | + return per_fqn_lookups |
| 326 | + |
| 327 | + def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]: |
| 328 | + return {} |
0 commit comments