Skip to content

Shared embedding ops #1935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <torchao/experimental/ops/embedding_xbit/packed_weights_header.h>
#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h>
#include <torchao/experimental/ops/packed_weights_header.h>
#include <torchao/experimental/ops/parallel.h>

Expand Down Expand Up @@ -266,3 +267,121 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) {
.to("meta");
}
#endif // USE_ATEN

#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
// Embedding lookup that reads rows directly out of weights packed for the
// linear_8bit_act_xbit_weight universal kernel, so a tied embedding/linear
// pair does not need a second packed copy of the weights.
//
// Args:
//   packed_weights: 1D int8 buffer produced by the linear op's packing
//     routine (PackedWeightsHeader followed by the packed payload).
//   group_size: quantization group size along k.
//   n: number of embeddings (rows of the table).
//   k: embedding dimension (columns of the table).
//   indices: 1D tensor of row indices, int32 or int64.
//   out: output tensor; filled with {indices.size(0), k} float32 rows.
//
// Returns `out`.
template <int weight_nbit>
Tensor shared_embedding_out_cpu(
    const Tensor& packed_weights,
    const int64_t& group_size,
    const int64_t& n, // same as num_embeddings
    const int64_t& k, // same as embedding_dim
    const Tensor& indices,
    Tensor& out) {
  // Check packed_weights are from linear op
  TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D");
#ifdef USE_ATEN
  TORCHAO_CHECK(
      packed_weights.dtype() == torch::kInt8, "packed_weights must be int8");
#endif // USE_ATEN
  TORCHAO_CHECK(
      packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(),
      "packed_weights is not big enough to read the header.");
  auto header =
      torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr());
  auto format = torchao::ops::linear_8bit_act_xbit_weight::PackedWeightsFormat::
      from_packed_weights_header(header);
  torchao::ops::linear_8bit_act_xbit_weight::check_format<weight_nbit>(
      format,
      torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal);
  // The aarch64 shared_embedding kernel below is only instantiated for this
  // tiling, so reject buffers packed with any other nr/kr/sr.
  constexpr int nr = 8;
  constexpr int kr = 16;
  constexpr int sr = 2;
  TORCHAO_CHECK(format.nr == nr, "shared_embedding only supports nr=8");
  TORCHAO_CHECK(format.kr == kr, "shared_embedding only supports kr=16");
  TORCHAO_CHECK(format.sr == sr, "shared_embedding only supports sr=2");

  int num_out = indices.size(0);

#ifdef USE_ATEN
  TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
  out.resize_({num_out, k});
#endif // USE_ATEN

#ifdef USE_EXECUTORCH
  // ExecuTorch tensors cannot be resized here; the caller must pass a
  // correctly shaped output.
  TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
  TORCHAO_CHECK(out.size(0) == num_out, "out shape is incorrect");
  TORCHAO_CHECK(out.size(1) == k, "out shape is incorrect");
#endif // USE_EXECUTORCH

  const int32_t* index32_ptr = nullptr;
  const int64_t* index64_ptr = nullptr;
  if (indices.dtype() == Tensor_dtype_kInt32) {
    index32_ptr = indices.const_data_ptr<int32_t>();
  } else {
    TORCHAO_CHECK(
        indices.dtype() == Tensor_dtype_kInt64,
        "indices must be int32 or int64");
    index64_ptr = indices.const_data_ptr<int64_t>();
  }
  torchao::parallel_1d(0, num_out, [&](int64_t idx) {
    int index = -1;
    if (index32_ptr != nullptr) {
      index = index32_ptr[idx];
    } else {
      index = index64_ptr[idx];
    }
    // An index selects one of the n embedding rows, so the valid range is
    // [0, n). (The previous check compared against k, the embedding dim.)
    TORCHAO_CHECK(index >= 0 && index < n, "index out of bounds");
#if defined(TORCHAO_BUILD_CPU_AARCH64)
    torchao::kernels::cpu::aarch64::embedding::
        shared_embedding<weight_nbit, nr, kr, sr>(
            out.mutable_data_ptr<float>() + idx * k,
            packed_weights.const_data_ptr<int8_t>() +
                torchao::ops::PackedWeightsHeader::size(),
            n,
            k,
            group_size,
            format.has_weight_zeros,
            format.has_bias,
            index);
#else
    TORCHAO_CHECK(false, "Unsupported platform");
#endif // TORCHAO_BUILD_CPU_AARCH64
  });

  return out;
}

#ifdef USE_ATEN
// ATen entry point for the functional (non-out) variant: allocates an empty
// float32 tensor and delegates to shared_embedding_out_cpu, which resizes it
// to {indices.size(0), k} and fills it.
template <int weight_nbit>
Tensor shared_embedding_cpu(
    const Tensor& packed_weights,
    const int64_t& group_size,
    const int64_t& n, // same as num_embeddings
    const int64_t& k, // same as embedding_dim
    const Tensor& indices) {
  Tensor result = torch::empty({}, torch::kFloat32);
  return shared_embedding_out_cpu<weight_nbit>(
      packed_weights, group_size, n, k, indices, result);
}
#endif // USE_ATEN

#ifdef USE_ATEN
// Meta-device kernel: returns a tensor of the correct output shape on the
// "meta" device without computing any data (used for shape propagation and
// tracing/compile).
template <int weight_nbit>
Tensor shared_embedding_meta(
    const Tensor& packed_weights,
    const int64_t& group_size,
    const int64_t& n, // same as num_embeddings
    const int64_t& k, // same as embedding_dim
    const Tensor& indices) {
  return torch::empty({indices.size(0), k}).to("meta");
}
#endif // USE_ATEN

#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH)
54 changes: 35 additions & 19 deletions torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,43 @@

#include <torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h>

#define DEFINE_OP(weight_nbit) \
m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor"); \
m.def( \
"_embedding_" #weight_nbit \
"bit(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
m.def( \
"_embedding_" #weight_nbit \
"bit.out(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)");
// Declares the torch-library schemas for the pack/embedding/shared-embedding
// ops at a given weight bit width. The _shared_embedding_* ops operate on
// weights packed by the linear_8bit_act_xbit_weight op (see the .out and
// functional variants below).
#define DEFINE_OP(weight_nbit)                                                 \
  m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor"); \
  m.def(                                                                       \
      "_embedding_" #weight_nbit                                               \
      "bit(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
  m.def(                                                                       \
      "_embedding_" #weight_nbit                                               \
      "bit.out(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
  m.def(                                                                       \
      "_shared_embedding_" #weight_nbit                                        \
      "bit.out(Tensor packed_weights, int group_size, int n, int k, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
  m.def(                                                                       \
      "_shared_embedding_" #weight_nbit                                        \
      "bit(Tensor packed_weights, int group_size, int n, int k, Tensor indices) -> Tensor");

#define DEFINE_CPU_IMPL(weight_nbit) \
m.impl( \
"_pack_embedding_" #weight_nbit "bit", \
&pack_embedding_cpu<weight_nbit>); \
m.impl("_embedding_" #weight_nbit "bit", &embedding_cpu<weight_nbit>); \
m.impl("_embedding_" #weight_nbit "bit.out", &embedding_out_cpu<weight_nbit>);
// Binds the CPU kernels for the ops declared by DEFINE_OP at the same bit
// width. Each name here must match a schema registered above.
#define DEFINE_CPU_IMPL(weight_nbit)                                          \
  m.impl(                                                                     \
      "_pack_embedding_" #weight_nbit "bit",                                  \
      &pack_embedding_cpu<weight_nbit>);                                      \
  m.impl("_embedding_" #weight_nbit "bit", &embedding_cpu<weight_nbit>);      \
  m.impl(                                                                     \
      "_embedding_" #weight_nbit "bit.out", &embedding_out_cpu<weight_nbit>); \
  m.impl(                                                                     \
      "_shared_embedding_" #weight_nbit "bit",                                \
      &shared_embedding_cpu<weight_nbit>);                                    \
  m.impl(                                                                     \
      "_shared_embedding_" #weight_nbit "bit.out",                            \
      &shared_embedding_out_cpu<weight_nbit>);

#define DEFINE_META_IMPL(weight_nbit) \
m.impl( \
"_pack_embedding_" #weight_nbit "bit", \
&pack_embedding_meta<weight_nbit>); \
m.impl("_embedding_" #weight_nbit "bit", &embedding_meta<weight_nbit>);
// Binds the meta-device (shape-only) kernels for the functional ops at the
// same bit width. Out-variants are not registered for meta.
#define DEFINE_META_IMPL(weight_nbit)                                     \
  m.impl(                                                                 \
      "_pack_embedding_" #weight_nbit "bit",                              \
      &pack_embedding_meta<weight_nbit>);                                 \
  m.impl("_embedding_" #weight_nbit "bit", &embedding_meta<weight_nbit>); \
  m.impl(                                                                 \
      "_shared_embedding_" #weight_nbit "bit",                            \
      &shared_embedding_meta<weight_nbit>);

TORCH_LIBRARY_FRAGMENT(torchao, m) {
DEFINE_OP(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,31 @@ DEFINE_OP(5);
DEFINE_OP(6);
DEFINE_OP(7);
DEFINE_OP(8);

// Defines and registers the ExecuTorch out-variant kernel for the shared
// embedding op at the given bit width. The wrapper discards the runtime
// context and forwards to shared_embedding_out_cpu.
//
// NOTE: the EXECUTORCH_LIBRARY registration must name the function this
// macro defines (_shared_op_out_##weight_nbit); it previously referenced
// _op_out_##weight_nbit, which binds the plain embedding kernel (defined by
// DEFINE_OP above) — with the wrong signature — to the shared-embedding
// schema.
#define DEFINE_SHARED_OP(weight_nbit)                  \
  Tensor _shared_op_out_##weight_nbit(                 \
      RuntimeContext& ctx,                             \
      const Tensor& packed_weights,                    \
      const int64_t& group_size,                       \
      const int64_t& n,                                \
      const int64_t& k,                                \
      const Tensor& indices,                           \
      Tensor& out) {                                   \
    (void)ctx;                                         \
    shared_embedding_out_cpu<weight_nbit>(             \
        packed_weights, group_size, n, k, indices, out); \
    return out;                                        \
  }                                                    \
  EXECUTORCH_LIBRARY(                                  \
      torchao,                                         \
      "_shared_embedding_" #weight_nbit "bit.out",     \
      _shared_op_out_##weight_nbit)

DEFINE_SHARED_OP(1);
DEFINE_SHARED_OP(2);
DEFINE_SHARED_OP(3);
DEFINE_SHARED_OP(4);
DEFINE_SHARED_OP(5);
DEFINE_SHARED_OP(6);
DEFINE_SHARED_OP(7);
DEFINE_SHARED_OP(8);
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#pragma once
#include <cpuinfo.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/packed_weights_header.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_format.h>

#if defined(TORCHAO_BUILD_CPU_AARCH64)
#include <torchao/experimental/kernels/cpu/aarch64/linear/linear.h>
Expand All @@ -23,49 +23,6 @@

namespace torchao::ops::linear_8bit_act_xbit_weight {

struct PackedWeightsFormat {
torchao::ops::PackedWeightsType type;
int weight_nbit;
bool has_weight_zeros;
bool has_bias;
int nr;
int kr;
int sr;

PackedWeightsFormat(
torchao::ops::PackedWeightsType type,
int weight_nbit,
bool has_weight_zeros,
bool has_bias,
int nr,
int kr,
int sr)
: type{type},
weight_nbit{weight_nbit},
has_weight_zeros{has_weight_zeros},
has_bias{has_bias},
nr{nr},
kr{kr},
sr{sr} {}

static PackedWeightsFormat from_packed_weights_header(
torchao::ops::PackedWeightsHeader header) {
return PackedWeightsFormat(
header.type,
header.params[0],
static_cast<bool>(header.params[1]),
static_cast<bool>(header.params[2]),
header.params[3],
header.params[4],
header.params[5]);
}

inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const {
return torchao::ops::PackedWeightsHeader(
type, {weight_nbit, has_weight_zeros, has_bias, nr, kr, sr});
}
};

struct UKernelConfigRegistrationTable {
private:
using Key = std::pair<torchao::ops::PackedWeightsHeader, cpuinfo_uarch>;
Expand Down Expand Up @@ -107,25 +64,6 @@ struct UKernelConfigRegistrationTable {
}
};

template <int weight_nbit>
void check_format(
PackedWeightsFormat format,
torchao::ops::PackedWeightsType type) {
if (format.type != type) {
throw std::runtime_error(
"Kernel expects packed_weights type=" +
std::to_string(static_cast<int>(type)) +
", but got packed_weights with type=" +
std::to_string(static_cast<int>(format.type)));
}
if (format.weight_nbit != weight_nbit) {
throw std::runtime_error(
"Kernel expects weight_nbit=" + std::to_string(weight_nbit) +
", but got packed_weights with weight_nbit=" +
std::to_string(format.weight_nbit));
}
}

void log_registration(PackedWeightsFormat format, std::string description) {
// Logging is only supported in ATen mode
#ifdef USE_ATEN
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torchao/experimental/ops/packed_weights_header.h>

namespace torchao::ops::linear_8bit_act_xbit_weight {

// Self-describing record of how a linear_8bit_act_xbit_weight buffer was
// packed. It round-trips through torchao::ops::PackedWeightsHeader (which is
// serialized at the front of the packed buffer), so other ops — e.g. the
// shared embedding ops — can validate compatibility before reading the
// payload.
struct PackedWeightsFormat {
  torchao::ops::PackedWeightsType type; // packing layout / kernel family
  int weight_nbit;       // bit width of the quantized weights
  bool has_weight_zeros; // payload includes weight zero points
  bool has_bias;         // payload includes bias
  int nr;                // kernel tiling parameters used at pack time
  int kr;
  int sr;

  PackedWeightsFormat(
      torchao::ops::PackedWeightsType type,
      int weight_nbit,
      bool has_weight_zeros,
      bool has_bias,
      int nr,
      int kr,
      int sr)
      : type{type},
        weight_nbit{weight_nbit},
        has_weight_zeros{has_weight_zeros},
        has_bias{has_bias},
        nr{nr},
        kr{kr},
        sr{sr} {}

  // Deserializes a format from a header read off a packed buffer. The
  // params[] ordering must stay in sync with to_packed_weights_header below.
  static PackedWeightsFormat from_packed_weights_header(
      torchao::ops::PackedWeightsHeader header) {
    return PackedWeightsFormat(
        header.type,
        header.params[0],
        static_cast<bool>(header.params[1]),
        static_cast<bool>(header.params[2]),
        header.params[3],
        header.params[4],
        header.params[5]);
  }

  // Serializes this format into a header suitable for embedding at the
  // front of a packed-weights buffer.
  inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const {
    return torchao::ops::PackedWeightsHeader(
        type, {weight_nbit, has_weight_zeros, has_bias, nr, kr, sr});
  }
};

// Validates that a packed-weights format matches what a kernel compiled for
// `weight_nbit` and packing `type` expects; throws std::runtime_error with a
// descriptive message on any mismatch.
template <int weight_nbit>
void check_format(
    PackedWeightsFormat format,
    torchao::ops::PackedWeightsType type) {
  const int expected_type = static_cast<int>(type);
  const int actual_type = static_cast<int>(format.type);
  if (actual_type != expected_type) {
    throw std::runtime_error(
        "Kernel expects packed_weights type=" + std::to_string(expected_type) +
        ", but got packed_weights with type=" + std::to_string(actual_type));
  }
  if (format.weight_nbit != weight_nbit) {
    throw std::runtime_error(
        "Kernel expects weight_nbit=" + std::to_string(weight_nbit) +
        ", but got packed_weights with weight_nbit=" +
        std::to_string(format.weight_nbit));
  }
}

} // namespace torchao::ops::linear_8bit_act_xbit_weight
Loading