From 91e40bf9a9df80c6c25c72aa1bd99ef1e8dda7eb Mon Sep 17 00:00:00 2001 From: "Gregory L. Wagner" Date: Tue, 20 Apr 2021 20:37:12 -0700 Subject: [PATCH] Improved and simplified BinaryOperation with "stubborn" location inference (#1599) * More general location inference for binary operations * Add identity to list of unary operators * Adds additional type parameter to zero field * Revert "Adds additional type parameter to zero field" This reverts commit 1a77150444ddb2a34345a4ed2fa358b2d5689990. * Introduces "stubborn" binary operations * Cleans up doc strings and adapt_structure for AbstractOperations * Fix some scoping issues with interpolation wrapper for at macro * Move interpolate_operation definition to at.jl * No need to throw errors for operations at Nothing location * Computations with AveragedField might work? * Adds incremental build up of AveragedField operations testing * Dont interpolate unecessarily * Bugfix in abstract operations tests * Use custom binary op function for getindex * No more test skip in abstrsct operations tests! --- src/AbstractOperations/AbstractOperations.jl | 4 +- src/AbstractOperations/at.jl | 18 ++++- src/AbstractOperations/binary_operations.jl | 72 +++++++++---------- src/AbstractOperations/derivatives.jl | 7 +- src/AbstractOperations/multiary_operations.jl | 2 +- .../show_abstract_operations.jl | 2 +- src/AbstractOperations/unary_operations.jl | 8 +-- src/Fields/abstract_field.jl | 2 +- ...test_abstract_operations_computed_field.jl | 66 +++++++++++------ 9 files changed, 104 insertions(+), 77 deletions(-) diff --git a/src/AbstractOperations/AbstractOperations.jl b/src/AbstractOperations/AbstractOperations.jl index 79b7c951e3..52f34ec8b8 100644 --- a/src/AbstractOperations/AbstractOperations.jl +++ b/src/AbstractOperations/AbstractOperations.jl @@ -40,7 +40,6 @@ Base.parent(op::AbstractOperation) = op # AbstractOperation macros add their associated functions to this list const operators = Set() -include("at.jl") include("grid_validation.jl") include("unary_operations.jl") @@ -78,4 +77,7 @@ eval(define_multiary_operator(:*)) push!(operators, :*) push!(multiary_operators, :*) +include("at.jl") + end # module + diff --git a/src/AbstractOperations/at.jl b/src/AbstractOperations/at.jl index 634d602418..cae8fefdb7 100644 --- a/src/AbstractOperations/at.jl +++ b/src/AbstractOperations/at.jl @@ -22,6 +22,15 @@ end "Fallback for when `insert_location` is called on objects other than expressions." insert_location!(anything, location) = nothing +# A very special UnaryOperation +@inbounds identity(i, j, k, grid, a::Number) = a +@inbounds identity(i, j, k, grid, a::AbstractField) = @inbounds a[i, j, k] + +function interpolate_operation(L, x::AbstractField) + L == location(x) && return x # Don't interpolate unecessarily + return _unary_operation(L, identity, x, location(x), x.grid) +end + """ @at location abstract_operation @@ -30,5 +39,12 @@ Modify the `abstract_operation` so that it returns values at """ macro at(location, abstract_operation) insert_location!(abstract_operation, location) - return esc(abstract_operation) + + # We wrap it all in an interpolator to help "stubborn" binary operations + # arrive in the right place. + wrapped_operation = quote + interpolate_operation($(esc(location)), $(esc(abstract_operation))) + end + + return wrapped_operation end diff --git a/src/AbstractOperations/binary_operations.jl b/src/AbstractOperations/binary_operations.jl index f13ada4d6d..adb74720e4 100644 --- a/src/AbstractOperations/binary_operations.jl +++ b/src/AbstractOperations/binary_operations.jl @@ -1,37 +1,26 @@ const binary_operators = Set() -""" - BinaryOperation{X, Y, Z, O, A, B, IA, IB, IΩ, G} <: AbstractOperation{X, Y, Z, G} - -An abstract representation of a binary operation on `AbstractField`s. -""" -struct BinaryOperation{X, Y, Z, O, A, B, IA, IB, IΩ, G} <: AbstractOperation{X, Y, Z, G} +struct BinaryOperation{X, Y, Z, O, A, B, IA, IB, G} <: AbstractOperation{X, Y, Z, G} op :: O a :: A b :: B ▶a :: IA ▶b :: IB - ▶op :: IΩ grid :: G """ - BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, ▶op, grid) + BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, grid) - Returns an abstract representation of the binary operation `op(▶a(a), ▶b(b))`, - followed by interpolation by `▶op` to `(X, Y, Z)`, where `▶a` and `▶b` interpolate - `a` and `b` to a common location. + Returns an abstract representation of the binary operation `op(▶a(a), ▶b(b))`. + where `▶a` and `▶b` interpolate `a` and `b` to (X, Y, Z). """ - function BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, ▶op, grid) where {X, Y, Z} - - any((X, Y, Z) .=== Nothing) && throw(ArgumentError("Nothing locations are invalid! " * - "Cannot construct BinaryOperation at ($X, $Y, $Z).")) - + function BinaryOperation{X, Y, Z}(op, a, b, ▶a, ▶b, grid) where {X, Y, Z} return new{X, Y, Z, typeof(op), typeof(a), typeof(b), typeof(▶a), typeof(▶b), - typeof(▶op), typeof(grid)}(op, a, b, ▶a, ▶b, ▶op, grid) + typeof(grid)}(op, a, b, ▶a, ▶b, grid) end end -@inline Base.getindex(β::BinaryOperation, i, j, k) = β.▶op(i, j, k, β.grid, β.op, β.▶a, β.▶b, β.a, β.b) +@inline Base.getindex(β::BinaryOperation, i, j, k) = β.op(i, j, k, β.grid, β.▶a, β.▶b, β.a, β.b) ##### ##### BinaryOperation construction @@ -39,13 +28,21 @@ end """Create a binary operation for `op` acting on `a` and `b` with locations `La` and `Lb`. The operator acts at `Lab` and the result is interpolated to `Lc`.""" -function _binary_operation(Lc, op, a, b, La, Lb, Lab, grid) - ▶a = interpolation_operator(La, Lab) - ▶b = interpolation_operator(Lb, Lab) - ▶op = interpolation_operator(Lab, Lc) - return BinaryOperation{Lc[1], Lc[2], Lc[3]}(op, a, b, ▶a, ▶b, ▶op, grid) +function _binary_operation(Lc, op, a, b, La, Lb, grid) + ▶a = interpolation_operator(La, Lc) + ▶b = interpolation_operator(Lb, Lc) + return BinaryOperation{Lc[1], Lc[2], Lc[3]}(op, a, b, ▶a, ▶b, grid) end +const ConcreteLocationType = Union{Type{Face}, Type{Center}} + +# Precedence rules for choosing operation location: +choose_location(La, Lb, Lc) = Lc # Fallback to the specification Lc, but also... +choose_location(::Type{Face}, ::Type{Face}, Lc) = Face # keep common locations; and +choose_location(::Type{Center}, ::Type{Center}, Lc) = Center # +choose_location(La::ConcreteLocationType, ::Type{Nothing}, Lc) = La # don't interpolate unspecified locations. +choose_location(::Type{Nothing}, Lb::ConcreteLocationType, Lc) = Lb # + """Return an expression that defines an abstract `BinaryOperator` named `op` for `AbstractField`.""" function define_binary_operator(op) return quote @@ -60,26 +57,29 @@ function define_binary_operator(op) @inbounds $op(▶a(i, j, k, grid, a), ▶b(i, j, k, grid, b)) """ - $($op)(Lc, Lab, a, b) + $($op)(Lc, a, b) - Returns an abstract representation of the operator `$($op)` acting on `a` and `b` at - location `Lab`, and subsequently interpolated to location `Lc`. + Returns an abstract representation of the operator `$($op)` acting on `a` and `b`. + The operation occurs at location(a) except for Nothing dimensions. In that case, + the location of the dimension in question is supplied either by location(b) or + if that is also Nothing, Lc. """ - function $op(Lc::Tuple, Lop::Tuple, a, b) + function $op(Lc::Tuple, a, b) La = location(a) Lb = location(b) + Lab = choose_location.(La, Lb, Lc) + grid = Oceananigans.AbstractOperations.validate_grid(a, b) - return Oceananigans.AbstractOperations._binary_operation(Lc, $op, a, b, La, Lb, Lop, grid) + + return Oceananigans.AbstractOperations._binary_operation(Lab, $op, a, b, La, Lb, grid) end - $op(Lc::Tuple, a, b) = $op(Lc, Lc, a, b) - $op(Lc::Tuple, a::Number, b) = $op(Lc, location(b), a, b) - $op(Lc::Tuple, a, b::Number) = $op(Lc, location(a), a, b) - $op(Lc::Tuple, a::AF{X, Y, Z}, b::AF{X, Y, Z}) where {X, Y, Z} = $op(Lc, location(a), a, b) + # Numbers are not fields... + $op(Lc::Tuple, a::Number, b::Number) = $op(a, b) # Sugar for mixing in functions of (x, y, z) - $op(Lc::Tuple, a::Function, b::AbstractField) = $op(Lc, FunctionField(Lc, a, b.grid), b) - $op(Lc::Tuple, a::AbstractField, b::Function) = $op(Lc, a, FunctionField(Lc, b, a.grid)) + $op(Lc::Tuple, f::Function, b::AbstractField) = $op(Lc, FunctionField(location(b), f, b.grid), b) + $op(Lc::Tuple, a::AbstractField, f::Function) = $op(Lc, a, FunctionField(location(a), f, a.grid)) # Sugary versions with default locations $op(a::AF, b::AF) = $op(location(a), a, b) @@ -184,5 +184,5 @@ end "Adapt `BinaryOperation` to work on the GPU via CUDAnative and CUDAdrv." Adapt.adapt_structure(to, binary::BinaryOperation{X, Y, Z}) where {X, Y, Z} = BinaryOperation{X, Y, Z}(Adapt.adapt(to, binary.op), Adapt.adapt(to, binary.a), Adapt.adapt(to, binary.b), - Adapt.adapt(to, binary.▶a), Adapt.adapt(to, binary.▶b), Adapt.adapt(to, binary.▶op), - binary.grid) + Adapt.adapt(to, binary.▶a), Adapt.adapt(to, binary.▶b), Adapt.adapt(to, binary.grid)) + diff --git a/src/AbstractOperations/derivatives.jl b/src/AbstractOperations/derivatives.jl index f5f3657f3e..fd895fb73e 100644 --- a/src/AbstractOperations/derivatives.jl +++ b/src/AbstractOperations/derivatives.jl @@ -1,10 +1,5 @@ using Oceananigans.Operators: interpolation_code -""" - Derivative{X, Y, Z, D, A, I, G} <: AbstractOperation{X, Y, Z, G} - -An abstract representation of a derivative of an `AbstractField`. -""" struct Derivative{X, Y, Z, D, A, I, G} <: AbstractOperation{X, Y, Z, G} ∂ :: D arg :: A @@ -122,4 +117,4 @@ compute_at!(∂::Derivative, time) = compute_at!(∂.arg, time) "Adapt `Derivative` to work on the GPU via CUDAnative and CUDAdrv." Adapt.adapt_structure(to, deriv::Derivative{X, Y, Z}) where {X, Y, Z} = Derivative{X, Y, Z}(Adapt.adapt(to, deriv.∂), Adapt.adapt(to, deriv.arg), - Adapt.adapt(to, deriv.▶), deriv.grid) + Adapt.adapt(to, deriv.▶), Adapt.adapt(to, deriv.grid)) diff --git a/src/AbstractOperations/multiary_operations.jl b/src/AbstractOperations/multiary_operations.jl index e16aa51a3d..b3b732d68c 100644 --- a/src/AbstractOperations/multiary_operations.jl +++ b/src/AbstractOperations/multiary_operations.jl @@ -162,4 +162,4 @@ end "Adapt `MultiaryOperation` to work on the GPU via CUDAnative and CUDAdrv." Adapt.adapt_structure(to, multiary::MultiaryOperation{X, Y, Z}) where {X, Y, Z} = MultiaryOperation{X, Y, Z}(Adapt.adapt(to, multiary.op), Adapt.adapt(to, multiary.args), - Adapt.adapt(to, multiary.▶), multiary.grid) + Adapt.adapt(to, multiary.▶), Adapt.adapt(to, multiary.grid)) diff --git a/src/AbstractOperations/show_abstract_operations.jl b/src/AbstractOperations/show_abstract_operations.jl index 10f8d9cf94..1c5f888a2e 100644 --- a/src/AbstractOperations/show_abstract_operations.jl +++ b/src/AbstractOperations/show_abstract_operations.jl @@ -48,7 +48,7 @@ end function tree_show(binary::BinaryOperation{X, Y, Z}, depth, nesting) where {X, Y, Z} padding = get_tree_padding(depth, nesting) - return string(binary.op, " at ", show_location(X, Y, Z), " via ", show_interp(binary.▶op), '\n', + return string(binary.op, " at ", show_location(X, Y, Z), '\n', padding, "├── ", tree_show(binary.a, depth+1, nesting+1), '\n', padding, "└── ", tree_show(binary.b, depth+1, nesting)) end diff --git a/src/AbstractOperations/unary_operations.jl b/src/AbstractOperations/unary_operations.jl index 5a0315b74c..dd763a849d 100644 --- a/src/AbstractOperations/unary_operations.jl +++ b/src/AbstractOperations/unary_operations.jl @@ -1,11 +1,5 @@ const unary_operators = Set() -""" - UnaryOperation{X, Y, Z, O, A, I, G} <: AbstractOperation{X, Y, Z, G} - -An abstract representation of a unary operation on an `AbstractField`; or a function -`f(x)` with on argument acting on `x::AbstractField`. -""" struct UnaryOperation{X, Y, Z, O, A, I, G} <: AbstractOperation{X, Y, Z, G} op :: O arg :: A @@ -131,4 +125,4 @@ compute_at!(υ::UnaryOperation, time) = compute_at!(υ.arg, time) "Adapt `UnaryOperation` to work on the GPU via CUDAnative and CUDAdrv." Adapt.adapt_structure(to, unary::UnaryOperation{X, Y, Z}) where {X, Y, Z} = UnaryOperation{X, Y, Z}(Adapt.adapt(to, unary.op), Adapt.adapt(to, unary.arg), - Adapt.adapt(to, unary.▶), unary.grid) + Adapt.adapt(to, unary.▶), Adapt.adapt(to, unary.grid)) diff --git a/src/Fields/abstract_field.jl b/src/Fields/abstract_field.jl index e48b970f08..86fccec648 100644 --- a/src/Fields/abstract_field.jl +++ b/src/Fields/abstract_field.jl @@ -102,7 +102,7 @@ end ##### AbstractField functionality ##### -@inline location(a) = nothing +@inline location(a) = (Nothing, Nothing, Nothing) "Returns the location `(X, Y, Z)` of an `AbstractField{X, Y, Z}`." @inline location(::AbstractField{X, Y, Z}) where {X, Y, Z} = (X, Y, Z) # note no instantiation diff --git a/test/test_abstract_operations_computed_field.jl b/test/test_abstract_operations_computed_field.jl index 026e83e5ac..af629569dc 100644 --- a/test/test_abstract_operations_computed_field.jl +++ b/test/test_abstract_operations_computed_field.jl @@ -544,8 +544,6 @@ end u, v, w, T, S = fields(model) - @test_throws ArgumentError @at (Nothing, Nothing, Center) T * S - for ϕ in (u, v, w, T, S) for op in (sin, cos, sqrt, exp, tanh) @test op(ϕ) isa UnaryOperation @@ -597,6 +595,15 @@ end @test compute_plus(model) @test compute_minus(model) @test compute_times(model) + + # Basic compilation test for nested BinaryOperations... + u, v, w = model.velocities + @test try + compute!(ComputedField(u + v - w)) + true + catch + false + end end @testset "Multiary computations [$FT, $(typeof(arch))]" begin @@ -661,12 +668,7 @@ end @info " Testing operations with AveragedField..." T, S = model.tracers - TS = AveragedField(T * S, dims=(1, 2)) - - @test_throws ArgumentError @at (Nothing, Nothing, Center) T * S - @test_throws ArgumentError TS * S - @test operations_with_averaged_field(model) end @@ -695,26 +697,44 @@ end @test computations_with_averaged_field_derivative(model) - # These don't work on the GPU right now - if arch isa CPU - @test computations_with_averaged_fields(model) - else - @test_skip computations_with_averaged_fields(model) - end + u, v, w = model.velocities + + set!(model, enforce_incompressibility = false, u = (x, y, z) -> z, v = 2, w = 3) + + # Two ways to compute turbulent kinetic energy + U = AveragedField(u, dims=(1, 2)) + V = AveragedField(v, dims=(1, 2)) + + # Build up compilation tests incrementally... + u_prime = u - U + u_prime_ccc = @at (Center, Center, Center) u - U + u_prime_squared = (u - U)^2 + u_prime_squared_ccc = @at (Center, Center, Center) (u - U)^2 + horizontal_twice_tke = (u - U)^2 + (v - V)^2 + horizontal_tke = ((u - U)^2 + (v - V)^2) / 2 + horizontal_tke_ccc = @at (Center, Center, Center) ((u - U)^2 + (v - V)^2) / 2 + twice_tke = (u - U)^2 + (v - V)^2 + w^2 + tke = ((u - U)^2 + (v - V)^2 + w^2) / 2 + tke_ccc = @at (Center, Center, Center) ((u - U)^2 + (v - V)^2 + w^2) / 2 + + @test try compute!(ComputedField(u_prime )); true; catch; false; end + @test try compute!(ComputedField(u_prime_ccc )); true; catch; false; end + @test try compute!(ComputedField(u_prime_squared )); true; catch; false; end + @test try compute!(ComputedField(u_prime_squared_ccc )); true; catch; false; end + @test try compute!(ComputedField(horizontal_twice_tke)); true; catch; false; end + @test try compute!(ComputedField(horizontal_tke )); true; catch; false; end + @test try compute!(ComputedField(twice_tke )); true; catch; false; end + + @test try compute!(ComputedField(horizontal_tke_ccc )); true; catch; false; end + @test try compute!(ComputedField(tke )); true; catch; false; end + @test try compute!(ComputedField(tke_ccc )); true; catch; false; end + + computed_tke = ComputedField(tke_ccc) + @test (compute!(computed_tke); all(interior(computed_tke)[2:3, 2:3, 2:3] .== 9/2)) end @testset "Computations with ComputedFields [$FT, $(typeof(arch))]" begin @info " Testing computations with ComputedField [$FT, $(typeof(arch))]..." - - # Basic compilation test... - u, v, w = model.velocities - @test try - compute!(ComputedField(u + v - w)) - true - catch - false - end - @test computations_with_computed_fields(model) end