Skip to content

Commit 6f95f37

Browse files
authored
[ENH] Use AVX in distance calculations (#5258)
## Description of changes This PR updates the distance crate to prioritize AVX over SSE, and forces builds with AVX flags when ENABLE_AVX512 is set in the Dockerfile ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 06e36fc commit 6f95f37

File tree

1 file changed

+27
-27
lines changed

1 file changed

+27
-27
lines changed

rust/distance/src/types.rs

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -243,15 +243,6 @@ impl DistanceFunction {
243243
return unsafe { crate::distance_neon::euclidean_distance(a, b) };
244244
}
245245
}
246-
#[cfg(all(
247-
any(target_arch = "x86_64", target_arch = "x86"),
248-
target_feature = "sse"
249-
))]
250-
{
251-
if std::arch::is_x86_feature_detected!("sse") {
252-
return unsafe { crate::distance_sse::euclidean_distance(a, b) };
253-
}
254-
}
255246
#[cfg(all(
256247
target_arch = "x86_64",
257248
all(target_feature = "avx", target_feature = "fma")
@@ -263,6 +254,15 @@ impl DistanceFunction {
263254
return unsafe { crate::distance_avx::euclidean_distance(a, b) };
264255
}
265256
}
257+
#[cfg(all(
258+
any(target_arch = "x86_64", target_arch = "x86"),
259+
target_feature = "sse"
260+
))]
261+
{
262+
if std::arch::is_x86_feature_detected!("sse") {
263+
return unsafe { crate::distance_sse::euclidean_distance(a, b) };
264+
}
265+
}
266266
let mut sum = 0.0;
267267
for i in 0..a.len() {
268268
sum += (a[i] - b[i]).powi(2);
@@ -276,15 +276,6 @@ impl DistanceFunction {
276276
return unsafe { crate::distance_neon::cosine_distance(a, b) };
277277
}
278278
}
279-
#[cfg(all(
280-
any(target_arch = "x86_64", target_arch = "x86"),
281-
target_feature = "sse"
282-
))]
283-
{
284-
if std::arch::is_x86_feature_detected!("sse") {
285-
return unsafe { crate::distance_sse::cosine_distance(a, b) };
286-
}
287-
}
288279
#[cfg(all(
289280
target_arch = "x86_64",
290281
all(target_feature = "avx", target_feature = "fma")
@@ -296,6 +287,15 @@ impl DistanceFunction {
296287
return unsafe { crate::distance_avx::cosine_distance(a, b) };
297288
}
298289
}
290+
#[cfg(all(
291+
any(target_arch = "x86_64", target_arch = "x86"),
292+
target_feature = "sse"
293+
))]
294+
{
295+
if std::arch::is_x86_feature_detected!("sse") {
296+
return unsafe { crate::distance_sse::cosine_distance(a, b) };
297+
}
298+
}
299299
// For cosine we just assume the vectors have been normalized, since that
300300
// is what our indices expect.
301301
let mut sum = 0.0;
@@ -311,15 +311,6 @@ impl DistanceFunction {
311311
return unsafe { crate::distance_neon::inner_product(a, b) };
312312
}
313313
}
314-
#[cfg(all(
315-
any(target_arch = "x86_64", target_arch = "x86"),
316-
target_feature = "sse"
317-
))]
318-
{
319-
if std::arch::is_x86_feature_detected!("sse") {
320-
return unsafe { crate::distance_sse::inner_product(a, b) };
321-
}
322-
}
323314
#[cfg(all(
324315
target_arch = "x86_64",
325316
all(target_feature = "avx", target_feature = "fma")
@@ -331,6 +322,15 @@ impl DistanceFunction {
331322
return unsafe { crate::distance_avx::inner_product(a, b) };
332323
}
333324
}
325+
#[cfg(all(
326+
any(target_arch = "x86_64", target_arch = "x86"),
327+
target_feature = "sse"
328+
))]
329+
{
330+
if std::arch::is_x86_feature_detected!("sse") {
331+
return unsafe { crate::distance_sse::inner_product(a, b) };
332+
}
333+
}
334334
let mut sum = 0.0;
335335
for i in 0..a.len() {
336336
sum += a[i] * b[i];

0 commit comments

Comments
 (0)