Commit e341c2e

Update
[ghstack-poisoned]
2 parents 32005c9 + 3fb1665 commit e341c2e

21 files changed: +1594 / -602 lines

.github/workflows/torchao_experimental_test.yml

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ jobs:
         # Install executorch first because it installs its own version
         # of torch and torchao, which we do not want to use
         pip install executorch
-        pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
+        pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
         pip install numpy
         pip install pytest
         pip install parameterized

benchmarks/mx_formats/cast_bench.py

Lines changed: 4 additions & 4 deletions
@@ -6,7 +6,7 @@
 from torch._inductor.utils import do_bench_using_profiling
 
 from torchao.prototype.mx_formats.custom_cast import (
-    to_mxfp8_dim1,
+    triton_to_mxfp8_dim1,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
@@ -172,12 +172,12 @@ def run(
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
     elif mode == "dim1_mx_triton":
-        y_d1, s_d1 = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+        y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
 
         for _ in range(2):
-            __ = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+            __ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x, b: to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
+            lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
             x,
             BLOCK_SIZE,
         )

test/prototype/mx_formats/test_custom_cast.py

Lines changed: 6 additions & 5 deletions
@@ -26,12 +26,11 @@
     get_bits,
     pack_uint4,
     pack_uint6,
-    # TODO(before land): better name?
-    to_mxfp8_dim1,
-    to_mxfp8_dim1_reference,
     triton_f4_to_bf16,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
+    triton_to_mxfp8_dim1,
+    triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
 from torchao.prototype.mx_formats.fp_format_spec import (
@@ -460,9 +459,11 @@ def test_fp6_e3m2_pack_unpack():
 )
 @pytest.mark.parametrize("M", (256, 2048))
 @pytest.mark.parametrize("K", (256, 2048))
+# @pytest.mark.parametrize("M", (256,))
+# @pytest.mark.parametrize("K", (256,))
 def test_triton_mxfp8_dim1(M, K):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    x_mx_ref, x_s_ref = to_mxfp8_dim1_reference(x, block_size=32)
-    x_mx_t, x_s_t = to_mxfp8_dim1(x, inner_block_size=32)
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
+    x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h

Lines changed: 58 additions & 0 deletions
@@ -10,8 +10,10 @@
 
 #include <arm_neon.h>
 #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
 #include <torchao/experimental/kernels/cpu/aarch64/macro.h>
 #include <cassert>
+#include <vector>
 
 namespace torchao::kernels::cpu::aarch64::embedding {
 
@@ -177,6 +179,7 @@ inline void embedding_(
     if (weight_zeros != nullptr) {
       zero = weight_zeros[group_idx];
     }
+
     internal::vec_dequantize_and_store_16_values(out + i, qvals0, scale, zero);
     internal::vec_dequantize_and_store_16_values(
         out + i + 16, qvals1, scale, zero);
@@ -322,6 +325,61 @@ inline void pack_embedding_weight_qvals(
         qvals + index * embedding_dim);
   }
 
+// Embedding op that shares weights with unembedding linear op
+template <int weight_nbit, int nr, int kr, int sr>
+inline void shared_embedding(
+    // Output
+    float* out,
+    // Inputs
+    const void* packed_weights,
+    int n,
+    int k,
+    int group_size,
+    bool has_weight_zeros,
+    bool has_bias,
+    int index) {
+  assert(k % group_size == 0);
+  assert(group_size % 16 == 0);
+
+  int groups_per_k = k / group_size;
+  std::vector<int8_t> weight_qvals(k * nr);
+  std::vector<float> weight_scales(groups_per_k * nr);
+  std::vector<int8_t> weight_zeros(groups_per_k * nr);
+  std::vector<float> bias(nr);
+
+  // Set n_idx to multiple of nr that is at most index
+  // j is index of "index" in nr group
+  int n_idx = index / nr;
+  n_idx = n_idx * nr;
+  int j = index - n_idx;
+
+  torchao::kernels::cpu::aarch64::linear::packing::
+      unpack_weights_at_n_idx<weight_nbit, nr, kr, sr>(
+          weight_qvals.data(),
+          weight_scales.data(),
+          has_weight_zeros ? weight_zeros.data() : nullptr,
+          has_bias ? bias.data() : nullptr,
+          n_idx,
+          n,
+          k,
+          group_size,
+          has_weight_zeros,
+          has_bias,
+          packed_weights);
+
+  // Dequantize and store to output (size k)
+  int8x16_t qvals;
+  for (int i = 0; i < k; i += 16) {
+    qvals = vld1q_s8(weight_qvals.data() + j * k + i);
+    float scale = weight_scales[j * groups_per_k + i / group_size];
+    float zero = 0.0;
+    if (has_weight_zeros) {
+      zero = weight_zeros[j * groups_per_k + i / group_size];
+    }
+    internal::vec_dequantize_and_store_16_values(out + i, qvals, scale, zero);
+  }
+}
+
 } // namespace torchao::kernels::cpu::aarch64::embedding
 
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
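
Note (illustrative, not part of this diff): the new shared_embedding op dequantizes one row of the same packed buffer that the groupwise low-bit linear kernels consume, so an embedding table and a tied unembedding linear layer can keep a single copy of the quantized weights. The sketch below shows one plausible way to drive it, built only from the entry points visible in this commit (packed_weights_size, pack_weights, shared_embedding); the weight_nbit/nr/kr/sr choices, the buffer setup, and the row index are assumptions for the example, and it would only compile inside the torchao aarch64 tree.

// Hypothetical usage sketch (not from the commit): pack quantized weights
// once, then dequantize a single row with shared_embedding.
#include <torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h>
#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
#include <cstdint>
#include <vector>

void shared_embedding_example(
    const int8_t* weight_qvals,  // n x k quantized weight values
    const float* weight_scales,  // n x (k / group_size) group scales
    int n,
    int k,
    int group_size) {
  // Assumed configuration; kr = 32 / sr = 2 mirror the 1x1x32 kernel below.
  constexpr int weight_nbit = 4, nr = 1, kr = 32, sr = 2;
  const bool has_weight_zeros = false;
  const bool has_bias = false;

  // One packed buffer, shared by the linear kernel and the embedding op.
  size_t packed_size =
      torchao::kernels::cpu::aarch64::linear::packing::packed_weights_size(
          n, k, group_size, weight_nbit, has_weight_zeros, has_bias, nr);
  std::vector<char> packed(packed_size);
  torchao::kernels::cpu::aarch64::linear::packing::
      pack_weights<weight_nbit, nr, kr, sr>(
          packed.data(), n, k, group_size, weight_qvals, weight_scales,
          /*weight_zeros=*/nullptr, /*bias=*/nullptr);

  // Dequantize row `index` (one token's embedding) straight from the
  // packed buffer; `out` receives k floats.
  std::vector<float> out(k);
  int index = 0;  // any row in [0, n)
  torchao::kernels::cpu::aarch64::embedding::
      shared_embedding<weight_nbit, nr, kr, sr>(
          out.data(), packed.data(), n, k, group_size, has_weight_zeros,
          has_bias, index);
}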

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot-impl.h

Lines changed: 22 additions & 75 deletions
@@ -10,6 +10,7 @@
 
 #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
 #include <torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_prepare_activation_data_1xk_f32-impl.h>
+#include <torchao/experimental/kernels/cpu/aarch64/linear/pack_weights.h>
 #include <torchao/experimental/kernels/cpu/aarch64/quantization/quantize.h>
 #include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
 #include <cassert>
@@ -149,8 +150,8 @@ void kernel_impl(
       int32_t activation_qvals_sum = *((int32_t*)activation_ptr);
       activation_ptr += sizeof(int32_t);
 
-      int8_t weight_zero = *((int8_t*)weight_data_byte_ptr);
-      weight_data_byte_ptr += sizeof(int8_t);
+      int32_t weight_zero = *((int32_t*)weight_data_byte_ptr);
+      weight_data_byte_ptr += sizeof(int32_t);
 
       res += (weight_scale * activation_scale) *
           (qval_dot - (activation_zero * weight_qvals_sum) -
@@ -190,31 +191,14 @@ size_t inline weight_data_size_impl(
     int weight_nbit,
     bool has_weight_zeros,
     bool has_bias) {
-  assert(k % group_size == 0);
-  assert(k % 32 == 0);
-  int groups_per_col = k / group_size;
-  int col_size = 0;
-
-  // qvals
-  // (k * weight_bit) bits -> ((k / 8) * weight_bit) bytes
-  col_size += (k / 8) * weight_nbit;
-
-  // scales
-  col_size += sizeof(float) * groups_per_col;
-
-  // qvals_sum
-  col_size += sizeof(int32_t) * groups_per_col;
-
-  // zeros
-  if (has_weight_zeros) {
-    col_size += sizeof(int8_t) * groups_per_col;
-  }
-
-  if (has_bias) {
-    col_size += sizeof(float);
-  }
-
-  return col_size * n;
+  return torchao::kernels::cpu::aarch64::linear::packing::packed_weights_size(
+      n,
+      k,
+      group_size,
+      weight_nbit,
+      has_weight_zeros,
+      has_bias,
+      /*nr*/ 1);
 }
 
 template <int weight_nbit>
@@ -227,56 +211,19 @@ void prepare_weight_data_impl(
     int group_size,
     const int8_t* weight_qvals,
     const float* weight_scales,
+    // Ignored if has_weight_zeros = false
     const int8_t* weight_zeros,
     const float* bias) {
-  assert(k % group_size == 0);
-  assert(group_size % 32 == 0);
-
-  bool has_weight_zeros = (weight_zeros != nullptr);
-  bool has_bias = (bias != nullptr);
-
-  auto weight_data_byte_ptr = (char*)weight_data;
-  constexpr int bytes_per_32_weight_values = 4 * weight_nbit;
-
-  int8x16_t wq0, wq1;
-
-  const int8_t* qvals_ptr = weight_qvals;
-  const float* scales_ptr = weight_scales;
-  const int8_t* zeros_ptr = weight_zeros;
-  const float* bias_ptr = bias;
-
-  for (int n_idx = 0; n_idx < n; n_idx++) {
-    for (int k_idx = 0; k_idx < k; k_idx += group_size) {
-      int32_t group_qvals_sum = 0;
-      for (int i = 0; i < group_size; i += 32) {
-        wq0 = vld1q_s8(qvals_ptr);
-        wq1 = vld1q_s8(qvals_ptr + 16);
-        qvals_ptr += 32;
-
-        group_qvals_sum += vaddlvq_s8(wq0) + vaddlvq_s8(wq1);
-
-        torchao::bitpacking::vec_pack_32_lowbit_values<weight_nbit>(
-            /*packed=*/(uint8_t*)weight_data_byte_ptr,
-            /*unpacked0=*/wq0,
-            /*unpacked1=*/wq1);
-        weight_data_byte_ptr += bytes_per_32_weight_values;
-      }
-      *((float*)weight_data_byte_ptr) = *scales_ptr++;
-      weight_data_byte_ptr += sizeof(float);
-
-      *((int32_t*)weight_data_byte_ptr) = group_qvals_sum;
-      weight_data_byte_ptr += sizeof(int32_t);
-
-      if (has_weight_zeros) {
-        *((int8_t*)weight_data_byte_ptr) = *zeros_ptr++;
-        weight_data_byte_ptr += sizeof(int8_t);
-      }
-    }
-    if (has_bias) {
-      *((float*)weight_data_byte_ptr) = *bias_ptr++;
-      weight_data_byte_ptr += sizeof(float);
-    }
-  }
+  torchao::kernels::cpu::aarch64::linear::packing::
+      pack_weights<weight_nbit, /*nr*/ 1, /*kr*/ 32, /*sr*/ 2>(
+          weight_data,
+          n,
+          k,
+          group_size,
+          weight_qvals,
+          weight_scales,
+          weight_zeros,
+          bias);
 }
 
 } // namespace
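
Note (illustrative, not part of this diff): with this change the per-group zero point is read as an int32_t (4 bytes) instead of an int8_t, and both the size computation and the packing itself are delegated to the shared pack_weights helpers. For intuition, a per-column size consistent with what this nr = 1 kernel now reads would be the bit-packed qvals plus, per group, a float scale, an int32 qvals_sum, and optionally an int32 zero, plus an optional float bias. The arithmetic below is a rough sketch under that assumption; packed_weights_size in pack_weights.h is the source of truth.

#include <cstddef>
#include <cstdint>

// Estimated packed size for the nr = 1 layout described above.
// Illustrative assumption only; not the torchao implementation.
size_t estimated_packed_weights_size_nr1(
    int n,
    int k,
    int group_size,
    int weight_nbit,
    bool has_weight_zeros,
    bool has_bias) {
  int groups_per_col = k / group_size;
  size_t col_size = 0;
  // Bit-packed qvals: (k * weight_nbit) bits = (k / 8) * weight_nbit bytes.
  col_size += (k / 8) * weight_nbit;
  // Per-group float scale and int32 qvals_sum.
  col_size += (sizeof(float) + sizeof(int32_t)) * groups_per_col;
  // Per-group zero point, now stored as int32 (was int8 before this commit).
  if (has_weight_zeros) {
    col_size += sizeof(int32_t) * groups_per_col;
  }
  // Optional float bias per column.
  if (has_bias) {
    col_size += sizeof(float);
  }
  return col_size * n;
}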
