Skip to content

Commit ca0df12

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
add stride into KJT pytree (#2587)
Summary: Pull Request resolved: #2587 # context * Previously for a KJT, only the following fields and `_keys` are stored in the pytree flatten specs. All other arguments/parameters would be derived accordingly. ``` _fields = [ "_values", "_weights", "_lengths", "_offsets", ] ``` * Particularly, the `stride` (int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt`: ``` def _maybe_compute_stride_kjt( keys: List[str], stride: Optional[int], lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor], stride_per_key_per_rank: Optional[List[List[int]]], ) -> int: if stride is None: if len(keys) == 0: stride = 0 elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0: stride = max([sum(s) for s in stride_per_key_per_rank]) elif offsets is not None and offsets.numel() > 0: stride = (offsets.numel() - 1) // len(keys) elif lengths is not None: stride = lengths.numel() // len(keys) else: stride = 0 return stride ``` * The previously stored pytree flatten specs are enough if the `batch_size` is static, however, this no longer holds true in a variable batch size scenario, where the `stride_per_key_per_rank` is not `None`. * An example is that with `dedup_ebc`, where the actual batch_size is variable (depending on the dedup data), but the output of the ebc should always be the **true** `stride` (static). * During ir_export, the output shape will be calculated from `kjt.stride()` function, which would be incorrect if the pytree specs only contains the `keys`. * This diff adds the `stride` into the KJT pytree flatten/unflatten functions so that a fakified KJT would have the correct stride value. Reviewed By: PaulZhang12 Differential Revision: D66400821
1 parent a719848 commit ca0df12

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

.github/scripts/install_libs.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ elif [ "$CHANNEL" = "test" ]; then
2828
fi
2929

3030

31-
${CONDA_RUN} pip install importlib-metadata
31+
${CONDA_RUN} pip install importlib-metadata click

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
black
2+
click
23
cmake
34
fbgemm-gpu
45
hypothesis==6.70.1

0 commit comments

Comments
 (0)