Skip to content

Commit

Permalink
Improved and simplified BinaryOperation with "stubborn" location infe…
Browse files Browse the repository at this point in the history
…rence (#1599)

* More general location inference for binary operations

* Add identity to list of unary operators

* Adds additional type parameter to zero field

* Revert "Adds additional type parameter to zero field"

This reverts commit 1a77150.

* Introduces "stubborn" binary operations

* Cleans up doc strings and adapt_structure for AbstractOperations

* Fix some scoping issues with interpolation wrapper for at macro

* Move interpolate_operation definition to at.jl

* No need to throw errors for operations at Nothing location

* Computations with AveragedField might work?

* Adds incremental build up of AveragedField operations testing

* Dont interpolate unecessarily

* Bugfix in abstract operations tests

* Use custom binary op function for getindex

* No more test skip in abstrsct operations tests!
  • Loading branch information
glwagner authored Apr 21, 2021
1 parent 91bdb1a commit 91e40bf
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 77 deletions.
4 changes: 3 additions & 1 deletion src/AbstractOperations/AbstractOperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ Base.parent(op::AbstractOperation) = op
# AbstractOperation macros add their associated functions to this list
const operators = Set()

include("at.jl")
include("grid_validation.jl")

include("unary_operations.jl")
Expand Down Expand Up @@ -78,4 +77,7 @@ eval(define_multiary_operator(:*))
push!(operators, :*)
push!(multiary_operators, :*)

include("at.jl")

end # module

18 changes: 17 additions & 1 deletion src/AbstractOperations/at.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ end
"Fallback for when `insert_location` is called on objects other than expressions."
insert_location!(anything, location) = nothing

# A very special UnaryOperation
@inbounds identity(i, j, k, grid, a::Number) = a
@inbounds identity(i, j, k, grid, a::AbstractField) = @inbounds a[i, j, k]

function interpolate_operation(L, x::AbstractField)
L == location(x) && return x # Don't interpolate unecessarily
return _unary_operation(L, identity, x, location(x), x.grid)
end

"""
@at location abstract_operation
Expand All @@ -30,5 +39,12 @@ Modify the `abstract_operation` so that it returns values at
"""
macro at(location, abstract_operation)
insert_location!(abstract_operation, location)
return esc(abstract_operation)

# We wrap it all in an interpolator to help "stubborn" binary operations
# arrive in the right place.
wrapped_operation = quote
interpolate_operation($(esc(location)), $(esc(abstract_operation)))
end

return wrapped_operation
end
72 changes: 36 additions & 36 deletions src/AbstractOperations/binary_operations.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,48 @@
const binary_operators = Set()

"""
BinaryOperation{X, Y, Z, O, A, B, IA, IB, IΩ, G} <: AbstractOperation{X, Y, Z, G}
An abstract representation of a binary operation on `AbstractField`s.
"""
struct BinaryOperation{X, Y, Z, O, A, B, IA, IB, IΩ, G} <: AbstractOperation{X, Y, Z, G}
struct BinaryOperation{X, Y, Z, O, A, B, IA, IB, G} <: AbstractOperation{X, Y, Z, G}
op :: O
a :: A
b :: B
▶a :: IA
▶b :: IB
▶op :: IΩ
grid :: G

"""
BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, ▶op, grid)
BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, grid)
Returns an abstract representation of the binary operation `op(▶a(a), ▶b(b))`,
followed by interpolation by `▶op` to `(X, Y, Z)`, where `▶a` and `▶b` interpolate
`a` and `b` to a common location.
Returns an abstract representation of the binary operation `op(▶a(a), ▶b(b))`.
where `▶a` and `▶b` interpolate `a` and `b` to (X, Y, Z).
"""
function BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, ▶op, grid) where {X, Y, Z}

any((X, Y, Z) .=== Nothing) && throw(ArgumentError("Nothing locations are invalid! " *
"Cannot construct BinaryOperation at ($X, $Y, $Z)."))

function BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, grid) where {X, Y, Z}
return new{X, Y, Z, typeof(op), typeof(a), typeof(b), typeof(▶a), typeof(▶b),
typeof(▶op), typeof(grid)}(op, a, b, ▶a, ▶b, ▶op, grid)
typeof(grid)}(op, a, b, ▶a, ▶b, grid)
end
end

@inline Base.getindex::BinaryOperation, i, j, k) = β.op(i, j, k, β.grid, β.op, β.▶a, β.▶b, β.a, β.b)
@inline Base.getindex::BinaryOperation, i, j, k) = β.op(i, j, k, β.grid, β.▶a, β.▶b, β.a, β.b)

#####
##### BinaryOperation construction
#####

"""Create a binary operation for `op` acting on `a` and `b` with locations `La` and `Lb`.
The operator acts at `Lab` and the result is interpolated to `Lc`."""
function _binary_operation(Lc, op, a, b, La, Lb, Lab, grid)
▶a = interpolation_operator(La, Lab)
▶b = interpolation_operator(Lb, Lab)
▶op = interpolation_operator(Lab, Lc)
return BinaryOperation{Lc[1], Lc[2], Lc[3]}(op, a, b, ▶a, ▶b, ▶op, grid)
function _binary_operation(Lc, op, a, b, La, Lb, grid)
▶a = interpolation_operator(La, Lc)
▶b = interpolation_operator(Lb, Lc)
return BinaryOperation{Lc[1], Lc[2], Lc[3]}(op, a, b, ▶a, ▶b, grid)
end

const ConcreteLocationType = Union{Type{Face}, Type{Center}}

# Precedence rules for choosing operation location:
choose_location(La, Lb, Lc) = Lc # Fallback to the specification Lc, but also...
choose_location(::Type{Face}, ::Type{Face}, Lc) = Face # keep common locations; and
choose_location(::Type{Center}, ::Type{Center}, Lc) = Center #
choose_location(La::ConcreteLocationType, ::Type{Nothing}, Lc) = La # don't interpolate unspecified locations.
choose_location(::Type{Nothing}, Lb::ConcreteLocationType, Lc) = Lb #

"""Return an expression that defines an abstract `BinaryOperator` named `op` for `AbstractField`."""
function define_binary_operator(op)
return quote
Expand All @@ -60,26 +57,29 @@ function define_binary_operator(op)
@inbounds $op(▶a(i, j, k, grid, a), ▶b(i, j, k, grid, b))

"""
$($op)(Lc, Lab, a, b)
$($op)(Lc, a, b)
Returns an abstract representation of the operator `$($op)` acting on `a` and `b` at
location `Lab`, and subsequently interpolated to location `Lc`.
Returns an abstract representation of the operator `$($op)` acting on `a` and `b`.
The operation occurs at location(a) except for Nothing dimensions. In that case,
the location of the dimension in question is supplied either by location(b) or
if that is also Nothing, Lc.
"""
function $op(Lc::Tuple, Lop::Tuple, a, b)
function $op(Lc::Tuple, a, b)
La = location(a)
Lb = location(b)
Lab = choose_location.(La, Lb, Lc)

grid = Oceananigans.AbstractOperations.validate_grid(a, b)
return Oceananigans.AbstractOperations._binary_operation(Lc, $op, a, b, La, Lb, Lop, grid)

return Oceananigans.AbstractOperations._binary_operation(Lab, $op, a, b, La, Lb, grid)
end

$op(Lc::Tuple, a, b) = $op(Lc, Lc, a, b)
$op(Lc::Tuple, a::Number, b) = $op(Lc, location(b), a, b)
$op(Lc::Tuple, a, b::Number) = $op(Lc, location(a), a, b)
$op(Lc::Tuple, a::AF{X, Y, Z}, b::AF{X, Y, Z}) where {X, Y, Z} = $op(Lc, location(a), a, b)
# Numbers are not fields...
$op(Lc::Tuple, a::Number, b::Number) = $op(a, b)

# Sugar for mixing in functions of (x, y, z)
$op(Lc::Tuple, a::Function, b::AbstractField) = $op(Lc, FunctionField(Lc, a, b.grid), b)
$op(Lc::Tuple, a::AbstractField, b::Function) = $op(Lc, a, FunctionField(Lc, b, a.grid))
$op(Lc::Tuple, f::Function, b::AbstractField) = $op(Lc, FunctionField(location(b), f, b.grid), b)
$op(Lc::Tuple, a::AbstractField, f::Function) = $op(Lc, a, FunctionField(location(a), f, a.grid))

# Sugary versions with default locations
$op(a::AF, b::AF) = $op(location(a), a, b)
Expand Down Expand Up @@ -184,5 +184,5 @@ end
"Adapt `BinaryOperation` to work on the GPU via CUDAnative and CUDAdrv."
Adapt.adapt_structure(to, binary::BinaryOperation{X, Y, Z}) where {X, Y, Z} =
BinaryOperation{X, Y, Z}(Adapt.adapt(to, binary.op), Adapt.adapt(to, binary.a), Adapt.adapt(to, binary.b),
Adapt.adapt(to, binary.▶a), Adapt.adapt(to, binary.▶b), Adapt.adapt(to, binary.▶op),
binary.grid)
Adapt.adapt(to, binary.▶a), Adapt.adapt(to, binary.▶b), Adapt.adapt(to, binary.grid))

7 changes: 1 addition & 6 deletions src/AbstractOperations/derivatives.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
using Oceananigans.Operators: interpolation_code

"""
Derivative{X, Y, Z, D, A, I, G} <: AbstractOperation{X, Y, Z, G}
An abstract representation of a derivative of an `AbstractField`.
"""
struct Derivative{X, Y, Z, D, A, I, G} <: AbstractOperation{X, Y, Z, G}
:: D
arg :: A
Expand Down Expand Up @@ -122,4 +117,4 @@ compute_at!(∂::Derivative, time) = compute_at!(∂.arg, time)
"Adapt `Derivative` to work on the GPU via CUDAnative and CUDAdrv."
Adapt.adapt_structure(to, deriv::Derivative{X, Y, Z}) where {X, Y, Z} =
Derivative{X, Y, Z}(Adapt.adapt(to, deriv.∂), Adapt.adapt(to, deriv.arg),
Adapt.adapt(to, deriv.▶), deriv.grid)
Adapt.adapt(to, deriv.▶), Adapt.adapt(to, deriv.grid))
2 changes: 1 addition & 1 deletion src/AbstractOperations/multiary_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,4 @@ end
"Adapt `MultiaryOperation` to work on the GPU via CUDAnative and CUDAdrv."
Adapt.adapt_structure(to, multiary::MultiaryOperation{X, Y, Z}) where {X, Y, Z} =
MultiaryOperation{X, Y, Z}(Adapt.adapt(to, multiary.op), Adapt.adapt(to, multiary.args),
Adapt.adapt(to, multiary.▶), multiary.grid)
Adapt.adapt(to, multiary.▶), Adapt.adapt(to, multiary.grid))
2 changes: 1 addition & 1 deletion src/AbstractOperations/show_abstract_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end
function tree_show(binary::BinaryOperation{X, Y, Z}, depth, nesting) where {X, Y, Z}
padding = get_tree_padding(depth, nesting)

return string(binary.op, " at ", show_location(X, Y, Z), " via ", show_interp(binary.▶op), '\n',
return string(binary.op, " at ", show_location(X, Y, Z), '\n',
padding, "├── ", tree_show(binary.a, depth+1, nesting+1), '\n',
padding, "└── ", tree_show(binary.b, depth+1, nesting))
end
Expand Down
8 changes: 1 addition & 7 deletions src/AbstractOperations/unary_operations.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
const unary_operators = Set()

"""
UnaryOperation{X, Y, Z, O, A, I, G} <: AbstractOperation{X, Y, Z, G}
An abstract representation of a unary operation on an `AbstractField`; or a function
`f(x)` with on argument acting on `x::AbstractField`.
"""
struct UnaryOperation{X, Y, Z, O, A, I, G} <: AbstractOperation{X, Y, Z, G}
op :: O
arg :: A
Expand Down Expand Up @@ -131,4 +125,4 @@ compute_at!(υ::UnaryOperation, time) = compute_at!(υ.arg, time)
"Adapt `UnaryOperation` to work on the GPU via CUDAnative and CUDAdrv."
Adapt.adapt_structure(to, unary::UnaryOperation{X, Y, Z}) where {X, Y, Z} =
UnaryOperation{X, Y, Z}(Adapt.adapt(to, unary.op), Adapt.adapt(to, unary.arg),
Adapt.adapt(to, unary.▶), unary.grid)
Adapt.adapt(to, unary.▶), Adapt.adapt(to, unary.grid))
2 changes: 1 addition & 1 deletion src/Fields/abstract_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ end
##### AbstractField functionality
#####

@inline location(a) = nothing
@inline location(a) = (Nothing, Nothing, Nothing)

"Returns the location `(X, Y, Z)` of an `AbstractField{X, Y, Z}`."
@inline location(::AbstractField{X, Y, Z}) where {X, Y, Z} = (X, Y, Z) # note no instantiation
Expand Down
66 changes: 43 additions & 23 deletions test/test_abstract_operations_computed_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,6 @@ end

u, v, w, T, S = fields(model)

@test_throws ArgumentError @at (Nothing, Nothing, Center) T * S

for ϕ in (u, v, w, T, S)
for op in (sin, cos, sqrt, exp, tanh)
@test op(ϕ) isa UnaryOperation
Expand Down Expand Up @@ -597,6 +595,15 @@ end
@test compute_plus(model)
@test compute_minus(model)
@test compute_times(model)

# Basic compilation test for nested BinaryOperations...
u, v, w = model.velocities
@test try
compute!(ComputedField(u + v - w))
true
catch
false
end
end

@testset "Multiary computations [$FT, $(typeof(arch))]" begin
Expand Down Expand Up @@ -661,12 +668,7 @@ end
@info " Testing operations with AveragedField..."

T, S = model.tracers

TS = AveragedField(T * S, dims=(1, 2))

@test_throws ArgumentError @at (Nothing, Nothing, Center) T * S
@test_throws ArgumentError TS * S

@test operations_with_averaged_field(model)
end

Expand Down Expand Up @@ -695,26 +697,44 @@ end

@test computations_with_averaged_field_derivative(model)

# These don't work on the GPU right now
if arch isa CPU
@test computations_with_averaged_fields(model)
else
@test_skip computations_with_averaged_fields(model)
end
u, v, w = model.velocities

set!(model, enforce_incompressibility = false, u = (x, y, z) -> z, v = 2, w = 3)

# Two ways to compute turbulent kinetic energy
U = AveragedField(u, dims=(1, 2))
V = AveragedField(v, dims=(1, 2))

# Build up compilation tests incrementally...
u_prime = u - U
u_prime_ccc = @at (Center, Center, Center) u - U
u_prime_squared = (u - U)^2
u_prime_squared_ccc = @at (Center, Center, Center) (u - U)^2
horizontal_twice_tke = (u - U)^2 + (v - V)^2
horizontal_tke = ((u - U)^2 + (v - V)^2) / 2
horizontal_tke_ccc = @at (Center, Center, Center) ((u - U)^2 + (v - V)^2) / 2
twice_tke = (u - U)^2 + (v - V)^2 + w^2
tke = ((u - U)^2 + (v - V)^2 + w^2) / 2
tke_ccc = @at (Center, Center, Center) ((u - U)^2 + (v - V)^2 + w^2) / 2

@test try compute!(ComputedField(u_prime )); true; catch; false; end
@test try compute!(ComputedField(u_prime_ccc )); true; catch; false; end
@test try compute!(ComputedField(u_prime_squared )); true; catch; false; end
@test try compute!(ComputedField(u_prime_squared_ccc )); true; catch; false; end
@test try compute!(ComputedField(horizontal_twice_tke)); true; catch; false; end
@test try compute!(ComputedField(horizontal_tke )); true; catch; false; end
@test try compute!(ComputedField(twice_tke )); true; catch; false; end

@test try compute!(ComputedField(horizontal_tke_ccc )); true; catch; false; end
@test try compute!(ComputedField(tke )); true; catch; false; end
@test try compute!(ComputedField(tke_ccc )); true; catch; false; end

computed_tke = ComputedField(tke_ccc)
@test (compute!(computed_tke); all(interior(computed_tke)[2:3, 2:3, 2:3] .== 9/2))
end

@testset "Computations with ComputedFields [$FT, $(typeof(arch))]" begin
@info " Testing computations with ComputedField [$FT, $(typeof(arch))]..."

# Basic compilation test...
u, v, w = model.velocities
@test try
compute!(ComputedField(u + v - w))
true
catch
false
end

@test computations_with_computed_fields(model)
end

Expand Down

0 comments on commit 91e40bf

Please sign in to comment.