This repository was archived by the owner on Apr 28, 2025. It is now read-only.

Improve sqrt and sqrtf by using hardware more often. #222

Closed · wants to merge 17 commits

Changes from all commits
4 changes: 4 additions & 0 deletions README.md
@@ -1,3 +1,7 @@
[![crates.io](https://img.shields.io/crates/v/libm.svg)](https://crates.io/crates/libm)
[![docs.rs](https://docs.rs/libm/badge.svg)](https://docs.rs/libm/)
[![Build Status](https://dev.azure.com/rust-lang/libm/_apis/build/status/rust-lang-nursery.libm?branchName=master)](https://dev.azure.com/rust-lang/libm/_build/latest?definitionId=7&branchName=master)

# `libm`

A port of [MUSL]'s libm to Rust.
5 changes: 1 addition & 4 deletions src/lib.rs
@@ -1,10 +1,7 @@
//! libm in pure Rust
#![deny(warnings)]
#![no_std]
#![cfg_attr(
all(target_arch = "wasm32", feature = "unstable"),
feature(core_intrinsics)
)]
#![cfg_attr(feature = "unstable", feature(core_intrinsics))]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::many_single_char_names)]
#![allow(clippy::needless_return)]
300 changes: 159 additions & 141 deletions src/math/sqrt.rs
@@ -1,3 +1,54 @@
/// [Square root](https://en.wikipedia.org/wiki/Square_root) of an `f64`.
///
/// This function is intended to exactly match the
/// [`sqrt`](https://en.cppreference.com/w/c/numeric/math/sqrt) function as
/// defined by the C/C++ spec.
#[allow(unreachable_code)]
pub fn sqrt(x: f64) -> f64 {
// On most targets LLVM will issue a hardware sqrt instruction instead of a
// call to this function when you call `core::intrinsics::sqrtf64(x)`.
// However, not all targets have a hardware sqrt, and people might also end
// up calling this function directly, so we must still do our best to use a
// hardware instruction, and fall back to a software implementation if
// necessary.

// Nightly: if we're *sure* that LLVM will lower the intrinsic to a hardware
// instruction then we'll call the LLVM intrinsic. We have to be conservative
// about when to do this, because if the intrinsic ends up calling back into
// this function the result is infinite recursion.
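// (Illustration, not from the PR: on a target with no hardware sqrt,
// `core::intrinsics::sqrtf64` lowers to `llvm.sqrt.f64`, which LLVM in turn
// lowers to a libcall to `sqrt`, i.e. this very function, so the call would
// never return.)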
#[cfg(all(
feature = "unstable",
any(
all(target_arch = "x86", not(target_feature = "soft-float")),
Member

Sorry, I may not have been clear enough earlier, but soft-float and hard-float are not target features; I don't think that target_feature = "soft-float" is ever set (nor hard-float).

Contributor Author

Hmm, I got the names of those target features from rustc --print target-features --target TARGET, so at least rustc thinks they exist.

How else should we detect hard or soft float?

Member

Unfortunately -C target-feature, which is what --print target-features lists, is not equivalent to #[cfg(target_feature = "..")]. I'd recommend looking at the compiled code to see how the #[cfg] here isn't working.

Contributor Author

Alright, I see now that target_feature = "soft-float" will always evaluate to false (for example), but that brings us back to the question: how do we respect the hard-float/soft-float configuration of the target? If we can't do that, we can never be sure that a hardware instruction exists to be called.

Member

Sorry, I don't always have the answers to all the questions; I'm just pointing out how this isn't working as intended. I don't know how this would be detected.

Contributor Author

It's totally fine if you don't know it all! You just happen to usually know the most, and you're "here already". We'll have to ask around, I guess.

Without an answer to this problem we're just plain stuck. It would be incorrect to call back to the LLVM intrinsic if the arch's version of hardware floating point isn't enabled, since LLVM will end up calling us again, and that's just infinite recursion.

The only real relief to be had is that if there is a hardware sqrt then LLVM won't ever call us here, so we could just skip trying to divert back to LLVM at all. Except that's basically where we started, and that's what could produce the performance regressions we're trying to avoid.

Contributor Author

Opened an issue to seek a wider audience for support: rust-lang/rust#64514

Member

So if I'm understanding the issue, it's that if we export sqrt and sqrtf from compiler_builtins unconditionally, then any C code that's linked will use this version, and it might perform more slowly than what it was already using. I think the solution is to not export sqrt and sqrtf unconditionally, not only to avoid the performance issue but also because C code could depend on the specific error handling of the libm it was compiled against, which this code can't emulate. core already depends on fmod, so I think sqrt can and should be handled in the same way.
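A small probe (illustrative, not part of this PR) makes the target_feature distinction discussed above concrete: cfg! expands to a compile-time boolean reflecting exactly the cfgs rustc sets (also listed by rustc --print cfg), so the soft-float check below prints false even when building for a soft-float target:

// Which target_feature cfgs does rustc actually set for this build?
fn main() {
    // Set when compiling for x86/x86_64 with SSE2 enabled, so this can be true:
    println!("sse2:       {}", cfg!(target_feature = "sse2"));
    // Never set by rustc (per the discussion above), so this is always false:
    println!("soft-float: {}", cfg!(target_feature = "soft-float"));
}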

all(target_arch = "x86_64", not(target_feature = "soft-float")),
all(target_arch = "aarch64", not(target_feature = "soft-float")),
all(target_arch = "powerpc", target_feature = "hard-float"),
all(target_arch = "powerpc64", target_feature = "hard-float"),
all(target_arch = "risc", target_feature = "d"),
target_arch = "wasm32",
)
))]
{
return unsafe { core::intrinsics::sqrtf64(x) };
}
// Stable: We can use `sse2` if available. As more intrinsic sets stabilize
// we can expand this to use hardware more often.
#[cfg(target_feature = "sse2")]
{
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
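// `_mm_set_sd` puts x in the low lane and zeroes the high lane; `_mm_sqrt_pd`
// then takes the square root of both lanes (sqrt(0) = 0 in the unused lane),
// and `_mm_cvtsd_f64` reads the low-lane result back out.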
return unsafe {
let m = _mm_set_sd(x);
let m_sqrt = _mm_sqrt_pd(m);
_mm_cvtsd_f64(m_sqrt)
};
}
// Finally, if we must, we use the software version (below).
software_sqrt(x)
}
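A quick check of the documented special cases (an illustrative sketch, assuming the crate is in scope as `libm`; the expected values follow the C `sqrt` behavior described above):

fn main() {
    assert_eq!(libm::sqrt(9.0), 3.0);
    assert_eq!(libm::sqrt(0.0).to_bits(), 0.0f64.to_bits()); // sqrt(+0) = +0
    assert_eq!(libm::sqrt(-0.0).to_bits(), (-0.0f64).to_bits()); // sqrt(-0) = -0
    assert_eq!(libm::sqrt(f64::INFINITY), f64::INFINITY); // sqrt(+inf) = +inf
    assert!(libm::sqrt(-1.0).is_nan()); // sqrt(-ve) = NaN
    assert!(libm::sqrt(f64::NAN).is_nan()); // sqrt(NaN) = NaN
}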

/* origin: FreeBSD /usr/src/lib/msun/src/e_sqrt.c */
/*
* ====================================================
@@ -75,169 +126,136 @@
* sqrt(-ve) = NaN ... with invalid signal
* sqrt(NaN) = NaN ... with invalid signal for signaling NaN
*/

use core::f64;

#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn sqrt(x: f64) -> f64 {
// On wasm32 we know that LLVM's intrinsic will compile to an optimized
// `f64.sqrt` native instruction, so we can leverage this for both code size
// and speed.
llvm_intrinsically_optimized! {
#[cfg(target_arch = "wasm32")] {
return if x < 0.0 {
f64::NAN
} else {
unsafe { ::core::intrinsics::sqrtf64(x) }
}
}
}
#[cfg(target_feature = "sse2")]
{
// Note: This path is unlikely since LLVM will usually have already
// optimized sqrt calls into hardware instructions if sse2 is available,
// but if someone does end up here they'll appreciate the speed increase.
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
unsafe {
let m = _mm_set_sd(x);
let m_sqrt = _mm_sqrt_pd(m);
_mm_cvtsd_f64(m_sqrt)
}
}
#[cfg(not(target_feature = "sse2"))]
{
use core::num::Wrapping;
fn software_sqrt(x: f64) -> f64 {
use core::num::Wrapping;

const TINY: f64 = 1.0e-300;

let mut z: f64;
let sign: Wrapping<u32> = Wrapping(0x80000000);
let mut ix0: i32;
let mut s0: i32;
let mut q: i32;
let mut m: i32;
let mut t: i32;
let mut i: i32;
let mut r: Wrapping<u32>;
let mut t1: Wrapping<u32>;
let mut s1: Wrapping<u32>;
let mut ix1: Wrapping<u32>;
let mut q1: Wrapping<u32>;

ix0 = (x.to_bits() >> 32) as i32;
ix1 = Wrapping(x.to_bits() as u32);

/* take care of Inf and NaN */
if (ix0 & 0x7ff00000) == 0x7ff00000 {
return x * x + x; /* sqrt(NaN)=NaN, sqrt(+inf)=+inf, sqrt(-inf)=sNaN */
}
/* take care of zero */
if ix0 <= 0 {
if ((ix0 & !(sign.0 as i32)) | ix1.0 as i32) == 0 {
return x; /* sqrt(+-0) = +-0 */
}
if ix0 < 0 {
return (x - x) / (x - x); /* sqrt(-ve) = sNaN */
}
}
/* normalize x */
m = ix0 >> 20;
if m == 0 {
/* subnormal x */
while ix0 == 0 {
m -= 21;
ix0 |= (ix1 >> 11).0 as i32;
ix1 <<= 21;
}
i = 0;
while (ix0 & 0x00100000) == 0 {
i += 1;
ix0 <<= 1;
}
m -= i - 1;
ix0 |= (ix1 >> (32 - i) as usize).0 as i32;
ix1 = ix1 << i as usize;
}
m -= 1023; /* unbias exponent */
ix0 = (ix0 & 0x000fffff) | 0x00100000;
if (m & 1) == 1 {
/* odd m, double x to make it even */
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
}
m >>= 1; /* m = [m/2] */

/* generate sqrt(x) bit by bit */
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
q = 0; /* [q,q1] = sqrt(x) */
q1 = Wrapping(0);
s0 = 0;
s1 = Wrapping(0);
r = Wrapping(0x00200000); /* r = moving bit from right to left */

while r != Wrapping(0) {
t = s0 + r.0 as i32;
if t <= ix0 {
s0 = t + r.0 as i32;
ix0 -= t;
q += r.0 as i32;
}
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
r >>= 1;
}

r = sign;
while r != Wrapping(0) {
t1 = s1 + r;
t = s0;
if t < ix0 || (t == ix0 && t1 <= ix1) {
s1 = t1 + r;
if (t1 & sign) == sign && (s1 & sign) == Wrapping(0) {
s0 += 1;
}
ix0 -= t;
if ix1 < t1 {
ix0 -= 1;
}
ix1 -= t1;
q1 += r;
}
ix0 += ix0 + ((ix1 & sign) >> 31).0 as i32;
ix1 += ix1;
r >>= 1;
}

/* use floating add to find out rounding direction */
if (ix0 as u32 | ix1.0) != 0 {
z = 1.0 - TINY; /* raise inexact flag */
if z >= 1.0 {
z = 1.0 + TINY;
if q1.0 == 0xffffffff {
q1 = Wrapping(0);
q += 1;
} else if z > 1.0 {
if q1.0 == 0xfffffffe {
q += 1;
}
q1 += Wrapping(2);
} else {
q1 += q1 & Wrapping(1);
}
}
}
ix0 = (q >> 1) + 0x3fe00000;
ix1 = q1 >> 1;
if (q & 1) == 1 {
ix1 |= sign;
}
ix0 += m << 20;
f64::from_bits((ix0 as u64) << 32 | ix1.0 as u64)
}
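The loops above compute a digit-by-digit (shift-and-subtract) square root across the (ix0, ix1) word pair. The same idea in miniature, on a single u32 (an illustrative sketch, not part of the PR): try each result bit from high to low, and keep it whenever the running square still fits in the remainder.

// Integer square root by the same bit-by-bit method (returns floor(sqrt(v))).
fn isqrt(v: u32) -> u32 {
    let mut rem = v; // remaining value
    let mut root = 0u32; // result, built up one bit per iteration
    let mut bit = 1u32 << 30; // highest power of four that fits in a u32
    while bit > rem {
        bit >>= 2;
    }
    while bit != 0 {
        if rem >= root + bit {
            rem -= root + bit;
            root = (root >> 1) + bit;
        } else {
            root >>= 1;
        }
        bit >>= 2;
    }
    root // e.g. isqrt(10) == 3, isqrt(16) == 4
}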

#[cfg(test)]