Skip to content

Commit

Permalink
Faster AVX2 prompt processing for k-quants and IQ4_XS (#394)
Browse files Browse the repository at this point in the history
  • Loading branch information
ikawrakow authored May 7, 2024
1 parent 911d58f commit e6532f7
Show file tree
Hide file tree
Showing 3 changed files with 843 additions and 2 deletions.
53 changes: 51 additions & 2 deletions llama.cpp/ggml-quants.inc
Original file line number Diff line number Diff line change
Expand Up @@ -3341,8 +3341,57 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6
}
}

void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q8_K_reference(x, y, k);
void quantize_row_q8_K(const float * restrict x, void * restrict vy, int64_t k) {
#ifdef __AVX2__
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_K * y = vy;
const __m256 signBit = _mm256_set1_ps( -0.0f );
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
for (int i = 0; i < nb; i++) {
const float * xb = x + i*QK_K;
__m256 maxAbs = _mm256_setzero_ps();
const float * xx = xb;
for (int ib = 0; ib < QK_K/8; ++ib) {
const __m256 v = _mm256_loadu_ps(xx); xx += 8;
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
}
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
const float d = maxScalar / 127.f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
xx = xb;
int8_t * q8 = y[i].qs;
for (int ib = 0; ib < QK_K/32; ++ib) {
__m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
__m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
i0 = _mm256_packs_epi16( i0, i2 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)q8, i0);
q8 += 32;
}
}
#else
quantize_row_q8_K_reference(x, vy, k);
#endif
}

//===================================== Dot ptoducts =================================
Expand Down
Loading

0 comments on commit e6532f7

Please sign in to comment.