Commit 8b5124f
add stride into KJT pytree
Summary:
# context
* Previously for a KJT, only the following fields and `_keys` are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly.
```
_fields = [
"_values",
"_weights",
"_lengths",
"_offsets",
]
```
* Particularly, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`:
```
def _maybe_compute_stride_kjt(
keys: List[str],
stride: Optional[int],
lengths: Optional[torch.Tensor],
offsets: Optional[torch.Tensor],
stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
if stride is None:
if len(keys) == 0:
stride = 0
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
stride = max([sum(s) for s in stride_per_key_per_rank])
elif offsets is not None and offsets.numel() > 0:
stride = (offsets.numel() - 1) // len(keys)
elif lengths is not None:
stride = lengths.numel() // len(keys)
else:
stride = 0
return stride
```
* The previously stored pytree flatten specs are enough if the `batch_size` is static, however, this no longer holds true in a variable batch size scenario, where the `stride_per_key_per_rank` is not `None`.
* An example is that with `dedup_ebc`, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the **true** `stride` (static).
* During ir_export, the output shape will be calculated from `kjt.stride()` function, which would be incorrect if the pytree specs only contains the `keys`.
* This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value.
Differential Revision: D664008211 parent 7f3b7dc commit 8b5124f
1 file changed
+11
-5
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3026 | 3026 | | |
3027 | 3027 | | |
3028 | 3028 | | |
3029 | | - | |
3030 | | - | |
| 3029 | + | |
| 3030 | + | |
| 3031 | + | |
| 3032 | + | |
| 3033 | + | |
| 3034 | + | |
3031 | 3035 | | |
3032 | 3036 | | |
3033 | 3037 | | |
3034 | 3038 | | |
3035 | | - | |
| 3039 | + | |
3036 | 3040 | | |
3037 | 3041 | | |
3038 | 3042 | | |
| |||
3041 | 3045 | | |
3042 | 3046 | | |
3043 | 3047 | | |
3044 | | - | |
| 3048 | + | |
| 3049 | + | |
3045 | 3050 | | |
3046 | | - | |
| 3051 | + | |
| 3052 | + | |
3047 | 3053 | | |
3048 | 3054 | | |
3049 | 3055 | | |
| |||
0 commit comments