Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more accurate sqrt function #129

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 121 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,40 +281,87 @@ impl<T: Float> Complex<T> {
///
/// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`.
#[inline]
pub fn sqrt(self) -> Self {
if self.im.is_zero() {
if self.re.is_sign_positive() {
// simple positive real √r, and copy `im` for its sign
Self::new(self.re.sqrt(), self.im)
pub fn sqrt(mut self) -> Self {
// complex sqrt algorithm based on the algorithm from
// dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks
// to increase accuracy. Compared to a naive implementationt that
// reuses the complex exp/ln implementations this algorithm has better
// accuarcy since both (real) sqrt and (real) hypot are garunteed to
// round perfectly. It's also faster since this implementation requires
// less transcendental functions and those it does use (sqrt/hypto) are
// faster comparted to exp/sin/cos.
//
// The musl libc implementation was referenced while implementing the
// algorithm here:
// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c

// TODO: rounding for very tiny subnormal numbers isn't perfect yet so
// the assert shown fails in the very worst case this leads to about
// 10% accuracy loss (see example below). As the magnitude increase the
// error quickly drops to basically zero.
//
// glibc handles that (but other implementations like musl and numpy do
// not) by upscaling very small values. That upscaling (and particularly
// it's reversal) are weird and hard to understand (and rely on mantissa
// bit size which we can't get out of the trait). In general the glibc
// implementation is ever so subtley different and I wouldn't want to
// introduce bugs by trying to adapt the underflow handling.
//
// assert_eq!(
// Complex64::new(5.212e-324, 5.212e-324).sqrt(),
// Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162)
// );

// specical cases for correct nan/inf handling
// see https://en.cppreference.com/w/c/numeric/complex/csqrt

if self.re.is_zero() && self.im.is_zero() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a source for all these special cases? e.g.
https://en.cppreference.com/w/c/numeric/complex/csqrt
(and make sure all those are covered)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added more test to test_nan to make sure all of these are covered by theses and added a comment

// 0 +/- 0 i
return Self::new(T::zero(), self.im);
}
if self.im.is_infinite() {
// inf +/- inf i
return Self::new(T::infinity(), self.im);
}
if self.re.is_nan() {
// nan + nan i
return Self::new(self.re, T::nan());
}
if self.re.is_infinite() {
// √(inf +/- NaN i) = inf +/- NaN i
// √(inf +/- x i) = inf +/- 0 i
// √(-inf +/- NaN i) = NaN +/- inf i
// √(-inf +/- x i) = 0 +/- inf i

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a variable to make this clearer:

Suggested change
#[allow(clippy::eq_op)]
let zero_or_nan = self.im - self.im;

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is indeed more readable, I also added a comments. good point

// if im is inf (or nan) this is nan, otherwise it's zero
#[allow(clippy::eq_op)]
let zero_or_nan = self.im - self.im;
if self.re.is_sign_negative() {
return Self::new(zero_or_nan.abs(), self.re.copysign(self.im));
} else {
// √(r e^(iπ)) = √r e^(iπ/2) = i√r
// √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r
let re = T::zero();
let im = (-self.re).sqrt();
if self.im.is_sign_positive() {
Self::new(re, im)
} else {
Self::new(re, -im)
}
}
} else if self.re.is_zero() {
// √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2)
// √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2)
let one = T::one();
let two = one + one;
let x = (self.im.abs() / two).sqrt();
if self.im.is_sign_positive() {
Self::new(x, x)
} else {
Self::new(x, -x)
return Self::new(self.re, zero_or_nan.copysign(self.im));
}
}
let two = T::one() + T::one();
let four = two + two;
let overflow = T::max_value() / (T::one() + T::sqrt(two));
let max_magnitude = self.re.abs().max(self.im.abs());
let scale = max_magnitude >= overflow;
if scale {
self = self / four;
}
if self.re.is_sign_negative() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also use a citation and link in a comment for the algorithm you mentioned.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a citation to the algorithm and the musl libc implementation as well as provide some additional background in a comement

let tmp = ((-self.re + self.norm()) / two).sqrt();
self.re = self.im.abs() / (two * tmp);
self.im = tmp.copysign(self.im);
} else {
// formula: sqrt(r e^(it)) = sqrt(r) e^(it/2)
let one = T::one();
let two = one + one;
let (r, theta) = self.to_polar();
Self::from_polar(r.sqrt(), theta / two)
self.re = ((self.re + self.norm()) / two).sqrt();
self.im = self.im / (two * self.re);
}
if scale {
self = self * two;
}
self
}

/// Computes the principal value of the cube root of `self`.
Expand Down Expand Up @@ -2065,6 +2112,50 @@ pub(crate) mod test {
}
}

#[test]
fn test_sqrt_nan() {
assert!(close_naninf(
Complex64::new(f64::INFINITY, f64::NAN).sqrt(),
Complex64::new(f64::INFINITY, f64::NAN),
));
assert!(close_naninf(
Complex64::new(f64::NAN, f64::INFINITY).sqrt(),
Complex64::new(f64::INFINITY, f64::INFINITY),
));
assert!(close_naninf(
Complex64::new(f64::NEG_INFINITY, -f64::NAN).sqrt(),
Complex64::new(f64::NAN, f64::NEG_INFINITY),
));
assert!(close_naninf(
Complex64::new(f64::NEG_INFINITY, f64::NAN).sqrt(),
Complex64::new(f64::NAN, f64::INFINITY),
));
assert!(close_naninf(
Complex64::new(-0.0, 0.0).sqrt(),
Complex64::new(0.0, 0.0),
));
for x in (-100..100).map(f64::from) {
assert!(close_naninf(
Complex64::new(x, f64::INFINITY).sqrt(),
Complex64::new(f64::INFINITY, f64::INFINITY),
));
assert!(close_naninf(
Complex64::new(f64::NAN, x).sqrt(),
Complex64::new(f64::NAN, f64::NAN),
));
// √(inf + x i) = inf + 0 i
assert!(close_naninf(
Complex64::new(f64::INFINITY, x).sqrt(),
Complex64::new(f64::INFINITY, 0.0.copysign(x)),
));
// √(-inf + x i) = 0 + inf i
assert!(close_naninf(
Complex64::new(f64::NEG_INFINITY, x).sqrt(),
Complex64::new(0.0, f64::INFINITY.copysign(x)),
));
}
}

#[test]
fn test_cbrt() {
assert!(close(_0_0i.cbrt(), _0_0i));
Expand Down