64 changes: 64 additions & 0 deletions test/test_ops.py
@@ -764,5 +764,69 @@ def test_swizzle_mm():
)


EMBEDDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
EMBEDDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
EMBEDDINGBAG_VECTOR_SIZES = [1, 128, 512]
EMBEDDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]

EMBEDDINGBAG_TEST_PARAMS = list(
    itertools.product(
        EMBEDDINGBAG_MULTIHOT_SIZES,
        EMBEDDINGBAG_BAG_SIZES,
        EMBEDDINGBAG_VECTOR_SIZES,
        EMBEDDINGBAG_INDEX_DTYPES,
    )
)


@pytest.mark.skipif(
"CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
reason="cpp kernels not built",
)
@pytest.mark.parametrize(
"multi_hot, batch_size, vector_size, index_type",
    EMBEDDINGBAG_TEST_PARAMS,
ids=str,
)
def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type):
qtype = torch.float8_e4m3fn
dtype = torch.float32
weight_scale = torch.tensor([2.0])
include_last_offset = True
mode = "sum"

if mode == "sum":
mode_enum = 0
elif mode == "mean":
mode_enum = 1
elif mode == "max":
mode_enum = 2
indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(index_type)

m = torch.nn.EmbeddingBag(
1000,
vector_size,
mode=mode,
dtype=dtype,
include_last_offset=include_last_offset,
)
fp8_weight = m.weight.data.to(qtype)
m.weight.data = fp8_weight.to(m.weight.dtype)

with torch.no_grad():
refe_out = m.forward(indices, offsets) * weight_scale
test_out = torch.ops.torchao._scaled_embedding_bag(
fp8_weight,
indices,
offsets,
weight_scale,
1.0,
mode_enum,
include_last_offset,
).to(dtype)
torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
pytest.main(sys.argv)
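For reference, the semantics that this test checks can be written as a small plain-Python sketch (not part of the PR; it assumes sum mode, include_last_offset=True, and non-empty bags, which is the only configuration the kernel currently accepts):

import torch

def scaled_embedding_bag_ref(qweight, indices, offsets, weight_scale, o_scale=1.0):
    # Dequantize the fp8 rows, sum them per bag, then scale by weight_scale / o_scale.
    # With include_last_offset=True, offsets holds batch_size + 1 boundaries.
    weight = qweight.to(torch.float32)
    scale = float(weight_scale) / o_scale
    bags = []
    for b in range(offsets.numel() - 1):
        start, end = int(offsets[b]), int(offsets[b + 1])
        bags.append(weight[indices[start:end]].sum(dim=0) * scale)
    return torch.stack(bags)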
183 changes: 183 additions & 0 deletions torchao/csrc/cpu/scaled_embedding_bag.cpp
@@ -0,0 +1,183 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec512/vec512_float8.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/EmbeddingBag.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Unroll.h>
#include <torch/all.h>

namespace torchao {

namespace {

#if defined(CPU_CAPABILITY_AVX512)
static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
__m512 o;
__m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
return o;
}
#endif

template <typename index_t>
inline void _scaled_embedding_bag_krnl(
const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
const int64_t emb_dim, const index_t last_offset, const index_t *indices,
const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
float *result, const int64_t num_batch) {
#if defined(CPU_CAPABILITY_AVX512)
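  // Fast path: when emb_dim is a multiple of 128, each bag is accumulated in
  // 128-float blocks held in eight 16-lane AVX512 registers.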
if (emb_dim % 128 == 0) {
constexpr int64_t block_dim = 128;
const int64_t num_blocks = emb_dim / block_dim;
__m512 scale_v = _mm512_set1_ps(scale);
for (int64_t b = bs_begin; b < bs_end; ++b) {
__m512 x0, x1, x2, x3, x4, x5, x6, x7;
int64_t start_idx = offsets[b];
int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
? last_offset
: offsets[b + 1];
for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
// load first indices
int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
float *block_result = result + block_dim * block_id;
x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
for (int64_t j = start_idx + 1; j < end_idx; ++j) {
// add following idx
idx = indices[j] * emb_dim + block_dim * block_id;
x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
}
x0 = _mm512_mul_ps(x0, scale_v);
x1 = _mm512_mul_ps(x1, scale_v);
x2 = _mm512_mul_ps(x2, scale_v);
x3 = _mm512_mul_ps(x3, scale_v);
x4 = _mm512_mul_ps(x4, scale_v);
x5 = _mm512_mul_ps(x5, scale_v);
x6 = _mm512_mul_ps(x6, scale_v);
x7 = _mm512_mul_ps(x7, scale_v);
// store
_mm512_store_ps(block_result, x0);
_mm512_store_ps(block_result + 16, x1);
_mm512_store_ps(block_result + 32, x2);
_mm512_store_ps(block_result + 48, x3);
_mm512_store_ps(block_result + 64, x4);
_mm512_store_ps(block_result + 80, x5);
_mm512_store_ps(block_result + 96, x6);
_mm512_store_ps(block_result + 112, x7);
}
result += num_emb * emb_dim;
}
return;
}
#endif
for (int64_t b = bs_begin; b < bs_end; ++b) {
int64_t start_idx = offsets[b];
int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
? last_offset
: offsets[b + 1];
for (int64_t d = 0; d < emb_dim; d++) {
int64_t idx = indices[start_idx] * emb_dim;
float value = float(weight[idx + d]);
for (int64_t j = start_idx + 1; j < end_idx; ++j) {
idx = indices[j] * emb_dim;
value += float(weight[idx + d]);
}
value = value * scale;
result[d] = value;
}
result += num_emb * emb_dim;
}
}

template <typename index_t, typename data_t>
void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
index_t *offsets_ptr, int64_t num_batch,
int64_t emb_dim, index_t last_offset, double w_scale,
double o_scale) {
constexpr int64_t b_block = 512;
const int64_t n_b_blocks = (num_batch - 1) / b_block + 1;
w_scale /= o_scale;
const int64_t num_emb = 1;
#pragma omp parallel for collapse(2)
for (int64_t b = 0; b < n_b_blocks; ++b) {
for (int64_t n = 0; n < num_emb; ++n) {
const int64_t bs_begin = b * b_block;
const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
      // pass last_offset so the final bag is bounded even if offsets lacks a trailing entry
_scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim,
last_offset, indices_ptr, offsets_ptr, w_ptr,
w_scale, r, num_batch);
}
}
}

at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
const at::Tensor &indices,
const at::Tensor &offsets,
const at::Tensor &w_scales,
double o_scale, const int64_t mode,
bool include_last_offset) {
// Only support include_last_offset == True and mode ==
// at::native::EmbeddingBagMode::SUM
  // TODO: Support more cases
  TORCH_CHECK(include_last_offset,
              "_scaled_embedding_bag: only support include_last_offset == True");
  TORCH_CHECK(mode == at::native::EmbeddingBagMode::SUM,
              "_scaled_embedding_bag: only support sum mode");
int64_t batch_size =
include_last_offset ? offsets.size(0) - 1 : offsets.size(0);
int64_t emb_dim = qweight.size(1);

auto index_type = indices.scalar_type();
auto qtype = qweight.scalar_type();
float w_scale = w_scales.data_ptr<float>()[0];

TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
"_scaled_embedding_bag: only accept contiguous input");
TORCH_CHECK(
offsets.scalar_type() == index_type,
"_scaled_embedding_bag: index and offset must be of the same type");
TORCH_CHECK(qweight.is_contiguous(),
"_scaled_embedding_bag: only accept contiguous weight");
TORCH_CHECK(qweight.dim() == 2,
"_scaled_embedding_bag: only accept weight with dim == 2");
  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
              "_scaled_embedding_bag: only support e4m3fn weight");
// handle last offsets
int64_t last_offset = indices.numel();

at::Tensor output =
at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
index_t *indices_ptr = indices.data_ptr<index_t>();
index_t *offsets_ptr = offsets.data_ptr<index_t>();
float *output_ptr = output.data_ptr<float>();
_scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
last_offset, w_scale, o_scale);
});
return output;
}

} // anonymous namespace

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
m.impl("torchao::_scaled_embedding_bag", &_scaled_embedding_bag_impl);
}

} // namespace torchao
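The driver splits the bags into blocks of 512 (b_block) and, since num_emb is fixed at 1, the collapse(2) OpenMP loop effectively parallelizes over those blocks; last_offset (set to indices.numel()) bounds the final bag when offsets has no trailing entry. A rough Python sketch of the per-bag index ranges this produces (illustrative only; the function name is invented here):

def bag_index_ranges(offsets, num_batch, last_offset, b_block=512):
    # Each OpenMP task handles up to b_block consecutive bags; the final bag
    # ends at last_offset instead of offsets[num_batch] when b is the last bag.
    ranges = []
    n_b_blocks = (num_batch - 1) // b_block + 1
    for blk in range(n_b_blocks):
        bs_begin = blk * b_block
        bs_end = min(num_batch, (blk + 1) * b_block)
        for b in range(bs_begin, bs_end):
            start = int(offsets[b])
            if b + 1 == num_batch and last_offset != -1:
                end = int(last_offset)
            else:
                end = int(offsets[b + 1])
            ranges.append((start, end))
    return ranges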
19 changes: 19 additions & 0 deletions torchao/ops.py
@@ -68,6 +68,9 @@
lib.define(
"da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
)
lib.define(
"_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
)


def register_custom_op(name):
@@ -1098,3 +1101,19 @@ def _(
assert weight.dim() == 4
N = weight.size(0) * weight.size(3) * 2
return input.new_empty(*input.shape[:-1], N, dtype=out_dtype)


@register_custom_op("torchao::_scaled_embedding_bag")
def _(
qweight: Tensor,
indices: Tensor,
offsets: Tensor,
w_scales: Tensor,
o_scale: float,
mode: int,
include_last_offset: bool,
) -> Tensor:
    # Only include_last_offset == True is supported (mirrors the C++ checks)
    assert include_last_offset
    batch_size = offsets.shape[0] - 1
    # The CPU kernel materializes a float32 output, so the fake tensor must match
    return qweight.new_empty(batch_size, qweight.shape[1], dtype=torch.float32)
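With the schema defined above and the fake (meta) implementation registered, the op can be traced and called directly. A minimal usage sketch, assuming the CPU extension is built (o_scale=1.0, mode=0 for sum, include_last_offset=True, matching the checks in the C++ kernel):

import torch
import torchao  # loads the torchao ops

qweight = torch.randn(1000, 128).to(torch.float8_e4m3fn)  # fp8 embedding table
indices = torch.randint(1000, (8,), dtype=torch.int64)    # 4 bags x 2 lookups each
offsets = torch.arange(0, 10, 2, dtype=torch.int64)       # 5 boundaries (include_last_offset=True)
w_scale = torch.tensor([2.0])
out = torch.ops.torchao._scaled_embedding_bag(
    qweight, indices, offsets, w_scale, 1.0, 0, True
)
print(out.shape)  # torch.Size([4, 128])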