Skip to content

Commit

Permalink
fix hypot with multiple arguments
Browse files Browse the repository at this point in the history
Fixes #27141: The previous code led to under/overflow.
  • Loading branch information
Jorge Fernandez-de-Cossio-Diaz authored and vtjnash committed Oct 26, 2020
1 parent 031c877 commit 02a66f7
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 54 deletions.
95 changes: 45 additions & 50 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,10 @@ by Carlos F. Borges
The article is available online at ArXiv at the link
https://arxiv.org/abs/1904.09481
hypot(x...)
Compute the hypotenuse ``\\sqrt{\\sum |x_i|^2}`` avoiding overflow and underflow.
# Examples
```jldoctest; filter = r"Stacktrace:(\\n \\[[0-9]+\\].*)*"
julia> a = Int64(10)^10;
Expand All @@ -625,85 +629,81 @@ Stacktrace:
julia> hypot(3, 4im)
5.0
julia> hypot(-5.7)
5.7
julia> hypot(3, 4im, 12.0)
13.0
```
"""
hypot(x::Number, y::Number) = hypot(promote(x, y)...)
hypot(x::Complex, y::Complex) = hypot(abs(x), abs(y))
hypot(x::T, y::T) where {T<:Real} = hypot(float(x), float(y))
function hypot(x::T, y::T) where {T<:Number}
if !iszero(x)
z = y/x
z2 = z*z
hypot(x::Number) = abs(float(x))
hypot(x::Number, y::Number, xs::Number...) = _hypot(float.(promote(x, y, xs...))...)
function _hypot(x, y)
# preserves unit
axu = abs(x)
ayu = abs(y)

abs(x) * sqrt(oneunit(z2) + z2)
else
abs(y)
end
end
# unitless
ax = axu / oneunit(axu)
ay = ayu / oneunit(ayu)

function hypot(x::T, y::T) where T<:AbstractFloat
# Return Inf if either or both inputs is Inf (Compliance with IEEE754)
if isinf(x) || isinf(y)
return T(Inf)
if isinf(ax) || isinf(ay)
return oftype(axu, Inf)
end

# Order the operands
ax,ay = abs(x), abs(y)
if ay > ax
ax,ay = ay,ax
axu, ayu = ayu, axu
ax, ay = ay, ax
end

# Widely varying operands
if ay <= ax*sqrt(eps(T)/2) #Note: This also gets ay == 0
return ax
if ay <= ax*sqrt(eps(typeof(ax))/2) #Note: This also gets ay == 0
return axu
end

# Operands do not vary widely
scale = eps(T)*sqrt(floatmin(T)) #Rescaling constant
if ax > sqrt(floatmax(T)/2)
scale = eps(typeof(ax))*sqrt(floatmin(ax)) #Rescaling constant
if ax > sqrt(floatmax(ax)/2)
ax = ax*scale
ay = ay*scale
scale = inv(scale)
elseif ay < sqrt(floatmin(T))
elseif ay < sqrt(floatmin(ax))
ax = ax/scale
ay = ay/scale
else
scale = one(scale)
scale = oneunit(scale)
end
h = sqrt(muladd(ax,ax,ay*ay))
h = sqrt(muladd(ax, ax, ay*ay))
# This branch is correctly rounded but requires a native hardware fma.
if Base.Math.FMA_NATIVE
hsquared = h*h
axsquared = ax*ax
h -= (fma(-ay,ay,hsquared-axsquared) + fma(h,h,-hsquared) - fma(ax,ax,-axsquared))/(2*h)
h -= (fma(-ay, ay, hsquared-axsquared) + fma(h, h,-hsquared) - fma(ax, ax, -axsquared))/(2*h)
# This branch is within one ulp of correctly rounded.
else
if h <= 2*ay
delta = h-ay
h -= muladd(delta,delta-2*(ax-ay),ax*(2*delta - ax))/(2*h)
h -= muladd(delta, delta-2*(ax-ay), ax*(2*delta - ax))/(2*h)
else
delta = h-ax
h -= muladd(delta,delta,muladd(ay,(4*delta-ay),2*delta*(ax-2*ay)))/(2*h)
h -= muladd(delta, delta, muladd(ay, (4*delta - ay), 2*delta*(ax - 2*ay)))/(2*h)
end
end
return h*scale
return h*scale*oneunit(axu)
end
function _hypot(x...)
maxabs = maximum(abs, x)
if isnan(maxabs) && any(isinf, x)
return oftype(maxabs, Inf)
elseif (iszero(maxabs) || isinf(maxabs))
return maxabs
else
return maxabs * sqrt(sum(y -> abs2(y / maxabs), x))
end
end

"""
hypot(x...)
Compute the hypotenuse ``\\sqrt{\\sum |x_i|^2}`` avoiding overflow and underflow.
# Examples
```jldoctest
julia> hypot(-5.7)
5.7
julia> hypot(3, 4im, 12.0)
13.0
```
"""
hypot(x::Number...) = sqrt(sum(abs2(y) for y in x))

atan(y::Real, x::Real) = atan(promote(float(y),float(x))...)
atan(y::T, x::T) where {T<:AbstractFloat} = Base.no_op_err("atan", T)
Expand Down Expand Up @@ -1150,12 +1150,7 @@ for func in (:sin,:cos,:tan,:asin,:acos,:atan,:sinh,:cosh,:tanh,:asinh,:acosh,
end
end

for func in (:atan,:hypot)
@eval begin
$func(a::Float16,b::Float16) = Float16($func(Float32(a),Float32(b)))
end
end

atan(a::Float16,b::Float16) = Float16(atan(Float32(a),Float32(b)))
cbrt(a::Float16) = Float16(cbrt(Float32(a)))
sincos(a::Float16) = Float16.(sincos(Float32(a)))

Expand Down
68 changes: 64 additions & 4 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1107,8 +1107,68 @@ end

isdefined(Main, :Furlongs) || @eval Main include("testhelpers/Furlongs.jl")
using .Main.Furlongs
@test hypot(Furlong(0), Furlong(0)) == Furlong(0.0)
@test hypot(Furlong(3), Furlong(4)) == Furlong(5.0)
@test hypot(Complex(3), Complex(4)) === 5.0
@test hypot(Complex(6, 8), Complex(8, 6)) === 10.0*sqrt(2)
@test (@inferred hypot(Furlong(0), Furlong(0))) == Furlong(0.0)
@test (@inferred hypot(Furlong(3), Furlong(4))) == Furlong(5.0)
@test (@inferred hypot(Furlong(NaN), Furlong(Inf))) == Furlong(Inf)
@test (@inferred hypot(Furlong(Inf), Furlong(NaN))) == Furlong(Inf)
@test (@inferred hypot(Furlong(0), Furlong(0), Furlong(0))) == Furlong(0.0)
@test (@inferred hypot(Furlong(Inf), Furlong(Inf))) == Furlong(Inf)
@test (@inferred hypot(Furlong(1), Furlong(1), Furlong(1))) == Furlong(sqrt(3))
@test (@inferred hypot(Furlong(Inf), Furlong(NaN), Furlong(0))) == Furlong(Inf)
@test (@inferred hypot(Furlong(Inf), Furlong(Inf), Furlong(Inf))) == Furlong(Inf)
@test isnan(hypot(Furlong(NaN), Furlong(0), Furlong(1)))
ex = @test_throws ErrorException hypot(Furlong(1), 1)
@test startswith(ex.value.msg, "promotion of types ")

@test_throws MethodError hypot()
@test (@inferred hypot(floatmax())) == floatmax()
@test (@inferred hypot(floatmax(), floatmax())) == Inf
@test (@inferred hypot(floatmin(), floatmin())) == 2floatmin()
@test (@inferred hypot(floatmin(), floatmin(), floatmin())) == 3floatmin()
@test (@inferred hypot(1e-162)) 1e-162
@test (@inferred hypot(2e-162, 1e-162, 1e-162)) hypot(2, 1, 1)*1e-162
@test (@inferred hypot(1e162)) 1e162
@test hypot(-2) === 2.0
@test hypot(-2, 0) === 2.0
let i = typemax(Int)
@test (@inferred hypot(i, i)) i * 2
@test (@inferred hypot(i, i, i)) i * 3
@test (@inferred hypot(i, i, i, i)) 2.0i
@test (@inferred hypot(i//1, 1//i, 1//i)) i
end
let i = typemin(Int)
@test (@inferred hypot(i, i)) -√2i
@test (@inferred hypot(i, i, i)) -√3i
@test (@inferred hypot(i, i, i, i)) -2.0i
end
@testset "$T" for T in (Float32, Float64)
@test (@inferred hypot(T(Inf), T(NaN))) == T(Inf) # IEEE754 says so
@test (@inferred hypot(T(Inf), T(3//2), T(NaN))) == T(Inf)
@test (@inferred hypot(T(1e10), T(1e10), T(1e10), T(1e10))) 2e10
@test isnan_type(T, hypot(T(3), T(3//4), T(NaN)))
@test hypot(T(1), T(0)) === T(1)
@test hypot(T(1), T(0), T(0)) === T(1)
@test (@inferred hypot(T(Inf), T(Inf), T(Inf))) == T(Inf)
for s in (zero(T), floatmin(T)*1e3, floatmax(T)*1e-3, T(Inf))
@test hypot(1s, 2s) s * hypot(1, 2) rtol=8eps(T)
@test hypot(1s, 2s, 3s) s * hypot(1, 2, 3) rtol=8eps(T)
end
end
@testset "$T" for T in (Float16, Float32, Float64, BigFloat)
let x = 1.1sqrt(floatmin(T))
@test (@inferred hypot(x, x/4)) x * sqrt(17/BigFloat(16))
@test (@inferred hypot(x, x/4, x/4)) x * sqrt(9/BigFloat(8))
end
let x = 2sqrt(nextfloat(zero(T)))
@test (@inferred hypot(x, x/4)) x * sqrt(17/BigFloat(16))
@test (@inferred hypot(x, x/4, x/4)) x * sqrt(9/BigFloat(8))
end
let x = sqrt(nextfloat(zero(T))/eps(T))/8, f = sqrt(4eps(T))
@test hypot(x, x*f) x * hypot(one(f), f) rtol=eps(T)
@test hypot(x, x*f, x*f) x * hypot(one(f), f, f) rtol=eps(T)
end
end
# hypot on Complex returns Real
@test (@inferred hypot(3, 4im)) === 5.0
@test (@inferred hypot(3, 4im, 12)) === 13.0
end
7 changes: 7 additions & 0 deletions test/testhelpers/Furlongs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ Base.oneunit(x::Type{Furlong{p,T}}) where {p,T} = Furlong{p,T}(one(T))
Base.zero(x::Furlong{p,T}) where {p,T} = Furlong{p,T}(zero(T))
Base.zero(::Type{Furlong{p,T}}) where {p,T} = Furlong{p,T}(zero(T))
Base.iszero(x::Furlong) = iszero(x.val)
Base.float(x::Furlong{p}) where {p} = Furlong{p}(float(x.val))
Base.eps(::Type{Furlong{p,T}}) where {p,T<:AbstractFloat} = Furlong{p}(eps(T))
Base.eps(::Furlong{p,T}) where {p,T<:AbstractFloat} = eps(Furlong{p,T})
Base.floatmin(::Type{Furlong{p,T}}) where {p,T<:AbstractFloat} = Furlong{p}(floatmin(T))
Base.floatmin(::Furlong{p,T}) where {p,T<:AbstractFloat} = floatmin(Furlong{p,T})
Base.floatmax(::Type{Furlong{p,T}}) where {p,T<:AbstractFloat} = Furlong{p}(floatmax(T))
Base.floatmax(::Furlong{p,T}) where {p,T<:AbstractFloat} = floatmax(Furlong{p,T})

# convert Furlong exponent p to a canonical form. This
# is not type stable, but it doesn't matter since it is used
Expand Down

0 comments on commit 02a66f7

Please sign in to comment.