@@ -23,16 +23,14 @@ namespace torchao::kernels::cpu::aarch64::kleidi {
23
23
// Helper functions
24
24
// TODO: find a better place for these?
25
25
26
- size_t roundup (size_t a, size_t b) {
27
- return ((a + b - 1 ) / b) * b;
28
- }
26
+ size_t roundup (size_t a, size_t b) { return ((a + b - 1 ) / b) * b; }
29
27
30
28
uint16_t get_bf16_from_float (float f) {
31
29
uint16_t bf16;
32
30
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
33
31
memcpy (&bf16, &f, sizeof (uint16_t ));
34
32
#else
35
- const void * fp = reinterpret_cast <const void *>(
33
+ const void * fp = reinterpret_cast <const void *>(
36
34
reinterpret_cast <uintptr_t >(&f) + sizeof (float ) - sizeof (uint16_t ));
37
35
memcpy (&bf16, fp, sizeof (uint16_t ));
38
36
#endif // __BYTE_ORDER__
@@ -45,52 +43,31 @@ using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
45
43
46
44
size_t activation_data_size (const Ukernel ukernel, int m, int k) {
47
45
auto lhs_packing = get_lhs_packing ();
48
- return lhs_packing.get_lhs_packed_size (
49
- m, k, ukernel. get_mr (), ukernel.get_kr (), ukernel.get_sr ());
46
+ return lhs_packing.get_lhs_packed_size (m, k, ukernel. get_mr (),
47
+ ukernel.get_kr (), ukernel.get_sr ());
50
48
}
51
49
52
- void prepare_activation_data (
53
- const Ukernel ukernel,
54
- void * activation_data,
55
- int m,
56
- int k,
57
- const float * activations) {
50
+ void prepare_activation_data (const Ukernel ukernel, void *activation_data,
51
+ int m, int k, const float *activations) {
58
52
auto lhs_pack = get_lhs_packing ();
59
53
60
- lhs_pack.run_lhs_pack (
61
- m,
62
- k,
63
- ukernel.get_mr (),
64
- ukernel.get_kr (),
65
- ukernel.get_sr (),
66
- /* m_index_start=*/ 0 ,
67
- activations,
68
- /* lhs_stride=*/ k * sizeof (float ),
69
- activation_data);
54
+ lhs_pack.run_lhs_pack (m, k, ukernel.get_mr (), ukernel.get_kr (),
55
+ ukernel.get_sr (),
56
+ /* m_index_start=*/ 0 , activations,
57
+ /* lhs_stride=*/ k * sizeof (float ), activation_data);
70
58
}
71
59
72
60
size_t weight_data_size (const Ukernel ukernel, int n, int k, int group_size) {
73
61
auto rhs_pack = get_rhs_packing ();
74
- return rhs_pack.get_rhs_packed_size (
75
- n,
76
- k,
77
- ukernel.get_nr (),
78
- ukernel.get_kr (),
79
- ukernel.get_sr (),
80
- group_size,
81
- kai_datatype::kai_dt_bf16);
62
+ return rhs_pack.get_rhs_packed_size (n, k, ukernel.get_nr (), ukernel.get_kr (),
63
+ ukernel.get_sr (), group_size,
64
+ kai_datatype::kai_dt_bf16);
82
65
}
83
66
84
- void prepare_weight_data (
85
- const Ukernel ukernel,
86
- void * weight_data,
87
- int n,
88
- int k,
89
- int group_size,
90
- const int8_t * weight_qvals,
91
- const float * weight_scales,
92
- const int8_t * weight_zeros,
93
- const float * bias) {
67
+ void prepare_weight_data (const Ukernel ukernel, void *weight_data, int n, int k,
68
+ int group_size, const int8_t *weight_qvals,
69
+ const float *weight_scales, const int8_t *weight_zeros,
70
+ const float *bias) {
94
71
// TODO(T204312268) - remove this constraint and pad when possible
95
72
assert (n % 2 == 0 );
96
73
@@ -123,25 +100,19 @@ void prepare_weight_data(
123
100
}
124
101
125
102
// Parameters for packing
126
- rhs_packing::qparams_t qparams{
127
- .lhs_zero_point = 1 ,
128
- .rhs_zero_point = wzp,
129
- .scale_dt = kai_datatype::kai_dt_bf16};
103
+ rhs_packing::qparams_t qparams{.lhs_zero_point = 1 ,
104
+ .rhs_zero_point = wzp,
105
+ .scale_dt = kai_datatype::kai_dt_bf16};
130
106
131
107
auto rhs_pack = get_rhs_packing ();
132
108
133
109
rhs_pack.run_rhs_pack (
134
- /* groups=*/ 1 ,
135
- n,
136
- k,
137
- ukernel.get_nr (),
138
- ukernel.get_kr (),
139
- ukernel.get_sr (),
110
+ /* groups=*/ 1 , n, k, ukernel.get_nr (), ukernel.get_kr (), ukernel.get_sr (),
140
111
group_size,
141
- /* rhs=*/ reinterpret_cast <const uint8_t *>(packed_weight_qvals.data ()),
112
+ /* rhs=*/ reinterpret_cast <const uint8_t *>(packed_weight_qvals.data ()),
142
113
/* rhs_stride=*/ roundup (k, 2 ) / 2 ,
143
114
/* bias=*/ bias,
144
- /* scale=*/ reinterpret_cast <const uint16_t *>(weight_scales_bf16.data ()),
115
+ /* scale=*/ reinterpret_cast <const uint16_t *>(weight_scales_bf16.data ()),
145
116
/* scale_stride=*/ sizeof (uint16_t ) * (roundup (k, group_size) / group_size),
146
117
/* rhs_packed=*/ weight_data,
147
118
/* extra_bytes=*/ 0 ,
0 commit comments