diff --git a/bruteforce_test.go b/bruteforce_test.go index 825fd7b..801e399 100644 --- a/bruteforce_test.go +++ b/bruteforce_test.go @@ -13,7 +13,7 @@ import ( /* cpu: 13th Gen Intel(R) Core(TM) i7-13700K -BenchmarkIndex/search-24 3030 383868 ns/op 265 B/op 2 allocs/op +BenchmarkIndex/search-24 3807 316587 ns/op 264 B/op 2 allocs/op */ func BenchmarkIndex(b *testing.B) { data, err := loadDataset() diff --git a/internal/cosine/cosine_apple.c b/internal/cosine/cosine_apple.c index 17746bd..aa05d5d 100644 --- a/internal/cosine/cosine_apple.c +++ b/internal/cosine/cosine_apple.c @@ -3,44 +3,27 @@ // Licensed under the MIT license. See LICENSE file in the project root for details. #include -#include #include void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) { - float32x4_t sum_xy = vdupq_n_f32(0.0f); // Sum of x * y - float32x4_t sum_xx = vdupq_n_f32(0.0f); // Sum of x * x - float32x4_t sum_yy = vdupq_n_f32(0.0f); // Sum of y * y - - uint64_t i; - for (i = 0; i + 3 < size; i += 4) { - float32x4_t x_vec = vld1q_f32(x + i); - float32x4_t y_vec = vld1q_f32(y + i); - - sum_xy = vmlaq_f32(sum_xy, x_vec, y_vec); - sum_xx = vmlaq_f32(sum_xx, x_vec, x_vec); - sum_yy = vmlaq_f32(sum_yy, y_vec, y_vec); - } - - // Sum the elements of the vectors - float dot_xy = vaddvq_f32(sum_xy); - float norm_x = vaddvq_f32(sum_xx); - float norm_y = vaddvq_f32(sum_yy); - - // Handle any remaining elements - for (; i < size; i++) { - dot_xy += x[i] * y[i]; - norm_x += x[i] * x[i]; - norm_y += y[i] * y[i]; + float sum_xy = 0.0f; + float sum_xx = 0.0f; + float sum_yy = 0.0f; + + #pragma clang loop vectorize(enable) interleave_count(4) + for (uint64_t i = 0; i < size; i++) { + sum_xy += x[i] * y[i]; // Sum of x * y + sum_xx += x[i] * x[i]; // Sum of x * x + sum_yy += y[i] * y[i]; // Sum of y * y } - - // Avoid division by zero - float denominator = sqrtf(norm_x) * sqrtf(norm_y); + // Calculate the final result + float denominator = sqrtf(sum_xx) * sqrtf(sum_yy); if (denominator == 0.0f) { *result = (double)0.0f; return; } - double cosine_similarity = (double)dot_xy / (double)denominator; + double cosine_similarity = (double)sum_xy / (double)denominator; *result = cosine_similarity; } \ No newline at end of file diff --git a/internal/cosine/cosine_avx.c b/internal/cosine/cosine_avx.c index e2756d7..aa05d5d 100644 --- a/internal/cosine/cosine_avx.c +++ b/internal/cosine/cosine_avx.c @@ -2,58 +2,28 @@ // Copyright (c) Roman Atachiants and contributors. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for details. - #include -#include #include void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) { - __m256 sum_xy = _mm256_setzero_ps(); // Sum of x * y - __m256 sum_xx = _mm256_setzero_ps(); // Sum of x * x - __m256 sum_yy = _mm256_setzero_ps(); // Sum of y * y - - uint64_t i; - for (i = 0; i <= size - 8; i += 8) { - __m256 x_vec = _mm256_loadu_ps(x + i); - __m256 y_vec = _mm256_loadu_ps(y + i); - - sum_xy = _mm256_fmadd_ps(x_vec, y_vec, sum_xy); // sum_xy += x_vec * y_vec - sum_xx = _mm256_fmadd_ps(x_vec, x_vec, sum_xx); // sum_xx += x_vec * x_vec - sum_yy = _mm256_fmadd_ps(y_vec, y_vec, sum_yy); // sum_yy += y_vec * y_vec - } - - // Sum elements of sum_xy - __m256 temp_xy = _mm256_hadd_ps(sum_xy, sum_xy); // Sum adjacent pairs - temp_xy = _mm256_hadd_ps(temp_xy, temp_xy); // Sum adjacent quadruples - __m128 sum_xy_128 = _mm_add_ps(_mm256_castps256_ps128(temp_xy), _mm256_extractf128_ps(temp_xy, 1)); - float dot_xy = _mm_cvtss_f32(sum_xy_128); // Extract final sum - - // Sum elements of sum_xx - __m256 temp_xx = _mm256_hadd_ps(sum_xx, sum_xx); - temp_xx = _mm256_hadd_ps(temp_xx, temp_xx); - __m128 sum_xx_128 = _mm_add_ps(_mm256_castps256_ps128(temp_xx), _mm256_extractf128_ps(temp_xx, 1)); - float norm_x = _mm_cvtss_f32(sum_xx_128); - - // Sum elements of sum_yy - __m256 temp_yy = _mm256_hadd_ps(sum_yy, sum_yy); - temp_yy = _mm256_hadd_ps(temp_yy, temp_yy); - __m128 sum_yy_128 = _mm_add_ps(_mm256_castps256_ps128(temp_yy), _mm256_extractf128_ps(temp_yy, 1)); - float norm_y = _mm_cvtss_f32(sum_yy_128); - - // Handle remaining elements (if any) - for (; i < size; i++) { - dot_xy += x[i] * y[i]; - norm_x += x[i] * x[i]; - norm_y += y[i] * y[i]; + float sum_xy = 0.0f; + float sum_xx = 0.0f; + float sum_yy = 0.0f; + + #pragma clang loop vectorize(enable) interleave_count(4) + for (uint64_t i = 0; i < size; i++) { + sum_xy += x[i] * y[i]; // Sum of x * y + sum_xx += x[i] * x[i]; // Sum of x * x + sum_yy += y[i] * y[i]; // Sum of y * y } - // Avoid division by zero - float denominator = sqrtf(norm_x) * sqrtf(norm_y); + // Calculate the final result + float denominator = sqrtf(sum_xx) * sqrtf(sum_yy); if (denominator == 0.0f) { *result = (double)0.0f; return; } - double cosine_similarity = (double)dot_xy / (double)denominator; + double cosine_similarity = (double)sum_xy / (double)denominator; *result = cosine_similarity; } \ No newline at end of file diff --git a/internal/cosine/cosine_neon.c b/internal/cosine/cosine_neon.c index 17746bd..aa05d5d 100644 --- a/internal/cosine/cosine_neon.c +++ b/internal/cosine/cosine_neon.c @@ -3,44 +3,27 @@ // Licensed under the MIT license. See LICENSE file in the project root for details. #include -#include #include void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) { - float32x4_t sum_xy = vdupq_n_f32(0.0f); // Sum of x * y - float32x4_t sum_xx = vdupq_n_f32(0.0f); // Sum of x * x - float32x4_t sum_yy = vdupq_n_f32(0.0f); // Sum of y * y - - uint64_t i; - for (i = 0; i + 3 < size; i += 4) { - float32x4_t x_vec = vld1q_f32(x + i); - float32x4_t y_vec = vld1q_f32(y + i); - - sum_xy = vmlaq_f32(sum_xy, x_vec, y_vec); - sum_xx = vmlaq_f32(sum_xx, x_vec, x_vec); - sum_yy = vmlaq_f32(sum_yy, y_vec, y_vec); - } - - // Sum the elements of the vectors - float dot_xy = vaddvq_f32(sum_xy); - float norm_x = vaddvq_f32(sum_xx); - float norm_y = vaddvq_f32(sum_yy); - - // Handle any remaining elements - for (; i < size; i++) { - dot_xy += x[i] * y[i]; - norm_x += x[i] * x[i]; - norm_y += y[i] * y[i]; + float sum_xy = 0.0f; + float sum_xx = 0.0f; + float sum_yy = 0.0f; + + #pragma clang loop vectorize(enable) interleave_count(4) + for (uint64_t i = 0; i < size; i++) { + sum_xy += x[i] * y[i]; // Sum of x * y + sum_xx += x[i] * x[i]; // Sum of x * x + sum_yy += y[i] * y[i]; // Sum of y * y } - - // Avoid division by zero - float denominator = sqrtf(norm_x) * sqrtf(norm_y); + // Calculate the final result + float denominator = sqrtf(sum_xx) * sqrtf(sum_yy); if (denominator == 0.0f) { *result = (double)0.0f; return; } - double cosine_similarity = (double)dot_xy / (double)denominator; + double cosine_similarity = (double)sum_xy / (double)denominator; *result = cosine_similarity; } \ No newline at end of file diff --git a/internal/cosine/simd/cosine_apple.s b/internal/cosine/simd/cosine_apple.s index 59f48a3..133caef 100644 --- a/internal/cosine/simd/cosine_apple.s +++ b/internal/cosine/simd/cosine_apple.s @@ -8,120 +8,109 @@ TEXT ·f32_cosine_distance(SB), $0-32 MOVD size+24(FP), R3 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! ; 16-byte Folded Spill WORD $0x910003fd // mov x29, sp - WORD $0xf100107f // cmp x3, #4 - WORD $0x54000223 // b.lo LBB0_4 - WORD $0x6f00e400 // movi.2d v0, #0000000000000000 - WORD $0x52800068 // mov w8, #3 - WORD $0xaa0103e9 // mov x9, x1 - WORD $0xaa0003ea // mov x10, x0 - WORD $0x6f00e401 // movi.2d v1, #0000000000000000 - WORD $0x6f00e402 // movi.2d v2, #0000000000000000 + WORD $0x2f00e400 // movi d0, #0000000000000000 + WORD $0xb4000103 // cbz x3, LBB0_3 + WORD $0xf100407f // cmp x3, #16 + WORD $0x54000182 // b.hs LBB0_4 + WORD $0xd2800008 // mov x8, #0 + WORD $0x2f00e401 // movi d1, #0000000000000000 + WORD $0x2f00e405 // movi d5, #0000000000000000 + WORD $0x2f00e411 // movi d17, #0000000000000000 + WORD $0x1400003a // b LBB0_7 -BB0_2: - WORD $0x3cc10543 // ldr q3, [x10], #16 - WORD $0x3cc10524 // ldr q4, [x9], #16 - WORD $0x4e23cc80 // fmla.4s v0, v4, v3 - WORD $0x4e23cc61 // fmla.4s v1, v3, v3 - WORD $0x4e24cc82 // fmla.4s v2, v4, v4 - WORD $0x91001108 // add x8, x8, #4 - WORD $0xeb03011f // cmp x8, x3 - WORD $0x54ffff23 // b.lo LBB0_2 - WORD $0x927ef46b // and x11, x3, #0xfffffffffffffffc - WORD $0x14000005 // b LBB0_5 +BB0_3: + WORD $0x2f00e402 // movi d2, #0000000000000000 + WORD $0x2f00e401 // movi d1, #0000000000000000 + WORD $0x1e21c042 // fsqrt s2, s2 + WORD $0x1e202048 // fcmp s2, #0.0 + WORD $0x540008a0 // b.eq LBB0_10 + WORD $0x14000047 // b LBB0_11 BB0_4: - WORD $0xd280000b // mov x11, #0 - WORD $0x6f00e402 // movi.2d v2, #0000000000000000 WORD $0x6f00e401 // movi.2d v1, #0000000000000000 - WORD $0x6f00e400 // movi.2d v0, #0000000000000000 - -BB0_5: - WORD $0x6e20d400 // faddp.4s v0, v0, v0 - WORD $0x7e30d800 // faddp.2s s0, v0 - WORD $0x6e21d421 // faddp.4s v1, v1, v1 - WORD $0x7e30d821 // faddp.2s s1, v1 - WORD $0x6e22d442 // faddp.4s v2, v2, v2 - WORD $0x7e30d842 // faddp.2s s2, v2 - WORD $0xeb03017f // cmp x11, x3 - WORD $0x54000702 // b.hs LBB0_13 - WORD $0xcb0b0069 // sub x9, x3, x11 - WORD $0xf100213f // cmp x9, #8 - WORD $0x54000062 // b.hs LBB0_8 - WORD $0xaa0b03e8 // mov x8, x11 - WORD $0x14000028 // b LBB0_11 - -BB0_8: - WORD $0x927df12a // and x10, x9, #0xfffffffffffffff8 - WORD $0x8b0a0168 // add x8, x11, x10 + WORD $0x927cec68 // and x8, x3, #0xfffffffffffffff0 + WORD $0x6f00e402 // movi.2d v2, #0000000000000000 + WORD $0x91008009 // add x9, x0, #32 WORD $0x6f00e403 // movi.2d v3, #0000000000000000 + WORD $0x9100802a // add x10, x1, #32 WORD $0x6f00e404 // movi.2d v4, #0000000000000000 - WORD $0x6e040444 // mov.s v4[0], v2[0] - WORD $0x6f00e402 // movi.2d v2, #0000000000000000 - WORD $0x6e040422 // mov.s v2[0], v1[0] - WORD $0x6f00e401 // movi.2d v1, #0000000000000000 - WORD $0x6e040401 // mov.s v1[0], v0[0] - WORD $0xd37ef56b // lsl x11, x11, #2 - WORD $0x9100416c // add x12, x11, #16 - WORD $0x8b0c000b // add x11, x0, x12 - WORD $0x8b0c002c // add x12, x1, x12 - WORD $0xaa0a03ed // mov x13, x10 + WORD $0xaa0803eb // mov x11, x8 WORD $0x6f00e405 // movi.2d v5, #0000000000000000 - WORD $0x6f00e400 // movi.2d v0, #0000000000000000 + WORD $0x6f00e406 // movi.2d v6, #0000000000000000 + WORD $0x6f00e407 // movi.2d v7, #0000000000000000 + WORD $0x6f00e410 // movi.2d v16, #0000000000000000 + WORD $0x6f00e411 // movi.2d v17, #0000000000000000 + WORD $0x6f00e412 // movi.2d v18, #0000000000000000 + WORD $0x6f00e413 // movi.2d v19, #0000000000000000 + WORD $0x6f00e414 // movi.2d v20, #0000000000000000 -BB0_9: - WORD $0xad7f9d66 // ldp q6, q7, [x11, #-16] - WORD $0xad7fc590 // ldp q16, q17, [x12, #-16] - WORD $0x4e26ce01 // fmla.4s v1, v16, v6 - WORD $0x4e27ce20 // fmla.4s v0, v17, v7 - WORD $0x4e26ccc2 // fmla.4s v2, v6, v6 - WORD $0x4e27cce5 // fmla.4s v5, v7, v7 - WORD $0x4e30ce04 // fmla.4s v4, v16, v16 - WORD $0x4e31ce23 // fmla.4s v3, v17, v17 - WORD $0x9100816b // add x11, x11, #32 - WORD $0x9100818c // add x12, x12, #32 - WORD $0xf10021ad // subs x13, x13, #8 - WORD $0x54fffea1 // b.ne LBB0_9 - WORD $0x4e21d400 // fadd.4s v0, v0, v1 - WORD $0x6e20d400 // faddp.4s v0, v0, v0 - WORD $0x7e30d800 // faddp.2s s0, v0 - WORD $0x4e22d4a1 // fadd.4s v1, v5, v2 +BB0_5: + WORD $0xad7f5935 // ldp q21, q22, [x9, #-32] + WORD $0xacc26137 // ldp q23, q24, [x9], #64 + WORD $0xad7f6959 // ldp q25, q26, [x10, #-32] + WORD $0xacc2715b // ldp q27, q28, [x10], #64 + WORD $0x4e35cf21 // fmla.4s v1, v25, v21 + WORD $0x4e36cf42 // fmla.4s v2, v26, v22 + WORD $0x4e37cf63 // fmla.4s v3, v27, v23 + WORD $0x4e38cf84 // fmla.4s v4, v28, v24 + WORD $0x4e35cea5 // fmla.4s v5, v21, v21 + WORD $0x4e36cec6 // fmla.4s v6, v22, v22 + WORD $0x4e37cee7 // fmla.4s v7, v23, v23 + WORD $0x4e38cf10 // fmla.4s v16, v24, v24 + WORD $0x4e39cf31 // fmla.4s v17, v25, v25 + WORD $0x4e3acf52 // fmla.4s v18, v26, v26 + WORD $0x4e3bcf73 // fmla.4s v19, v27, v27 + WORD $0x4e3ccf94 // fmla.4s v20, v28, v28 + WORD $0xf100416b // subs x11, x11, #16 + WORD $0x54fffde1 // b.ne LBB0_5 + WORD $0x4e31d651 // fadd.4s v17, v18, v17 + WORD $0x4e31d671 // fadd.4s v17, v19, v17 + WORD $0x4e31d691 // fadd.4s v17, v20, v17 + WORD $0x6e31d631 // faddp.4s v17, v17, v17 + WORD $0x7e30da31 // faddp.2s s17, v17 + WORD $0x4e25d4c5 // fadd.4s v5, v6, v5 + WORD $0x4e25d4e5 // fadd.4s v5, v7, v5 + WORD $0x4e25d605 // fadd.4s v5, v16, v5 + WORD $0x6e25d4a5 // faddp.4s v5, v5, v5 + WORD $0x7e30d8a5 // faddp.2s s5, v5 + WORD $0x4e21d441 // fadd.4s v1, v2, v1 + WORD $0x4e21d461 // fadd.4s v1, v3, v1 + WORD $0x4e21d481 // fadd.4s v1, v4, v1 WORD $0x6e21d421 // faddp.4s v1, v1, v1 WORD $0x7e30d821 // faddp.2s s1, v1 - WORD $0x4e24d462 // fadd.4s v2, v3, v4 - WORD $0x6e22d442 // faddp.4s v2, v2, v2 - WORD $0x7e30d842 // faddp.2s s2, v2 - WORD $0xeb0a013f // cmp x9, x10 - WORD $0x54000180 // b.eq LBB0_13 + WORD $0xeb03011f // cmp x8, x3 + WORD $0x54000180 // b.eq LBB0_9 -BB0_11: +BB0_7: WORD $0xcb080069 // sub x9, x3, x8 WORD $0xd37ef50a // lsl x10, x8, #2 WORD $0x8b0a0028 // add x8, x1, x10 WORD $0x8b0a000a // add x10, x0, x10 -BB0_12: - WORD $0xbc404543 // ldr s3, [x10], #4 - WORD $0xbc404504 // ldr s4, [x8], #4 - WORD $0x1f030080 // fmadd s0, s4, s3, s0 - WORD $0x1f030461 // fmadd s1, s3, s3, s1 - WORD $0x1f040882 // fmadd s2, s4, s4, s2 +BB0_8: + WORD $0xbc404542 // ldr s2, [x10], #4 + WORD $0xbc404503 // ldr s3, [x8], #4 + WORD $0x1f020461 // fmadd s1, s3, s2, s1 + WORD $0x1f021445 // fmadd s5, s2, s2, s5 + WORD $0x1f034471 // fmadd s17, s3, s3, s17 WORD $0xf1000529 // subs x9, x9, #1 - WORD $0x54ffff41 // b.ne LBB0_12 + WORD $0x54ffff41 // b.ne LBB0_8 -BB0_13: - WORD $0x1e210841 // fmul s1, s2, s1 - WORD $0x1e21c022 // fsqrt s2, s1 - WORD $0x2f00e401 // movi d1, #0000000000000000 +BB0_9: + WORD $0x1e3108a2 // fmul s2, s5, s17 + WORD $0x1e22c021 // fcvt d1, s1 + WORD $0x1e21c042 // fsqrt s2, s2 WORD $0x1e202048 // fcmp s2, #0.0 - WORD $0x54000081 // b.ne LBB0_15 - WORD $0xfd000041 // str d1, [x2] + WORD $0x54000081 // b.ne LBB0_11 + +BB0_10: + WORD $0xfd000040 // str d0, [x2] WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload WORD $0xd65f03c0 // ret -BB0_15: - WORD $0x1e22c000 // fcvt d0, s0 - WORD $0x1e22c041 // fcvt d1, s2 - WORD $0x1e611801 // fdiv d1, d0, d1 - WORD $0xfd000041 // str d1, [x2] +BB0_11: + WORD $0x1e22c040 // fcvt d0, s2 + WORD $0x1e601820 // fdiv d0, d1, d0 + WORD $0xfd000040 // str d0, [x2] WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload WORD $0xd65f03c0 // ret diff --git a/internal/cosine/simd/cosine_avx.s b/internal/cosine/simd/cosine_avx.s index dafc3e6..d7401b9 100644 --- a/internal/cosine/simd/cosine_avx.s +++ b/internal/cosine/simd/cosine_avx.s @@ -6,130 +6,134 @@ TEXT ·f32_cosine_distance(SB), $0-32 MOVQ y+8(FP), SI MOVQ result+16(FP), DX MOVQ size+24(FP), CX - BYTE $0x55 // push rbp - WORD $0x8948; BYTE $0xe5 // mov rbp, rsp - LONG $0xf8e48348 // and rsp, -8 - LONG $0xf8498d4c // lea r9, [rcx - 8] - LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0 - LONG $0xe0c3c749; WORD $0xffff; BYTE $0xff // mov r11, -32 - WORD $0xc031 // xor eax, eax - WORD $0x8949; BYTE $0xc8 // mov r8, rcx - LONG $0xd257e8c5 // vxorps xmm2, xmm2, xmm2 - LONG $0xc957f0c5 // vxorps xmm1, xmm1, xmm1 + BYTE $0x55 // push rbp + WORD $0x8948; BYTE $0xe5 // mov rbp, rsp + LONG $0xe0e48348 // and rsp, -32 + LONG $0x60ec8348 // sub rsp, 96 + LONG $0xed57d0c5 // vxorps xmm5, xmm5, xmm5 + LONG $0xe457d8c5 // vxorps xmm4, xmm4, xmm4 + LONG $0xc957f0c5 // vxorps xmm1, xmm1, xmm1 + LONG $0xdb57e0c5 // vxorps xmm3, xmm3, xmm3 + WORD $0x8548; BYTE $0xc9 // test rcx, rcx + JE LBB0_8 + LONG $0x20f98348 // cmp rcx, 32 + JAE LBB0_4 + LONG $0xd257e8c5 // vxorps xmm2, xmm2, xmm2 + WORD $0x3145; BYTE $0xc0 // xor r8d, r8d + LONG $0xf657c8c5 // vxorps xmm6, xmm6, xmm6 + LONG $0xc957f0c5 // vxorps xmm1, xmm1, xmm1 + JMP LBB0_3 -LBB0_1: - LONG $0x1c10fcc5; BYTE $0x87 // vmovups ymm3, ymmword ptr [rdi + 4*rax] - LONG $0x2410fcc5; BYTE $0x86 // vmovups ymm4, ymmword ptr [rsi + 4*rax] - LONG $0xb865e2c4; BYTE $0xc4 // vfmadd231ps ymm0, ymm3, ymm4 - LONG $0xb865e2c4; BYTE $0xd3 // vfmadd231ps ymm2, ymm3, ymm3 - LONG $0xb85de2c4; BYTE $0xcc // vfmadd231ps ymm1, ymm4, ymm4 - LONG $0x08c08348 // add rax, 8 - LONG $0xe0c38349 // add r11, -32 - LONG $0xf8c08349 // add r8, -8 - WORD $0x394c; BYTE $0xc8 // cmp rax, r9 - JBE LBB0_1 - LONG $0xc07cffc5 // vhaddps ymm0, ymm0, ymm0 - LONG $0xc07cffc5 // vhaddps ymm0, ymm0, ymm0 - LONG $0xd27cefc5 // vhaddps ymm2, ymm2, ymm2 - LONG $0x197de3c4; WORD $0x01c3 // vextractf128 xmm3, ymm0, 1 - LONG $0xd27cefc5 // vhaddps ymm2, ymm2, ymm2 - LONG $0xc058e2c5 // vaddss xmm0, xmm3, xmm0 - LONG $0xc97cf7c5 // vhaddps ymm1, ymm1, ymm1 - LONG $0x197de3c4; WORD $0x01d3 // vextractf128 xmm3, ymm2, 1 - LONG $0xe17cf7c5 // vhaddps ymm4, ymm1, ymm1 - LONG $0xca58e2c5 // vaddss xmm1, xmm3, xmm2 - LONG $0x197de3c4; WORD $0x01e2 // vextractf128 xmm2, ymm4, 1 - LONG $0xd458eac5 // vaddss xmm2, xmm2, xmm4 - WORD $0x3948; BYTE $0xc8 // cmp rax, rcx - JAE LBB0_9 - WORD $0x8949; BYTE $0xc9 // mov r9, rcx - WORD $0x2949; BYTE $0xc1 // sub r9, rax - LONG $0x10f98349 // cmp r9, 16 - JAE LBB0_5 - WORD $0x8949; BYTE $0xc0 // mov r8, rax - JMP LBB0_8 +LBB0_4: + WORD $0x8949; BYTE $0xc8 // mov r8, rcx + LONG $0xe0e08349 // and r8, -32 + LONG $0xc057f8c5 // vxorps xmm0, xmm0, xmm0 + LONG $0x0429fcc5; BYTE $0x24 // vmovaps ymmword ptr [rsp], ymm0 + WORD $0xc031 // xor eax, eax + LONG $0xdb57e0c5 // vxorps xmm3, xmm3, xmm3 + LONG $0xe457d8c5 // vxorps xmm4, xmm4, xmm4 + LONG $0xed57d0c5 // vxorps xmm5, xmm5, xmm5 + LONG $0xf657c8c5 // vxorps xmm6, xmm6, xmm6 + LONG $0xff57c0c5 // vxorps xmm7, xmm7, xmm7 + LONG $0x573841c4; BYTE $0xc0 // vxorps xmm8, xmm8, xmm8 + LONG $0x573041c4; BYTE $0xc9 // vxorps xmm9, xmm9, xmm9 + LONG $0xc957f0c5 // vxorps xmm1, xmm1, xmm1 + LONG $0x572041c4; BYTE $0xdb // vxorps xmm11, xmm11, xmm11 + LONG $0x571841c4; BYTE $0xe4 // vxorps xmm12, xmm12, xmm12 + LONG $0x571041c4; BYTE $0xed // vxorps xmm13, xmm13, xmm13 LBB0_5: - WORD $0xf749; BYTE $0xdb // neg r11 - WORD $0x894d; BYTE $0xca // mov r10, r9 - LONG $0xf0e28349 // and r10, -16 - LONG $0xf0e08349 // and r8, -16 - WORD $0x0149; BYTE $0xc0 // add r8, rax - LONG $0xdb57e0c5 // vxorps xmm3, xmm3, xmm3 - LONG $0x0c61e3c4; WORD $0x01d2 // vblendps xmm2, xmm3, xmm2, 1 - LONG $0x0c61e3c4; WORD $0x01c9 // vblendps xmm1, xmm3, xmm1, 1 - LONG $0x0c61e3c4; WORD $0x01c0 // vblendps xmm0, xmm3, xmm0, 1 - LONG $0xdb57e0c5 // vxorps xmm3, xmm3, xmm3 - WORD $0x894c; BYTE $0xd0 // mov rax, r10 + LONG $0x6c29fcc5; WORD $0x2024 // vmovaps ymmword ptr [rsp + 32], ymm5 + LONG $0x34107cc5; BYTE $0x87 // vmovups ymm14, ymmword ptr [rdi + 4*rax] + LONG $0x7c107cc5; WORD $0x2087 // vmovups ymm15, ymmword ptr [rdi + 4*rax + 32] + LONG $0x54107cc5; WORD $0x4087 // vmovups ymm10, ymmword ptr [rdi + 4*rax + 64] + LONG $0x4410fcc5; WORD $0x6087 // vmovups ymm0, ymmword ptr [rdi + 4*rax + 96] + LONG $0x1410fcc5; BYTE $0x86 // vmovups ymm2, ymmword ptr [rsi + 4*rax] + LONG $0xec28fcc5 // vmovaps ymm5, ymm4 + LONG $0xe328fcc5 // vmovaps ymm4, ymm3 + LONG $0x1c28fcc5; BYTE $0x24 // vmovaps ymm3, ymmword ptr [rsp] + LONG $0xb86dc2c4; BYTE $0xde // vfmadd231ps ymm3, ymm2, ymm14 + LONG $0x1c29fcc5; BYTE $0x24 // vmovaps ymmword ptr [rsp], ymm3 + LONG $0xdc28fcc5 // vmovaps ymm3, ymm4 + LONG $0xe528fcc5 // vmovaps ymm4, ymm5 + LONG $0x6c28fcc5; WORD $0x2024 // vmovaps ymm5, ymmword ptr [rsp + 32] + LONG $0xb80dc2c4; BYTE $0xf6 // vfmadd231ps ymm6, ymm14, ymm14 + LONG $0x74107cc5; WORD $0x2086 // vmovups ymm14, ymmword ptr [rsi + 4*rax + 32] + LONG $0xb80dc2c4; BYTE $0xdf // vfmadd231ps ymm3, ymm14, ymm15 + LONG $0xb805c2c4; BYTE $0xff // vfmadd231ps ymm7, ymm15, ymm15 + LONG $0x7c107cc5; WORD $0x4086 // vmovups ymm15, ymmword ptr [rsi + 4*rax + 64] + LONG $0xb805c2c4; BYTE $0xe2 // vfmadd231ps ymm4, ymm15, ymm10 + LONG $0xb82d42c4; BYTE $0xc2 // vfmadd231ps ymm8, ymm10, ymm10 + LONG $0x54107cc5; WORD $0x6086 // vmovups ymm10, ymmword ptr [rsi + 4*rax + 96] + LONG $0xb82de2c4; BYTE $0xe8 // vfmadd231ps ymm5, ymm10, ymm0 + LONG $0xb87d62c4; BYTE $0xc8 // vfmadd231ps ymm9, ymm0, ymm0 + LONG $0xb86de2c4; BYTE $0xca // vfmadd231ps ymm1, ymm2, ymm2 + LONG $0xb80d42c4; BYTE $0xde // vfmadd231ps ymm11, ymm14, ymm14 + LONG $0xb80542c4; BYTE $0xe7 // vfmadd231ps ymm12, ymm15, ymm15 + LONG $0xb82d42c4; BYTE $0xea // vfmadd231ps ymm13, ymm10, ymm10 + LONG $0x20c08348 // add rax, 32 + WORD $0x3949; BYTE $0xc0 // cmp r8, rax + JNE LBB0_5 + LONG $0xc158a4c5 // vaddps ymm0, ymm11, ymm1 + LONG $0xc0589cc5 // vaddps ymm0, ymm12, ymm0 + LONG $0xc05894c5 // vaddps ymm0, ymm13, ymm0 + LONG $0x197de3c4; WORD $0x01c1 // vextractf128 xmm1, ymm0, 1 + LONG $0xc158f8c5 // vaddps xmm0, xmm0, xmm1 + LONG $0x0579e3c4; WORD $0x01c8 // vpermilpd xmm1, xmm0, 1 + LONG $0xc158f8c5 // vaddps xmm0, xmm0, xmm1 + LONG $0xc816fac5 // vmovshdup xmm1, xmm0 + LONG $0xc958fac5 // vaddss xmm1, xmm0, xmm1 + LONG $0xc658c4c5 // vaddps ymm0, ymm7, ymm6 + LONG $0xc058bcc5 // vaddps ymm0, ymm8, ymm0 + LONG $0xc058b4c5 // vaddps ymm0, ymm9, ymm0 + LONG $0x197de3c4; WORD $0x01c2 // vextractf128 xmm2, ymm0, 1 + LONG $0xc258f8c5 // vaddps xmm0, xmm0, xmm2 + LONG $0x0579e3c4; WORD $0x01d0 // vpermilpd xmm2, xmm0, 1 + LONG $0xc258f8c5 // vaddps xmm0, xmm0, xmm2 + LONG $0xd016fac5 // vmovshdup xmm2, xmm0 + LONG $0xf258fac5 // vaddss xmm6, xmm0, xmm2 + LONG $0x0458e4c5; BYTE $0x24 // vaddps ymm0, ymm3, ymmword ptr [rsp] + LONG $0xc058dcc5 // vaddps ymm0, ymm4, ymm0 + LONG $0xc058d4c5 // vaddps ymm0, ymm5, ymm0 + LONG $0x197de3c4; WORD $0x01c2 // vextractf128 xmm2, ymm0, 1 + LONG $0xc258f8c5 // vaddps xmm0, xmm0, xmm2 + LONG $0x0579e3c4; WORD $0x01d0 // vpermilpd xmm2, xmm0, 1 + LONG $0xc258f8c5 // vaddps xmm0, xmm0, xmm2 + LONG $0xd016fac5 // vmovshdup xmm2, xmm0 + LONG $0xd258fac5 // vaddss xmm2, xmm0, xmm2 + WORD $0x3949; BYTE $0xc8 // cmp r8, rcx LONG $0xe457d8c5 // vxorps xmm4, xmm4, xmm4 LONG $0xed57d0c5 // vxorps xmm5, xmm5, xmm5 + JE LBB0_7 -LBB0_6: - LONG $0x107ca1c4; WORD $0x1f74; BYTE $0xe0 // vmovups ymm6, ymmword ptr [rdi + r11 - 32] - LONG $0x107ca1c4; WORD $0x1f3c // vmovups ymm7, ymmword ptr [rdi + r11] - LONG $0x107c21c4; WORD $0x1e44; BYTE $0xe0 // vmovups ymm8, ymmword ptr [rsi + r11 - 32] - LONG $0x107c21c4; WORD $0x1e0c // vmovups ymm9, ymmword ptr [rsi + r11] - LONG $0xb83de2c4; BYTE $0xc6 // vfmadd231ps ymm0, ymm8, ymm6 - LONG $0xb835e2c4; BYTE $0xef // vfmadd231ps ymm5, ymm9, ymm7 - LONG $0xb84de2c4; BYTE $0xce // vfmadd231ps ymm1, ymm6, ymm6 - LONG $0xb845e2c4; BYTE $0xe7 // vfmadd231ps ymm4, ymm7, ymm7 - LONG $0xb83dc2c4; BYTE $0xd0 // vfmadd231ps ymm2, ymm8, ymm8 - LONG $0xb835c2c4; BYTE $0xd9 // vfmadd231ps ymm3, ymm9, ymm9 - LONG $0x40c38349 // add r11, 64 - LONG $0xf0c08348 // add rax, -16 - JNE LBB0_6 - LONG $0xc058d4c5 // vaddps ymm0, ymm5, ymm0 - LONG $0x197de3c4; WORD $0x01c5 // vextractf128 xmm5, ymm0, 1 - LONG $0xc558f8c5 // vaddps xmm0, xmm0, xmm5 - LONG $0x0579e3c4; WORD $0x01e8 // vpermilpd xmm5, xmm0, 1 - LONG $0xc558f8c5 // vaddps xmm0, xmm0, xmm5 - LONG $0xe816fac5 // vmovshdup xmm5, xmm0 - LONG $0xc558fac5 // vaddss xmm0, xmm0, xmm5 - LONG $0xc958dcc5 // vaddps ymm1, ymm4, ymm1 - LONG $0x197de3c4; WORD $0x01cc // vextractf128 xmm4, ymm1, 1 - LONG $0xcc58f0c5 // vaddps xmm1, xmm1, xmm4 - LONG $0x0579e3c4; WORD $0x01e1 // vpermilpd xmm4, xmm1, 1 - LONG $0xcc58f0c5 // vaddps xmm1, xmm1, xmm4 - LONG $0xe116fac5 // vmovshdup xmm4, xmm1 - LONG $0xcc58f2c5 // vaddss xmm1, xmm1, xmm4 - LONG $0xd258e4c5 // vaddps ymm2, ymm3, ymm2 - LONG $0x197de3c4; WORD $0x01d3 // vextractf128 xmm3, ymm2, 1 - LONG $0xd358e8c5 // vaddps xmm2, xmm2, xmm3 - LONG $0x0579e3c4; WORD $0x01da // vpermilpd xmm3, xmm2, 1 - LONG $0xd358e8c5 // vaddps xmm2, xmm2, xmm3 - LONG $0xda16fac5 // vmovshdup xmm3, xmm2 - LONG $0xd358eac5 // vaddss xmm2, xmm2, xmm3 - WORD $0x394d; BYTE $0xd1 // cmp r9, r10 - JE LBB0_9 - -LBB0_8: - LONG $0x107aa1c4; WORD $0x871c // vmovss xmm3, dword ptr [rdi + 4*r8] - LONG $0x107aa1c4; WORD $0x8624 // vmovss xmm4, dword ptr [rsi + 4*r8] - LONG $0xb959e2c4; BYTE $0xc3 // vfmadd231ss xmm0, xmm4, xmm3 +LBB0_3: + LONG $0x107aa1c4; WORD $0x8704 // vmovss xmm0, dword ptr [rdi + 4*r8] + LONG $0x107aa1c4; WORD $0x861c // vmovss xmm3, dword ptr [rsi + 4*r8] + LONG $0xb961e2c4; BYTE $0xd0 // vfmadd231ss xmm2, xmm3, xmm0 + LONG $0xb979e2c4; BYTE $0xf0 // vfmadd231ss xmm6, xmm0, xmm0 LONG $0xb961e2c4; BYTE $0xcb // vfmadd231ss xmm1, xmm3, xmm3 - LONG $0xb959e2c4; BYTE $0xd4 // vfmadd231ss xmm2, xmm4, xmm4 WORD $0xff49; BYTE $0xc0 // inc r8 WORD $0x394c; BYTE $0xc1 // cmp rcx, r8 - JNE LBB0_8 + JNE LBB0_3 -LBB0_9: - LONG $0xc959eac5 // vmulss xmm1, xmm2, xmm1 - LONG $0xd151f2c5 // vsqrtss xmm2, xmm1, xmm1 - LONG $0xc957f0c5 // vxorps xmm1, xmm1, xmm1 - LONG $0xdb57e0c5 // vxorps xmm3, xmm3, xmm3 - LONG $0xd32ef8c5 // vucomiss xmm2, xmm3 - JNE LBB0_10 - LONG $0x0a11fbc5 // vmovsd qword ptr [rdx], xmm1 +LBB0_7: + LONG $0xd959cac5 // vmulss xmm3, xmm6, xmm1 + LONG $0xca5aeac5 // vcvtss2sd xmm1, xmm2, xmm2 + +LBB0_8: + LONG $0xd351e2c5 // vsqrtss xmm2, xmm3, xmm3 + LONG $0xd52ef8c5 // vucomiss xmm2, xmm5 + JNE LBB0_9 + LONG $0x2211fbc5 // vmovsd qword ptr [rdx], xmm4 WORD $0x8948; BYTE $0xec // mov rsp, rbp BYTE $0x5d // pop rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper BYTE $0xc3 // ret -LBB0_10: - LONG $0xc05afac5 // vcvtss2sd xmm0, xmm0, xmm0 - LONG $0xca5aeac5 // vcvtss2sd xmm1, xmm2, xmm2 - LONG $0xc95efbc5 // vdivsd xmm1, xmm0, xmm1 - LONG $0x0a11fbc5 // vmovsd qword ptr [rdx], xmm1 +LBB0_9: + LONG $0xc25aeac5 // vcvtss2sd xmm0, xmm2, xmm2 + LONG $0xe05ef3c5 // vdivsd xmm4, xmm1, xmm0 + LONG $0x2211fbc5 // vmovsd qword ptr [rdx], xmm4 WORD $0x8948; BYTE $0xec // mov rsp, rbp BYTE $0x5d // pop rbp WORD $0xf8c5; BYTE $0x77 // vzeroupper diff --git a/internal/cosine/simd/cosine_neon.s b/internal/cosine/simd/cosine_neon.s index b3df4b5..767c1e8 100644 --- a/internal/cosine/simd/cosine_neon.s +++ b/internal/cosine/simd/cosine_neon.s @@ -7,117 +7,110 @@ TEXT ·f32_cosine_distance(SB), $0-32 MOVD result+16(FP), R2 MOVD size+24(FP), R3 WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! - WORD $0xf100107f // cmp x3, #4 + WORD $0x2f00e400 // movi d0, #0000000000000000 WORD $0x910003fd // mov x29, sp - WORD $0x54000223 // b.lo .LBB0_4 - WORD $0x6f00e401 // movi v1.2d, #0000000000000000 - WORD $0x52800068 // mov w8, #3 - WORD $0x6f00e402 // movi v2.2d, #0000000000000000 - WORD $0xaa0103e9 // mov x9, x1 - WORD $0x6f00e400 // movi v0.2d, #0000000000000000 - WORD $0xaa0003ea // mov x10, x0 + WORD $0xb4000103 // cbz x3, .LBB0_3 + WORD $0xf100407f // cmp x3, #16 + WORD $0x54000182 // b.hs .LBB0_4 + WORD $0x2f00e401 // movi d1, #0000000000000000 + WORD $0x2f00e402 // movi d2, #0000000000000000 + WORD $0x2f00e403 // movi d3, #0000000000000000 + WORD $0xaa1f03e8 // mov x8, xzr + WORD $0x1400003a // b .LBB0_7 -LBB0_2: - WORD $0x3cc10543 // ldr q3, [x10], #16 - WORD $0x3cc10524 // ldr q4, [x9], #16 - WORD $0x91001108 // add x8, x8, #4 - WORD $0x4e23cc62 // fmla v2.4s, v3.4s, v3.4s - WORD $0xeb03011f // cmp x8, x3 - WORD $0x4e23cc81 // fmla v1.4s, v4.4s, v3.4s - WORD $0x4e24cc80 // fmla v0.4s, v4.4s, v4.4s - WORD $0x54ffff23 // b.lo .LBB0_2 - WORD $0x927ef468 // and x8, x3, #0xfffffffffffffffc - WORD $0x14000005 // b .LBB0_5 +LBB0_3: + WORD $0x2f00e402 // movi d2, #0000000000000000 + WORD $0x2f00e401 // movi d1, #0000000000000000 + WORD $0x1e21c042 // fsqrt s2, s2 + WORD $0x1e202048 // fcmp s2, #0.0 + WORD $0x540008a0 // b.eq .LBB0_10 + WORD $0x14000047 // b .LBB0_11 LBB0_4: - WORD $0x6f00e400 // movi v0.2d, #0000000000000000 - WORD $0xaa1f03e8 // mov x8, xzr - WORD $0x6f00e402 // movi v2.2d, #0000000000000000 + WORD $0x927cec68 // and x8, x3, #0xfffffffffffffff0 + WORD $0x91008009 // add x9, x0, #32 WORD $0x6f00e401 // movi v1.2d, #0000000000000000 - -LBB0_5: - WORD $0x6e21d421 // faddp v1.4s, v1.4s, v1.4s - WORD $0xeb03011f // cmp x8, x3 - WORD $0x6e22d442 // faddp v2.4s, v2.4s, v2.4s - WORD $0x6e20d403 // faddp v3.4s, v0.4s, v0.4s - WORD $0x7e30d820 // faddp s0, v1.2s - WORD $0x7e30d841 // faddp s1, v2.2s - WORD $0x7e30d862 // faddp s2, v3.2s - WORD $0x540006c2 // b.hs .LBB0_12 - WORD $0xcb080069 // sub x9, x3, x8 - WORD $0xf100213f // cmp x9, #8 - WORD $0x54000503 // b.lo .LBB0_10 + WORD $0x9100802a // add x10, x1, #32 + WORD $0x6f00e402 // movi v2.2d, #0000000000000000 + WORD $0xaa0803eb // mov x11, x8 WORD $0x6f00e403 // movi v3.2d, #0000000000000000 - WORD $0xd37ef50b // lsl x11, x8, #2 WORD $0x6f00e404 // movi v4.2d, #0000000000000000 - WORD $0x927df12a // and x10, x9, #0xfffffffffffffff8 + WORD $0x6f00e411 // movi v17.2d, #0000000000000000 + WORD $0x6f00e412 // movi v18.2d, #0000000000000000 WORD $0x6f00e405 // movi v5.2d, #0000000000000000 - WORD $0x9100416c // add x12, x11, #16 - WORD $0x8b0a0108 // add x8, x8, x10 - WORD $0x8b0c000b // add x11, x0, x12 - WORD $0x6e040443 // mov v3.s[0], v2.s[0] - WORD $0x8b0c002c // add x12, x1, x12 - WORD $0x6f00e402 // movi v2.2d, #0000000000000000 - WORD $0xaa0a03ed // mov x13, x10 - WORD $0x6e040424 // mov v4.s[0], v1.s[0] - WORD $0x6e040405 // mov v5.s[0], v0.s[0] - WORD $0x6f00e400 // movi v0.2d, #0000000000000000 - WORD $0x6f00e401 // movi v1.2d, #0000000000000000 + WORD $0x6f00e406 // movi v6.2d, #0000000000000000 + WORD $0x6f00e413 // movi v19.2d, #0000000000000000 + WORD $0x6f00e414 // movi v20.2d, #0000000000000000 + WORD $0x6f00e407 // movi v7.2d, #0000000000000000 + WORD $0x6f00e410 // movi v16.2d, #0000000000000000 -LBB0_8: - WORD $0xad7f9d66 // ldp q6, q7, [x11, #-16] - WORD $0xf10021ad // subs x13, x13, #8 - WORD $0x9100816b // add x11, x11, #32 - WORD $0x4e26ccc4 // fmla v4.4s, v6.4s, v6.4s - WORD $0xad7fc590 // ldp q16, q17, [x12, #-16] - WORD $0x4e27cce0 // fmla v0.4s, v7.4s, v7.4s - WORD $0x9100818c // add x12, x12, #32 - WORD $0x4e26ce05 // fmla v5.4s, v16.4s, v6.4s - WORD $0x4e30ce03 // fmla v3.4s, v16.4s, v16.4s - WORD $0x4e27ce21 // fmla v1.4s, v17.4s, v7.4s - WORD $0x4e31ce22 // fmla v2.4s, v17.4s, v17.4s - WORD $0x54fffea1 // b.ne .LBB0_8 - WORD $0x4e25d421 // fadd v1.4s, v1.4s, v5.4s - WORD $0xeb0a013f // cmp x9, x10 - WORD $0x4e24d400 // fadd v0.4s, v0.4s, v4.4s - WORD $0x4e23d442 // fadd v2.4s, v2.4s, v3.4s - WORD $0x6e21d421 // faddp v1.4s, v1.4s, v1.4s - WORD $0x6e20d403 // faddp v3.4s, v0.4s, v0.4s +LBB0_5: + WORD $0xad7f5935 // ldp q21, q22, [x9, #-32] + WORD $0xf100416b // subs x11, x11, #16 + WORD $0x4e35ceb1 // fmla v17.4s, v21.4s, v21.4s + WORD $0xacc26137 // ldp q23, q24, [x9], #64 + WORD $0x4e36ced2 // fmla v18.4s, v22.4s, v22.4s + WORD $0x4e37cee5 // fmla v5.4s, v23.4s, v23.4s + WORD $0xad7f6959 // ldp q25, q26, [x10, #-32] + WORD $0x4e38cf06 // fmla v6.4s, v24.4s, v24.4s + WORD $0x4e35cf21 // fmla v1.4s, v25.4s, v21.4s + WORD $0x4e39cf33 // fmla v19.4s, v25.4s, v25.4s + WORD $0xacc2715b // ldp q27, q28, [x10], #64 + WORD $0x4e36cf42 // fmla v2.4s, v26.4s, v22.4s + WORD $0x4e3acf54 // fmla v20.4s, v26.4s, v26.4s + WORD $0x4e37cf63 // fmla v3.4s, v27.4s, v23.4s + WORD $0x4e3bcf67 // fmla v7.4s, v27.4s, v27.4s + WORD $0x4e38cf84 // fmla v4.4s, v28.4s, v24.4s + WORD $0x4e3ccf90 // fmla v16.4s, v28.4s, v28.4s + WORD $0x54fffde1 // b.ne .LBB0_5 + WORD $0x4e33d693 // fadd v19.4s, v20.4s, v19.4s + WORD $0xeb03011f // cmp x8, x3 + WORD $0x4e31d651 // fadd v17.4s, v18.4s, v17.4s + WORD $0x4e21d441 // fadd v1.4s, v2.4s, v1.4s + WORD $0x4e33d4e2 // fadd v2.4s, v7.4s, v19.4s + WORD $0x4e31d4a5 // fadd v5.4s, v5.4s, v17.4s + WORD $0x4e21d461 // fadd v1.4s, v3.4s, v1.4s + WORD $0x4e22d602 // fadd v2.4s, v16.4s, v2.4s + WORD $0x4e25d4c3 // fadd v3.4s, v6.4s, v5.4s + WORD $0x4e21d481 // fadd v1.4s, v4.4s, v1.4s WORD $0x6e22d442 // faddp v2.4s, v2.4s, v2.4s - WORD $0x7e30d820 // faddp s0, v1.2s - WORD $0x7e30d861 // faddp s1, v3.2s - WORD $0x7e30d842 // faddp s2, v2.2s - WORD $0x54000180 // b.eq .LBB0_12 + WORD $0x6e23d464 // faddp v4.4s, v3.4s, v3.4s + WORD $0x6e21d421 // faddp v1.4s, v1.4s, v1.4s + WORD $0x7e30d843 // faddp s3, v2.2s + WORD $0x7e30d882 // faddp s2, v4.2s + WORD $0x7e30d821 // faddp s1, v1.2s + WORD $0x54000180 // b.eq .LBB0_9 -LBB0_10: +LBB0_7: WORD $0xd37ef50a // lsl x10, x8, #2 WORD $0xcb080069 // sub x9, x3, x8 WORD $0x8b0a0028 // add x8, x1, x10 WORD $0x8b0a000a // add x10, x0, x10 -LBB0_11: - WORD $0xbc404543 // ldr s3, [x10], #4 - WORD $0xbc404504 // ldr s4, [x8], #4 +LBB0_8: + WORD $0xbc404544 // ldr s4, [x10], #4 + WORD $0xbc404505 // ldr s5, [x8], #4 WORD $0xf1000529 // subs x9, x9, #1 - WORD $0x1f030461 // fmadd s1, s3, s3, s1 - WORD $0x1f030080 // fmadd s0, s4, s3, s0 WORD $0x1f040882 // fmadd s2, s4, s4, s2 - WORD $0x54ffff41 // b.ne .LBB0_11 + WORD $0x1f0404a1 // fmadd s1, s5, s4, s1 + WORD $0x1f050ca3 // fmadd s3, s5, s5, s3 + WORD $0x54ffff41 // b.ne .LBB0_8 -LBB0_12: - WORD $0x1e210841 // fmul s1, s2, s1 - WORD $0x1e21c022 // fsqrt s2, s1 - WORD $0x2f00e401 // movi d1, #0000000000000000 +LBB0_9: + WORD $0x1e230842 // fmul s2, s2, s3 + WORD $0x1e22c021 // fcvt d1, s1 + WORD $0x1e21c042 // fsqrt s2, s2 WORD $0x1e202048 // fcmp s2, #0.0 - WORD $0x54000081 // b.ne .LBB0_14 - WORD $0xfd000041 // str d1, [x2] + WORD $0x54000081 // b.ne .LBB0_11 + +LBB0_10: + WORD $0xfd000040 // str d0, [x2] WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 WORD $0xd65f03c0 // ret -LBB0_14: - WORD $0x1e22c000 // fcvt d0, s0 - WORD $0x1e22c041 // fcvt d1, s2 - WORD $0x1e611801 // fdiv d1, d0, d1 - WORD $0xfd000041 // str d1, [x2] +LBB0_11: + WORD $0x1e22c040 // fcvt d0, s2 + WORD $0x1e601820 // fdiv d0, d1, d0 + WORD $0xfd000040 // str d0, [x2] WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 WORD $0xd65f03c0 // ret diff --git a/internal/cosine/simd/simd.go b/internal/cosine/simd/simd.go index ecfdfd1..70059c1 100644 --- a/internal/cosine/simd/simd.go +++ b/internal/cosine/simd/simd.go @@ -42,17 +42,18 @@ func Cosine(a, b []float32) float64 { } // cosine calculates the cosine similarity between two vectors -func cosine(vec1, vec2 []float32) float64 { - var dotProduct, normA, normB float64 - for i := range vec1 { - dotProduct += float64(vec1[i] * vec2[i]) - normA += float64(vec1[i] * vec1[i]) - normB += float64(vec2[i] * vec2[i]) +func cosine(x, y []float32) float64 { + var sum_xy, sum_xx, sum_yy float64 + for i := range x { + sum_xy += float64(x[i] * y[i]) + sum_xx += float64(x[i] * x[i]) + sum_yy += float64(y[i] * y[i]) } - if normA == 0 || normB == 0 { + denominator := math.Sqrt(sum_xx) * math.Sqrt(sum_yy) + if denominator == 0 { return 0.0 } - return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB)) + return sum_xy / denominator } diff --git a/internal/cosine/simd/simd_test.go b/internal/cosine/simd/simd_test.go index 8b14867..b65cac8 100644 --- a/internal/cosine/simd/simd_test.go +++ b/internal/cosine/simd/simd_test.go @@ -9,8 +9,8 @@ import ( /* cpu: 13th Gen Intel(R) Core(TM) i7-13700K -BenchmarkCosine/std-24 19490612 66.05 ns/op 0 B/op 0 allocs/op -BenchmarkCosine/our-24 67442631 17.83 ns/op 0 B/op 0 allocs/op +BenchmarkCosine/std-24 15074694 80.61 ns/op 0 B/op 0 allocs/op +BenchmarkCosine/our-24 45370162 25.92 ns/op 0 B/op 0 allocs/op */ func BenchmarkCosine(b *testing.B) { x := randVec() diff --git a/llama_test.go b/llama_test.go index dc7dfff..0bf8963 100644 --- a/llama_test.go +++ b/llama_test.go @@ -12,7 +12,7 @@ import ( ) /* -BenchmarkLLM/encode-24 400 2641772 ns/op 18.00 tok/s 2024 B/op 11 allocs/op +BenchmarkLLM/encode-24 465 2305573 ns/op 2024 B/op 11 allocs/op */ func BenchmarkLLM(b *testing.B) { m := loadModel()