Skip to content

Commit ac607dc

Browse files
nsajko and oscardssmith
authored
exact BigFloat to IEEE FP conversion in pure Julia (#50691)
There's lots of code, but most of it seems like it will be useful in general. For example, I think I'll use the changes in float.jl and rounding.jl to improve the #49749 PR. The changes in float.jl could also be used to refactor float.jl to remove many magic constants.

Benchmarking script:

```julia
using BenchmarkTools
f(::Type{T} = BigFloat, n::Int = 2000) where {T} = rand(T, n)
g!(u, v) = map!(eltype(u), u, v)
@btime g!(u, v) setup=(u = f(Float16); v = f();)
@btime g!(u, v) setup=(u = f(Float32); v = f();)
@btime g!(u, v) setup=(u = f(Float64); v = f();)
```

On master (dc06468):

```
46.116 μs (0 allocations: 0 bytes)
38.842 μs (0 allocations: 0 bytes)
37.039 μs (0 allocations: 0 bytes)
```

With both this commit and #50674 applied:

```
42.310 μs (0 allocations: 0 bytes)
42.661 μs (0 allocations: 0 bytes)
41.608 μs (0 allocations: 0 bytes)
```

So, with this benchmark at least, on an AMD Zen 2 laptop, conversion to `Float16` is faster, but there's a slowdown for `Float32` and `Float64`.

Fixes #50642 (exact conversion to `Float16`)

Co-authored-by: Oscar Smith <oscardssmith@gmail.com>
1 parent 61ebaf6 commit ac607dc

File tree

8 files changed

+445
-29
lines changed

8 files changed

+445
-29
lines changed

base/Base.jl

+1
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ include("hashing.jl")
224224
include("rounding.jl")
225225
using .Rounding
226226
include("div.jl")
227+
include("rawbigints.jl")
227228
include("float.jl")
228229
include("twiceprecision.jl")
229230
include("complex.jl")

base/float.jl

+62
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,68 @@ i.e. the maximum integer value representable by [`exponent_bits(T)`](@ref) bits.
137137
"""
138138
function exponent_raw_max end
139139

140+
"""
141+
IEEE 754 definition of the minimum exponent.
142+
"""
143+
ieee754_exponent_min(::Type{T}) where {T<:IEEEFloat} = Int(1 - exponent_max(T))::Int
144+
145+
exponent_min(::Type{Float16}) = ieee754_exponent_min(Float16)
146+
exponent_min(::Type{Float32}) = ieee754_exponent_min(Float32)
147+
exponent_min(::Type{Float64}) = ieee754_exponent_min(Float64)
148+
149+
function ieee754_representation(
150+
::Type{F}, sign_bit::Bool, exponent_field::Integer, significand_field::Integer
151+
) where {F<:IEEEFloat}
152+
T = uinttype(F)
153+
ret::T = sign_bit
154+
ret <<= exponent_bits(F)
155+
ret |= exponent_field
156+
ret <<= significand_bits(F)
157+
ret |= significand_field
158+
end
159+
160+
# ±floatmax(T)
161+
function ieee754_representation(
162+
::Type{F}, sign_bit::Bool, ::Val{:omega}
163+
) where {F<:IEEEFloat}
164+
ieee754_representation(F, sign_bit, exponent_raw_max(F) - 1, significand_mask(F))
165+
end
166+
167+
# NaN or an infinity
168+
function ieee754_representation(
169+
::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:nan}
170+
) where {F<:IEEEFloat}
171+
ieee754_representation(F, sign_bit, exponent_raw_max(F), significand_field)
172+
end
173+
174+
# NaN with default payload
175+
function ieee754_representation(
176+
::Type{F}, sign_bit::Bool, ::Val{:nan}
177+
) where {F<:IEEEFloat}
178+
ieee754_representation(F, sign_bit, one(uinttype(F)) << (significand_bits(F) - 1), Val(:nan))
179+
end
180+
181+
# Infinity
182+
function ieee754_representation(
183+
::Type{F}, sign_bit::Bool, ::Val{:inf}
184+
) where {F<:IEEEFloat}
185+
ieee754_representation(F, sign_bit, false, Val(:nan))
186+
end
187+
188+
# Subnormal or zero
189+
function ieee754_representation(
190+
::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:subnormal}
191+
) where {F<:IEEEFloat}
192+
ieee754_representation(F, sign_bit, false, significand_field)
193+
end
194+
195+
# Zero
196+
function ieee754_representation(
197+
::Type{F}, sign_bit::Bool, ::Val{:zero}
198+
) where {F<:IEEEFloat}
199+
ieee754_representation(F, sign_bit, false, Val(:subnormal))
200+
end
201+
140202
"""
141203
uabs(x::Integer)
142204

base/mpfr.jl

+84-28
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@ import
1717
cbrt, typemax, typemin, unsafe_trunc, floatmin, floatmax, rounding,
1818
setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero,
1919
isone, big, _string_n, decompose, minmax,
20-
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand
20+
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand,
21+
uinttype, exponent_max, exponent_min, ieee754_representation, significand_mask,
22+
RawBigIntRoundingIncrementHelper, truncated, RawBigInt
2123

2224

2325
using .Base.Libc
24-
import ..Rounding: rounding_raw, setrounding_raw
26+
import ..Rounding:
27+
rounding_raw, setrounding_raw, rounds_to_nearest, rounds_away_from_zero,
28+
tie_breaker_is_to_even, correct_rounding_requires_increment
2529

2630
import ..GMP: ClongMax, CulongMax, CdoubleMax, Limb, libgmp
2731

@@ -89,6 +93,21 @@ function convert(::Type{RoundingMode}, r::MPFRRoundingMode)
8993
end
9094
end
9195

96+
rounds_to_nearest(m::MPFRRoundingMode) = m == MPFRRoundNearest
97+
function rounds_away_from_zero(m::MPFRRoundingMode, sign_bit::Bool)
98+
if m == MPFRRoundToZero
99+
false
100+
elseif m == MPFRRoundUp
101+
!sign_bit
102+
elseif m == MPFRRoundDown
103+
sign_bit
104+
else
105+
# Assuming `m == MPFRRoundFromZero`
106+
true
107+
end
108+
end
109+
tie_breaker_is_to_even(::MPFRRoundingMode) = true
110+
92111
const ROUNDING_MODE = Ref{MPFRRoundingMode}(MPFRRoundNearest)
93112
const DEFAULT_PRECISION = Ref{Clong}(256)
94113

@@ -136,6 +155,9 @@ mutable struct BigFloat <: AbstractFloat
136155
end
137156
end
138157

158+
# The rounding mode here shouldn't matter.
159+
significand_limb_count(x::BigFloat) = div(sizeof(x._d), sizeof(Limb), RoundToZero)
160+
139161
rounding_raw(::Type{BigFloat}) = ROUNDING_MODE[]
140162
setrounding_raw(::Type{BigFloat}, r::MPFRRoundingMode) = ROUNDING_MODE[]=r
141163

@@ -386,35 +408,69 @@ function (::Type{T})(x::BigFloat) where T<:Integer
386408
trunc(T,x)
387409
end
388410

389-
## BigFloat -> AbstractFloat
390-
_cpynansgn(x::AbstractFloat, y::BigFloat) = isnan(x) && signbit(x) != signbit(y) ? -x : x
391-
392-
Float64(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) =
393-
_cpynansgn(ccall((:mpfr_get_d,libmpfr), Float64, (Ref{BigFloat}, MPFRRoundingMode), x, r), x)
394-
Float64(x::BigFloat, r::RoundingMode) = Float64(x, convert(MPFRRoundingMode, r))
395-
396-
Float32(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) =
397-
_cpynansgn(ccall((:mpfr_get_flt,libmpfr), Float32, (Ref{BigFloat}, MPFRRoundingMode), x, r), x)
398-
Float32(x::BigFloat, r::RoundingMode) = Float32(x, convert(MPFRRoundingMode, r))
399-
400-
function Float16(x::BigFloat) :: Float16
401-
res = Float32(x)
402-
resi = reinterpret(UInt32, res)
403-
if (resi&0x7fffffff) < 0x38800000 # if Float16(res) is subnormal
404-
#shift so that the mantissa lines up where it would for normal Float16
405-
shift = 113-((resi & 0x7f800000)>>23)
406-
if shift<23
407-
resi |= 0x0080_0000 # set implicit bit
408-
resi >>= shift
411+
function to_ieee754(::Type{T}, x::BigFloat, rm) where {T<:AbstractFloat}
412+
sb = signbit(x)
413+
is_zero = iszero(x)
414+
is_inf = isinf(x)
415+
is_nan = isnan(x)
416+
is_regular = !is_zero & !is_inf & !is_nan
417+
ieee_exp = Int(x.exp) - 1
418+
ieee_precision = precision(T)
419+
ieee_exp_max = exponent_max(T)
420+
ieee_exp_min = exponent_min(T)
421+
exp_diff = ieee_exp - ieee_exp_min
422+
is_normal = 0 ≤ exp_diff
423+
(rm_is_to_zero, rm_is_from_zero) = if rounds_to_nearest(rm)
424+
(false, false)
425+
else
426+
let from = rounds_away_from_zero(rm, sb)
427+
(!from, from)
409428
end
410-
end
411-
if (resi & 0x1fff == 0x1000) # if we are halfway between 2 Float16 values
412-
# adjust the value by 1 ULP in the direction that will make Float16(res) give the right answer
413-
res = nextfloat(res, cmp(x, res))
414-
end
415-
return res
429+
end::NTuple{2,Bool}
430+
exp_is_huge_p = ieee_exp_max < ieee_exp
431+
exp_is_huge_n = signbit(exp_diff + ieee_precision)
432+
rounds_to_inf = is_regular & exp_is_huge_p & !rm_is_to_zero
433+
rounds_to_zero = is_regular & exp_is_huge_n & !rm_is_from_zero
434+
U = uinttype(T)
435+
436+
ret_u = if is_regular & !rounds_to_inf & !rounds_to_zero
437+
if !exp_is_huge_p
438+
# significand
439+
v = RawBigInt(x.d, significand_limb_count(x))
440+
len = max(ieee_precision + min(exp_diff, 0), 0)::Int
441+
signif = truncated(U, v, len) & significand_mask(T)
442+
443+
# round up if necessary
444+
rh = RawBigIntRoundingIncrementHelper(v, len)
445+
incr = correct_rounding_requires_increment(rh, rm, sb)
446+
447+
# exponent
448+
exp_field = max(exp_diff, 0) + is_normal
449+
450+
ieee754_representation(T, sb, exp_field, signif) + incr
451+
else
452+
ieee754_representation(T, sb, Val(:omega))
453+
end
454+
else
455+
if is_zero | rounds_to_zero
456+
ieee754_representation(T, sb, Val(:zero))
457+
elseif is_inf | rounds_to_inf
458+
ieee754_representation(T, sb, Val(:inf))
459+
else
460+
ieee754_representation(T, sb, Val(:nan))
461+
end
462+
end::U
463+
464+
reinterpret(T, ret_u)
416465
end
417466

467+
Float16(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float16, x, r)
468+
Float32(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float32, x, r)
469+
Float64(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float64, x, r)
470+
Float16(x::BigFloat, r::RoundingMode) = to_ieee754(Float16, x, r)
471+
Float32(x::BigFloat, r::RoundingMode) = to_ieee754(Float32, x, r)
472+
Float64(x::BigFloat, r::RoundingMode) = to_ieee754(Float64, x, r)
473+
418474
promote_rule(::Type{BigFloat}, ::Type{<:Real}) = BigFloat
419475
promote_rule(::Type{BigInt}, ::Type{<:AbstractFloat}) = BigFloat
420476
promote_rule(::Type{BigFloat}, ::Type{<:AbstractFloat}) = BigFloat

base/rawbigints.jl

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
"""
4+
Segment of raw words of bits interpreted as a big integer. Less
5+
significant words come first. Each word is in machine-native bit-order.
6+
"""
7+
struct RawBigInt{T<:Unsigned}
8+
d::Ptr{T}
9+
word_count::Int
10+
11+
function RawBigInt{T}(d::Ptr{T}, word_count::Int) where {T<:Unsigned}
12+
new{T}(d, word_count)
13+
end
14+
end
15+
16+
RawBigInt(d::Ptr{T}, word_count::Int) where {T<:Unsigned} = RawBigInt{T}(d, word_count)
17+
elem_count(x::RawBigInt, ::Val{:words}) = x.word_count
18+
elem_count(x::Unsigned, ::Val{:bits}) = sizeof(x) * 8
19+
word_length(::RawBigInt{T}) where {T} = elem_count(zero(T), Val(:bits))
20+
elem_count(x::RawBigInt{T}, ::Val{:bits}) where {T} = word_length(x) * elem_count(x, Val(:words))
21+
reversed_index(n::Int, i::Int) = n - i - 1
22+
reversed_index(x, i::Int, v::Val) = reversed_index(elem_count(x, v), i)::Int
23+
split_bit_index(x::RawBigInt, i::Int) = divrem(i, word_length(x), RoundToZero)
24+
25+
"""
26+
`i` is the zero-based index of the wanted word in `x`, starting from
27+
the less significant words.
28+
"""
29+
function get_elem(x::RawBigInt, i::Int, ::Val{:words}, ::Val{:ascending})
30+
unsafe_load(x.d, i + 1)
31+
end
32+
33+
function get_elem(x, i::Int, v::Val, ::Val{:descending})
34+
j = reversed_index(x, i, v)
35+
get_elem(x, j, v, Val(:ascending))
36+
end
37+
38+
word_is_nonzero(x::RawBigInt, i::Int, v::Val) = !iszero(get_elem(x, i, Val(:words), v))
39+
40+
word_is_nonzero(x::RawBigInt, v::Val) = let x = x
41+
i -> word_is_nonzero(x, i, v)
42+
end
43+
44+
"""
45+
Returns a `Bool` indicating whether the `len` least significant words
46+
of `x` are nonzero.
47+
"""
48+
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:words})
49+
any(word_is_nonzero(x, Val(:ascending)), 0:(len - 1))
50+
end
51+
52+
"""
53+
Returns a `Bool` indicating whether the `len` least significant bits of
54+
the `i`-th (zero-based index) word of `x` are nonzero.
55+
"""
56+
function tail_is_nonzero(x::RawBigInt, len::Int, i::Int, ::Val{:word})
57+
!iszero(len) &&
58+
!iszero(get_elem(x, i, Val(:words), Val(:ascending)) << (word_length(x) - len))
59+
end
60+
61+
"""
62+
Returns a `Bool` indicating whether the `len` least significant bits of
63+
`x` are nonzero.
64+
"""
65+
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:bits})
66+
if 0 < len
67+
word_count, bit_count_in_word = split_bit_index(x, len)
68+
tail_is_nonzero(x, bit_count_in_word, word_count, Val(:word)) ||
69+
tail_is_nonzero(x, word_count, Val(:words))
70+
else
71+
false
72+
end::Bool
73+
end
74+
75+
"""
76+
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
77+
"""
78+
function get_elem(x::Unsigned, i::Int, ::Val{:bits}, ::Val{:ascending})
79+
(x >>> i) % Bool
80+
end
81+
82+
"""
83+
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
84+
"""
85+
function get_elem(x::RawBigInt, i::Int, ::Val{:bits}, v::Val{:ascending})
86+
vb = Val(:bits)
87+
if 0 ≤ i < elem_count(x, vb)
88+
word_index, bit_index_in_word = split_bit_index(x, i)
89+
word = get_elem(x, word_index, Val(:words), v)
90+
get_elem(word, bit_index_in_word, vb, v)
91+
else
92+
false
93+
end::Bool
94+
end
95+
96+
"""
97+
Returns an integer of type `R`, consisting of the `len` most
98+
significant bits of `x`.
99+
"""
100+
function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer}
101+
ret = zero(R)
102+
if 0 < len
103+
word_count, bit_count_in_word = split_bit_index(x, len)
104+
k = word_length(x)
105+
vals = (Val(:words), Val(:descending))
106+
107+
for w ∈ 0:(word_count - 1)
108+
ret <<= k
109+
word = get_elem(x, w, vals...)
110+
ret |= R(word)
111+
end
112+
113+
if !iszero(bit_count_in_word)
114+
ret <<= bit_count_in_word
115+
wrd = get_elem(x, word_count, vals...)
116+
ret |= R(wrd >>> (k - bit_count_in_word))
117+
end
118+
end
119+
ret::R
120+
end
121+
122+
struct RawBigIntRoundingIncrementHelper{T<:Unsigned}
123+
n::RawBigInt{T}
124+
trunc_len::Int
125+
126+
final_bit::Bool
127+
round_bit::Bool
128+
129+
function RawBigIntRoundingIncrementHelper{T}(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
130+
vals = (Val(:bits), Val(:descending))
131+
f = get_elem(n, len - 1, vals...)
132+
r = get_elem(n, len , vals...)
133+
new{T}(n, len, f, r)
134+
end
135+
end
136+
137+
function RawBigIntRoundingIncrementHelper(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
138+
RawBigIntRoundingIncrementHelper{T}(n, len)
139+
end
140+
141+
(h::RawBigIntRoundingIncrementHelper)(::Rounding.FinalBit) = h.final_bit
142+
143+
(h::RawBigIntRoundingIncrementHelper)(::Rounding.RoundBit) = h.round_bit
144+
145+
function (h::RawBigIntRoundingIncrementHelper)(::Rounding.StickyBit)
146+
v = Val(:bits)
147+
n = h.n
148+
tail_is_nonzero(n, elem_count(n, v) - h.trunc_len - 1, v)
149+
end

0 commit comments

Comments
 (0)