Skip to content

Commit 9753008

Browse files
SongXiaoXi authored and bluss committed
sgemm: Reduce unnecessary AVX register permutations
- Removed redundant `_mm256_permute2f128_ps` instructions for lane swapping.
- Consolidated `bv_lh` usage for upper and lower halves, reducing the number of separate permutes.
- Reordered final output assignments to match the expected layout directly, simplifying downstream processing.
- This change reduces register pressure and improves instruction efficiency without altering the computation logic.
1 parent 301ebc5 commit 9753008

File tree

1 file changed

+19
-23
lines changed

1 file changed

+19
-23
lines changed

src/sgemm_kernel.rs

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -404,21 +404,17 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
404404
let a1357 = _mm256_movehdup_ps(av); // Load: a1 a1 a3 a3 a5 a5 a7 a7
405405
let a3175 = _mm256_permute_ps(a1357, PERM32_2301);
406406

407-
let a4602 = _mm256_permute2f128_ps(a0246, a0246, PERM128_30);
408-
let a6420 = _mm256_permute2f128_ps(a2064, a2064, PERM128_30);
409-
410-
let a5713 = _mm256_permute2f128_ps(a1357, a1357, PERM128_30);
411-
let a7531 = _mm256_permute2f128_ps(a3175, a3175, PERM128_30);
407+
let bv_lh = _mm256_permute2f128_ps(bv, bv, PERM128_30);
412408

413409
ab[0] = MA::multiply_add(a0246, bv, ab[0]);
414410
ab[1] = MA::multiply_add(a2064, bv, ab[1]);
415-
ab[2] = MA::multiply_add(a4602, bv, ab[2]);
416-
ab[3] = MA::multiply_add(a6420, bv, ab[3]);
411+
ab[2] = MA::multiply_add(a0246, bv_lh, ab[2]);
412+
ab[3] = MA::multiply_add(a2064, bv_lh, ab[3]);
417413

418414
ab[4] = MA::multiply_add(a1357, bv, ab[4]);
419415
ab[5] = MA::multiply_add(a3175, bv, ab[5]);
420-
ab[6] = MA::multiply_add(a5713, bv, ab[6]);
421-
ab[7] = MA::multiply_add(a7531, bv, ab[7]);
416+
ab[6] = MA::multiply_add(a1357, bv_lh, ab[6]);
417+
ab[7] = MA::multiply_add(a3175, bv_lh, ab[7]);
422418

423419
if !is_last {
424420
a = a.add(MR);
@@ -441,19 +437,19 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
441437

442438
let ab0246 = ab[0];
443439
let ab2064 = ab[1];
444-
let ab4602 = ab[2];
445-
let ab6420 = ab[3];
440+
let ab4602 = ab[2]; // reverse order
441+
let ab6420 = ab[3]; // reverse order
446442

447443
let ab1357 = ab[4];
448444
let ab3175 = ab[5];
449-
let ab5713 = ab[6];
450-
let ab7531 = ab[7];
445+
let ab5713 = ab[6]; // reverse order
446+
let ab7531 = ab[7]; // reverse order
451447

452448
const SHUF_0123: i32 = shuffle_mask!(3, 2, 1, 0);
453449
debug_assert_eq!(SHUF_0123, 0xE4);
454450

455-
const PERM128_03: i32 = permute2f128_mask!(3, 0);
456-
const PERM128_21: i32 = permute2f128_mask!(1, 2);
451+
const PERM128_02: i32 = permute2f128_mask!(2, 0);
452+
const PERM128_31: i32 = permute2f128_mask!(1, 3);
457453

458454
// No elements are "shuffled" in truth, they all stay at their index
459455
// but we combine vectors to de-stripe them.
@@ -480,17 +476,17 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
480476
let ab5511 = _mm256_shuffle_ps(ab5713, ab7531, SHUF_0123);
481477
let ab7733 = _mm256_shuffle_ps(ab7531, ab5713, SHUF_0123);
482478

483-
let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_03);
484-
let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_21);
479+
let ab0000 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_02);
480+
let ab4444 = _mm256_permute2f128_ps(ab0044, ab4400, PERM128_31);
485481

486-
let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_03);
487-
let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_21);
482+
let ab2222 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_02);
483+
let ab6666 = _mm256_permute2f128_ps(ab2266, ab6622, PERM128_31);
488484

489-
let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_03);
490-
let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_21);
485+
let ab1111 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_02);
486+
let ab5555 = _mm256_permute2f128_ps(ab1155, ab5511, PERM128_31);
491487

492-
let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_03);
493-
let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_21);
488+
let ab3333 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_02);
489+
let ab7777 = _mm256_permute2f128_ps(ab3377, ab7733, PERM128_31);
494490

495491
ab[0] = ab0000;
496492
ab[1] = ab1111;

0 commit comments

Comments
 (0)