
Commit f3ff2e5

Add dynamic shape support for lowbit kernels (#1942)
* Init Summary: Adds dynamic shape support to linear op. * up * up
1 parent a8d2159 commit f3ff2e5

File tree

10 files changed: +135 −120 lines changed

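In summary, this commit moves the meta ("fake tensor") kernels for the lowbit linear and embedding ops from C++ into Python (torchao/experimental/op_lib.py), switches the embedding schemas from Tensor-wrapped sizes to plain ints, and hides backend-specific output resizing behind a new TORCHAO_RESIZE_TENSOR macro. The payoff is that torch.export can trace these ops with a symbolic batch dimension. A minimal sketch of the intended use, assuming torchao is built with the experimental ATen ops; the LowbitLinear wrapper is illustrative and not part of this commit:

import torch
from torch.export import Dim, export

import torchao  # loads libtorchao_ops_aten.* and the Python meta kernels

class LowbitLinear(torch.nn.Module):
    # Illustrative wrapper that calls the 4-bit linear op directly.
    def __init__(self, packed_weights, group_size, n, k):
        super().__init__()
        self.register_buffer("packed_weights", packed_weights)
        self.group_size, self.n, self.k = group_size, n, k

    def forward(self, x):
        return torch.ops.torchao._linear_8bit_act_4bit_weight(
            x, self.packed_weights, self.group_size, self.n, self.k
        )

# Previously the C++ meta kernels baked concrete sizes into traces; now the
# batch dimension can stay symbolic during export (packed weights come from
# the corresponding packing op):
# ep = export(LowbitLinear(packed, 32, 256, 512), (torch.randn(4, 512),),
#             dynamic_shapes=({0: Dim("batch")},))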

torchao/__init__.py

Lines changed: 3 additions & 8 deletions
@@ -34,14 +34,9 @@
     # They can also be built outside of the torchao install process by
     # running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
     # For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
-    experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*"))
-    if len(experimental_lib) > 0:
-        assert (
-            len(experimental_lib) == 1
-        ), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}"
-        torch.ops.load_library(str(experimental_lib[0]))
-except:
-    logging.debug("Skipping import of cpp extensions")
+    from torchao.experimental.op_lib import *  # noqa: F403
+except Exception as e:
+    logging.debug(f"Skipping import of cpp extensions: {e}")
 
 from torchao.quantization import (
     autoquant,
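For context, the new import sits inside the pre-existing try block at the top of torchao/__init__.py, so environments without the compiled ops still import cleanly; catching Exception instead of a bare except also makes the failure reason visible in debug logs. The guarded import now reads roughly as follows (the ellipsis stands in for the unrelated code that shares the try block):

import logging

try:
    ...
    from torchao.experimental.op_lib import *  # noqa: F403
except Exception as e:
    logging.debug(f"Skipping import of cpp extensions: {e}")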

torchao/experimental/op_lib.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+
+import torch
+from torch import Tensor
+from torch.library import impl
+
+# Load C++ ops
+lib_path = Path(__file__).parent.parent
+libs = list(lib_path.glob("libtorchao_ops_aten.*"))
+assert (
+    len(libs) == 1
+), f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
+torch.ops.load_library(str(libs[0]))
+
+
+# Define meta ops. To support dynamic shapes, some meta ops need to
+# be defined in python instead of C++.
+torchao_lib = torch.library.Library("torchao", "IMPL")
+for weight_nbit in range(1, 9):
+
+    @impl(torchao_lib, f"_linear_8bit_act_{weight_nbit}bit_weight", "Meta")
+    def _(
+        activations: Tensor,
+        packed_weights: Tensor,
+        group_size: int,
+        n: int,
+        k: int,
+    ):
+        assert activations.dim() == 2
+        m, k_ = activations.shape
+        assert k_ == k
+        return torch.empty(m, n, dtype=activations.dtype, device="meta")
+
+    @impl(torchao_lib, f"_embedding_{weight_nbit}bit", "Meta")
+    def _(
+        packed_weight_qvals: Tensor,
+        num_embeddings: int,
+        embedding_dim: int,
+        weight_scales: Tensor,
+        weight_zeros: Tensor,
+        indices: Tensor,
+    ):
+        assert indices.dim() == 1
+        num_out = indices.shape[0]
+        return torch.empty(num_out, embedding_dim, dtype=torch.float32, device="meta")
+
+    @impl(torchao_lib, f"_shared_embedding_{weight_nbit}bit", "Meta")
+    def _(packed_weights: Tensor, group_size: int, n: int, k: int, indices: Tensor):
+        assert indices.dim() == 1
+        num_out = indices.shape[0]
+        return torch.empty(num_out, k, dtype=torch.float32, device="meta")
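Two details worth noting. First, reusing the function name _ inside the loop is safe: each @impl decorator registers the function immediately, and the schema name is an f-string evaluated at decoration time, so all eight bit-widths get distinct meta kernels. Second, these kernels compute output shapes directly from their inputs, which is what lets fake tensors carry symbolic sizes through them. A small sketch of that behavior on meta tensors (the packed-weights size is an arbitrary placeholder; the meta kernel never reads the bytes):

import torch
import torchao  # registers torch.ops.torchao.* plus the meta kernels above

activations = torch.empty(7, 512, device="meta")             # m=7, k=512
packed = torch.empty(1024, dtype=torch.int8, device="meta")  # placeholder
out = torch.ops.torchao._linear_8bit_act_4bit_weight(
    activations, packed, 32, 256, 512  # group_size=32, n=256, k=512
)
assert out.shape == (7, 256) and out.device.type == "meta"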

torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h

Lines changed: 10 additions & 60 deletions
@@ -92,18 +92,12 @@ void check_embedding_inputs(
 template <int weight_nbit>
 Tensor embedding_out_cpu(
     const Tensor& packed_weight_qvals,
-    // TODO(T200095131): convert to
-    // int64_t when supported by AOTI
-    // Currently they are tensors with size
-    // equal to (0, the int they wrap)
-    const Tensor& num_embeddings_tensor,
-    const Tensor& embedding_dim_tensor,
+    const int64_t& num_embeddings,
+    const int64_t& embedding_dim,
     const Tensor& weight_scales,
     const Tensor& weight_zeros,
     const Tensor& indices,
     Tensor& out) {
-  int num_embeddings = num_embeddings_tensor.size(1);
-  int embedding_dim = embedding_dim_tensor.size(1);
   int group_size;
   check_embedding_inputs<weight_nbit>(
       packed_weight_qvals,
@@ -117,16 +111,8 @@ Tensor embedding_out_cpu(
   int num_out = indices.size(0);
   const int8_t* weight_zeros_ptr = weight_zeros.const_data_ptr<int8_t>();
 
-#ifdef USE_ATEN
-  TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
-  out.resize_({num_out, embedding_dim});
-#endif // USE_ATEN
-
-#ifdef USE_EXECUTORCH
-  TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
-  TORCHAO_CHECK(out.size(0) == num_out, "out shape is incorrect");
-  TORCHAO_CHECK(out.size(1) == embedding_dim, "out shape is incorrect");
-#endif // USE_EXECUTORCH
+  // Explicit cast from int64_t to int is required for Executorch
+  TORCHAO_RESIZE_TENSOR(out, {(int)num_out, (int)embedding_dim});
 
   const int32_t* index32_ptr = nullptr;
   const int64_t* index64_ptr = nullptr;
@@ -169,20 +155,16 @@ Tensor embedding_out_cpu(
 template <int weight_nbit>
 Tensor embedding_cpu(
     const Tensor& packed_weight_qvals,
-    // TODO(T200095131): convert to
-    // int64_t when supported by AOTI
-    // Currently they are tensors with size
-    // equal to (0, the int they wrap)
-    const Tensor& num_embeddings_tensor,
-    const Tensor& embedding_dim_tensor,
+    const int64_t& num_embeddings,
+    const int64_t& embedding_dim,
     const Tensor& weight_scales,
     const Tensor& weight_zeros,
     const Tensor& indices) {
   Tensor output_tensor = torch::empty({}, torch::kFloat32);
   embedding_out_cpu<weight_nbit>(
       packed_weight_qvals,
-      num_embeddings_tensor,
-      embedding_dim_tensor,
+      num_embeddings,
+      embedding_dim,
       weight_scales,
       weight_zeros,
       indices,
@@ -191,25 +173,6 @@ Tensor embedding_cpu(
 }
 #endif // USE_ATEN
 
-#ifdef USE_ATEN
-template <int weight_nbit>
-Tensor embedding_meta(
-    const Tensor& packed_weight_qvals,
-    // TODO(T200095131): convert to
-    // int64_t when supported by AOTI
-    // Currently they are tensors with size
-    // equal to (0, the int they wrap)
-    const Tensor& num_embeddings_tensor,
-    const Tensor& embedding_dim_tensor,
-    const Tensor& weight_scales,
-    const Tensor& weight_zeros,
-    const Tensor& indices) {
-  int embedding_dim = embedding_dim_tensor.size(1);
-  int num_out = indices.size(0);
-  return torch::empty({num_out, embedding_dim}).to("meta");
-}
-#endif // USE_ATEN
-
 #ifdef USE_ATEN
 template <int weight_nbit>
 Tensor pack_embedding_cpu(const Tensor& weight_qvals) {
@@ -261,10 +224,10 @@ Tensor pack_embedding_meta(const Tensor& weight_qvals) {
   TORCHAO_CHECK(
       embedding_dim % 8 == 0, "embedding_dim must be a multiple of 8 to pack");
   int packed_embedding_dim = embedding_dim * weight_nbit / 8;
+  auto options = torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8);
   return torch::empty(
       torchao::ops::PackedWeightsHeader::size() +
-      (num_embeddings * packed_embedding_dim))
-      .to("meta");
+      (num_embeddings * packed_embedding_dim), options);
 }
 #endif // USE_ATEN
 
@@ -371,17 +334,4 @@ Tensor shared_embedding_cpu(
 }
 #endif // USE_ATEN
 
-#ifdef USE_ATEN
-template <int weight_nbit>
-Tensor shared_embedding_meta(
-    const Tensor& packed_weights,
-    const int64_t& group_size,
-    const int64_t& n, // same as num_embeddings
-    const int64_t& k, // same as embedding_dim
-    const Tensor& indices) {
-  int num_out = indices.size(0);
-  return torch::empty({num_out, k}).to("meta");
-}
-#endif // USE_ATEN
-
 #endif // defined(USE_ATEN) || defined(USE_EXECUTORCH)

torchao/experimental/ops/embedding_xbit/op_embedding_xbit_aten.cpp

Lines changed: 3 additions & 7 deletions
@@ -10,10 +10,10 @@
   m.def("_pack_embedding_" #weight_nbit "bit(Tensor weight_qvals) -> Tensor"); \
   m.def( \
       "_embedding_" #weight_nbit \
-      "bit(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
+      "bit(Tensor packed_weight_qvals, int num_embeddings, int embedding_dim, Tensor weight_scales, Tensor weight_zeros, Tensor indices) -> Tensor"); \
   m.def( \
       "_embedding_" #weight_nbit \
-      "bit.out(Tensor packed_weight_qvals, Tensor num_embeddings_tensor, Tensor embedding_dim_tensor, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
+      "bit.out(Tensor packed_weight_qvals, int num_embeddings, int embedding_dim, Tensor weight_scales, Tensor weight_zeros, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
   m.def( \
       "_shared_embedding_" #weight_nbit \
       "bit.out(Tensor packed_weights, int group_size, int n, int k, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)"); \
@@ -38,11 +38,7 @@
 #define DEFINE_META_IMPL(weight_nbit) \
   m.impl( \
       "_pack_embedding_" #weight_nbit "bit", \
-      &pack_embedding_meta<weight_nbit>); \
-  m.impl("_embedding_" #weight_nbit "bit", &embedding_meta<weight_nbit>); \
-  m.impl( \
-      "_shared_embedding_" #weight_nbit "bit", \
-      &shared_embedding_meta<weight_nbit>);
+      &pack_embedding_meta<weight_nbit>);
 
 TORCH_LIBRARY_FRAGMENT(torchao, m) {
   DEFINE_OP(1);
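From the Python side, callers now pass num_embeddings and embedding_dim as plain ints. A hedged sketch of a call against the updated 4-bit schema, run on meta tensors so no real packed data is needed (buffer sizes are placeholders):

import torch
import torchao  # noqa: F401

packed = torch.empty(8, dtype=torch.int8, device="meta")
scales = torch.empty(0, dtype=torch.float32, device="meta")
zeros = torch.empty(0, dtype=torch.int8, device="meta")
indices = torch.empty(3, dtype=torch.int64, device="meta")

out = torch.ops.torchao._embedding_4bit(
    packed, 1000, 64, scales, zeros, indices  # num_embeddings=1000, embedding_dim=64
)
assert out.shape == (3, 64)  # row count follows indices, per the meta kernel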

torchao/experimental/ops/embedding_xbit/op_embedding_xbit_executorch.cpp

Lines changed: 4 additions & 4 deletions
@@ -10,17 +10,17 @@
   Tensor _op_out_##weight_nbit( \
       RuntimeContext& ctx, \
       const Tensor& packed_weight_qvals, \
-      const Tensor& num_embeddings_tensor, \
-      const Tensor& embedding_dim_tensor, \
+      const int64_t& num_embeddings, \
+      const int64_t& embedding_dim, \
       const Tensor& weight_scales, \
       const Tensor& weight_zeros, \
       const Tensor& indices, \
       Tensor& out) { \
     (void)ctx; \
     embedding_out_cpu<weight_nbit>( \
         packed_weight_qvals, \
-        num_embeddings_tensor, \
-        embedding_dim_tensor, \
+        num_embeddings, \
+        embedding_dim, \
         weight_scales, \
         weight_zeros, \
         indices, \

torchao/experimental/ops/library.h

Lines changed: 4 additions & 0 deletions
@@ -13,16 +13,20 @@ using Tensor = at::Tensor;
 #define Tensor_dtype_kInt32 torch::kInt32
 #define Tensor_dtype_kInt64 torch::kInt64
 #define TORCHAO_CHECK(cond, msg) TORCH_CHECK(cond, msg)
+#define TORCHAO_RESIZE_TENSOR(tensor, ...) tensor.resize_({__VA_ARGS__})
 
 #elif defined(USE_EXECUTORCH) && !defined(USE_ATEN)
 #pragma message("USE_EXECUTORCH")
 #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
 using Tensor = torch::executor::Tensor;
 using RuntimeContext = torch::executor::KernelRuntimeContext;
 #define Tensor_dtype_kInt32 torch::executor::ScalarType::Int
 #define Tensor_dtype_kInt64 torch::executor::ScalarType::Long
 #define TORCHAO_CHECK(cond, msg) ET_CHECK_MSG(cond, msg)
+#define TORCHAO_RESIZE_TENSOR(tensor, ...) \
+  ET_CHECK_MSG(torch::executor::resize_tensor(tensor, {__VA_ARGS__}) == torch::executor::Error::Ok, "resize failed")
 
 #elif !defined(USE_EXECUTORCH) && !defined(USE_ATEN)
 #pragma message("Neither USE_ATEN or USE_EXECUTORCH defined")

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h

Lines changed: 4 additions & 31 deletions
@@ -133,8 +133,8 @@ Tensor pack_weights_meta(
       torchao::ops::PackedWeightsHeader::size() +
       get_packed_weight_data_size(
           ukernel_config, n, k, group_size, has_weight_zeros, has_bias);
-  return torch::empty({static_cast<int64_t>(packed_weight_data_size)})
-      .to("meta");
+  auto options = torch::TensorOptions().device(c10::DeviceType::Meta).dtype(torch::kInt8);
+  return torch::empty({static_cast<int64_t>(packed_weight_data_size)}, options);
 }
 #endif // USE_ATEN
 
@@ -166,15 +166,8 @@ Tensor linear_out_cpu(
   TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
 #endif // USE_ATEN
 
-#ifdef USE_ATEN
-  out.resize_({m, n});
-#endif // USE_ATEN
-
-#ifdef USE_EXECUTORCH
-  TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
-  TORCHAO_CHECK(out.size(0) == m, "out shape is incorrect");
-  TORCHAO_CHECK(out.size(1) == n, "out shape is incorrect");
-#endif // USE_EXECUTORCH
+  // Explicit cast from int64_t to int is required for Executorch
+  TORCHAO_RESIZE_TENSOR(out, {(int)m, (int)n});
 
   using namespace torchao::ops::linear_8bit_act_xbit_weight;
 
@@ -254,24 +247,4 @@ Tensor linear_cpu(
 }
 #endif // USE_ATEN
 
-#ifdef USE_ATEN
-template <int weight_nbit>
-Tensor linear_meta(
-    const Tensor& activations,
-    const Tensor& packed_weights,
-    const int64_t& group_size,
-    const int64_t& n,
-    const int64_t& k) {
-  TORCHAO_CHECK(n >= 1, "n must be >= 1");
-  TORCHAO_CHECK(k >= 1, "k must be >= 1");
-
-  TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D");
-  int m = activations.size(0);
-  int k_ = activations.size(1);
-  TORCHAO_CHECK(
-      k == k_, "activation shape is incompatible with packed weights.");
-  return torch::empty({m, n}).to("meta");
-}
-#endif // USE_ATEN
-
 } // namespace

torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp

Lines changed: 1 addition & 4 deletions
@@ -34,10 +34,7 @@
       &pack_weights_meta<weight_nbit>); \
   m.impl( \
       "_pack_8bit_act_" #weight_nbit "bit_weight", \
-      &pack_weights_meta<weight_nbit>); \
-  m.impl( \
-      "_linear_8bit_act_" #weight_nbit "bit_weight", \
-      &linear_meta<weight_nbit>);
+      &pack_weights_meta<weight_nbit>)
 
 TORCH_LIBRARY_FRAGMENT(torchao, m) {
   DEFINE_OP(1);

torchao/experimental/quant_api.py

Lines changed: 2 additions & 6 deletions
@@ -384,12 +384,8 @@ def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
         self.register_buffer(
             "packed_weight_qvals", self.pack_weights_op(weight_qvals.to(torch.int8))
         )
-        self.register_buffer(
-            "num_embeddings", torch.empty(0, num_embeddings, dtype=torch.int8)
-        )
-        self.register_buffer(
-            "embedding_dim", torch.empty(0, embedding_dim, dtype=torch.int8)
-        )
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
         self.register_buffer("weight_scales", weight_scales)
         self.register_buffer("weight_zeros", weight_zeros.to(torch.int8))
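The zero-sized carrier buffers existed only because the old schema took Tensor-wrapped sizes; with int arguments in the schema, plain attributes flow straight into the op call. The module's forward is not part of this diff, so the following is an illustrative sketch of how such a call lines up with the new schema (self.embedding_op stands for whatever handle the module keeps to the registered op):

def forward(self, indices: torch.Tensor) -> torch.Tensor:
    # num_embeddings / embedding_dim are plain ints now, matching the
    # int arguments in the updated _embedding_{n}bit schema.
    return self.embedding_op(
        self.packed_weight_qvals,
        self.num_embeddings,
        self.embedding_dim,
        self.weight_scales,
        self.weight_zeros,
        indices,
    )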
