Skip to content

Commit

Permalink
auto-vectorize using clang instead
Browse files Browse the repository at this point in the history
  • Loading branch information
kelindar committed Oct 27, 2024
1 parent 9a112b2 commit 094c9fa
Show file tree
Hide file tree
Showing 10 changed files with 326 additions and 403 deletions.
2 changes: 1 addition & 1 deletion bruteforce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

/*
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
BenchmarkIndex/search-24 3030 383868 ns/op 265 B/op 2 allocs/op
BenchmarkIndex/search-24 3807 316587 ns/op 264 B/op 2 allocs/op
*/
func BenchmarkIndex(b *testing.B) {
data, err := loadDataset()
Expand Down
41 changes: 12 additions & 29 deletions internal/cosine/cosine_apple.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,27 @@
// Licensed under the MIT license. See LICENSE file in the project root for details.

#include <stdint.h>
#include <arm_neon.h>
#include <math.h>

// Computes dot(x, y) / (sqrt(dot(x, x)) * sqrt(dot(y, y))) over the first `size`
// elements of x and y and stores it, widened to double, in *result. Despite the
// "distance" in the name, the stored value is the cosine similarity. When the
// denominator is zero, 0.0 is stored instead.
//
// NOTE(review): this span is a diff rendered without +/- markers, so two bodies
// appear back to back: the hand-written NEON version (which appears to be the
// pre-commit code) and the plain scalar loop that replaces it. The marker
// comments below separate the two.
void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) {
// --- NEON intrinsics version: three 4-lane accumulators ---
float32x4_t sum_xy = vdupq_n_f32(0.0f); // Sum of x * y
float32x4_t sum_xx = vdupq_n_f32(0.0f); // Sum of x * x
float32x4_t sum_yy = vdupq_n_f32(0.0f); // Sum of y * y

uint64_t i;
// Consume four floats per iteration; `i + 3 < size` stops before a partial group.
for (i = 0; i + 3 < size; i += 4) {
float32x4_t x_vec = vld1q_f32(x + i);
float32x4_t y_vec = vld1q_f32(y + i);

// Lane-wise multiply-accumulate into the running sums.
sum_xy = vmlaq_f32(sum_xy, x_vec, y_vec);
sum_xx = vmlaq_f32(sum_xx, x_vec, x_vec);
sum_yy = vmlaq_f32(sum_yy, y_vec, y_vec);
}

// Sum the elements of the vectors
float dot_xy = vaddvq_f32(sum_xy);
float norm_x = vaddvq_f32(sum_xx);
float norm_y = vaddvq_f32(sum_yy);

// Handle any remaining elements
for (; i < size; i++) {
dot_xy += x[i] * y[i];
norm_x += x[i] * x[i];
norm_y += y[i] * y[i];
// --- Scalar version: one float accumulator per sum, loop left to clang ---
float sum_xy = 0.0f;
float sum_xx = 0.0f;
float sum_yy = 0.0f;

// Loop hint: ask clang to vectorize this loop and interleave four iterations.
// NOTE(review): vectorizing float reductions reorders the additions; presumably
// the build passes flags that permit this - confirm against the build script.
#pragma clang loop vectorize(enable) interleave_count(4)
for (uint64_t i = 0; i < size; i++) {
sum_xy += x[i] * y[i]; // Sum of x * y
sum_xx += x[i] * x[i]; // Sum of x * x
sum_yy += y[i] * y[i]; // Sum of y * y
}


// The next two comment/statement pairs are the NEON-version and scalar-version
// spellings of the same step: the product of the two vector norms.
// Avoid division by zero
float denominator = sqrtf(norm_x) * sqrtf(norm_y);
// Calculate the final result
float denominator = sqrtf(sum_xx) * sqrtf(sum_yy);
// A zero denominator means at least one norm is zero; report 0.0 rather than dividing.
if (denominator == 0.0f) {
*result = (double)0.0f;
return;
}

// Same division in both versions; only the accumulator name differs.
double cosine_similarity = (double)dot_xy / (double)denominator;
double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
54 changes: 12 additions & 42 deletions internal/cosine/cosine_avx.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,28 @@
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.


#include <stdint.h>
#include <immintrin.h>
#include <math.h>

// Computes dot(x, y) / (sqrt(dot(x, x)) * sqrt(dot(y, y))) over the first `size`
// elements of x and y and stores it, widened to double, in *result. Despite the
// "distance" in the name, the stored value is the cosine similarity. When the
// denominator is zero, 0.0 is stored instead.
//
// NOTE(review): this span is a diff rendered without +/- markers, so two bodies
// appear back to back: the hand-written AVX version (which appears to be the
// pre-commit code) and the plain scalar loop that replaces it. The marker
// comments below separate the two.
void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) {
// --- AVX intrinsics version: three 8-lane accumulators ---
__m256 sum_xy = _mm256_setzero_ps(); // Sum of x * y
__m256 sum_xx = _mm256_setzero_ps(); // Sum of x * x
__m256 sum_yy = _mm256_setzero_ps(); // Sum of y * y

uint64_t i;
// Consume eight floats per iteration using unaligned loads.
// NOTE(review): `size` is unsigned, so for size < 8 the expression `size - 8`
// wraps to a huge value and this condition holds, reading past both inputs.
for (i = 0; i <= size - 8; i += 8) {
__m256 x_vec = _mm256_loadu_ps(x + i);
__m256 y_vec = _mm256_loadu_ps(y + i);

sum_xy = _mm256_fmadd_ps(x_vec, y_vec, sum_xy); // sum_xy += x_vec * y_vec
sum_xx = _mm256_fmadd_ps(x_vec, x_vec, sum_xx); // sum_xx += x_vec * x_vec
sum_yy = _mm256_fmadd_ps(y_vec, y_vec, sum_yy); // sum_yy += y_vec * y_vec
}

// Horizontal reduction, repeated for each accumulator: two hadd passes, then
// the low and high 128-bit halves are added and element 0 is extracted.
// Sum elements of sum_xy
__m256 temp_xy = _mm256_hadd_ps(sum_xy, sum_xy); // Sum adjacent pairs
temp_xy = _mm256_hadd_ps(temp_xy, temp_xy); // Sum adjacent quadruples
__m128 sum_xy_128 = _mm_add_ps(_mm256_castps256_ps128(temp_xy), _mm256_extractf128_ps(temp_xy, 1));
float dot_xy = _mm_cvtss_f32(sum_xy_128); // Extract final sum

// Sum elements of sum_xx
__m256 temp_xx = _mm256_hadd_ps(sum_xx, sum_xx);
temp_xx = _mm256_hadd_ps(temp_xx, temp_xx);
__m128 sum_xx_128 = _mm_add_ps(_mm256_castps256_ps128(temp_xx), _mm256_extractf128_ps(temp_xx, 1));
float norm_x = _mm_cvtss_f32(sum_xx_128);

// Sum elements of sum_yy
__m256 temp_yy = _mm256_hadd_ps(sum_yy, sum_yy);
temp_yy = _mm256_hadd_ps(temp_yy, temp_yy);
__m128 sum_yy_128 = _mm_add_ps(_mm256_castps256_ps128(temp_yy), _mm256_extractf128_ps(temp_yy, 1));
float norm_y = _mm_cvtss_f32(sum_yy_128);

// Handle remaining elements (if any)
for (; i < size; i++) {
dot_xy += x[i] * y[i];
norm_x += x[i] * x[i];
norm_y += y[i] * y[i];
// --- Scalar version: one float accumulator per sum, loop left to clang ---
float sum_xy = 0.0f;
float sum_xx = 0.0f;
float sum_yy = 0.0f;

// Loop hint: ask clang to vectorize this loop and interleave four iterations.
// NOTE(review): vectorizing float reductions reorders the additions; presumably
// the build passes flags that permit this - confirm against the build script.
#pragma clang loop vectorize(enable) interleave_count(4)
for (uint64_t i = 0; i < size; i++) {
sum_xy += x[i] * y[i]; // Sum of x * y
sum_xx += x[i] * x[i]; // Sum of x * x
sum_yy += y[i] * y[i]; // Sum of y * y
}

// The next two comment/statement pairs are the AVX-version and scalar-version
// spellings of the same step: the product of the two vector norms.
// Avoid division by zero
float denominator = sqrtf(norm_x) * sqrtf(norm_y);
// Calculate the final result
float denominator = sqrtf(sum_xx) * sqrtf(sum_yy);
// A zero denominator means at least one norm is zero; report 0.0 rather than dividing.
if (denominator == 0.0f) {
*result = (double)0.0f;
return;
}

// Same division in both versions; only the accumulator name differs.
double cosine_similarity = (double)dot_xy / (double)denominator;
double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
41 changes: 12 additions & 29 deletions internal/cosine/cosine_neon.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,27 @@
// Licensed under the MIT license. See LICENSE file in the project root for details.

#include <stdint.h>
#include <arm_neon.h>
#include <math.h>

// Computes dot(x, y) / (sqrt(dot(x, x)) * sqrt(dot(y, y))) over the first `size`
// elements of x and y and stores it, widened to double, in *result. Despite the
// "distance" in the name, the stored value is the cosine similarity. When the
// denominator is zero, 0.0 is stored instead.
//
// NOTE(review): this span is a diff rendered without +/- markers, so two bodies
// appear back to back: the hand-written NEON version (which appears to be the
// pre-commit code) and the plain scalar loop that replaces it. The marker
// comments below separate the two.
void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) {
// --- NEON intrinsics version: three 4-lane accumulators ---
float32x4_t sum_xy = vdupq_n_f32(0.0f); // Sum of x * y
float32x4_t sum_xx = vdupq_n_f32(0.0f); // Sum of x * x
float32x4_t sum_yy = vdupq_n_f32(0.0f); // Sum of y * y

uint64_t i;
// Consume four floats per iteration; `i + 3 < size` stops before a partial group.
for (i = 0; i + 3 < size; i += 4) {
float32x4_t x_vec = vld1q_f32(x + i);
float32x4_t y_vec = vld1q_f32(y + i);

// Lane-wise multiply-accumulate into the running sums.
sum_xy = vmlaq_f32(sum_xy, x_vec, y_vec);
sum_xx = vmlaq_f32(sum_xx, x_vec, x_vec);
sum_yy = vmlaq_f32(sum_yy, y_vec, y_vec);
}

// Sum the elements of the vectors
float dot_xy = vaddvq_f32(sum_xy);
float norm_x = vaddvq_f32(sum_xx);
float norm_y = vaddvq_f32(sum_yy);

// Handle any remaining elements
for (; i < size; i++) {
dot_xy += x[i] * y[i];
norm_x += x[i] * x[i];
norm_y += y[i] * y[i];
// --- Scalar version: one float accumulator per sum, loop left to clang ---
float sum_xy = 0.0f;
float sum_xx = 0.0f;
float sum_yy = 0.0f;

// Loop hint: ask clang to vectorize this loop and interleave four iterations.
// NOTE(review): vectorizing float reductions reorders the additions; presumably
// the build passes flags that permit this - confirm against the build script.
#pragma clang loop vectorize(enable) interleave_count(4)
for (uint64_t i = 0; i < size; i++) {
sum_xy += x[i] * y[i]; // Sum of x * y
sum_xx += x[i] * x[i]; // Sum of x * x
sum_yy += y[i] * y[i]; // Sum of y * y
}


// The next two comment/statement pairs are the NEON-version and scalar-version
// spellings of the same step: the product of the two vector norms.
// Avoid division by zero
float denominator = sqrtf(norm_x) * sqrtf(norm_y);
// Calculate the final result
float denominator = sqrtf(sum_xx) * sqrtf(sum_yy);
// A zero denominator means at least one norm is zero; report 0.0 rather than dividing.
if (denominator == 0.0f) {
*result = (double)0.0f;
return;
}

// Same division in both versions; only the accumulator name differs.
double cosine_similarity = (double)dot_xy / (double)denominator;
double cosine_similarity = (double)sum_xy / (double)denominator;
*result = cosine_similarity;
}
175 changes: 82 additions & 93 deletions internal/cosine/simd/cosine_apple.s
Original file line number Diff line number Diff line change
Expand Up @@ -8,120 +8,109 @@ TEXT ·f32_cosine_distance(SB), $0-32
MOVD size+24(FP), R3
WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! ; 16-byte Folded Spill
WORD $0x910003fd // mov x29, sp
WORD $0xf100107f // cmp x3, #4
WORD $0x54000223 // b.lo LBB0_4
WORD $0x6f00e400 // movi.2d v0, #0000000000000000
WORD $0x52800068 // mov w8, #3
WORD $0xaa0103e9 // mov x9, x1
WORD $0xaa0003ea // mov x10, x0
WORD $0x6f00e401 // movi.2d v1, #0000000000000000
WORD $0x6f00e402 // movi.2d v2, #0000000000000000
WORD $0x2f00e400 // movi d0, #0000000000000000
WORD $0xb4000103 // cbz x3, LBB0_3
WORD $0xf100407f // cmp x3, #16
WORD $0x54000182 // b.hs LBB0_4
WORD $0xd2800008 // mov x8, #0
WORD $0x2f00e401 // movi d1, #0000000000000000
WORD $0x2f00e405 // movi d5, #0000000000000000
WORD $0x2f00e411 // movi d17, #0000000000000000
WORD $0x1400003a // b LBB0_7

BB0_2:
WORD $0x3cc10543 // ldr q3, [x10], #16
WORD $0x3cc10524 // ldr q4, [x9], #16
WORD $0x4e23cc80 // fmla.4s v0, v4, v3
WORD $0x4e23cc61 // fmla.4s v1, v3, v3
WORD $0x4e24cc82 // fmla.4s v2, v4, v4
WORD $0x91001108 // add x8, x8, #4
WORD $0xeb03011f // cmp x8, x3
WORD $0x54ffff23 // b.lo LBB0_2
WORD $0x927ef46b // and x11, x3, #0xfffffffffffffffc
WORD $0x14000005 // b LBB0_5
BB0_3:
WORD $0x2f00e402 // movi d2, #0000000000000000
WORD $0x2f00e401 // movi d1, #0000000000000000
WORD $0x1e21c042 // fsqrt s2, s2
WORD $0x1e202048 // fcmp s2, #0.0
WORD $0x540008a0 // b.eq LBB0_10
WORD $0x14000047 // b LBB0_11

BB0_4:
WORD $0xd280000b // mov x11, #0
WORD $0x6f00e402 // movi.2d v2, #0000000000000000
WORD $0x6f00e401 // movi.2d v1, #0000000000000000
WORD $0x6f00e400 // movi.2d v0, #0000000000000000

BB0_5:
WORD $0x6e20d400 // faddp.4s v0, v0, v0
WORD $0x7e30d800 // faddp.2s s0, v0
WORD $0x6e21d421 // faddp.4s v1, v1, v1
WORD $0x7e30d821 // faddp.2s s1, v1
WORD $0x6e22d442 // faddp.4s v2, v2, v2
WORD $0x7e30d842 // faddp.2s s2, v2
WORD $0xeb03017f // cmp x11, x3
WORD $0x54000702 // b.hs LBB0_13
WORD $0xcb0b0069 // sub x9, x3, x11
WORD $0xf100213f // cmp x9, #8
WORD $0x54000062 // b.hs LBB0_8
WORD $0xaa0b03e8 // mov x8, x11
WORD $0x14000028 // b LBB0_11

BB0_8:
WORD $0x927df12a // and x10, x9, #0xfffffffffffffff8
WORD $0x8b0a0168 // add x8, x11, x10
WORD $0x927cec68 // and x8, x3, #0xfffffffffffffff0
WORD $0x6f00e402 // movi.2d v2, #0000000000000000
WORD $0x91008009 // add x9, x0, #32
WORD $0x6f00e403 // movi.2d v3, #0000000000000000
WORD $0x9100802a // add x10, x1, #32
WORD $0x6f00e404 // movi.2d v4, #0000000000000000
WORD $0x6e040444 // mov.s v4[0], v2[0]
WORD $0x6f00e402 // movi.2d v2, #0000000000000000
WORD $0x6e040422 // mov.s v2[0], v1[0]
WORD $0x6f00e401 // movi.2d v1, #0000000000000000
WORD $0x6e040401 // mov.s v1[0], v0[0]
WORD $0xd37ef56b // lsl x11, x11, #2
WORD $0x9100416c // add x12, x11, #16
WORD $0x8b0c000b // add x11, x0, x12
WORD $0x8b0c002c // add x12, x1, x12
WORD $0xaa0a03ed // mov x13, x10
WORD $0xaa0803eb // mov x11, x8
WORD $0x6f00e405 // movi.2d v5, #0000000000000000
WORD $0x6f00e400 // movi.2d v0, #0000000000000000
WORD $0x6f00e406 // movi.2d v6, #0000000000000000
WORD $0x6f00e407 // movi.2d v7, #0000000000000000
WORD $0x6f00e410 // movi.2d v16, #0000000000000000
WORD $0x6f00e411 // movi.2d v17, #0000000000000000
WORD $0x6f00e412 // movi.2d v18, #0000000000000000
WORD $0x6f00e413 // movi.2d v19, #0000000000000000
WORD $0x6f00e414 // movi.2d v20, #0000000000000000

BB0_9:
WORD $0xad7f9d66 // ldp q6, q7, [x11, #-16]
WORD $0xad7fc590 // ldp q16, q17, [x12, #-16]
WORD $0x4e26ce01 // fmla.4s v1, v16, v6
WORD $0x4e27ce20 // fmla.4s v0, v17, v7
WORD $0x4e26ccc2 // fmla.4s v2, v6, v6
WORD $0x4e27cce5 // fmla.4s v5, v7, v7
WORD $0x4e30ce04 // fmla.4s v4, v16, v16
WORD $0x4e31ce23 // fmla.4s v3, v17, v17
WORD $0x9100816b // add x11, x11, #32
WORD $0x9100818c // add x12, x12, #32
WORD $0xf10021ad // subs x13, x13, #8
WORD $0x54fffea1 // b.ne LBB0_9
WORD $0x4e21d400 // fadd.4s v0, v0, v1
WORD $0x6e20d400 // faddp.4s v0, v0, v0
WORD $0x7e30d800 // faddp.2s s0, v0
WORD $0x4e22d4a1 // fadd.4s v1, v5, v2
BB0_5:
WORD $0xad7f5935 // ldp q21, q22, [x9, #-32]
WORD $0xacc26137 // ldp q23, q24, [x9], #64
WORD $0xad7f6959 // ldp q25, q26, [x10, #-32]
WORD $0xacc2715b // ldp q27, q28, [x10], #64
WORD $0x4e35cf21 // fmla.4s v1, v25, v21
WORD $0x4e36cf42 // fmla.4s v2, v26, v22
WORD $0x4e37cf63 // fmla.4s v3, v27, v23
WORD $0x4e38cf84 // fmla.4s v4, v28, v24
WORD $0x4e35cea5 // fmla.4s v5, v21, v21
WORD $0x4e36cec6 // fmla.4s v6, v22, v22
WORD $0x4e37cee7 // fmla.4s v7, v23, v23
WORD $0x4e38cf10 // fmla.4s v16, v24, v24
WORD $0x4e39cf31 // fmla.4s v17, v25, v25
WORD $0x4e3acf52 // fmla.4s v18, v26, v26
WORD $0x4e3bcf73 // fmla.4s v19, v27, v27
WORD $0x4e3ccf94 // fmla.4s v20, v28, v28
WORD $0xf100416b // subs x11, x11, #16
WORD $0x54fffde1 // b.ne LBB0_5
WORD $0x4e31d651 // fadd.4s v17, v18, v17
WORD $0x4e31d671 // fadd.4s v17, v19, v17
WORD $0x4e31d691 // fadd.4s v17, v20, v17
WORD $0x6e31d631 // faddp.4s v17, v17, v17
WORD $0x7e30da31 // faddp.2s s17, v17
WORD $0x4e25d4c5 // fadd.4s v5, v6, v5
WORD $0x4e25d4e5 // fadd.4s v5, v7, v5
WORD $0x4e25d605 // fadd.4s v5, v16, v5
WORD $0x6e25d4a5 // faddp.4s v5, v5, v5
WORD $0x7e30d8a5 // faddp.2s s5, v5
WORD $0x4e21d441 // fadd.4s v1, v2, v1
WORD $0x4e21d461 // fadd.4s v1, v3, v1
WORD $0x4e21d481 // fadd.4s v1, v4, v1
WORD $0x6e21d421 // faddp.4s v1, v1, v1
WORD $0x7e30d821 // faddp.2s s1, v1
WORD $0x4e24d462 // fadd.4s v2, v3, v4
WORD $0x6e22d442 // faddp.4s v2, v2, v2
WORD $0x7e30d842 // faddp.2s s2, v2
WORD $0xeb0a013f // cmp x9, x10
WORD $0x54000180 // b.eq LBB0_13
WORD $0xeb03011f // cmp x8, x3
WORD $0x54000180 // b.eq LBB0_9

BB0_11:
BB0_7:
WORD $0xcb080069 // sub x9, x3, x8
WORD $0xd37ef50a // lsl x10, x8, #2
WORD $0x8b0a0028 // add x8, x1, x10
WORD $0x8b0a000a // add x10, x0, x10

BB0_12:
WORD $0xbc404543 // ldr s3, [x10], #4
WORD $0xbc404504 // ldr s4, [x8], #4
WORD $0x1f030080 // fmadd s0, s4, s3, s0
WORD $0x1f030461 // fmadd s1, s3, s3, s1
WORD $0x1f040882 // fmadd s2, s4, s4, s2
BB0_8:
WORD $0xbc404542 // ldr s2, [x10], #4
WORD $0xbc404503 // ldr s3, [x8], #4
WORD $0x1f020461 // fmadd s1, s3, s2, s1
WORD $0x1f021445 // fmadd s5, s2, s2, s5
WORD $0x1f034471 // fmadd s17, s3, s3, s17
WORD $0xf1000529 // subs x9, x9, #1
WORD $0x54ffff41 // b.ne LBB0_12
WORD $0x54ffff41 // b.ne LBB0_8

BB0_13:
WORD $0x1e210841 // fmul s1, s2, s1
WORD $0x1e21c022 // fsqrt s2, s1
WORD $0x2f00e401 // movi d1, #0000000000000000
BB0_9:
WORD $0x1e3108a2 // fmul s2, s5, s17
WORD $0x1e22c021 // fcvt d1, s1
WORD $0x1e21c042 // fsqrt s2, s2
WORD $0x1e202048 // fcmp s2, #0.0
WORD $0x54000081 // b.ne LBB0_15
WORD $0xfd000041 // str d1, [x2]
WORD $0x54000081 // b.ne LBB0_11

BB0_10:
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
WORD $0xd65f03c0 // ret

BB0_15:
WORD $0x1e22c000 // fcvt d0, s0
WORD $0x1e22c041 // fcvt d1, s2
WORD $0x1e611801 // fdiv d1, d0, d1
WORD $0xfd000041 // str d1, [x2]
BB0_11:
WORD $0x1e22c040 // fcvt d0, s2
WORD $0x1e601820 // fdiv d0, d1, d0
WORD $0xfd000040 // str d0, [x2]
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
WORD $0xd65f03c0 // ret
Loading

0 comments on commit 094c9fa

Please sign in to comment.