Skip to content

Commit

Permalink
Merge #371
Browse files Browse the repository at this point in the history
371: Add benchmarks r=charleskawczynski a=charleskawczynski

This PR adds some benchmarks. Closes #369.

Co-authored-by: Charles Kawczynski <kawczynski.charles@gmail.com>
  • Loading branch information
bors[bot] and charleskawczynski authored Jun 25, 2023
2 parents a7c6d2b + 345b2aa commit b66d13a
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 8 deletions.
8 changes: 8 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,11 @@ steps:
queue: central
slurm_ntasks: 1
slurm_mem_per_cpu: 6G

- label: "Benchmarks"
command: "julia --color=yes --project=examples perf/benchmark.jl"
agents:
config: cpu
queue: central
slurm_ntasks: 1
slurm_mem_per_cpu: 6G
20 changes: 19 additions & 1 deletion examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "7501ebcb82a24608b802115bb6f84feb98949827"
project_hash = "243c6060b2bc9f1d650bfee030845a40aea5b195"

[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
Expand Down Expand Up @@ -44,6 +44,12 @@ version = "0.4.2"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[deps.BenchmarkTools]]
deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"]
git-tree-sha1 = "d9a9701b899b30332bbcb3e1679c41cce81fb0e8"
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
version = "1.3.2"

[[deps.BitFlags]]
git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d"
uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35"
Expand Down Expand Up @@ -235,6 +241,12 @@ git-tree-sha1 = "2613d054b0e18a3dea99ca1594e9a3960e025da4"
uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3"
version = "1.9.7"

[[deps.Infiltrator]]
deps = ["InteractiveUtils", "Markdown", "REPL", "UUIDs"]
git-tree-sha1 = "6e48065ac352c8c9616013faa419b0ea65bb6455"
uuid = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
version = "1.6.3"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down Expand Up @@ -600,6 +612,12 @@ version = "1.4.0"
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[deps.Suppressor]]
deps = ["Logging"]
git-tree-sha1 = "9a428c8eb6cca9a9566ab619b176e83d441064ba"
uuid = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
version = "0.2.3"

[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Expand Down
3 changes: 3 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CLIMAParameters = "6eacf6c3-8458-43b9-ae03-caf5306d3d53"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GaussQuadrature = "d54b0c1a-921d-58e0-8e36-89d8069c0969"
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
RRTMGP = "a01a1ee8-cea4-48fc-987c-fc7878d79da1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
91 changes: 91 additions & 0 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#=
julia --project=examples
julia --project=examples perf/benchmark.jl
```
include(joinpath("perf", "benchmark.jl"))
```
=#
# ARGS[1]

using BenchmarkTools
using Suppressor
root_dir = joinpath(dirname(@__DIR__))
import ClimaComms
import Logging

@info "------------------------------------------------- Benchmark: gray_atm"
@suppress_out begin
include(joinpath(root_dir, "test", "gray_atm_utils.jl"))
gray_atmos_lw_equil(ClimaComms.context(), OneScalar, Float64; exfiltrate = true)
end
(; slv, max_threads) = Infiltrator.exfiltrated
@info "gray_atm lw"
solve_lw!(slv, max_threads) # compile first
trial = @benchmark solve_lw!(slv, max_threads)
show(stdout, MIME("text/plain"), trial)
println()

gray_atmos_sw_test(ClimaComms.context(), OneScalar, Float64, 1; exfiltrate = true)
(; slv, max_threads) = Infiltrator.exfiltrated
solve_sw!(slv, max_threads) # compile first
@info "gray_atm sw"
trial = @benchmark solve_sw!(slv, max_threads)
show(stdout, MIME("text/plain"), trial)
println()
@info "------------------------------------------------- Benchmark: clear_sky"
@suppress_out begin
include(joinpath(root_dir, "test", "clear_sky_utils.jl"))
context = ClimaComms.context()
clear_sky(ClimaComms.context(), TwoStream, SourceLW2Str, VmrGM, Float64; exfiltrate = true)
end
(; slv, max_threads, lookup_sw, lookup_lw) = Infiltrator.exfiltrated

@info "clear_sky lw"
solve_lw!(slv, max_threads, lookup_lw) # compile first
trial = @benchmark solve_lw!(slv, max_threads, lookup_lw)
show(stdout, MIME("text/plain"), trial)
println()

@info "clear_sky sw"
solve_sw!(slv, max_threads, lookup_sw) # compile first
trial = @benchmark solve_sw!(slv, max_threads, lookup_sw)
show(stdout, MIME("text/plain"), trial)
println()

@info "------------------------------------------------- Benchmark: all_sky"
@suppress_out begin
include(joinpath(root_dir, "test", "all_sky_utils.jl"))
all_sky(ClimaComms.context(), TwoStream, Float64; use_lut = true, cldfrac = Float64(1), exfiltrate = true)
end

(; slv, max_threads, lookup_sw, lookup_sw_cld, lookup_lw, lookup_lw_cld) = Infiltrator.exfiltrated

solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld) # compile first
solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld) # compile first

@info "all_sky, lw, use_lut=true"
trial = @benchmark solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld)
show(stdout, MIME("text/plain"), trial)
println()
@info "all_sky, sw, use_lut=true"
trial = @benchmark solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld)
show(stdout, MIME("text/plain"), trial)
println()

@suppress_out begin
all_sky(ClimaComms.context(), TwoStream, Float64; use_lut = false, cldfrac = Float64(1), exfiltrate = true)
end

(; slv, max_threads, lookup_sw, lookup_sw_cld, lookup_lw, lookup_lw_cld) = Infiltrator.exfiltrated

solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld) # compile first
solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld) # compile first

@info "all_sky, lw, use_lut=false"
trial = @benchmark solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld)
show(stdout, MIME("text/plain"), trial)
println()
@info "all_sky, sw, use_lut=false"
trial = @benchmark solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld)
show(stdout, MIME("text/plain"), trial)
println()
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GaussQuadrature = "d54b0c1a-921d-58e0-8e36-89d8069c0969"
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
4 changes: 2 additions & 2 deletions test/all_sky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ include("all_sky_utils.jl")

context = ClimaComms.context()
@testset "Cloudy (all-sky, Two-stream calculations using lookup table method" begin
@time all_sky(context, TwoStream, Float64, use_lut = true, cldfrac = Float64(1))
@time all_sky(context, TwoStream, Float64; use_lut = true, cldfrac = Float64(1))
end
@testset "Cloudy (all-sky), Two-stream calculations using Pade method" begin
@time all_sky(context, TwoStream, Float64, use_lut = false, cldfrac = Float64(1))
@time all_sky(context, TwoStream, Float64; use_lut = false, cldfrac = Float64(1))
end
4 changes: 3 additions & 1 deletion test/all_sky_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Pkg.Artifacts
using NCDatasets

import JET
import Infiltrator
import ClimaComms
using RRTMGP
using RRTMGP.Vmrs
Expand Down Expand Up @@ -31,6 +32,7 @@ function all_sky(
::Type{FT};
use_lut::Bool = true,
cldfrac = FT(1),
exfiltrate = false,
) where {FT <: AbstractFloat, OPC}
opc = Symbol(OPC)
device = ClimaComms.device(context)
Expand Down Expand Up @@ -127,6 +129,7 @@ function all_sky(
end

println("calling shortwave solver; ncol = $ncol")
exfiltrate && Infiltrator.@exfiltrate
@time solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld)
if device isa ClimaComms.CPUSingleThreaded
JET.@test_opt solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld)
Expand Down Expand Up @@ -174,4 +177,3 @@ function all_sky(
@test max_err_flux_up_sw < toler
@test max_err_flux_dn_sw < toler
end

6 changes: 4 additions & 2 deletions test/clear_sky_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Pkg.Artifacts
using NCDatasets
import JET
import ClimaComms
import Infiltrator

using RRTMGP
using RRTMGP.Vmrs
Expand Down Expand Up @@ -31,7 +32,8 @@ function clear_sky(
::Type{OPC},
::Type{SRC},
::Type{VMR},
::Type{FT},
::Type{FT};
exfiltrate = false,
) where {FT <: AbstractFloat, OPC, SRC, VMR}
device = ClimaComms.device(context)
DA = ClimaComms.array_type(device)
Expand Down Expand Up @@ -91,7 +93,7 @@ function clear_sky(
#--------------------------------------------------
# initializing RTE solver
slv = Solver(context, as, op, src_lw, src_sw, bcs_lw, bcs_sw, fluxb_lw, fluxb_sw, flux_lw, flux_sw)

exfiltrate && Infiltrator.@exfiltrate
println("calling longwave solver; ncol = $ncol")
@time solve_lw!(slv, max_threads, lookup_lw)
if device isa ClimaComms.CPUSingleThreaded
Expand Down
13 changes: 11 additions & 2 deletions test/gray_atm_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Test
using CUDA
import ClimaComms
import JET
import Infiltrator
using RRTMGP
using RRTMGP.AngularDiscretizations
using RRTMGP.Fluxes
Expand All @@ -25,7 +26,7 @@ param_set = create_insolation_parameters(Float64)
"""
Example program to demonstrate the calculation of longwave radiative fluxes in a model gray atmosphere.
"""
function gray_atmos_lw_equil(context, ::Type{OPC}, ::Type{FT}) where {FT <: AbstractFloat, OPC}
function gray_atmos_lw_equil(context, ::Type{OPC}, ::Type{FT}; exfiltrate = false) where {FT <: AbstractFloat, OPC}
device = ClimaComms.device(context)
ncol = if device isa ClimaComms.CUDADevice
4096
Expand Down Expand Up @@ -84,6 +85,7 @@ function gray_atmos_lw_equil(context, ::Type{OPC}, ::Type{FT}) where {FT <: Abst
T_ex_lev = DA{FT}(undef, nlev, ncol)
flux_grad = DA{FT}(undef, nlay, ncol)
flux_grad_err = FT(0)
exfiltrate && Infiltrator.@exfiltrate
for i in 1:nsteps
# calling the long wave gray radiation solver
solve_lw!(slv, max_threads)
Expand Down Expand Up @@ -130,7 +132,13 @@ function gray_atmos_lw_equil(context, ::Type{OPC}, ::Type{FT}) where {FT <: Abst
end
#------------------------------------------------------------------------------

function gray_atmos_sw_test(context, ::Type{OPC}, ::Type{FT}, ncol::Int) where {FT <: AbstractFloat, OPC}
function gray_atmos_sw_test(
context,
::Type{OPC},
::Type{FT},
ncol::Int;
exfiltrate = false,
) where {FT <: AbstractFloat, OPC}
device = ClimaComms.device(context)
DA = ClimaComms.array_type(device)
opc = Symbol(OPC)
Expand Down Expand Up @@ -178,6 +186,7 @@ function gray_atmos_sw_test(context, ::Type{OPC}, ::Type{FT}, ncol::Int) where {
flux_sw = FluxSW(ncol, nlay, FT, DA)

slv = Solver(context, as, op, nothing, src_sw, nothing, bcs_sw, nothing, fluxb_sw, nothing, flux_sw)
exfiltrate && Infiltrator.@exfiltrate
solve_sw!(slv, max_threads)

τ = Array(slv.op.τ)
Expand Down

0 comments on commit b66d13a

Please sign in to comment.