Commit 22d7d51

Reformat (#1723)
* reformat
* up
1 parent 8fc49fe commit 22d7d51

File tree

6 files changed: +962 -1741 lines changed


torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h

Lines changed: 23 additions & 52 deletions
@@ -23,16 +23,14 @@ namespace torchao::kernels::cpu::aarch64::kleidi {
 // Helper functions
 // TODO: find a better place for these?
 
-size_t roundup(size_t a, size_t b) {
-  return ((a + b - 1) / b) * b;
-}
+size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }
 
 uint16_t get_bf16_from_float(float f) {
   uint16_t bf16;
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
   memcpy(&bf16, &f, sizeof(uint16_t));
 #else
-  const void* fp = reinterpret_cast<const void*>(
+  const void *fp = reinterpret_cast<const void *>(
       reinterpret_cast<uintptr_t>(&f) + sizeof(float) - sizeof(uint16_t));
   memcpy(&bf16, fp, sizeof(uint16_t));
 #endif // __BYTE_ORDER__
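
For context, the two helpers touched in this hunk are small enough to exercise on their own. The snippet below is an illustrative standalone sketch, not part of the commit; it reproduces the reformatted definitions (the trailing return of get_bf16_from_float is inferred, since the hunk ends at the #endif) and checks what they compute on a little-endian machine.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Round a up to the nearest multiple of b (as reformatted in the hunk above).
size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }

// Truncate a float to bfloat16 by keeping its upper 16 bits
// (as reformatted in the hunk above; the final return is inferred).
uint16_t get_bf16_from_float(float f) {
  uint16_t bf16;
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  memcpy(&bf16, &f, sizeof(uint16_t));
#else
  const void *fp = reinterpret_cast<const void *>(
      reinterpret_cast<uintptr_t>(&f) + sizeof(float) - sizeof(uint16_t));
  memcpy(&bf16, fp, sizeof(uint16_t));
#endif // __BYTE_ORDER__
  return bf16;
}

int main() {
  assert(roundup(10, 8) == 16);                 // pad 10 up to a multiple of 8
  assert(roundup(32, 32) == 32);                // already aligned, unchanged
  assert(get_bf16_from_float(1.0f) == 0x3F80);  // 1.0f is 0x3F800000; keep the top half
  return 0;
}
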
@@ -45,52 +43,31 @@ using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
 
 size_t activation_data_size(const Ukernel ukernel, int m, int k) {
   auto lhs_packing = get_lhs_packing();
-  return lhs_packing.get_lhs_packed_size(
-      m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
+  return lhs_packing.get_lhs_packed_size(m, k, ukernel.get_mr(),
+                                         ukernel.get_kr(), ukernel.get_sr());
 }
 
-void prepare_activation_data(
-    const Ukernel ukernel,
-    void* activation_data,
-    int m,
-    int k,
-    const float* activations) {
+void prepare_activation_data(const Ukernel ukernel, void *activation_data,
+                             int m, int k, const float *activations) {
   auto lhs_pack = get_lhs_packing();
 
-  lhs_pack.run_lhs_pack(
-      m,
-      k,
-      ukernel.get_mr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
-      /*m_index_start=*/0,
-      activations,
-      /*lhs_stride=*/k * sizeof(float),
-      activation_data);
+  lhs_pack.run_lhs_pack(m, k, ukernel.get_mr(), ukernel.get_kr(),
+                        ukernel.get_sr(),
+                        /*m_index_start=*/0, activations,
+                        /*lhs_stride=*/k * sizeof(float), activation_data);
 }
 
 size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
   auto rhs_pack = get_rhs_packing();
-  return rhs_pack.get_rhs_packed_size(
-      n,
-      k,
-      ukernel.get_nr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
-      group_size,
-      kai_datatype::kai_dt_bf16);
+  return rhs_pack.get_rhs_packed_size(n, k, ukernel.get_nr(), ukernel.get_kr(),
+                                      ukernel.get_sr(), group_size,
+                                      kai_datatype::kai_dt_bf16);
 }
 
-void prepare_weight_data(
-    const Ukernel ukernel,
-    void* weight_data,
-    int n,
-    int k,
-    int group_size,
-    const int8_t* weight_qvals,
-    const float* weight_scales,
-    const int8_t* weight_zeros,
-    const float* bias) {
+void prepare_weight_data(const Ukernel ukernel, void *weight_data, int n, int k,
+                         int group_size, const int8_t *weight_qvals,
+                         const float *weight_scales, const int8_t *weight_zeros,
+                         const float *bias) {
   // TODO(T204312268) - remove this constraint and pad when possible
   assert(n % 2 == 0);
 

@@ -123,25 +100,19 @@ void prepare_weight_data(
   }
 
   // Parameters for packing
-  rhs_packing::qparams_t qparams{
-      .lhs_zero_point = 1,
-      .rhs_zero_point = wzp,
-      .scale_dt = kai_datatype::kai_dt_bf16};
+  rhs_packing::qparams_t qparams{.lhs_zero_point = 1,
+                                 .rhs_zero_point = wzp,
+                                 .scale_dt = kai_datatype::kai_dt_bf16};
 
   auto rhs_pack = get_rhs_packing();
 
   rhs_pack.run_rhs_pack(
-      /*groups=*/1,
-      n,
-      k,
-      ukernel.get_nr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
+      /*groups=*/1, n, k, ukernel.get_nr(), ukernel.get_kr(), ukernel.get_sr(),
       group_size,
-      /*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
+      /*rhs=*/reinterpret_cast<const uint8_t *>(packed_weight_qvals.data()),
       /*rhs_stride=*/roundup(k, 2) / 2,
       /*bias=*/bias,
-      /*scale=*/reinterpret_cast<const uint16_t*>(weight_scales_bf16.data()),
+      /*scale=*/reinterpret_cast<const uint16_t *>(weight_scales_bf16.data()),
       /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size),
       /*rhs_packed=*/weight_data,
       /*extra_bytes=*/0,
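
As a sanity check on the stride arguments in this call, the sketch below works through the arithmetic with hypothetical shapes (it is illustrative only and not part of the commit): the rhs_stride expression is consistent with two 4-bit quantized values packed per byte, and scale_stride with one 2-byte bf16 scale per quantization group.

#include <cassert>
#include <cstddef>
#include <cstdint>

// Same roundup helper as in the diff above.
size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }

int main() {
  // Hypothetical shapes, chosen only to work through the stride arithmetic.
  const size_t k = 64;           // reduction dimension
  const size_t group_size = 32;  // quantization group size along k

  // Two 4-bit quantized values per byte -> 32 bytes per row of weight qvals.
  const size_t rhs_stride = roundup(k, 2) / 2;
  assert(rhs_stride == 32);

  // One 2-byte bf16 scale per group -> 64 / 32 = 2 groups -> 4 bytes per row.
  const size_t scale_stride =
      sizeof(uint16_t) * (roundup(k, group_size) / group_size);
  assert(scale_stride == 4);
  return 0;
}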

0 commit comments
