
Clean up op interface #1998


Merged: 1 commit, Apr 1, 2025
@@ -60,27 +60,47 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {

using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;

template <int mr, int kr, int sr>
size_t
activation_data_size(int m, int k, int group_size, bool has_weight_zeros) {
size_t packed_activations_size(
int m,
int k,
int group_size,
bool has_weight_zeros,
int mr,
int kr,
int sr) {
(void)group_size; // unused
(void)has_weight_zeros; // unused
auto lhs_packing = get_lhs_packing();
return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr);
}

template <int mr, int kr, int sr>
void prepare_activation_data(
void* activation_data,
size_t packed_activations_offset(
int m_idx,
int k,
int group_size,
bool has_weight_zeros,
int mr,
int kr,
int sr) {
(void)group_size; // unused
(void)has_weight_zeros; // unused
auto lhs_pack = get_lhs_packing();
return lhs_pack.get_lhs_packed_offset(m_idx, k, mr, kr, sr);
}

void pack_activations(
void* packed_activations,
int m,
int k,
int group_size,
const float* activations,
bool has_weight_zeros) {
bool has_weight_zeros,
int mr,
int kr,
int sr) {
(void)group_size; // unused
(void)has_weight_zeros; // unused
auto lhs_pack = get_lhs_packing();

lhs_pack.run_lhs_pack(
m,
k,
@@ -90,33 +110,62 @@ void prepare_activation_data(
/*m_index_start=*/0,
activations,
/*lhs_stride=*/k * sizeof(float),
activation_data);
packed_activations);
}

template <int nr, int kr, int sr>
size_t weight_data_size(
size_t packed_weights_size(
int n,
int k,
int group_size,
int weight_nbit,
bool has_weight_zeros,
bool has_bias) {
bool has_bias,
int nr,
int kr,
int sr) {
(void)weight_nbit; // unused
(void)has_weight_zeros; // unused
(void)has_bias; // unused
auto rhs_pack = get_rhs_packing();
return rhs_pack.get_rhs_packed_size(
n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16);
internal::adjust_n(n),
k,
nr,
kr,
sr,
group_size,
kai_datatype::kai_dt_bf16);
}

size_t packed_weights_offset(
int n_idx,
int k,
int group_size,
int weight_nbit,
bool has_weight_zeros,
bool has_bias,
int nr,
int kr,
int sr) {
(void)has_weight_zeros; // unused
(void)has_bias; // unused
auto rhs_pack = get_rhs_packing();
return rhs_pack.get_rhs_packed_offset(
n_idx, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16);
}

template <int nr, int kr, int sr>
void prepare_weight_data(
void* weight_data,
void pack_weights(
void* packed_weights,
int n,
int k,
int group_size,
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros,
const float* bias) {
const float* bias,
int nr,
int kr,
int sr) {
if (group_size % 32 != 0) {
throw std::runtime_error(
"Group size must be a multiple of 32, but got group_size=" +
@@ -187,7 +236,7 @@ void prepare_weight_data(
reinterpret_cast<const uint16_t*>(weight_scales_bf16_padded.data()),
/*scale_stride=*/sizeof(uint16_t) *
(internal::roundup(k, group_size) / group_size),
/*rhs_packed=*/weight_data,
/*rhs_packed=*/packed_weights,
/*extra_bytes=*/0,
/*qparams=*/&qparams);
}
@@ -220,8 +269,8 @@ size_t get_preferred_alignement() {
int n, \
int k, \
int group_size, \
const void* weight_data, \
const void* activation_data, \
const void* packed_weights, \
const void* packed_activations, \
float clamp_min, \
float clamp_max, \
bool has_weight_zeros, \
Expand All @@ -235,11 +284,11 @@ size_t get_preferred_alignement() {
} \
get_ukernel().run_matmul( \
m, \
internal::adjust_n(n), \
n, \
k, \
group_size, \
activation_data, \
weight_data, \
packed_activations, \
packed_weights, \
output, \
/*dst_stride_row=*/output_m_stride * sizeof(float), \
/*dst_stride_col=*/sizeof(float), \
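End of the first file's changes. Taken together, this file now exposes packed_activations_size / packed_activations_offset / pack_activations and packed_weights_size / packed_weights_offset / pack_weights, with mr/kr/sr (and nr/kr/sr on the weight side) passed as ordinary runtime arguments instead of template parameters; the *_offset queries report where a given m_idx or n_idx tile starts in the packed buffer, which a caller can use when splitting work across threads. The sketch below is not part of the PR; it only illustrates, under the signatures shown in the diff, how a caller might size and fill the packed-activation buffer. The mr/kr/sr values and problem sizes are placeholders, and the weight side follows the same size/offset/pack pattern.

// Sketch only -- not part of this PR. The declarations mirror the post-PR
// signatures from the diff above so the example is self-contained; in real
// code they come from the corresponding header.
#include <cstddef>
#include <vector>

size_t packed_activations_size(
    int m, int k, int group_size, bool has_weight_zeros, int mr, int kr, int sr);
void pack_activations(
    void* packed_activations, int m, int k, int group_size,
    const float* activations, bool has_weight_zeros, int mr, int kr, int sr);

void pack_lhs_example(const float* activations, int m, int k) {
  const int group_size = 32;            // placeholder values; real callers take
  const bool has_weight_zeros = false;  // these from the chosen ukernel config
  const int mr = 4, kr = 16, sr = 2;

  // Size the destination buffer with the runtime-parameter query...
  std::vector<char> packed(
      packed_activations_size(m, k, group_size, has_weight_zeros, mr, kr, sr));

  // ...then pack, passing mr/kr/sr as arguments rather than template parameters.
  pack_activations(
      packed.data(), m, k, group_size, activations, has_weight_zeros,
      mr, kr, sr);
}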
@@ -49,15 +49,21 @@ inline size_t packed_activations_offset(
return (m_idx / mr) * packed_activations_size_mr_rows;
}

template <int mr, int kr, int sr>
template <int mr_, int kr_, int sr_>
void pack_activations(
void* packed_activations,
int m,
int k,
int group_size,
const float* activations,
bool has_weight_zeros) {
activation_packing::pack_activations<mr, kr, sr>(
bool has_weight_zeros,
int mr,
int kr,
int sr) {
(void)mr; // unused
(void)kr; // unused
(void)sr; // unused
activation_packing::pack_activations<mr_, kr_, sr_>(
packed_activations, m, k, group_size, activations, has_weight_zeros);
}

@@ -93,7 +99,7 @@ inline size_t packed_weights_offset(
return (n_idx / nr) * packed_weights_size_nr_cols;
}

template <int weight_nbit, int nr, int kr, int sr>
template <int weight_nbit, int nr_, int kr_, int sr_>
void pack_weights(
void* packed_weights,
int n,
@@ -102,8 +108,14 @@ void pack_weights(
const int8_t* weight_qvals,
const float* weight_scales,
const int8_t* weight_zeros,
const float* bias) {
weight_packing::pack_weights<weight_nbit, nr, kr, sr>(
const float* bias,
int nr,
int kr,
int sr) {
(void)nr; // unused
(void)kr; // unused
(void)sr; // unused
weight_packing::pack_weights<weight_nbit, nr_, kr_, sr_>(
packed_weights,
n,
k,
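The second file keeps its templated wrappers but widens their signatures to match: the compile-time mr_/kr_/sr_ (and weight_nbit, nr_/kr_/sr_) still drive the underlying packing routines, while the new runtime mr/kr/sr arguments are accepted only so every backend presents the same interface, and are explicitly discarded with (void) casts. A reduced, self-contained sketch of that pattern follows; the names pack_impl and pack_wrapper are made up for illustration and do not appear in the PR.

// Illustration only -- not part of this PR.
#include <cstdio>

// Stand-in for the compile-time-specialized packing routine
// (e.g. activation_packing::pack_activations<mr_, kr_, sr_> in the diff).
template <int mr_, int kr_, int sr_>
void pack_impl(int m, int k) {
  std::printf("pack %d x %d with mr=%d kr=%d sr=%d\n", m, k, mr_, kr_, sr_);
}

// Wrapper with the unified runtime signature: mr/kr/sr are accepted so the
// caller-facing interface matches the non-templated backends, but this
// backend ignores them and trusts its template parameters.
template <int mr_, int kr_, int sr_>
void pack_wrapper(int m, int k, int mr, int kr, int sr) {
  (void)mr;  // unused
  (void)kr;  // unused
  (void)sr;  // unused
  pack_impl<mr_, kr_, sr_>(m, k);
}

int main() {
  // The runtime values would normally agree with the template arguments.
  pack_wrapper<4, 16, 2>(/*m=*/8, /*k=*/128, /*mr=*/4, /*kr=*/16, /*sr=*/2);
  return 0;
}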