Skip to content

Commit 2e8bf1d

Browse files
committed
exact BigFloat to IEEE FP conversion in pure Julia
There's lots of code, but most of it seems like it will be useful in general. For example, I think I'll use the changes in float.jl and rounding.jl to improve the JuliaLang#49749 PR. The changes in float.jl could also be used to refactor float.jl to remove many magic constants. Benchmarking script: ```julia using BenchmarkTools f(::Type{T} = BigFloat, n::Int = 2000) where {T} = rand(T, n) g!(u, v) = map!(eltype(u), u, v) @Btime g!(u, v) setup=(u = f(Float16); v = f();) @Btime g!(u, v) setup=(u = f(Float32); v = f();) @Btime g!(u, v) setup=(u = f(Float64); v = f();) ``` On master (dc06468): ``` 46.116 μs (0 allocations: 0 bytes) 38.842 μs (0 allocations: 0 bytes) 37.039 μs (0 allocations: 0 bytes) ``` With both this commit and JuliaLang#50674 applied: ``` 42.310 μs (0 allocations: 0 bytes) 42.661 μs (0 allocations: 0 bytes) 41.608 μs (0 allocations: 0 bytes) ``` So, with this benchmark at least, on an AMD Zen 2 laptop, conversion to `Float16` is faster, but there's a slowdown for `Float32` and `Float64`. Fixes JuliaLang#50642 (exact conversion to `Float16`)
1 parent dc06468 commit 2e8bf1d

File tree

8 files changed

+420
-29
lines changed

8 files changed

+420
-29
lines changed

base/Base.jl

+1
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ include("hashing.jl")
224224
include("rounding.jl")
225225
using .Rounding
226226
include("div.jl")
227+
include("rawbigints.jl")
227228
include("float.jl")
228229
include("twiceprecision.jl")
229230
include("complex.jl")

base/float.jl

+62
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,68 @@ i.e. the maximum integer value representable by [`exponent_bits(T)`](@ref) bits.
137137
"""
138138
function exponent_raw_max end
139139

140+
"""
141+
IEEE 754 definition of the minimum exponent.
142+
"""
143+
ieee754_exponent_min(::Type{T}) where {T<:IEEEFloat} = Int(1 - exponent_max(T))::Int
144+
145+
exponent_min(::Type{Float16}) = ieee754_exponent_min(Float16)
146+
exponent_min(::Type{Float32}) = ieee754_exponent_min(Float32)
147+
exponent_min(::Type{Float64}) = ieee754_exponent_min(Float64)
148+
149+
function ieee754_representation(
150+
::Type{F}, sign_bit::Bool, exponent_field::Integer, significand_field::Integer
151+
) where {F<:IEEEFloat}
152+
T = uinttype(F)
153+
ret::T = sign_bit
154+
ret <<= exponent_bits(F)
155+
ret |= exponent_field
156+
ret <<= significand_bits(F)
157+
ret |= significand_field
158+
end
159+
160+
# ±floatmax(T)
161+
function ieee754_representation(
162+
::Type{F}, sign_bit::Bool, ::Val{:omega}
163+
) where {F<:IEEEFloat}
164+
ieee754_representation(F, sign_bit, exponent_raw_max(F) - 1, significand_mask(F))
165+
end
166+
167+
# NaN or an infinity
168+
function ieee754_representation(
169+
::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:nan}
170+
) where {F<:IEEEFloat}
171+
ieee754_representation(F, sign_bit, exponent_raw_max(F), significand_field)
172+
end
173+
174+
# NaN with default payload
175+
function ieee754_representation(
176+
::Type{F}, sign_bit::Bool, ::Val{:nan}
177+
) where {F<:IEEEFloat}
178+
ieee754_representation(F, sign_bit, one(uinttype(F)) << (significand_bits(F) - 1), Val(:nan))
179+
end
180+
181+
# Infinity
182+
function ieee754_representation(
183+
::Type{F}, sign_bit::Bool, ::Val{:inf}
184+
) where {F<:IEEEFloat}
185+
ieee754_representation(F, sign_bit, false, Val(:nan))
186+
end
187+
188+
# Subnormal or zero
189+
function ieee754_representation(
190+
::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:subnormal}
191+
) where {F<:IEEEFloat}
192+
ieee754_representation(F, sign_bit, false, significand_field)
193+
end
194+
195+
# Zero
196+
function ieee754_representation(
197+
::Type{F}, sign_bit::Bool, ::Val{:zero}
198+
) where {F<:IEEEFloat}
199+
ieee754_representation(F, sign_bit, false, Val(:subnormal))
200+
end
201+
140202
"""
141203
uabs(x::Integer)
142204

base/mpfr.jl

+84-28
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@ import
1717
cbrt, typemax, typemin, unsafe_trunc, floatmin, floatmax, rounding,
1818
setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero,
1919
isone, big, _string_n, decompose, minmax,
20-
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand
20+
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand,
21+
uinttype, exponent_max, exponent_min, ieee754_representation, significand_mask,
22+
RawBigIntRoundingIncrementHelper, truncated, RawBigInt
2123

2224

2325
using .Base.Libc
24-
import ..Rounding: rounding_raw, setrounding_raw
26+
import ..Rounding:
27+
rounding_raw, setrounding_raw, rounds_to_nearest, rounds_away_from_zero,
28+
tie_breaker_is_to_even, correct_rounding_requires_increment
2529

2630
import ..GMP: ClongMax, CulongMax, CdoubleMax, Limb, libgmp
2731

@@ -89,6 +93,21 @@ function convert(::Type{RoundingMode}, r::MPFRRoundingMode)
8993
end
9094
end
9195

96+
rounds_to_nearest(m::MPFRRoundingMode) = m == MPFRRoundNearest
97+
function rounds_away_from_zero(m::MPFRRoundingMode, sign_bit::Bool)
98+
if m == MPFRRoundToZero
99+
false
100+
elseif m == MPFRRoundUp
101+
!sign_bit
102+
elseif m == MPFRRoundDown
103+
sign_bit
104+
else
105+
# Assuming `m == MPFRRoundFromZero`
106+
true
107+
end
108+
end
109+
tie_breaker_is_to_even(::MPFRRoundingMode) = true
110+
92111
const ROUNDING_MODE = Ref{MPFRRoundingMode}(MPFRRoundNearest)
93112
const DEFAULT_PRECISION = Ref{Clong}(256)
94113

@@ -130,6 +149,9 @@ mutable struct BigFloat <: AbstractFloat
130149
end
131150
end
132151

152+
# The rounding mode here shouldn't matter.
153+
significand_limb_count(x::BigFloat) = div(sizeof(x._d), sizeof(Limb), RoundToZero)
154+
133155
rounding_raw(::Type{BigFloat}) = ROUNDING_MODE[]
134156
setrounding_raw(::Type{BigFloat}, r::MPFRRoundingMode) = ROUNDING_MODE[]=r
135157

@@ -380,35 +402,69 @@ function (::Type{T})(x::BigFloat) where T<:Integer
380402
trunc(T,x)
381403
end
382404

383-
## BigFloat -> AbstractFloat
384-
_cpynansgn(x::AbstractFloat, y::BigFloat) = isnan(x) && signbit(x) != signbit(y) ? -x : x
385-
386-
Float64(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) =
387-
_cpynansgn(ccall((:mpfr_get_d,libmpfr), Float64, (Ref{BigFloat}, MPFRRoundingMode), x, r), x)
388-
Float64(x::BigFloat, r::RoundingMode) = Float64(x, convert(MPFRRoundingMode, r))
389-
390-
Float32(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) =
391-
_cpynansgn(ccall((:mpfr_get_flt,libmpfr), Float32, (Ref{BigFloat}, MPFRRoundingMode), x, r), x)
392-
Float32(x::BigFloat, r::RoundingMode) = Float32(x, convert(MPFRRoundingMode, r))
393-
394-
function Float16(x::BigFloat) :: Float16
395-
res = Float32(x)
396-
resi = reinterpret(UInt32, res)
397-
if (resi&0x7fffffff) < 0x38800000 # if Float16(res) is subnormal
398-
#shift so that the mantissa lines up where it would for normal Float16
399-
shift = 113-((resi & 0x7f800000)>>23)
400-
if shift<23
401-
resi |= 0x0080_0000 # set implicit bit
402-
resi >>= shift
405+
function to_ieee754(::Type{T}, x::BigFloat, rm) where {T<:AbstractFloat}
406+
sb = signbit(x)
407+
is_zero = iszero(x)
408+
is_inf = isinf(x)
409+
is_nan = isnan(x)
410+
is_regular = !is_zero & !is_inf & !is_nan
411+
ieee_exp = Int(x.exp) - 1
412+
ieee_precision = precision(T)
413+
ieee_exp_max = exponent_max(T)
414+
ieee_exp_min = exponent_min(T)
415+
exp_diff = ieee_exp - ieee_exp_min
416+
is_normal = 0 exp_diff
417+
(rm_is_to_zero, rm_is_from_zero) = if rounds_to_nearest(rm)
418+
(false, false)
419+
else
420+
let from = rounds_away_from_zero(rm, sb)
421+
(!from, from)
403422
end
404-
end
405-
if (resi & 0x1fff == 0x1000) # if we are halfway between 2 Float16 values
406-
# adjust the value by 1 ULP in the direction that will make Float16(res) give the right answer
407-
res = nextfloat(res, cmp(x, res))
408-
end
409-
return res
423+
end::NTuple{2,Bool}
424+
exp_is_huge_p = ieee_exp_max < ieee_exp
425+
exp_is_huge_n = signbit(exp_diff + ieee_precision)
426+
rounds_to_inf = is_regular & exp_is_huge_p & !rm_is_to_zero
427+
rounds_to_zero = is_regular & exp_is_huge_n & !rm_is_from_zero
428+
U = uinttype(T)
429+
430+
ret_u = if is_regular & !rounds_to_inf & !rounds_to_zero
431+
if !exp_is_huge_p
432+
# significand
433+
v = RawBigInt(x.d, significand_limb_count(x))
434+
len = max(ieee_precision + min(exp_diff, 0), 0)::Int
435+
signif = truncated(U, v, len) & significand_mask(T)
436+
437+
# round up if necessary
438+
rh = RawBigIntRoundingIncrementHelper(v, len)
439+
incr = correct_rounding_requires_increment(rh, rm, sb)
440+
441+
# exponent
442+
exp_field = max(exp_diff, 0) + is_normal
443+
444+
ieee754_representation(T, sb, exp_field, signif) + incr
445+
else
446+
ieee754_representation(T, sb, Val(:omega))
447+
end
448+
else
449+
if is_zero | rounds_to_zero
450+
ieee754_representation(T, sb, Val(:zero))
451+
elseif is_inf | rounds_to_inf
452+
ieee754_representation(T, sb, Val(:inf))
453+
else
454+
ieee754_representation(T, sb, Val(:nan))
455+
end
456+
end::U
457+
458+
reinterpret(T, ret_u)
410459
end
411460

461+
Float16(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float16, x, r)
462+
Float32(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float32, x, r)
463+
Float64(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float64, x, r)
464+
Float16(x::BigFloat, r::RoundingMode) = to_ieee754(Float16, x, r)
465+
Float32(x::BigFloat, r::RoundingMode) = to_ieee754(Float32, x, r)
466+
Float64(x::BigFloat, r::RoundingMode) = to_ieee754(Float64, x, r)
467+
412468
promote_rule(::Type{BigFloat}, ::Type{<:Real}) = BigFloat
413469
promote_rule(::Type{BigInt}, ::Type{<:AbstractFloat}) = BigFloat
414470
promote_rule(::Type{BigFloat}, ::Type{<:AbstractFloat}) = BigFloat

base/rawbigints.jl

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
"""
4+
Segment of raw words of bits interpreted as a big integer. Less
5+
significant words come first. Each word is in machine-native bit-order.
6+
"""
7+
struct RawBigInt{T<:Unsigned}
8+
d::Ptr{T}
9+
word_count::Int
10+
11+
function RawBigInt{T}(d::Ptr{T}, word_count::Int) where {T<:Unsigned}
12+
new{T}(d, word_count)
13+
end
14+
end
15+
16+
RawBigInt(d::Ptr{T}, word_count::Int) where {T<:Unsigned} = RawBigInt{T}(d, word_count)
17+
elem_count(x::RawBigInt, ::Val{:words}) = x.word_count
18+
elem_count(x::Unsigned, ::Val{:bits}) = sizeof(x) * 8
19+
word_length(::RawBigInt{T}) where {T} = elem_count(zero(T), Val(:bits))
20+
elem_count(x::RawBigInt{T}, ::Val{:bits}) where {T} = word_length(x) * elem_count(x, Val(:words))
21+
reversed_index(n::Int, i::Int) = n - i - 1
22+
reversed_index(x, i::Int, v::Val) = reversed_index(elem_count(x, v), i)::Int
23+
split_bit_index(x::RawBigInt, i::Int) = divrem(i, word_length(x), RoundToZero)
24+
25+
"""
26+
`i` is the zero-based index of the wanted word in `x`, starting from
27+
the less significant words.
28+
"""
29+
function get_elem(x::RawBigInt, i::Int, ::Val{:words}, ::Val{:ascending})
30+
unsafe_load(x.d, i + 1)
31+
end
32+
33+
function get_elem(x, i::Int, v::Val, ::Val{:descending})
34+
j = reversed_index(x, i, v)
35+
get_elem(x, j, v, Val(:ascending))
36+
end
37+
38+
word_is_nonzero(x::RawBigInt, i::Int, v::Val) = !iszero(get_elem(x, i, Val(:words), v))
39+
40+
word_is_nonzero(x::RawBigInt, v::Val) = let x = x
41+
i -> word_is_nonzero(x, i, v)
42+
end
43+
44+
"""
45+
Returns a `Bool` indicating whether the `len` least significant words
46+
of `x` are nonzero.
47+
"""
48+
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:words})
49+
any(word_is_nonzero(x, Val(:ascending)), 0:(len - 1))
50+
end
51+
52+
"""
53+
Returns a `Bool` indicating whether the `len` least significant bits of
54+
the `i`-th (zero-based index) word of `x` are nonzero.
55+
"""
56+
function tail_is_nonzero(x::RawBigInt, len::Int, i::Int, ::Val{:word})
57+
!iszero(len) &&
58+
!iszero(get_elem(x, i, Val(:words), Val(:ascending)) << (word_length(x) - len))
59+
end
60+
61+
"""
62+
Returns a `Bool` indicating whether the `len` least significant bits of
63+
`x` are nonzero.
64+
"""
65+
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:bits})
66+
if 0 < len
67+
word_count, bit_count_in_word = split_bit_index(x, len)
68+
tail_is_nonzero(x, bit_count_in_word, word_count, Val(:word)) ||
69+
tail_is_nonzero(x, word_count, Val(:words))
70+
else
71+
false
72+
end::Bool
73+
end
74+
75+
"""
76+
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
77+
"""
78+
function get_elem(x::Unsigned, i::Int, ::Val{:bits}, ::Val{:ascending})
79+
(x >>> i) % Bool
80+
end
81+
82+
"""
83+
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
84+
"""
85+
function get_elem(x::RawBigInt, i::Int, ::Val{:bits}, v::Val{:ascending})
86+
vb = Val(:bits)
87+
if 0 i < elem_count(x, vb)
88+
word_index, bit_index_in_word = split_bit_index(x, i)
89+
word = get_elem(x, word_index, Val(:words), v)
90+
get_elem(word, bit_index_in_word, vb, v)
91+
else
92+
false
93+
end::Bool
94+
end
95+
96+
"""
97+
Returns an integer of type `R`, consisting of the `len` most
98+
significant bits of `x`.
99+
"""
100+
function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer}
101+
ret = zero(R)
102+
if 0 < len
103+
word_count, bit_count_in_word = split_bit_index(x, len)
104+
k = word_length(x)
105+
vals = (Val(:words), Val(:descending))
106+
107+
for w 0:(word_count - 1)
108+
ret <<= k
109+
word = get_elem(x, w, vals...)
110+
ret |= R(word)
111+
end
112+
113+
if !iszero(bit_count_in_word)
114+
ret <<= bit_count_in_word
115+
wrd = get_elem(x, word_count, vals...)
116+
ret |= R(wrd >>> (k - bit_count_in_word))
117+
end
118+
end
119+
ret::R
120+
end
121+
122+
struct RawBigIntRoundingIncrementHelper{T<:Unsigned}
123+
n::RawBigInt{T}
124+
trunc_len::Int
125+
126+
final_bit::Bool
127+
round_bit::Bool
128+
129+
function RawBigIntRoundingIncrementHelper{T}(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
130+
vals = (Val(:bits), Val(:descending))
131+
f = get_elem(n, len - 1, vals...)
132+
r = get_elem(n, len , vals...)
133+
new{T}(n, len, f, r)
134+
end
135+
end
136+
137+
function RawBigIntRoundingIncrementHelper(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
138+
RawBigIntRoundingIncrementHelper{T}(n, len)
139+
end
140+
141+
(h::RawBigIntRoundingIncrementHelper)(::Rounding.FinalBit) = h.final_bit
142+
143+
(h::RawBigIntRoundingIncrementHelper)(::Rounding.RoundBit) = h.round_bit
144+
145+
function (h::RawBigIntRoundingIncrementHelper)(::Rounding.StickyBit)
146+
v = Val(:bits)
147+
n = h.n
148+
tail_is_nonzero(n, elem_count(n, v) - h.trunc_len - 1, v)
149+
end

0 commit comments

Comments
 (0)