Skip to content

Commit

Permalink
Merge pull request #23 from SymbolicML/additional-utils
Browse files Browse the repository at this point in the history
Additional utilities for identity functions, `+`, and `-`
  • Loading branch information
MilesCranmer authored Jun 14, 2023
2 parents edca829 + 2412d55 commit 2044168
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 25 deletions.
24 changes: 15 additions & 9 deletions src/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,28 @@ Base.:*(l::Dimensions, r::Dimensions) = @map_dimensions(+, l, r)
Base.:*(l::Quantity, r::Quantity) = Quantity(l.value * r.value, l.dimensions * r.dimensions)
Base.:*(l::Quantity, r::Dimensions) = Quantity(l.value, l.dimensions * r)
Base.:*(l::Dimensions, r::Quantity) = Quantity(r.value, l * r.dimensions)
Base.:*(l::Quantity, r::Number) = Quantity(l.value * r, l.dimensions)
Base.:*(l::Number, r::Quantity) = Quantity(l * r.value, r.dimensions)
Base.:*(l::Dimensions, r::Number) = Quantity(r, l)
Base.:*(l::Number, r::Dimensions) = Quantity(l, r)
Base.:*(l::Quantity, r) = Quantity(l.value * r, l.dimensions)
Base.:*(l, r::Quantity) = Quantity(l * r.value, r.dimensions)
Base.:*(l::Dimensions, r) = Quantity(r, l)
Base.:*(l, r::Dimensions) = Quantity(l, r)

Base.:/(l::Dimensions, r::Dimensions) = @map_dimensions(-, l, r)
Base.:/(l::Quantity, r::Quantity) = Quantity(l.value / r.value, l.dimensions / r.dimensions)
Base.:/(l::Quantity, r::Dimensions) = Quantity(l.value, l.dimensions / r)
Base.:/(l::Dimensions, r::Quantity) = Quantity(inv(r.value), l / r.dimensions)
Base.:/(l::Quantity, r::Number) = Quantity(l.value / r, l.dimensions)
Base.:/(l::Number, r::Quantity) = l * inv(r)
Base.:/(l::Dimensions, r::Number) = Quantity(inv(r), l)
Base.:/(l::Number, r::Dimensions) = Quantity(l, inv(r))
Base.:/(l::Quantity, r) = Quantity(l.value / r, l.dimensions)
Base.:/(l, r::Quantity) = l * inv(r)
Base.:/(l::Dimensions, r) = Quantity(inv(r), l)
Base.:/(l, r::Dimensions) = Quantity(l, inv(r))

Base.:+(l::Quantity, r::Quantity) = dimension(l) == dimension(r) ? Quantity(l.value + r.value, l.dimensions) : throw(DimensionError(l, r))
Base.:-(l::Quantity, r::Quantity) = dimension(l) == dimension(r) ? Quantity(l.value - r.value, l.dimensions) : throw(DimensionError(l, r))
Base.:-(l::Quantity) = Quantity(-l.value, l.dimensions)
Base.:-(l::Quantity, r::Quantity) = l + (-r)

Base.:+(l::Quantity, r) = dimension(l) == dimension(r) ? Quantity(l.value + r, l.dimensions) : throw(DimensionError(l, r))
Base.:+(l, r::Quantity) = dimension(l) == dimension(r) ? Quantity(l + r.value, r.dimensions) : throw(DimensionError(l, r))
Base.:-(l::Quantity, r) = l + (-r)
Base.:-(l, r::Quantity) = l + (-r)

_pow(l::Dimensions, r) = @map_dimensions(Base.Fix1(*, r), l)
_pow(l::Quantity{T}, r) where {T} = Quantity(l.value^r, _pow(l.dimensions, r))
Expand Down
33 changes: 25 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ Base.iszero(q::Quantity) = iszero(q.value)
Base.getindex(d::Dimensions, k::Symbol) = getfield(d, k)
Base.:(==)(l::Dimensions, r::Dimensions) = @all_dimensions(==, l, r)
Base.:(==)(l::Quantity, r::Quantity) = l.value == r.value && l.dimensions == r.dimensions
Base.:(==)(l, r::Quantity) = ustrip(l) == ustrip(r) && dimension(l) == dimension(r)
Base.:(==)(l::Quantity, r) = ustrip(l) == ustrip(r) && dimension(l) == dimension(r)
Base.isless(l::Quantity, r::Quantity) = dimension(l) == dimension(r) ? isless(ustrip(l), ustrip(r)) : throw(DimensionError(l, r))
Base.isless(l::Quantity, r) = dimension(l) == dimension(r) ? isless(ustrip(l), r) : throw(DimensionError(l, r))
Base.isless(l, r::Quantity) = dimension(l) == dimension(r) ? isless(l, ustrip(r)) : throw(DimensionError(l, r))
Base.isapprox(l::Quantity, r::Quantity; kws...) = isapprox(l.value, r.value; kws...) && l.dimensions == r.dimensions
Base.length(::Dimensions) = 1
Base.length(::Quantity) = 1
Expand All @@ -45,16 +50,26 @@ Base.iterate(::Dimensions, ::Nothing) = nothing
Base.iterate(q::Quantity) = (q, nothing)
Base.iterate(::Quantity, ::Nothing) = nothing

Base.zero(::Type{Quantity{T,R}}) where {T,R} = Quantity(zero(T), R)
# Multiplicative identities:
Base.one(::Type{Quantity{T,R}}) where {T,R} = Quantity(one(T), R)
Base.one(::Type{Dimensions{R}}) where {R} = Dimensions{R}()

Base.zero(::Type{Quantity{T}}) where {T} = zero(Quantity{T,DEFAULT_DIM_TYPE})
Base.one(::Type{Quantity{T}}) where {T} = one(Quantity{T,DEFAULT_DIM_TYPE})

Base.zero(::Type{Quantity}) = zero(Quantity{DEFAULT_VALUE_TYPE})
Base.one(::Type{Quantity}) = one(Quantity{DEFAULT_VALUE_TYPE})
Base.one(::Type{Dimensions{R}}) where {R} = Dimensions{R}()
Base.one(::Type{Dimensions}) = one(Dimensions{DEFAULT_DIM_TYPE})
Base.one(q::Quantity) = Quantity(one(ustrip(q)), one(dimension(q)))
Base.one(d::Dimensions) = one(typeof(d))

# Additive identities:
Base.zero(q::Quantity) = Quantity(zero(ustrip(q)), dimension(q))
Base.zero(::Dimensions) = error("There is no such thing as an additive identity for a `Dimensions` object, as + is only defined for `Quantity`.")
Base.zero(::Type{<:Quantity}) = error("Cannot create an additive identity for a `Quantity` type, as the dimensions are unknown. Please use `zero(::Quantity)` instead.")
Base.zero(::Type{<:Dimensions}) = error("There is no such thing as an additive identity for a `Dimensions` type, as + is only defined for `Quantity`.")

# Dimensionful 1:
Base.oneunit(q::Quantity) = Quantity(oneunit(ustrip(q)), dimension(q))
Base.oneunit(::Dimensions) = error("There is no such thing as a dimensionful 1 for a `Dimensions` object, as + is only defined for `Quantity`.")
Base.oneunit(::Type{<:Quantity}) = error("Cannot create a dimensionful 1 for a `Quantity` type without knowing the dimensions. Please use `oneunit(::Quantity)` instead.")
Base.oneunit(::Type{<:Dimensions}) = error("There is no such thing as a dimensionful 1 for a `Dimensions` type, as + is only defined for `Quantity`.")

Base.show(io::IO, d::Dimensions) =
let tmp_io = IOBuffer()
Expand Down Expand Up @@ -101,15 +116,17 @@ Base.convert(::Type{Dimensions{R}}, d::Dimensions) where {R} = Dimensions{R}(d)
Remove the units from a quantity.
"""
ustrip(q::Quantity) = q.value
ustrip(q::Number) = q
ustrip(::Dimensions) = error("Cannot remove units from a `Dimensions` object.")
ustrip(q) = q

"""
dimension(q::Quantity)
Get the dimensions of a quantity, returning a `Dimensions` object.
"""
dimension(q::Quantity) = q.dimensions
dimension(::Number) = Dimensions()
dimension(d::Dimensions) = d
dimension(_) = Dimensions()

"""
ulength(q::Quantity)
Expand Down
80 changes: 72 additions & 8 deletions test/unittests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,27 @@ using Test
@test uluminosity(y) == R(0)
@test uamount(y) == R(0)
@test ustrip(y) T(0.2^2.1)

dimensionless = Quantity(one(T), R)
y = T(2) + dimensionless
@test ustrip(y) == T(3)
@test dimension(y) == Dimensions(R)
@test typeof(y) == Quantity{T,R}

y = T(2) - dimensionless
@test ustrip(y) == T(1)
@test dimension(y) == Dimensions(R)
@test typeof(y) == Quantity{T,R}

y = dimensionless + T(2)
@test ustrip(y) == T(3)
y = dimensionless - T(2)
@test ustrip(y) == T(-1)

@test_throws DimensionError Quantity(one(T), R, length=1) + 1.0
@test_throws DimensionError Quantity(one(T), R, length=1) - 1.0
@test_throws DimensionError 1.0 + Quantity(one(T), R, length=1)
@test_throws DimensionError 1.0 - Quantity(one(T), R, length=1)
end

x = Quantity(-1.2, length=2 // 5)
Expand All @@ -108,7 +129,12 @@ end

@testset "Fallbacks" begin
@test ustrip(0.5) == 0.5
@test ustrip(ones(32)) == ones(32)
@test dimension(0.5) == Dimensions()
@test dimension(ones(32)) == Dimensions()
@test dimension(Dimensions()) === Dimensions()

@test_throws ErrorException ustrip(Dimensions())
end

@testset "Arrays" begin
Expand All @@ -126,6 +152,14 @@ end

uX = X .* Quantity(2, length=2.5, luminosity=0.5)
@test sum(X) == 0.5 * ustrip(sum(uX))

x = Quantity(ones(T, 32))
@test ustrip(x + ones(T, 32))[32] == 2
@test typeof(x + ones(T, 32)) <: Quantity{Vector{T}}
@test typeof(x - ones(T, 32)) <: Quantity{Vector{T}}
@test typeof(ones(T, 32) * Dimensions(length=1)) <: Quantity{Vector{T}}
@test typeof(ones(T, 32) / Dimensions(length=1)) <: Quantity{Vector{T}}
@test ones(T, 32) / Dimensions(length=1) == Quantity(ones(T, 32), length=-1)
end
end

Expand All @@ -150,25 +184,55 @@ end

@test Dimensions{Int8}([0 for i=1:length(DIMENSION_NAMES)]...) == Dimensions{Int8}()

@test zero(Quantity{ComplexF64,Int8}) + Quantity(1) == Quantity(1.0+0.0im, length=Int8(0))
@test one(Quantity{ComplexF64,Int8}) - Quantity(1) == Quantity(0.0+0.0im, length=Int8(0))
@test zero(Quantity(0.0+0.0im)) + Quantity(1) == Quantity(1.0+0.0im, length=Int8(0))
@test oneunit(Quantity(0.0+0.0im)) - Quantity(1) == Quantity(0.0+0.0im, length=Int8(0))
@test typeof(one(Dimensions{Int16})) == Dimensions{Int16}
@test one(Dimensions{Int16}) == Dimensions(mass=Int16(0))

@test zero(Quantity{ComplexF64}) == Quantity(0.0+0.0im)
@test zero(Quantity(0.0im)) == Quantity(0.0+0.0im)
@test one(Quantity{ComplexF64}) == Quantity(1.0+0.0im)

@test zero(Quantity) == Quantity(0.0)
@test typeof(zero(Quantity)) == Quantity{DEFAULT_VALUE_TYPE,DEFAULT_DIM_TYPE}
@test one(Quantity) - Quantity(1) == Quantity(0.0)
@test typeof(one(Quantity)) == Quantity{DEFAULT_VALUE_TYPE,DEFAULT_DIM_TYPE}
@test typeof(one(Dimensions)) == Dimensions{DEFAULT_DIM_TYPE}
@test zero(Quantity(0.0)) == Quantity(0.0)
@test typeof(zero(Quantity(0.0))) == Quantity{Float64,DEFAULT_DIM_TYPE}
@test oneunit(Quantity(1.0)) - Quantity(1.0) == Quantity(0.0)
@test typeof(one(Quantity(1.0))) == Quantity{DEFAULT_VALUE_TYPE,DEFAULT_DIM_TYPE}
@test one(Dimensions) == Dimensions()
@test one(Dimensions()) == Dimensions()
@test typeof(one(Quantity)) == Quantity{DEFAULT_VALUE_TYPE,DEFAULT_DIM_TYPE}
@test ustrip(one(Quantity)) === one(DEFAULT_VALUE_TYPE)
@test typeof(one(Quantity(ones(32, 32)))) == Quantity{Matrix{Float64},DEFAULT_DIM_TYPE}
@test dimension(one(Quantity(ones(32, 32), length=1))) == Dimensions()

x = Quantity(1, length=1)

@test zero(x) == Quantity(0, length=1)
@test typeof(zero(x)) == Quantity{Int64,DEFAULT_DIM_TYPE}

# Invalid calls:
@test_throws ErrorException zero(Quantity)
@test_throws ErrorException zero(Dimensions())
@test_throws ErrorException zero(Dimensions)
@test_throws ErrorException oneunit(Quantity)
@test_throws ErrorException oneunit(Dimensions())
@test_throws ErrorException oneunit(Dimensions)

@test sqrt(z * -1) == Quantity(sqrt(52), length=1 // 2, mass=1)
@test cbrt(z) == Quantity(cbrt(-52), length=1 // 3, mass=2 // 3)

@test 1.0 * (Dimensions(length=3)^2) == Quantity(1.0, length=6)

x = 0.9u"km/s"
y = 0.3 * x
@test x > y
@test y < x

x = Quantity(1.0)

@test x == 1.0
@test x >= 1.0
@test x < 2.0

@test_throws DimensionError x < 1.0u"m"
end

@testset "Manual construction" begin
Expand Down

0 comments on commit 2044168

Please sign in to comment.