Skip to content

Commit

Permalink
Further saturating arithmetic test and implementation cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentyrone committed Apr 27, 2023
1 parent f00824e commit fe15113
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 34 deletions.
70 changes: 44 additions & 26 deletions Sources/IntegerUtilities/SaturatingArithmetic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@
//===----------------------------------------------------------------------===//

extension FixedWidthInteger {
@_transparent @usableFromInline
var sextOrZext: Self { self >> Self.bitWidth }
/// `~0` (all-ones) if this value is negative, otherwise `0`.
///
/// Note that if `Self` is unsigned, this always returns `0`,
/// but it is useful for writing algorithms that are generic over
/// signed and unsigned integers.
@inline(__always) @usableFromInline
var signbit: Self {
return self < .zero ? ~.zero : .zero
}

/// Saturating integer addition
///
Expand All @@ -29,8 +36,7 @@ extension FixedWidthInteger {
@inlinable
public func addingWithSaturation(_ other: Self) -> Self {
let (wrapped, overflow) = addingReportingOverflow(other)
if !overflow { return wrapped }
return Self.max &- sextOrZext
return overflow ? Self.max &- signbit : wrapped
}

/// Saturating integer subtraction
Expand All @@ -54,7 +60,7 @@ extension FixedWidthInteger {
public func subtractingWithSaturation(_ other: Self) -> Self {
let (wrapped, overflow) = subtractingReportingOverflow(other)
if !overflow { return wrapped }
return Self.isSigned ? Self.max &- sextOrZext : 0
return Self.isSigned ? Self.max &- signbit : 0
}

/// Saturating integer negation
Expand Down Expand Up @@ -85,10 +91,10 @@ extension FixedWidthInteger {
public func multipliedWithSaturation(by other: Self) -> Self {
let (high, low) = multipliedFullWidth(by: other)
let wrapped = Self(truncatingIfNeeded: low)
if high == wrapped.sextOrZext { return wrapped }
return Self.max &- high.sextOrZext
if high == wrapped.signbit { return wrapped }
return Self.max &- high.signbit
}

/// Bitwise left with rounding and saturation.
///
/// `self` multiplied by the rational number 2^(`count`), saturated to the
Expand All @@ -102,28 +108,20 @@ extension FixedWidthInteger {
/// and if negative a right shift.
/// - rounding rule: the direction in which to round if `count` is negative.
@inlinable
public func shiftedWithSaturation<Count: BinaryInteger>(
leftBy count: Count, rounding rule: RoundingRule = .down
public func shiftedWithSaturation(
leftBy count: Int,
rounding rule: RoundingRule = .down
) -> Self {
// If count is zero or negative, negate it and do a right
// shift without saturation instead, since we already have
// that implemented.
if count == 0 { return self }
// If count is negative, negate it and do a right shift without
// saturation instead, since we already have that implemented.
guard count > 0 else {
// negating count is tricky, because count's type can be
// an arbitrary BinaryInteger; in particular, it could be
// .min of a signed type, so that its negation cannot be
// represented in the same type. Fortunately, Int64 is
// always big enough to represent arbitrary shifts of
// arbitrary types, so we can use that as an intermediate
// type, and then we can use negatedWithSaturation() to
// handle the .min case.
let int64Count = Int64(clamping: count)
return shifted(
rightBy: int64Count.negatedWithSaturation(),
rightBy: count.negatedWithSaturation(),
rounding: rule
)
}
let clamped = Self.max &- sextOrZext
let clamped = Self.max &- signbit
guard count < Self.bitWidth else {
// If count is bitWidth or greater, we always overflow
// unless self is zero.
Expand All @@ -143,7 +141,27 @@ extension FixedWidthInteger {
// does equal 0b0000_0000.
let valueBits = Self.bitWidth &- (Self.isSigned ? 1 : 0)
let wrapped = self &<< count
let complement = valueBits &- Int(count)
return self &>> complement == sextOrZext ? wrapped : clamped
let complement = valueBits &- count
return self &>> complement == signbit ? wrapped : clamped
}

/// Bitwise left with rounding and saturation.
///
/// `self` multiplied by the rational number 2^(`count`), saturated to the
/// range `Self.min ... Self.max`, and rounded according to `rule`.
///
/// See `shifted(rightBy:rounding:)` for more discussion of rounding
/// shifts with examples.
///
/// - Parameters:
/// - leftBy count: the number of bits to shift by. If positive, this is a left-shift,
/// and if negative a right shift.
/// - rounding rule: the direction in which to round if `count` is negative.
@_transparent
public func shiftedWithSaturation(
leftBy count: some BinaryInteger,
rounding rule: RoundingRule = .down
) -> Self {
self.shiftedWithSaturation(leftBy: Int(clamping: count), rounding: rule)
}
}
45 changes: 42 additions & 3 deletions Sources/IntegerUtilities/ShiftWithRounding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ extension BinaryInteger {
/// a.shifted(rightBy: count, rounding: rule)
/// a.divided(by: 1 << count, rounding: rule)
@inlinable
public func shifted<Count: BinaryInteger>(
rightBy count: Count,
public func shifted(
rightBy count: Int,
rounding rule: RoundingRule = .down
) -> Self {
// Easiest case: count is zero or negative, so shift is always exact;
Expand All @@ -61,7 +61,7 @@ extension BinaryInteger {
// shifts by first shifting all but bitWidth - 1 bits with sticky
// rounding, and then shifting the remaining bitWidth - 1 bits with
// the desired rounding mode.
let count = count - Count(bitWidth - 1)
let count = count - (bitWidth - 1)
let floor = self >> count
let lost = self - (floor << count)
let sticky = floor | (lost == 0 ? 0 : 1)
Expand Down Expand Up @@ -155,4 +155,43 @@ extension BinaryInteger {
return floor
}
}

/// `self` divided by 2^(`count`), rounding the result according to `rule`.
///
/// The default rounding rule is `.down`, which matches the behavior of
/// the `>>` operator from the standard library.
///
/// Some examples of different rounding rules:
///
/// // 3/2 is 1.5, which rounds (down by default) to 1.
/// 3.shifted(rightBy: 1)
///
/// // 1.5 rounds up to 2.
/// 3.shifted(rightBy: 1, rounding: .up)
///
/// // The two closest values are 1 and 2, 1 is returned because it
/// // is odd.
/// 3.shifted(rightBy: 1, rounding: .toOdd)
///
/// // 7/2^2 = 1.75, so the result is 1 with probability 1/4, and 2
/// // with probability 3/4.
/// 7.shifted(rightBy: 2, rounding: .stochastically)
///
/// // 4/2^2 = 4/4 = 1, exactly.
/// 4.shifted(rightBy: 2, rounding: .trap)
///
/// // 5/2 is 2.5, which is not exact, so this traps.
/// 5.shifted(rightBy: 1, rounding: .requireExact)
///
/// When `Self(1) << count` is positive, the following are equivalent:
///
/// a.shifted(rightBy: count, rounding: rule)
/// a.divided(by: 1 << count, rounding: rule)
@_transparent
public func shifted(
rightBy count: some BinaryInteger,
rounding rule: RoundingRule = .down
) -> Self {
self.shifted(rightBy: Int(clamping: count), rounding: rule)
}
}
23 changes: 18 additions & 5 deletions Tests/IntegerUtilitiesTests/SaturatingArithmeticTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
}
}

func testSaturatingSubtractSigned() {
func testSaturatingSubSigned() {
for a in Int8.min ... Int8.max {
for b in Int8.min ... Int8.max {
let expected = Int8(clamping: Int16(a) - Int16(b))
Expand All @@ -48,7 +48,7 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
}
}

func testSaturatingNegation() {
func testSaturatingNegSigned() {
for a in Int8.min ... Int8.max {
let expected = Int8(clamping: 0 - Int16(a))
let observed = a.negatedWithSaturation()
Expand All @@ -62,7 +62,7 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
}
}

func testSaturatingMultiplicationSigned() {
func testSaturatingMulSigned() {
for a in Int8.min ... Int8.max {
for b in Int8.min ... Int8.max {
let expected = Int8(clamping: Int16(a) * Int16(b))
Expand Down Expand Up @@ -94,7 +94,7 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
}
}

func testSaturatingSubtractUnsigned() {
func testSaturatingSubUnsigned() {
for a in UInt8.min ... UInt8.max {
for b in UInt8.min ... UInt8.max {
let expected = UInt8(clamping: Int16(a) - Int16(b))
Expand All @@ -110,7 +110,20 @@ final class IntegerUtilitiesSaturatingTests: XCTestCase {
}
}

func testSaturatingMultiplicationUnsigned() {
func testSaturatingNegUnsigned() {
for a in UInt8.min ... UInt8.max {
let observed = a.negatedWithSaturation()
if 0 != observed {
print("Error found in (\(a)).negatedWithSaturation().")
print("Expected: zero")
print("Observed: \(String(observed, radix: 16))")
XCTFail()
return
}
}
}

func testSaturatingMulUnsigned() {
for a in UInt8.min ... UInt8.max {
for b in UInt8.min ... UInt8.max {
let expected = UInt8(clamping: UInt16(a) * UInt16(b))
Expand Down
16 changes: 16 additions & 0 deletions Tests/WindowsMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,21 @@ extension IntegerUtilitiesRotateTests {
])
}

extension IntegerUtilitiesSaturatingTests {
static var all = testCase([
("testSaturatingAddSigned", IntegerUtilitiesSaturatingTests.testSaturatingAddSigned),
("testSaturatingSubSigned", IntegerUtilitiesSaturatingTests.testSaturatingSubSigned),
("testSaturatingNegSigned", IntegerUtilitiesSaturatingTests.testSaturatingNegSigned),
("testSaturatingMulSigned", IntegerUtilitiesSaturatingTests.testSaturatingMulSigned),
("testSaturatingAddUnsigned", IntegerUtilitiesSaturatingTests.testSaturatingAddUnsigned),
("testSaturatingSubUnsigned", IntegerUtilitiesSaturatingTests.testSaturatingSubUnsigned),
("testSaturatingNegUnsigned", IntegerUtilitiesSaturatingTests.testSaturatingNegUnsigned),
("testSaturatingMulUnsigned", IntegerUtilitiesSaturatingTests.testSaturatingMulUnsigned),
("testSaturatingShifts", IntegerUtilitiesSaturatingTests.testSaturatingShifts),
("testEdgeCaseForNegativeCount", IntegerUtilitiesSaturatingTests.testEdgeCaseForNegativeCount)
])
}

extension IntegerUtilitiesShiftTests {
static var all = testCase([
("testRoundingShifts", IntegerUtilitiesShiftTests.testRoundingShifts),
Expand Down Expand Up @@ -170,6 +185,7 @@ var testCases = [
IntegerUtilitiesGCDTests.all,
IntegerUtilitiesRotateTests.all,
IntegerUtilitiesShiftTests.all,
IntegerUtilitiesSaturatingTests.all,
IntegerUtilitiesTests.DoubleWidthTests.all,
]

Expand Down

0 comments on commit fe15113

Please sign in to comment.