Summary:
Pull Request resolved: #1641
This diff adds an optimized implementation of TBE training forward,
namely
`split_embedding_codegen_forward_[weighted|unweighted]_v2_kernel`.
The implementation currently supports only a subset of TBE use cases
(a supported configuration is sketched after the lists below), including:
- Split TBE (`SplitTableBatchedEmbeddingBagsCodegen`)
- Pooled TBE (`pooling_mode`: `PoolingMode.SUM`, `PoolingMode.MEAN`)
- Weighted and unweighted TBE (`per_sample_weights`: `Tensor`, `None`)
- FP32 and FP16 weight types (`weights_precision`: `SparseType.FP32`,
`SparseType.FP16`)
- FP32 and FP16 output types (`output_dtype`: `SparseType.FP32`,
`SparseType.FP16`)
- Device, managed, and managed-caching embedding locations
(`EmbeddingLocation`: `EmbeddingLocation.DEVICE`,
`EmbeddingLocation.MANAGED`,
`EmbeddingLocation.MANAGED_CACHING`)
Cases that the new implementation does **NOT** support:
- Dense TBE (`DenseTableBatchedEmbeddingBagsCodegen`)
- Sequence TBE (`pooling_mode`: `PoolingMode.NONE`)
- FP8, INT8, INT4, INT2, and BF16 weight types (`weights_precision`:
`SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`,
`SparseType.INT2`, `SparseType.BF16`)
- FP8, INT8, INT4, INT2, and BF16 output types (`output_dtype`:
`SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`,
`SparseType.INT2`, `SparseType.BF16`)
- Host embedding locations (`EmbeddingLocation`:
`EmbeddingLocation.HOST`)
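For reference, a minimal sketch (not part of this diff) of a configuration that falls inside the supported subset; the `fbgemm_gpu` import paths below are assumed and may differ across versions:
```python
import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops import (
    ComputeDevice,
    EmbeddingLocation,
    PoolingMode,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Split, pooled, FP16-weight, FP16-output, device-resident TBE:
# a combination covered by the new kernel.
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, location, compute_device)
        (10_000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        (20_000, 256, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    pooling_mode=PoolingMode.SUM,
    weights_precision=SparseType.FP16,
    output_dtype=SparseType.FP16,
    device=torch.device("cuda"),
)

# Unweighted forward (per_sample_weights=None); the weighted case passes a
# float tensor with one weight per index. B = 2 samples, T = 2 tables.
indices = torch.tensor([1, 5, 7, 3, 9], dtype=torch.int64, device="cuda")
offsets = torch.tensor([0, 2, 3, 5, 5], dtype=torch.int64, device="cuda")
output = emb_op(indices=indices, offsets=offsets)  # shape [B, sum(Ds)] = [2, 384]
```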
The `FBGEMM_EXPERIMENTAL_TBE` environment variable flag is added for
enabling/disabling the new implementation at runtime. If
`FBGEMM_EXPERIMENTAL_TBE` is not set, TBE will use the original implementation.
If `FBGEMM_EXPERIMENTAL_TBE=1`, TBE will use the new implementation. If the
TBE use case is not supported by the new implementation, TBE will
fall back to the original implementation. By default,
`FBGEMM_EXPERIMENTAL_TBE` is not set.
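As a usage sketch, the flag can be set in the launching shell (`FBGEMM_EXPERIMENTAL_TBE=1 python train.py`) or, assuming it is read after process start, from Python before the TBE forward runs:
```python
import os

# Opt in to the experimental forward kernel. Leaving the variable unset (the
# default) keeps the original implementation; unsupported use cases fall back
# to it automatically.
os.environ["FBGEMM_EXPERIMENTAL_TBE"] = "1"
```
Whether the variable is checked at import time or at kernel dispatch is not stated here, so setting it in the environment before launching the process is the safer option.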
The new implementation contains the following optimizations:
- Use multiple warps per bag for D > 128 to maintain a constant
number of registers per thread
- Use subwarps to process subsets of input rows in a bag if D < 128
- Cooperatively compute weight pointers and store them in shared
memory
- Save state variables in shared memory instead of registers to free
registers for compiler optimizations
- Use the upper bound number of warps for all tables to avoid complex
warp offset computation
- Process multiple samples (up to `kWarpSize` samples) in a warp for
small Ls (a rough sketch of this warp allocation follows the note below)
Note: D = embedding dimension, L = pooling factor
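To make the warp-allocation scheme above concrete, here is a rough, illustrative sketch of the arithmetic; the constants (`kWarpSize = 32`, a 128-element per-warp tile, 4 elements per thread) and function names are assumptions for illustration, not the kernel's actual parameters:
```python
kWarpSize = 32        # threads per warp (assumed 32, as on NVIDIA GPUs)
ELEMS_PER_WARP = 128  # embedding elements one warp covers per row (illustrative)

def warps_per_bag(D: int) -> int:
    # D > 128: spread one bag across multiple warps so the number of
    # registers needed per thread stays constant as D grows.
    return max(1, (D + ELEMS_PER_WARP - 1) // ELEMS_PER_WARP)

def subwarps_per_warp(D: int) -> int:
    # D < 128: split the warp into subwarps; each subwarp accumulates a
    # subset of a bag's input rows, and for small L a single warp can
    # cover multiple samples.
    threads_per_row = max(1, (D + 3) // 4)  # assume 4 elements per thread
    return max(1, kWarpSize // threads_per_row)

# Example: D = 256 uses 2 warps per bag; D = 64 splits a warp into 2 subwarps.
assert warps_per_bag(256) == 2
assert subwarps_per_warp(64) == 2
```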
Differential Revision: D43634651
fbshipit-source-id: 42b8c5b853dd30df9bb3b2f808668d1ebf0db9a7