Commit 696d332
fix stride_per_key_per_rank in stagger scenario in D74366343 (#3111)
Summary:
Pull Request resolved: #3111
# context
* original diff D74366343 broke cogwheel test and was reverted
* the error stack P1844048578 is shown below:
```
File "/dev/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/dev/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/dev/torchrec/distributed/train_pipeline/runtime_forwards.py", line 84, in __call__
data = request.wait()
File "/dev/torchrec/distributed/types.py", line 334, in wait
ret: W = self._wait_impl()
File "/dev/torchrec/distributed/embedding_sharding.py", line 655, in _wait_impl
kjts.append(w.wait())
File "/dev/torchrec/distributed/types.py", line 334, in wait
ret: W = self._wait_impl()
File "/dev/torchrec/distributed/dist_data.py", line 426, in _wait_impl
return type(self._input).dist_init(
File "/dev/torchrec/sparse/jagged_tensor.py", line 2993, in dist_init
return kjt.sync()
File "/dev/torchrec/sparse/jagged_tensor.py", line 2067, in sync
self.length_per_key()
File "/dev/torchrec/sparse/jagged_tensor.py", line 2281, in length_per_key
_length_per_key = _maybe_compute_length_per_key(
File "/dev/torchrec/sparse/jagged_tensor.py", line 1192, in _maybe_compute_length_per_key
_length_per_key_from_stride_per_key(lengths, stride_per_key)
File "/dev/torchrec/sparse/jagged_tensor.py", line 1144, in _length_per_key_from_stride_per_key
if _use_segment_sum_csr(stride_per_key):
File "/dev/torchrec/sparse/jagged_tensor.py", line 1131, in _use_segment_sum_csr
elements_per_segment = sum(stride_per_key) / len(stride_per_key)
ZeroDivisionError: division by zero
```
* the complaint is `stride_per_key` is an empty list, which comes from the following function call:
```
stride_per_key = _maybe_compute_stride_per_key(
self._stride_per_key,
self._stride_per_key_per_rank,
self.stride(),
self._keys,
)
```
* the only place this `stride_per_key` could be empty is when the `stride_per_key_per_rank.dim() != 2`
```
def _maybe_compute_stride_per_key(
stride_per_key: Optional[List[int]],
stride_per_key_per_rank: Optional[torch.IntTensor],
stride: Optional[int],
keys: List[str],
) -> Optional[List[int]]:
if stride_per_key is not None:
return stride_per_key
elif stride_per_key_per_rank is not None:
if stride_per_key_per_rank.dim() != 2:
# after permute the kjt could be empty
return []
rt: List[int] = stride_per_key_per_rank.sum(dim=1).tolist()
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
pt2_checks_all_is_size(rt)
return rt
elif stride is not None:
return [stride] * len(keys)
else:
return None
```
# the main change from D74366343 is that the `stride_per_key_per_rank` in `dist_init`:
* baseline
```
if stagger > 1:
stride_per_key_per_rank_stagger: List[List[int]] = []
local_world_size = num_workers // stagger
for i in range(len(keys)):
stride_per_rank_stagger: List[int] = []
for j in range(local_world_size):
stride_per_rank_stagger.extend(
stride_per_key_per_rank[i][j::local_world_size]
)
stride_per_key_per_rank_stagger.append(stride_per_rank_stagger)
stride_per_key_per_rank = stride_per_key_per_rank_stagger
```
* D76875546 (correct, this diff)
```
if stagger > 1:
indices = torch.arange(num_workers).view(stagger, -1).T.reshape(-1)
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
```
* D74366343 (incorrect, reverted)
```
if stagger > 1:
local_world_size = num_workers // stagger
indices = [
list(range(i, num_workers, local_world_size))
for i in range(local_world_size)
]
stride_per_key_per_rank = stride_per_key_per_rank[:, indices]
```
Differential Revision: D769036461 parent 08e4f7b commit 696d332
File tree
3 files changed
+83
-48
lines changed- torchrec
- pt2
- schema/api_tests
- sparse
3 files changed
+83
-48
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
54 | 54 | | |
55 | 55 | | |
56 | 56 | | |
57 | | - | |
| 57 | + | |
58 | 58 | | |
59 | 59 | | |
60 | 60 | | |
| |||
85 | 85 | | |
86 | 86 | | |
87 | 87 | | |
88 | | - | |
| 88 | + | |
89 | 89 | | |
90 | 90 | | |
91 | 91 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9 | 9 | | |
10 | 10 | | |
11 | 11 | | |
12 | | - | |
| 12 | + | |
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
| |||
112 | 112 | | |
113 | 113 | | |
114 | 114 | | |
115 | | - | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
116 | 118 | | |
117 | 119 | | |
118 | 120 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1094 | 1094 | | |
1095 | 1095 | | |
1096 | 1096 | | |
1097 | | - | |
| 1097 | + | |
1098 | 1098 | | |
1099 | 1099 | | |
1100 | 1100 | | |
1101 | 1101 | | |
1102 | | - | |
1103 | | - | |
| 1102 | + | |
| 1103 | + | |
| 1104 | + | |
| 1105 | + | |
1104 | 1106 | | |
1105 | 1107 | | |
1106 | 1108 | | |
| |||
1452 | 1454 | | |
1453 | 1455 | | |
1454 | 1456 | | |
1455 | | - | |
1456 | | - | |
| 1457 | + | |
| 1458 | + | |
1457 | 1459 | | |
1458 | 1460 | | |
1459 | 1461 | | |
| |||
1639 | 1641 | | |
1640 | 1642 | | |
1641 | 1643 | | |
1642 | | - | |
| 1644 | + | |
1643 | 1645 | | |
1644 | 1646 | | |
1645 | 1647 | | |
1646 | 1648 | | |
1647 | 1649 | | |
1648 | 1650 | | |
1649 | | - | |
| 1651 | + | |
| 1652 | + | |
| 1653 | + | |
| 1654 | + | |
| 1655 | + | |
| 1656 | + | |
| 1657 | + | |
1650 | 1658 | | |
1651 | 1659 | | |
1652 | 1660 | | |
| |||
1725 | 1733 | | |
1726 | 1734 | | |
1727 | 1735 | | |
1728 | | - | |
| 1736 | + | |
| 1737 | + | |
| 1738 | + | |
1729 | 1739 | | |
1730 | 1740 | | |
1731 | 1741 | | |
| |||
1747 | 1757 | | |
1748 | 1758 | | |
1749 | 1759 | | |
1750 | | - | |
1751 | | - | |
| 1760 | + | |
| 1761 | + | |
| 1762 | + | |
| 1763 | + | |
| 1764 | + | |
| 1765 | + | |
| 1766 | + | |
| 1767 | + | |
1752 | 1768 | | |
1753 | 1769 | | |
1754 | 1770 | | |
| |||
1759 | 1775 | | |
1760 | 1776 | | |
1761 | 1777 | | |
| 1778 | + | |
| 1779 | + | |
1762 | 1780 | | |
1763 | 1781 | | |
1764 | 1782 | | |
| |||
1774 | 1792 | | |
1775 | 1793 | | |
1776 | 1794 | | |
1777 | | - | |
1778 | | - | |
1779 | | - | |
1780 | | - | |
1781 | 1795 | | |
1782 | 1796 | | |
1783 | 1797 | | |
| |||
1987 | 2001 | | |
1988 | 2002 | | |
1989 | 2003 | | |
1990 | | - | |
| 2004 | + | |
1991 | 2005 | | |
1992 | 2006 | | |
1993 | 2007 | | |
| |||
2152 | 2166 | | |
2153 | 2167 | | |
2154 | 2168 | | |
2155 | | - | |
2156 | | - | |
| 2169 | + | |
| 2170 | + | |
| 2171 | + | |
| 2172 | + | |
| 2173 | + | |
| 2174 | + | |
| 2175 | + | |
| 2176 | + | |
| 2177 | + | |
| 2178 | + | |
| 2179 | + | |
| 2180 | + | |
| 2181 | + | |
| 2182 | + | |
| 2183 | + | |
| 2184 | + | |
| 2185 | + | |
| 2186 | + | |
| 2187 | + | |
2157 | 2188 | | |
2158 | 2189 | | |
2159 | 2190 | | |
2160 | 2191 | | |
| 2192 | + | |
| 2193 | + | |
| 2194 | + | |
2161 | 2195 | | |
2162 | 2196 | | |
2163 | 2197 | | |
| |||
2302 | 2336 | | |
2303 | 2337 | | |
2304 | 2338 | | |
| 2339 | + | |
| 2340 | + | |
2305 | 2341 | | |
2306 | 2342 | | |
2307 | 2343 | | |
2308 | 2344 | | |
2309 | 2345 | | |
2310 | | - | |
| 2346 | + | |
2311 | 2347 | | |
| 2348 | + | |
2312 | 2349 | | |
2313 | 2350 | | |
2314 | 2351 | | |
| |||
2456 | 2493 | | |
2457 | 2494 | | |
2458 | 2495 | | |
2459 | | - | |
2460 | 2496 | | |
2461 | 2497 | | |
2462 | 2498 | | |
2463 | 2499 | | |
2464 | 2500 | | |
2465 | 2501 | | |
2466 | | - | |
2467 | | - | |
2468 | | - | |
2469 | | - | |
| 2502 | + | |
| 2503 | + | |
| 2504 | + | |
| 2505 | + | |
| 2506 | + | |
| 2507 | + | |
| 2508 | + | |
| 2509 | + | |
| 2510 | + | |
| 2511 | + | |
| 2512 | + | |
| 2513 | + | |
2470 | 2514 | | |
2471 | 2515 | | |
2472 | 2516 | | |
| |||
2518 | 2562 | | |
2519 | 2563 | | |
2520 | 2564 | | |
2521 | | - | |
2522 | | - | |
2523 | | - | |
| 2565 | + | |
2524 | 2566 | | |
2525 | 2567 | | |
2526 | 2568 | | |
2527 | 2569 | | |
2528 | 2570 | | |
2529 | 2571 | | |
2530 | 2572 | | |
2531 | | - | |
2532 | | - | |
| 2573 | + | |
| 2574 | + | |
2533 | 2575 | | |
2534 | 2576 | | |
2535 | 2577 | | |
| |||
2848 | 2890 | | |
2849 | 2891 | | |
2850 | 2892 | | |
2851 | | - | |
| 2893 | + | |
2852 | 2894 | | |
2853 | 2895 | | |
2854 | 2896 | | |
| |||
2885 | 2927 | | |
2886 | 2928 | | |
2887 | 2929 | | |
2888 | | - | |
2889 | | - | |
2890 | | - | |
| 2930 | + | |
| 2931 | + | |
| 2932 | + | |
| 2933 | + | |
2891 | 2934 | | |
2892 | | - | |
2893 | | - | |
2894 | 2935 | | |
2895 | | - | |
2896 | | - | |
2897 | | - | |
2898 | | - | |
2899 | | - | |
2900 | | - | |
2901 | | - | |
2902 | | - | |
2903 | | - | |
2904 | | - | |
| 2936 | + | |
| 2937 | + | |
2905 | 2938 | | |
2906 | 2939 | | |
2907 | 2940 | | |
| |||
0 commit comments