Summary:
Pull Request resolved: pytorch#1641
This diff adds an optimized implementation of TBE training forward,
namely
`split_embedding_codegen_forward_[weighted|unweighted]_v2_kernel`.
The implementation currently supports only a subset of TBE use cases,
including (a configuration sketch follows the two lists below):
- Split TBE (`SplitTableBatchedEmbeddingBagsCodegen`)
- Pooled TBE (`pooling_mode`: `PoolingMode.SUM`, `PoolingMode.MEAN`)
- Weighted and unweighted TBE (`per_sample_weights`: `Tensor`, `None`)
- FP32 and FP16 weight types (`weights_precision`: `SparseType.FP32`,
`SparseType.FP16`)
- FP32 and FP16 output types (`output_dtype`: `SparseType.FP32`,
`SparseType.FP16`)
- Device, managed, and managed-caching embedding locations
(`EmbeddingLocation`: `EmbeddingLocation.DEVICE`,
`EmbeddingLocation.MANAGED`,
`EmbeddingLocation.MANAGED_CACHING`)
Cases that the new implementation does **NOT** support:
- Dense TBE (`DenseTableBatchedEmbeddingBagsCodegen`)
- Sequence TBE (`pooling_mode`: `PoolingMode.NONE`)
- FP8, INT8, INT4, INT2, and BF16 weight types (`weights_precision`:
`SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`,
`SparseType.INT2`, `SparseType.BF16`)
- FP8, INT8, INT4, INT2, and BF16 output types (`output_dtype`:
`SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`,
`SparseType.INT2`, `SparseType.BF16`)
- Host embedding locations (`EmbeddingLocation`:
`EmbeddingLocation.HOST`)
Note that this optimization is enabled for NVIDIA GPUs, but **not**
enabled for AMD GPUs.
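As an illustration, the following is a minimal sketch of a configuration
that falls inside the supported subset. The import paths and argument
names are assumed from the `fbgemm_gpu` frontend and are not part of this
diff.
```
# Illustrative sketch only: a TBE configuration within the supported subset
# (split TBE, sum pooling, FP16 weights and output, device/managed locations).
# Import paths and argument names are assumed from the fbgemm_gpu frontend.
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops import (
    ComputeDevice,
    EmbeddingLocation,
    PoolingMode,
    SplitTableBatchedEmbeddingBagsCodegen,
)

emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute_device)
        (100_000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        (100_000, 256, EmbeddingLocation.MANAGED, ComputeDevice.CUDA),
    ],
    pooling_mode=PoolingMode.SUM,       # PoolingMode.NONE is not supported
    weights_precision=SparseType.FP16,  # FP32 or FP16 only
    output_dtype=SparseType.FP16,       # FP32 or FP16 only
)
```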
**Usage**
The frontend changes are in D44479772.
The `FBGEMM_EXPERIMENTAL_TBE` environment variable flag is added for
enabling/disabling the new implementation at runtime. If
`FBGEMM_EXPERIMENTAL_TBE` is not set, TBE will use the original
implementation. If `FBGEMM_EXPERIMENTAL_TBE=1`, TBE will use the new
implementation. If a TBE use case is not supported by the new
implementation, TBE will fall back to the original implementation. By
default, `FBGEMM_EXPERIMENTAL_TBE` is not set.
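For example, a minimal sketch of setting the flag from Python (this
assumes the flag is read when the TBE operator is constructed; exporting
it in the shell before launching the process works the same way):
```
# Enable the experimental TBE forward path at runtime.
# Assumption: the flag must be set before the TBE operator is constructed.
import os

os.environ["FBGEMM_EXPERIMENTAL_TBE"] = "1"  # unset (default) -> original path
```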
This can also be enabled by passing `use_experimental_tbe=True` when
instantiating the TBE operator.
```
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=...,
    ...,
    use_experimental_tbe=True,
)
```
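Once constructed, the operator is invoked as usual; the call below follows
the standard TBE frontend signature and is a sketch, not part of this diff.
```
# Sketch of a forward call (standard TBE frontend signature, assumed here).
# indices: flattened lookup indices across all tables and samples.
# offsets: bag boundaries, one entry per bag plus one (length B * T + 1).
import torch

indices = torch.tensor([1, 5, 7, 3], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 2, 4], dtype=torch.int64, device="cuda")

# Unweighted path; pass per_sample_weights (one FP32 weight per index)
# to exercise the weighted kernel instead.
output = emb_op(indices=indices, offsets=offsets)
```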
**Optimization**
The new implementation contains the following optimizations:
- Use multiple warps per bag for D > 128 to maintain a constant
number of registers per thread (see the sketch after this list)
- Use subwarps to process subsets of input rows in a bag if D < 128
- Cooperatively compute weight pointers and store them in shared
memory
- Save state variables in shared memory instead of registers to free
registers for compiler optimizations
- Use the upper-bound number of warps across all tables to avoid complex
warp-offset computation
- Process multiple samples (up to kWarpSize samples) in a warp for
small Ls
Note: D = embedding dimension, L = pooling factor
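To make the first two bullets concrete, here is a hypothetical
back-of-the-envelope sketch of the warp-to-bag mapping; the constants and
helper names are assumptions inferred from the description above, not
values taken from the kernel.
```
# Hypothetical illustration of the warp-to-bag mapping described above.
# Constants and names are assumptions, not values from the actual kernel.
K_WARP_SIZE = 32       # threads per warp on NVIDIA GPUs
ELEMS_PER_THREAD = 4   # assumed vector width per thread (hypothetical)
MAX_D_PER_WARP = K_WARP_SIZE * ELEMS_PER_THREAD  # = 128, per the description

def warps_per_bag(D: int) -> int:
    # D > 128: use multiple warps per bag so the register count per
    # thread stays constant regardless of D.
    return -(-D // MAX_D_PER_WARP)  # ceiling division

def subwarps_per_warp(D: int) -> int:
    # D < 128: split the warp into subwarps, each handling a subset of
    # the input rows of a bag (parallelizing over L instead of D).
    return max(1, MAX_D_PER_WARP // max(D, 1))

for D in (64, 128, 256, 512):
    print(f"D={D}: warps_per_bag={warps_per_bag(D)}, "
          f"subwarps_per_warp={subwarps_per_warp(D)}")
```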
Reviewed By: jianyuh
Differential Revision: D43634651
fbshipit-source-id: 96ad56f0e5567959fd28c72a649f862e1f5dd307