Skip to content

Commit b077c63

Browse files
authored
Overhaul of rounding for rational numbers (#34658)
* Simplify rounding for Rationals (fix #34645) * Fix rounding for infinite Rationals (fix #34657) * Remove explicit DivideError * Type-stable rounding with digits/sigdigits
1 parent a916783 commit b077c63

File tree

4 files changed

+90
-54
lines changed

4 files changed

+90
-54
lines changed

base/floatfuncs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ end
128128
# NOTE: this relies on the current keyword dispatch behaviour (#9498).
129129
function round(x::Real, r::RoundingMode=RoundNearest;
130130
digits::Union{Nothing,Integer}=nothing, sigdigits::Union{Nothing,Integer}=nothing, base::Union{Nothing,Integer}=nothing)
131-
isfinite(x) || return x
132131
if digits === nothing
133132
if sigdigits === nothing
134133
if base === nothing
@@ -139,10 +138,12 @@ function round(x::Real, r::RoundingMode=RoundNearest;
139138
# or throw(ArgumentError("`round` cannot use `base` argument without `digits` or `sigdigits` arguments."))
140139
end
141140
else
141+
isfinite(x) || return float(x)
142142
_round_sigdigits(x, r, sigdigits, base === nothing ? 10 : base)
143143
end
144144
else
145145
if sigdigits === nothing
146+
isfinite(x) || return float(x)
146147
_round_digits(x, r, digits, base === nothing ? 10 : base)
147148
else
148149
throw(ArgumentError("`round` cannot use both `digits` and `sigdigits` arguments."))

base/missing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ round(::Type{T}, ::Missing, ::RoundingMode=RoundNearest) where {T} =
135135
throw(MissingException("cannot convert a missing value to type $T: use Union{$T, Missing} instead"))
136136
round(::Type{T}, x::Any, r::RoundingMode=RoundNearest) where {T>:Missing} = round(nonmissingtype_checked(T), x, r)
137137
# to fix ambiguities
138-
round(::Type{T}, x::Rational, r::RoundingMode=RoundNearest) where {T>:Missing} = round(nonmissingtype_checked(T), x, r)
138+
round(::Type{T}, x::Rational{Tr}, r::RoundingMode=RoundNearest) where {T>:Missing,Tr} = round(nonmissingtype_checked(T), x, r)
139139
round(::Type{T}, x::Rational{Bool}, r::RoundingMode=RoundNearest) where {T>:Missing} = round(nonmissingtype_checked(T), x, r)
140140

141141
# Handle ceil, floor, and trunc separately as they have no RoundingMode argument

base/rational.jl

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -395,66 +395,26 @@ for (S, T) in ((Rational, Integer), (Integer, Rational), (Rational, Rational))
395395
end
396396
end
397397

398-
trunc(::Type{T}, x::Rational) where {T} = convert(T,div(x.num,x.den))
399-
floor(::Type{T}, x::Rational) where {T} = convert(T,fld(x.num,x.den))
400-
ceil(::Type{T}, x::Rational) where {T} = convert(T,cld(x.num,x.den))
401-
round(::Type{T}, x::Rational, r::RoundingMode=RoundNearest) where {T} = _round_rational(T, x, r)
402-
round(x::Rational, r::RoundingMode) = round(Rational, x, r)
403-
404-
function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:Nearest}) where {T,Tr}
405-
if denominator(x) == zero(Tr) && T <: Integer
406-
throw(DivideError())
407-
elseif denominator(x) == zero(Tr)
408-
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
409-
end
410-
q,r = divrem(numerator(x), denominator(x))
411-
s = q
412-
if abs(r) >= abs((denominator(x)-copysign(Tr(4), numerator(x))+one(Tr)+iseven(q))>>1 + copysign(Tr(2), numerator(x)))
413-
s += copysign(one(Tr),numerator(x))
414-
end
415-
convert(T, s)
416-
end
398+
trunc(::Type{T}, x::Rational) where {T} = round(T, x, RoundToZero)
399+
floor(::Type{T}, x::Rational) where {T} = round(T, x, RoundDown)
400+
ceil(::Type{T}, x::Rational) where {T} = round(T, x, RoundUp)
417401

418-
function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTiesAway}) where {T,Tr}
419-
if denominator(x) == zero(Tr) && T <: Integer
420-
throw(DivideError())
421-
elseif denominator(x) == zero(Tr)
422-
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
423-
end
424-
q,r = divrem(numerator(x), denominator(x))
425-
s = q
426-
if abs(r) >= abs((denominator(x)-copysign(Tr(4), numerator(x))+one(Tr))>>1 + copysign(Tr(2), numerator(x)))
427-
s += copysign(one(Tr),numerator(x))
428-
end
429-
convert(T, s)
430-
end
402+
round(x::Rational, r::RoundingMode=RoundNearest) = round(typeof(x), x, r)
431403

432-
function _round_rational(::Type{T}, x::Rational{Tr}, ::RoundingMode{:NearestTiesUp}) where {T,Tr}
433-
if denominator(x) == zero(Tr) && T <: Integer
434-
throw(DivideError())
435-
elseif denominator(x) == zero(Tr)
404+
function round(::Type{T}, x::Rational{Tr}, r::RoundingMode=RoundNearest) where {T,Tr}
405+
if iszero(denominator(x)) && !(T <: Integer)
436406
return convert(T, copysign(one(Tr)//zero(Tr), numerator(x)))
437407
end
438-
q,r = divrem(numerator(x), denominator(x))
439-
s = q
440-
if abs(r) >= abs((denominator(x)-copysign(Tr(4), numerator(x))+one(Tr)+(numerator(x)<0))>>1 + copysign(Tr(2), numerator(x)))
441-
s += copysign(one(Tr),numerator(x))
442-
end
443-
convert(T, s)
408+
convert(T, div(numerator(x), denominator(x), r))
444409
end
445410

446411
function round(::Type{T}, x::Rational{Bool}, ::RoundingMode=RoundNearest) where T
447-
if denominator(x) == false && (T <: Union{Integer, Bool})
412+
if denominator(x) == false && (T <: Integer)
448413
throw(DivideError())
449414
end
450415
convert(T, x)
451416
end
452417

453-
trunc(x::Rational{T}) where {T} = Rational(trunc(T,x))
454-
floor(x::Rational{T}) where {T} = Rational(floor(T,x))
455-
ceil(x::Rational{T}) where {T} = Rational(ceil(T,x))
456-
round(x::Rational{T}) where {T} = Rational(round(T,x))
457-
458418
function ^(x::Rational, n::Integer)
459419
n >= 0 ? power_by_squaring(x,n) : power_by_squaring(inv(x),-n)
460420
end

test/rational.jl

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -247,22 +247,97 @@ end
247247
end
248248

249249
@testset "round" begin
250-
@test round(11//2) == 6//1 # rounds to closest _even_ integer
251-
@test round(-11//2) == -6//1 # rounds to closest _even_ integer
252-
@test round(11//3) == 4//1 # rounds to closest _even_ integer
253-
@test round(-11//3) == -4//1 # rounds to closest _even_ integer
250+
@test round(11//2) == round(11//2, RoundNearest) == 6//1 # rounds to closest _even_ integer
251+
@test round(-11//2) == round(-11//2, RoundNearest) == -6//1 # rounds to closest _even_ integer
252+
@test round(13//2) == round(13//2, RoundNearest) == 6//1 # rounds to closest _even_ integer
253+
@test round(-13//2) == round(-13//2, RoundNearest) == -6//1 # rounds to closest _even_ integer
254+
@test round(11//3) == round(11//3, RoundNearest) == 4//1 # rounds to closest _even_ integer
255+
@test round(-11//3) == round(-11//3, RoundNearest) == -4//1 # rounds to closest _even_ integer
256+
257+
@test round(11//2, RoundNearestTiesAway) == 6//1
258+
@test round(-11//2, RoundNearestTiesAway) == -6//1
259+
@test round(13//2, RoundNearestTiesAway) == 7//1
260+
@test round(-13//2, RoundNearestTiesAway) == -7//1
261+
@test round(11//3, RoundNearestTiesAway) == 4//1
262+
@test round(-11//3, RoundNearestTiesAway) == -4//1
263+
264+
@test round(11//2, RoundNearestTiesUp) == 6//1
265+
@test round(-11//2, RoundNearestTiesUp) == -5//1
266+
@test round(13//2, RoundNearestTiesUp) == 7//1
267+
@test round(-13//2, RoundNearestTiesUp) == -6//1
268+
@test round(11//3, RoundNearestTiesUp) == 4//1
269+
@test round(-11//3, RoundNearestTiesUp) == -4//1
270+
271+
@test trunc(11//2) == round(11//2, RoundToZero) == 5//1
272+
@test trunc(-11//2) == round(-11//2, RoundToZero) == -5//1
273+
@test trunc(13//2) == round(13//2, RoundToZero) == 6//1
274+
@test trunc(-13//2) == round(-13//2, RoundToZero) == -6//1
275+
@test trunc(11//3) == round(11//3, RoundToZero) == 3//1
276+
@test trunc(-11//3) == round(-11//3, RoundToZero) == -3//1
277+
278+
@test ceil(11//2) == round(11//2, RoundUp) == 6//1
279+
@test ceil(-11//2) == round(-11//2, RoundUp) == -5//1
280+
@test ceil(13//2) == round(13//2, RoundUp) == 7//1
281+
@test ceil(-13//2) == round(-13//2, RoundUp) == -6//1
282+
@test ceil(11//3) == round(11//3, RoundUp) == 4//1
283+
@test ceil(-11//3) == round(-11//3, RoundUp) == -3//1
284+
285+
@test floor(11//2) == round(11//2, RoundDown) == 5//1
286+
@test floor(-11//2) == round(-11//2, RoundDown) == -6//1
287+
@test floor(13//2) == round(13//2, RoundDown) == 6//1
288+
@test floor(-13//2) == round(-13//2, RoundDown) == -7//1
289+
@test floor(11//3) == round(11//3, RoundDown) == 3//1
290+
@test floor(-11//3) == round(-11//3, RoundDown) == -4//1
254291

255292
for T in (Float16, Float32, Float64)
256293
@test round(T, true//false) === convert(T, Inf)
257294
@test round(T, true//true) === one(T)
258295
@test round(T, false//true) === zero(T)
296+
@test trunc(T, true//false) === convert(T, Inf)
297+
@test trunc(T, true//true) === one(T)
298+
@test trunc(T, false//true) === zero(T)
299+
@test floor(T, true//false) === convert(T, Inf)
300+
@test floor(T, true//true) === one(T)
301+
@test floor(T, false//true) === zero(T)
302+
@test ceil(T, true//false) === convert(T, Inf)
303+
@test ceil(T, true//true) === one(T)
304+
@test ceil(T, false//true) === zero(T)
259305
end
260306

261307
for T in (Int8, Int16, Int32, Int64, Bool)
262308
@test_throws DivideError round(T, true//false)
263309
@test round(T, true//true) === one(T)
264310
@test round(T, false//true) === zero(T)
311+
@test_throws DivideError trunc(T, true//false)
312+
@test trunc(T, true//true) === one(T)
313+
@test trunc(T, false//true) === zero(T)
314+
@test_throws DivideError floor(T, true//false)
315+
@test floor(T, true//true) === one(T)
316+
@test floor(T, false//true) === zero(T)
317+
@test_throws DivideError ceil(T, true//false)
318+
@test ceil(T, true//true) === one(T)
319+
@test ceil(T, false//true) === zero(T)
265320
end
321+
322+
# issue 34657
323+
@test round(1//0) === round(Rational, 1//0) === 1//0
324+
@test trunc(1//0) === trunc(Rational, 1//0) === 1//0
325+
@test floor(1//0) === floor(Rational, 1//0) === 1//0
326+
@test ceil(1//0) === ceil(Rational, 1//0) === 1//0
327+
@test round(-1//0) === round(Rational, -1//0) === -1//0
328+
@test trunc(-1//0) === trunc(Rational, -1//0) === -1//0
329+
@test floor(-1//0) === floor(Rational, -1//0) === -1//0
330+
@test ceil(-1//0) === ceil(Rational, -1//0) === -1//0
331+
for r = [RoundNearest, RoundNearestTiesAway, RoundNearestTiesUp,
332+
RoundToZero, RoundUp, RoundDown]
333+
@test round(1//0, r) === 1//0
334+
@test round(-1//0, r) === -1//0
335+
end
336+
337+
@test @inferred(round(1//0, digits=1)) === Inf
338+
@test @inferred(trunc(1//0, digits=2)) === Inf
339+
@test @inferred(floor(-1//0, sigdigits=1)) === -Inf
340+
@test @inferred(ceil(-1//0, sigdigits=2)) === -Inf
266341
end
267342

268343
@testset "issue 1552" begin

0 commit comments

Comments
 (0)