Commit d84191c

Add embedding ops aten
Differential Revision: D64477035
Pull Request resolved: #1129
1 parent a234bc1 commit d84191c

11 files changed: +664 additions, -10 deletions


torchao/experimental/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -33,11 +33,13 @@ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
     # Defines target torchao_kernels_aarch64
     add_subdirectory(kernels/cpu/aarch64)
     add_subdirectory(ops/linear_8bit_act_xbit_weight)
+    add_subdirectory(ops/embedding_xbit)
 
     add_library(torchao_ops_aten SHARED)
     target_link_libraries(
         torchao_ops_aten PRIVATE
         torchao_ops_linear_8bit_act_xbit_weight_aten
+        torchao_ops_embedding_xbit_aten
     )
     install(
         TARGETS torchao_ops_aten

torchao/experimental/ops/embedding_xbit/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake)

find_package(Torch REQUIRED)
add_library(torchao_ops_embedding_xbit_aten OBJECT
    op_embedding_xbit_aten.cpp
)
target_link_torchao_parallel_backend(torchao_ops_embedding_xbit_aten "aten_openmp")
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64)
target_include_directories(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(torchao_ops_embedding_xbit_aten PRIVATE USE_ATEN=1)

torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h

Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h>
#endif // defined(__aarch64__) || defined(__ARM_NEON)

#include <torchao/experimental/ops/embedding_xbit/packed_weights_header.h>
#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/packed_weights_header.h>
#include <torchao/experimental/ops/parallel.h>

template <int weight_nbit>
void check_embedding_inputs(
    const Tensor& packed_weight_qvals,
    int num_embeddings,
    int embedding_dim,
    const Tensor& weight_scales,
    const Tensor& weight_zeros,
    const Tensor& indices,
    int& group_size) {
  TORCHAO_CHECK(
      packed_weight_qvals.dim() == 1, "packed_weight_qvals must be 1D");
#ifdef USE_ATEN
  TORCHAO_CHECK(
      packed_weight_qvals.dtype() == torch::kInt8,
      "packed_weight_qvals must be byte");
#endif // USE_ATEN
  TORCHAO_CHECK(
      (embedding_dim * weight_nbit) % 8 == 0,
      "embedding_dim * weight_nbit must be a multiple of 8");
  int packed_embedding_dim = (embedding_dim * weight_nbit) / 8;
  TORCHAO_CHECK(
      packed_weight_qvals.size(0) ==
          (torchao::ops::PackedWeightsHeader::size() +
           (num_embeddings * packed_embedding_dim)),
      "packed_weight_qvals is not the correct size");

  // Check header
  auto header = torchao::ops::PackedWeightsHeader::read(
      packed_weight_qvals.const_data_ptr());
  TORCHAO_CHECK(
      header ==
          torchao::ops::embedding_xbit::get_packed_weights_header_universal(
              weight_nbit,
              /*min_value_chunk_size=*/32,
              /*max_value_chunk_size=*/128),
      "packed_weights are not compatible with the kernel");

#ifdef USE_ATEN
  TORCHAO_CHECK(
      weight_scales.dtype() == torch::kFloat32,
      "weight_scales must be float32");
#endif // USE_ATEN
  TORCHAO_CHECK(weight_scales.dim() == 2, "weight_scales must be 2D");
  TORCHAO_CHECK(
      weight_scales.size(0) == num_embeddings,
      "weight_scales must be same shape as packed_weight_qvals in dim0 (num_embeddings)");
  int num_groups = weight_scales.size(1);
  TORCHAO_CHECK(
      num_groups >= 1, "weight_scales must be at least 1 in dim1 (num_groups)");
  TORCHAO_CHECK(
      embedding_dim % num_groups == 0,
      "embedding_dim must be a multiple of num_groups");
  group_size = embedding_dim / num_groups;
  TORCHAO_CHECK(group_size % 32 == 0, "group_size must be a multiple of 32");

#ifdef USE_ATEN
  TORCHAO_CHECK(
      weight_zeros.dtype() == torch::kInt8, "weight_zeros must be int8");
#endif // USE_ATEN
  TORCHAO_CHECK(weight_zeros.dim() == 2, "weight_zeros must be 2D");
  TORCHAO_CHECK(
      weight_zeros.size(0) == weight_scales.size(0) &&
          weight_zeros.size(1) == weight_scales.size(1),
      "zeros must be same shape as scales");

  TORCHAO_CHECK(indices.dim() == 1, "indices must be 1D");
  TORCHAO_CHECK(
      (indices.dtype() == Tensor_dtype_kInt32) ||
          (indices.dtype() == Tensor_dtype_kInt64),
      "indices must be int32 or int64");
}

#if defined(USE_ATEN) || defined(USE_EXECUTORCH)
template <int weight_nbit>
Tensor embedding_out_cpu(
    const Tensor& packed_weight_qvals,
    // TODO(T200095131): convert to
    // int64_t when supported by AOTI
    // Currently they are tensors with size
    // equal to (0, the int they wrap)
    const Tensor& num_embeddings_tensor,
    const Tensor& embedding_dim_tensor,
    const Tensor& weight_scales,
    const Tensor& weight_zeros,
    const Tensor& indices,
    Tensor& out) {
  int num_embeddings = num_embeddings_tensor.size(1);
  int embedding_dim = embedding_dim_tensor.size(1);
  int group_size;
  check_embedding_inputs<weight_nbit>(
      packed_weight_qvals,
      num_embeddings,
      embedding_dim,
      weight_scales,
      weight_zeros,
      indices,
      group_size);

  int num_out = indices.size(0);
  const int8_t* weight_zeros_ptr = weight_zeros.const_data_ptr<int8_t>();

#ifdef USE_ATEN
  TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
  out.resize_({num_out, embedding_dim});
#endif // USE_ATEN

#ifdef USE_EXECUTORCH
  TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
  TORCHAO_CHECK(out.size(0) == num_out, "out shape is incorrect");
  TORCHAO_CHECK(out.size(1) == embedding_dim, "out shape is incorrect");
#endif // USE_EXECUTORCH

  const int32_t* index32_ptr = nullptr;
  const int64_t* index64_ptr = nullptr;
  if (indices.dtype() == Tensor_dtype_kInt32) {
    index32_ptr = indices.const_data_ptr<int32_t>();
  } else {
    TORCHAO_CHECK(
        indices.dtype() == Tensor_dtype_kInt64,
        "indices must be int32 or int64");
    index64_ptr = indices.const_data_ptr<int64_t>();
  }
  torchao::parallel_1d(0, num_out, [&](int64_t idx) {
    int index = -1;
    if (index32_ptr != nullptr) {
      index = index32_ptr[idx];
    } else {
      index = index64_ptr[idx];
    }
    TORCHAO_CHECK(index >= 0 && index < num_embeddings, "index out of bounds");
#if defined(__aarch64__) || defined(__ARM_NEON)
    torchao::kernels::cpu::aarch64::embedding::embedding<weight_nbit>(
        out.mutable_data_ptr<float>() + idx * embedding_dim,
        embedding_dim,
        group_size,
        packed_weight_qvals.const_data_ptr<int8_t>() +
            torchao::ops::PackedWeightsHeader::size(),
        weight_scales.const_data_ptr<float>(),
        weight_zeros_ptr,
        index);
#else
    TORCHAO_CHECK(false, "Unsupported platform");
#endif // defined(__aarch64__) || defined(__ARM_NEON)
  });

  return out;
}
#endif // defined(USE_ATEN) || defined(USE_EXECUTORCH)

#ifdef USE_ATEN
template <int weight_nbit>
Tensor embedding_cpu(
    const Tensor& packed_weight_qvals,
    // TODO(T200095131): convert to
    // int64_t when supported by AOTI
    // Currently they are tensors with size
    // equal to (0, the int they wrap)
    const Tensor& num_embeddings_tensor,
    const Tensor& embedding_dim_tensor,
    const Tensor& weight_scales,
    const Tensor& weight_zeros,
    const Tensor& indices) {
  Tensor output_tensor = torch::empty({}, torch::kFloat32);
  embedding_out_cpu<weight_nbit>(
      packed_weight_qvals,
      num_embeddings_tensor,
      embedding_dim_tensor,
      weight_scales,
      weight_zeros,
      indices,
      output_tensor);
  return output_tensor;
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor embedding_meta(
    const Tensor& packed_weight_qvals,
    // TODO(T200095131): convert to
    // int64_t when supported by AOTI
    // Currently they are tensors with size
    // equal to (0, the int they wrap)
    const Tensor& num_embeddings_tensor,
    const Tensor& embedding_dim_tensor,
    const Tensor& weight_scales,
    const Tensor& weight_zeros,
    const Tensor& indices) {
  int embedding_dim = embedding_dim_tensor.size(1);
  int num_out = indices.size(0);
  return torch::empty({num_out, embedding_dim}).to("meta");
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
  TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");
  int num_embeddings = weight_qvals.size(0);
  int embedding_dim = weight_qvals.size(1);
  TORCHAO_CHECK(
      embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack");
  int packed_embedding_dim = embedding_dim * weight_nbit / 8;
  TORCHAO_CHECK(
      weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");

  auto out = torch::empty(
                 torchao::ops::PackedWeightsHeader::size() +
                 (num_embeddings * packed_embedding_dim))
                 .to(torch::kInt8);

  auto header =
      torchao::ops::embedding_xbit::get_packed_weights_header_universal(
          weight_nbit,
          /*min_value_chunk_size=*/32,
          /*max_value_chunk_size=*/128);
  header.write(out.mutable_data_ptr());

  torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) {
#if defined(__aarch64__) || defined(__ARM_NEON)
    torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals<
        weight_nbit>(
        out.mutable_data_ptr<int8_t>() +
            torchao::ops::PackedWeightsHeader::size(),
        embedding_dim,
        weight_qvals.const_data_ptr<int8_t>(),
        idx);
#else
    TORCHAO_CHECK(false, "Unsupported platform");
#endif // defined(__aarch64__) || defined(__ARM_NEON)
  });

  return out;
}
#endif // USE_ATEN

#ifdef USE_ATEN
template <int weight_nbit>
Tensor pack_embedding_meta(const Tensor& weight_qvals) {
  TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");
  int num_embeddings = weight_qvals.size(0);
  int embedding_dim = weight_qvals.size(1);
  TORCHAO_CHECK(
      embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack");
  int packed_embedding_dim = embedding_dim * weight_nbit / 8;
  return torch::empty(
             torchao::ops::PackedWeightsHeader::size() +
             (num_embeddings * packed_embedding_dim))
      .to("meta");
}
#endif // USE_ATEN
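
Note (not part of the diff): the layout that check_embedding_inputs and pack_embedding_cpu agree on is a flat int8 buffer holding a fixed-size PackedWeightsHeader followed by num_embeddings rows of embedding_dim * weight_nbit / 8 packed bytes. Below is a minimal standalone C++ sketch of that size arithmetic; the 16-byte header size is an illustrative assumption, the real value comes from torchao::ops::PackedWeightsHeader::size().

#include <cstdint>
#include <cstdio>

// Illustrative header size only; the ops above use PackedWeightsHeader::size().
constexpr int64_t kHeaderBytes = 16;

// Bytes per packed embedding row; embedding_dim * weight_nbit must be a
// multiple of 8, exactly as check_embedding_inputs enforces.
int64_t packed_embedding_bytes(int64_t embedding_dim, int weight_nbit) {
  return embedding_dim * weight_nbit / 8;
}

// Total size of the packed buffer: header + all packed rows.
int64_t packed_buffer_bytes(int64_t num_embeddings, int64_t embedding_dim, int weight_nbit) {
  return kHeaderBytes + num_embeddings * packed_embedding_bytes(embedding_dim, weight_nbit);
}

int main() {
  // e.g. a 4-bit table with 30,000 embeddings of dimension 256:
  // each row packs to 256 * 4 / 8 = 128 bytes.
  std::printf("%lld\n",
              (long long)packed_buffer_bytes(30000, 256, /*weight_nbit=*/4));
  return 0;
}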

torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h>

#define DEFINE_OP(weight_nbit)                                                  \
  m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor");  \
  m.def(                                                                        \
      "_embedding_" #weight_nbit                                                \
      "bit(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
  m.def(                                                                        \
      "_embedding_" #weight_nbit                                                \
      "bit.out(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)");

#define DEFINE_CPU_IMPL(weight_nbit)                                            \
  m.impl(                                                                       \
      "_pack_embedding_" #weight_nbit "bit",                                    \
      &pack_embedding_cpu<weight_nbit>);                                        \
  m.impl("_embedding_" #weight_nbit "bit", &embedding_cpu<weight_nbit>);        \
  m.impl("_embedding_" #weight_nbit "bit.out", &embedding_out_cpu<weight_nbit>);

#define DEFINE_META_IMPL(weight_nbit)                                           \
  m.impl(                                                                       \
      "_pack_embedding_" #weight_nbit "bit",                                    \
      &pack_embedding_meta<weight_nbit>);                                       \
  m.impl("_embedding_" #weight_nbit "bit", &embedding_meta<weight_nbit>);

TORCH_LIBRARY_FRAGMENT(torchao, m) {
  DEFINE_OP(1);
  DEFINE_OP(2);
  DEFINE_OP(3);
  DEFINE_OP(4);
  DEFINE_OP(5);
  DEFINE_OP(6);
}

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  DEFINE_CPU_IMPL(1);
  DEFINE_CPU_IMPL(2);
  DEFINE_CPU_IMPL(3);
  DEFINE_CPU_IMPL(4);
  DEFINE_CPU_IMPL(5);
  DEFINE_CPU_IMPL(6);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
  DEFINE_META_IMPL(1);
  DEFINE_META_IMPL(2);
  DEFINE_META_IMPL(3);
  DEFINE_META_IMPL(4);
  DEFINE_META_IMPL(5);
  DEFINE_META_IMPL(6);
}
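
Note (not part of the diff): once this shared library is loaded, the schemas registered above are reachable through the normal dispatcher, e.g. torch.ops.torchao._pack_embedding_4bit from Python. A rough C++ sketch of the same call follows, assuming libtorchao_ops_aten is already loaded into the process; the shapes and values are illustrative only.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Sketch only: assumes the torchao::_pack_embedding_4bit schema defined above
// has been registered by loading the ops library.
at::Tensor pack_4bit_example() {
  // int8 quantized values in [-8, 7] for a 4-bit table (100 x 256 here,
  // purely illustrative).
  at::Tensor weight_qvals = at::randint(-8, 8, {100, 256}).to(at::kChar);

  // Look up the registered op by its fully qualified name and call it.
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("torchao::_pack_embedding_4bit", "")
                .typed<at::Tensor(const at::Tensor&)>();
  return op.call(weight_qvals);
}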

torchao/experimental/ops/embedding_xbit/packed_weights_header.h

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <torchao/experimental/ops/library.h>
#include <torchao/experimental/ops/packed_weights_header.h>

namespace torchao::ops::embedding_xbit {

inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal(
    int weight_nbit,
    int min_value_chunk_size,
    int max_value_chunk_size,
    int version = 1) {
  return torchao::ops::PackedWeightsHeader(
      torchao::ops::PackedWeightsFormat::embedding_xbit_universal,
      {version,
       weight_nbit,
       min_value_chunk_size,
       max_value_chunk_size,
       0,
       0,
       0,
       0,
       0,
       0,
       0,
       0,
       0,
       0});
}

} // namespace torchao::ops::embedding_xbit
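
Note (not part of the diff): pack_embedding_cpu writes this header at the front of the packed buffer, and check_embedding_inputs reads it back and compares it against the expected value, rejecting incompatible buffers. A minimal sketch of that round trip, with a plain std::vector standing in for the packed int8 tensor:

#include <cassert>
#include <vector>
#include <torchao/experimental/ops/embedding_xbit/packed_weights_header.h>

// Sketch of the header round trip used by the embedding ops above.
void header_roundtrip_example() {
  auto header =
      torchao::ops::embedding_xbit::get_packed_weights_header_universal(
          /*weight_nbit=*/4,
          /*min_value_chunk_size=*/32,
          /*max_value_chunk_size=*/128);

  // In the real op this region is the first PackedWeightsHeader::size()
  // bytes of the packed weights tensor.
  std::vector<char> buffer(torchao::ops::PackedWeightsHeader::size());
  header.write(buffer.data());

  auto read_back = torchao::ops::PackedWeightsHeader::read(buffer.data());
  assert(read_back == header);  // kernels reject buffers whose header differs
}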
