Conversation

@TroyGarden
Contributor

Summary:

# context

* Previously, for a KJT, only the following fields plus `_keys` were stored in the pytree flatten specs; all other arguments/parameters were derived from them.
```
    _fields = [
        "_values",
        "_weights",
        "_lengths",
        "_offsets",
    ]
```
* In particular, the `stride` (an int) of a KJT, which represents the `batch_size`, is computed by `_maybe_compute_stride_kjt` (a small check of this derivation is sketched after this list):
```
def _maybe_compute_stride_kjt(
    keys: List[str],
    stride: Optional[int],
    lengths: Optional[torch.Tensor],
    offsets: Optional[torch.Tensor],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> int:
    if stride is None:
        if len(keys) == 0:
            stride = 0
        elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
            stride = max([sum(s) for s in stride_per_key_per_rank])
        elif offsets is not None and offsets.numel() > 0:
            stride = (offsets.numel() - 1) // len(keys)
        elif lengths is not None:
            stride = lengths.numel() // len(keys)
        else:
            stride = 0
    return stride
```
* The previously stored pytree flatten specs are sufficient when the `batch_size` is static; this no longer holds in a variable-batch-size scenario, where `stride_per_key_per_rank` is not `None`.
* An example is `dedup_ebc`, where the actual batch size varies (depending on the deduplicated data), but the output of the EBC should always use the **true** (static) `stride`.
* During `ir_export`, the output shape is calculated from the `kjt.stride()` function, which would be incorrect if the pytree specs contained only the `keys`.
* This diff adds the `stride` to the KJT pytree flatten/unflatten functions so that a fakified KJT has the correct stride value (see the flatten/unflatten sketch after this list).
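As a quick illustration of the derivation above (the keys, lengths, and per-rank batch sizes here are made up for the example, not taken from torchrec): with a static batch, `stride` falls back to `lengths.numel() // len(keys)`, while in the variable-batch case it is the maximum per-key sum of `stride_per_key_per_rank`.

```python
import torch

# Hypothetical toy data: 2 keys, per-key batch size 3.
keys = ["f1", "f2"]
lengths = torch.tensor([1, 0, 2, 3, 1, 1])  # 6 length entries = 2 keys * batch_size 3

# Static case: stride falls back to lengths.numel() // len(keys).
stride_from_lengths = lengths.numel() // len(keys)
assert stride_from_lengths == 3

# Variable-batch case: stride is the max of the per-key sums across ranks.
stride_per_key_per_rank = [[2, 1], [3, 0]]  # per-key, per-rank batch sizes
stride_variable = max(sum(s) for s in stride_per_key_per_rank)
assert stride_variable == 3
```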
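And a minimal sketch of the idea of carrying `stride` in the flatten context, using a toy stand-in class rather than the real `KeyedJaggedTensor` (the names `ToyKJT`, `_toy_kjt_flatten`, and `_toy_kjt_unflatten` are hypothetical, and the sketch assumes a PyTorch version that exposes `torch.utils._pytree.register_pytree_node`): the tensors become pytree leaves, while `keys` and `stride` travel in the static context, so a round-tripped (e.g. fakified) KJT reports the true batch size instead of re-deriving it from possibly deduplicated lengths.

```python
from typing import Any, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree


class ToyKJT:
    """Hypothetical stand-in for a KJT, reduced to the fields relevant here."""

    def __init__(
        self,
        keys: List[str],
        values: torch.Tensor,
        lengths: torch.Tensor,
        stride: Optional[int] = None,
    ) -> None:
        self._keys = keys
        self._values = values
        self._lengths = lengths
        self._stride = stride

    def stride(self) -> int:
        # Prefer the stored (true) stride; otherwise derive it from lengths.
        if self._stride is not None:
            return self._stride
        return self._lengths.numel() // len(self._keys)


def _toy_kjt_flatten(kjt: "ToyKJT") -> Tuple[List[torch.Tensor], Any]:
    # Tensors become pytree leaves; keys *and* stride go into the static context.
    return [kjt._values, kjt._lengths], (kjt._keys, kjt._stride)


def _toy_kjt_unflatten(values: Any, context: Any) -> "ToyKJT":
    values = list(values)
    keys, stride = context
    return ToyKJT(keys, values[0], values[1], stride=stride)


pytree.register_pytree_node(ToyKJT, _toy_kjt_flatten, _toy_kjt_unflatten)

# Round trip: lengths alone would imply a stride of 2, but the stored true
# batch size (3) survives flatten/unflatten, which is what ir_export needs.
kjt = ToyKJT(["f1", "f2"], torch.arange(4), torch.tensor([1, 1, 1, 1]), stride=3)
leaves, spec = pytree.tree_flatten(kjt)
assert pytree.tree_unflatten(leaves, spec).stride() == 3
```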

Differential Revision: D66400821

@facebook-github-bot added the CLA Signed label Nov 23, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D66400821

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 12, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jun 18, 2025
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jun 18, 2025
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jun 18, 2025
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jun 18, 2025
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jun 18, 2025
Summary:
Pull Request resolved: meta-pytorch#2587

Differential Revision: D66400821

Reviewed By: PaulZhang12
@TroyGarden deleted the export-D66400821 branch June 19, 2025 07:23
Labels

CLA Signed, fb-exported

2 participants