Commit 3c8f53a
add stride into KJT pytree (meta-pytorch#2587)
Summary:
Pull Request resolved: meta-pytorch#2587
# context
* Previously for a KJT, only the following fields and `_keys` are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly.
```
_fields = [
"_values",
"_weights",
"_lengths",
"_offsets",
]
```
* Particularly, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`:
```
def _maybe_compute_stride_kjt(
keys: List[str],
stride: Optional[int],
lengths: Optional[torch.Tensor],
offsets: Optional[torch.Tensor],
stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
if stride is None:
if len(keys) == 0:
stride = 0
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
stride = max([sum(s) for s in stride_per_key_per_rank])
elif offsets is not None and offsets.numel() > 0:
stride = (offsets.numel() - 1) // len(keys)
elif lengths is not None:
stride = lengths.numel() // len(keys)
else:
stride = 0
return stride
```
* The previously stored pytree flatten specs are enough if the `batch_size` is static, however, this no longer holds true in a variable batch size scenario, where the `stride_per_key_per_rank` is not `None`.
* An example is that with `dedup_ebc`, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the **true** `stride` (static).
* During ir_export, the output shape will be calculated from `kjt.stride()` function, which would be incorrect if the pytree specs only contains the `keys`.
* This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value.
Reviewed By: PaulZhang12
Differential Revision: D664008211 parent f126ded commit 3c8f53a
2 files changed
+4
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
28 | 28 | | |
29 | 29 | | |
30 | 30 | | |
31 | | - | |
| 31 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
| 6 | + | |
5 | 7 | | |
6 | 8 | | |
7 | 9 | | |
| |||
13 | 15 | | |
14 | 16 | | |
15 | 17 | | |
| 18 | + | |
16 | 19 | | |
17 | 20 | | |
18 | 21 | | |
| |||
0 commit comments