Skip to content

Commit

Permalink
Tune tolerances for single-precision solver (#421)
Browse files Browse the repository at this point in the history
Improve result printout for tests
  • Loading branch information
sriharshakandala authored Jan 10, 2024
1 parent 1a52a76 commit 9ef0b6c
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 79 deletions.
8 changes: 4 additions & 4 deletions src/optics/GasOptics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ Compute interpolation fraction for binary species parameter.
itemp = 1
@inbounds η_half = vmr_ref[tropo, ig[1] + 1, jtemp + itemp - 1] / vmr_ref[tropo, ig[2] + 1, jtemp + itemp - 1]
col_mix1 = vmr1 + η_half * vmr2
η = col_mix1 eps(FT) * 2 ? vmr1 / col_mix1 : FT(0.5)
η = col_mix1 > 0 ? vmr1 / col_mix1 : FT(0.5) # rte-rrtmgp uses col_mix1 > tiny(col_mix1)
loc_η = FT* (n_η - 1))
jη1 = min(unsafe_trunc(Int, loc_η) + 1, n_η - 1)
fη1 = loc_η - unsafe_trunc(Int, loc_η)

itemp = 2
@inbounds η_half = vmr_ref[tropo, ig[1] + 1, jtemp + itemp - 1] / vmr_ref[tropo, ig[2] + 1, jtemp + itemp - 1]
col_mix2 = vmr1 + η_half * vmr2
η = col_mix2 eps(FT) * 2 ? vmr1 / col_mix2 : FT(0.5)
η = col_mix2 > 0 ? vmr1 / col_mix2 : FT(0.5) # rte-rrtmgp uses col_mix2 > tiny(col_mix2)
loc_η = FT* (n_η - 1))
jη2 = min(unsafe_trunc(Int, loc_η) + 1, n_η - 1)
fη2 = loc_η - unsafe_trunc(Int, loc_η)
Expand Down Expand Up @@ -159,7 +159,7 @@ Compute optical thickness, single scattering albedo, and asymmetry parameter.
end
τ = τ_major + τ_minor + τ_ray
ssa = FT(0)
if τ > 2 * eps(FT) # single scattering albedo
if τ > 0 # single scattering albedo
ssa = τ_ray / τ
end
return (τ, ssa, zero(τ)) # initializing asymmetry parameter
Expand Down Expand Up @@ -234,7 +234,7 @@ Compute optical thickness contributions from minor gases.

@inbounds for i in minor_bnd_st[ibnd]:(minor_bnd_st[ibnd + 1] - 1)
vmr_imnr = get_vmr(vmr, idx_gases_minor[i], glay, gcol)
if vmr_imnr > eps(FT) * 2
if vmr_imnr > 0
scaling = vmr_imnr * col_dry

if minor_scales_with_density[i] == 1
Expand Down
4 changes: 2 additions & 2 deletions src/rte/longwave1scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ function rte_lw_noscat_source!(src_lw::SourceLWNoScat{FT}, op::OneScalar{FT}, gc
Ds = op.angle_disc.gauss_Ds
τ = op.τ

τ_thresh = sqrt(sqrt(eps(FT))) # or abs(eps(FT))?
τ_thresh = 100 * eps(FT)

@inbounds for glay in 1:nlay
τ_loc = τ[glay, gcol] * Ds[1] # Optical path and transmission,
τ_loc = τ[glay, gcol] * Ds[1] # Optical path and transmission,

trans = exp(-τ_loc) # used in source function and transport calculations
# Weighting factor. Use 2nd order series expansion when rounding error (~tau^2)
Expand Down
7 changes: 4 additions & 3 deletions src/rte/longwave2stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ function rte_lw_2stream_source!(
# setting references
(; τ, ssa, g) = op
(; Rdif, Tdif, lev_source, src_up, src_dn) = src_lw
#k_min = FT === Float64 ? FT(1e-12) : FT(1e-4)
k_min = FT(1e4 * eps(FT))
#k_min = FT === Float64 ? FT(1e-12) : FT(1e-4) used in RRTMGP-RTE FORTRAN code
k_min = sqrt(eps(FT)) #FT(1e4 * eps(FT))
lw_diff_sec = FT(1.66)
τ_thresh = sqrt(eps(FT))
τ_thresh = 100 * eps(FT)# tau(icol,ilay) > 1.0e-8_wp used in rte-rrtmgp
# this is chosen to prevent catastrophic cancellation in src_up and src_dn calculation

@inbounds for glay in 1:nlay
γ1 = lw_diff_sec * (1 - FT(0.5) * ssa[glay, gcol] * (1 + g[glay, gcol]))
Expand Down
2 changes: 1 addition & 1 deletion src/rte/shortwave2stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ Equations are developed in Meador and Weaver, 1980,
doi:10.1175/1520-0469(1980)037<0630:TSATRT>2.0.CO;2
"""
function sw_2stream_coeffs::FT, ssa::FT, g::FT, μ₀::FT) where {FT}
k_min = FT(1e4 * eps(FT)) # Suggestion from Chiel van Heerwaarden
k_min = sqrt(eps(FT)) #FT(1e4 * eps(FT)) # Suggestion from Chiel van Heerwaarden
# Zdunkowski Practical Improved Flux Method "PIFM"
# (Zdunkowski et al., 1980; Contributions to Atmospheric Physics 53, 147-66)
γ1 = (FT(8) - ssa * (FT(5) + FT(3) * g)) * FT(0.25)
Expand Down
69 changes: 50 additions & 19 deletions test/all_sky_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,15 @@ function all_sky(
# initializing RTE solver
slv = Solver(context, as, op, src_lw, src_sw, bcs_lw, bcs_sw, fluxb_lw, fluxb_sw, flux_lw, flux_sw)
#------calling solvers
println("calling longwave solver; ncol = $ncol")
@time solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld)
solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld)
if device isa ClimaComms.CPUSingleThreaded
JET.@test_opt solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld)
@test_broken (@allocated solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld)) == 0
@test (@allocated solve_lw!(slv, max_threads, lookup_lw, lookup_lw_cld)) 736
end

println("calling shortwave solver; ncol = $ncol")
exfiltrate && Infiltrator.@exfiltrate
@time solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld)
solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld)
if device isa ClimaComms.CPUSingleThreaded
JET.@test_opt solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld)
@test_broken (@allocated solve_sw!(slv, max_threads, lookup_sw, lookup_sw_cld)) == 0
Expand All @@ -138,34 +136,67 @@ function all_sky(
method = use_lut ? "Lookup Table Interpolation method" : "PADE method"
comp_flux_up_lw, comp_flux_dn_lw, comp_flux_up_sw, comp_flux_dn_sw = load_comparison_data(use_lut, bot_at_1, ncol)

comp_flux_net_lw = comp_flux_up_lw .- comp_flux_dn_lw
comp_flux_net_sw = comp_flux_up_sw .- comp_flux_dn_sw

flux_up_lw = Array(slv.flux_lw.flux_up)
flux_dn_lw = Array(slv.flux_lw.flux_dn)
flux_net_lw = flux_up_lw .- flux_dn_lw

max_err_flux_up_lw = maximum(abs.(flux_up_lw .- comp_flux_up_lw))
max_err_flux_dn_lw = maximum(abs.(flux_dn_lw .- comp_flux_dn_lw))
println("=======================================")
println("Cloudy-sky longwave test - $opc")
println(method)
println("max_err_flux_up_lw = $max_err_flux_up_lw")
println("max_err_flux_dn_lw = $max_err_flux_dn_lw")
max_err_flux_net_lw = maximum(abs.(flux_net_lw .- comp_flux_net_lw))

rel_err_flux_net_lw = abs.(flux_net_lw .- comp_flux_net_lw)

for gcol in 1:ncol, glev in 1:nlev
den = abs(comp_flux_net_lw[glev, gcol])
if den > 10 * eps(FT)
rel_err_flux_net_lw[glev, gcol] /= den
end
end
max_rel_err_flux_net_lw = maximum(rel_err_flux_net_lw)
color2 = :cyan
printstyled("Cloudy-sky longwave test with ncol = $ncol, nlev = $nlev, OPC = $opc, FT = $FT\n", color = color2)
printstyled("device = $device\n", color = color2)
printstyled("$method\n\n", color = color2)
println("L∞ error in flux_up = $max_err_flux_up_lw")
println("L∞ error in flux_dn = $max_err_flux_dn_lw")
println("L∞ error in flux_net = $max_err_flux_net_lw")
println("L∞ relative error in flux_net = $(max_rel_err_flux_net_lw * 100) %\n")

flux_up_sw = Array(slv.flux_sw.flux_up)
flux_dn_sw = Array(slv.flux_sw.flux_dn)
flux_dn_dir_sw = Array(slv.flux_sw.flux_dn_dir)
flux_net_sw = flux_up_sw .- flux_dn_sw

max_err_flux_up_sw = maximum(abs.(flux_up_sw .- comp_flux_up_sw))
max_err_flux_dn_sw = maximum(abs.(flux_dn_sw .- comp_flux_dn_sw))
max_err_flux_net_sw = maximum(abs.(flux_net_sw .- comp_flux_net_sw))

rel_err_flux_net_sw = abs.(flux_net_sw .- comp_flux_net_sw)

for gcol in 1:ncol, glev in 1:nlev
den = abs(comp_flux_net_sw[glev, gcol])
if den > 10 * eps(FT)
rel_err_flux_net_sw[glev, gcol] /= den
end
end
max_rel_err_flux_net_sw = maximum(rel_err_flux_net_sw)

printstyled("Cloudy-sky shortwave test with ncol = $ncol, nlev = $nlev, OPC = $opc, FT = $FT\n", color = color2)
printstyled("device = $device\n", color = color2)
printstyled("$method\n\n", color = color2)
println("L∞ error in flux_up = $max_err_flux_up_sw")
println("L∞ error in flux_dn = $max_err_flux_dn_sw")
println("L∞ error in flux_net = $max_err_flux_net_sw")
println("L∞ relative error in flux_net = $(max_rel_err_flux_net_sw * 100) %\n")

println("Cloudy-sky shortwave test - $opc")
println(method)
println("max_err_flux_up_sw = $max_err_flux_up_sw")
println("max_err_flux_dn_sw = $max_err_flux_dn_sw")
println("=======================================")
toler = FT(1e-5)
toler = Dict(Float64 => Float64(1e-5), Float32 => Float32(0.05))

@test max_err_flux_up_lw toler broken = (FT == Float32)
@test max_err_flux_dn_lw toler broken = (FT == Float32)
@test max_err_flux_up_lw toler[FT]
@test max_err_flux_dn_lw toler[FT]

@test max_err_flux_up_sw toler broken = (FT == Float32)
@test max_err_flux_dn_sw toler broken = (FT == Float32)
@test max_err_flux_up_sw toler[FT]
@test max_err_flux_dn_sw toler[FT]
end
70 changes: 51 additions & 19 deletions test/clear_sky_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,14 @@ function clear_sky(
# initializing RTE solver
slv = Solver(context, as, op, src_lw, src_sw, bcs_lw, bcs_sw, fluxb_lw, fluxb_sw, flux_lw, flux_sw)
exfiltrate && Infiltrator.@exfiltrate
println("calling longwave solver; ncol = $ncol")
@time solve_lw!(slv, max_threads, lookup_lw)
solve_lw!(slv, max_threads, lookup_lw)
if device isa ClimaComms.CPUSingleThreaded
JET.@test_opt solve_lw!(slv, max_threads, lookup_lw)
@test_broken (@allocated solve_lw!(slv, max_threads, lookup_lw)) == 0
@test (@allocated solve_lw!(slv, max_threads, lookup_lw)) 448
end

println("calling shortwave solver; ncol = $ncol")
@time solve_sw!(slv, max_threads, lookup_sw)
solve_sw!(slv, max_threads, lookup_sw)
if device isa ClimaComms.CPUSingleThreaded
JET.@test_opt solve_sw!(slv, max_threads, lookup_sw)
@test_broken (@allocated solve_sw!(slv, max_threads, lookup_sw)) == 0
Expand All @@ -121,19 +119,35 @@ function clear_sky(
comp_flux_dn_lw = Array(ds_flux_dn_lw["rld"])[flip_ind, :, exp_no]
close(ds_flux_dn_lw)

comp_flux_net_lw = comp_flux_up_lw .- comp_flux_dn_lw

flux_up_lw = Array(slv.flux_lw.flux_up)
flux_dn_lw = Array(slv.flux_lw.flux_dn)

flux_net_lw = flux_up_lw .- flux_dn_lw

max_err_flux_up_lw = maximum(abs.(flux_up_lw .- comp_flux_up_lw))
max_err_flux_dn_lw = maximum(abs.(flux_dn_lw .- comp_flux_dn_lw))
max_err_flux_net_lw = maximum(abs.(flux_net_lw .- comp_flux_net_lw))

rel_err_flux_net_lw = abs.(flux_net_lw .- comp_flux_net_lw)

for gcol in 1:ncol, glev in 1:nlev
den = abs(comp_flux_net_lw[glev, gcol])
if den > 10 * eps(FT)
rel_err_flux_net_lw[glev, gcol] /= den
end
end
max_rel_err_flux_net_lw = maximum(rel_err_flux_net_lw)

println("=======================================")
println("Clear-sky longwave test - $opc")
println("max_err_flux_up_lw = $max_err_flux_up_lw")
println("max_err_flux_dn_lw = $max_err_flux_dn_lw")
color2 = :cyan
printstyled("Clear-sky longwave test with ncol = $ncol, nlev = $nlev, OPC = $opc, FT = $FT\n", color = color2)
printstyled("device = $device\n\n", color = color2)
println("L∞ error in flux_up = $max_err_flux_up_lw")
println("L∞ error in flux_dn = $max_err_flux_dn_lw")
println("L∞ error in flux_net = $max_err_flux_net_lw")
println("L∞ relative error in flux_net = $(max_rel_err_flux_net_lw * 100) %\n")

toler_lw = 1e-4
#--------------------------------------------------------------
# comparing shortwave fluxes with data from RRTMGP FORTRAN code
ds_flux_up_sw = Dataset(flux_up_file_sw, "r")
comp_flux_up_sw = Array(ds_flux_up_sw["rsu"])[flip_ind, :, exp_no]
Expand All @@ -143,29 +157,47 @@ function clear_sky(
comp_flux_dn_sw = Array(ds_flux_dn_sw["rsd"])[flip_ind, :, exp_no]
close(ds_flux_dn_sw)

comp_flux_net_sw = comp_flux_up_sw .- comp_flux_dn_sw

flux_up_sw = Array(slv.flux_sw.flux_up)
flux_dn_sw = Array(slv.flux_sw.flux_dn)
flux_net_sw = flux_up_sw .- flux_dn_sw

for i in 1:ncol
if usecol[i] == 0
flux_up_sw[:, i] .= FT(0)
flux_dn_sw[:, i] .= FT(0)
flux_net_sw[:, i] .= FT(0)
end
end

max_err_flux_up_sw = maximum(abs.(flux_up_sw .- comp_flux_up_sw))
max_err_flux_dn_sw = maximum(abs.(flux_dn_sw .- comp_flux_dn_sw))
max_err_flux_net_sw = maximum(abs.(flux_net_sw .- comp_flux_net_sw))

rel_err_flux_net_sw = abs.(flux_net_sw .- comp_flux_net_sw)

for gcol in 1:ncol, glev in 1:nlev
den = abs(comp_flux_net_sw[glev, gcol])
if den > 10 * eps(FT)
rel_err_flux_net_sw[glev, gcol] /= den
end
end

println("Clear-sky shortwave test, opc = $opc")
println("max_err_flux_up_sw = $max_err_flux_up_sw")
println("max_err_flux_dn_sw = $max_err_flux_dn_sw")
println("=======================================")
max_rel_err_flux_net_sw = maximum(rel_err_flux_net_sw)

printstyled("Clear-sky shortwave test with ncol = $ncol, nlev = $nlev, OPC = $opc, FT = $FT\n", color = color2)
printstyled("device = $device\n\n", color = color2)
println("L∞ error in flux_up = $max_err_flux_up_sw")
println("L∞ error in flux_dn = $max_err_flux_dn_sw")
println("L∞ error in flux_net = $max_err_flux_net_sw")
println("L∞ relative error in flux_net = $(max_rel_err_flux_net_sw * 100) %\n")

toler_sw = FT(0.001)
toler_lw = Dict(Float64 => Float64(1e-4), Float32 => Float32(0.04))
toler_sw = Dict(Float64 => Float64(1e-3), Float32 => Float32(0.04))

@test max_err_flux_up_lw toler_lw broken = (FT == Float32)
@test max_err_flux_dn_lw toler_lw broken = (FT == Float32)
@test max_err_flux_up_sw toler_sw broken = (FT == Float32)
@test max_err_flux_dn_sw toler_sw broken = (FT == Float32)
@test max_err_flux_up_lw toler_lw[FT]
@test max_err_flux_dn_lw toler_lw[FT]
@test max_err_flux_up_sw toler_sw[FT]
@test max_err_flux_dn_sw toler_sw[FT]
end
20 changes: 14 additions & 6 deletions test/gray_atm_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,13 @@ function gray_atmos_lw_equil(context, ::Type{OPC}, ::Type{FT}; exfiltrate = fals
end

t_error = maximum(abs.(T_ex_lev .- gray_as.t_lev))
println("*************************************************")
println("Longwave test for gray atmosphere model - $opc; ncol = $ncol; context = $context")
println("Integration time = $(FT(nsteps)/FT(24.0/tstep) / FT(365.0)) years")
println("t_error = $(t_error); flux_grad_err = $(flux_grad_err)")
color2 = :cyan

printstyled("\nGray atmosphere longwave test with ncol = $ncol, nlev = $nlev, OPC = $opc\n", color = color2)
printstyled("device = $device\n", color = color2)
printstyled("Integration time = $(FT(nsteps)/FT(24.0/tstep) / FT(365.0)) years\n\n", color = color2)

println("t_error = $(t_error); flux_grad_err = $(flux_grad_err)\n")

@test maximum(t_error) < temp_toler
if device isa ClimaComms.CPUSingleThreaded
Expand Down Expand Up @@ -201,9 +204,14 @@ function gray_atmos_sw_test(

rel_toler = FT(0.001)
rel_error = abs(flux_dn_dir[1] - exact) / exact
println("*************************************************")
println("Running shortwave test for gray atmosphere model - $(opc); ncol = $ncol; context = $context")

color2 = :cyan

printstyled("\nGray atmosphere shortwave test with ncol = $ncol, nlev = $nlev, OPC = $opc\n", color = color2)
printstyled("device = $device\n\n", color = color2)

println("relative error = $rel_error")

@test rel_error < rel_toler

if device isa ClimaComms.CPUSingleThreaded
Expand Down
15 changes: 10 additions & 5 deletions test/rfmip_clear_sky_lw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,16 @@ function lw_rfmip(context, ::Type{OPC}, ::Type{SRC}, ::Type{VMR}, ::Type{FT}) wh
max_err_flux_up = FT(maximum(abs.(Array(slv.flux_lw.flux_up) .- comp_flux_up)))
max_err_flux_dn = FT(maximum(abs.(Array(slv.flux_lw.flux_dn) .- comp_flux_dn)))

println("=======================================")
println("Clear-sky longwave test - $opc")
println("max_err_flux_up = $max_err_flux_up")
println("max_err_flux_dn = $max_err_flux_dn")
println("=======================================")
color2 = :cyan

printstyled(
"Stand-alone clear-sky longwave test with ncol = $ncol, nlev = $nlev, OPC = $opc, FT = $FT\n",
color = color2,
)
printstyled("device = $device\n\n", color = color2)
println("L∞ error in flux_up = $max_err_flux_up")
println("L∞ error in flux_dn = $max_err_flux_dn\n")

toler = FT(1e-4)

@test max_err_flux_up toler
Expand Down
14 changes: 10 additions & 4 deletions test/rfmip_clear_sky_sw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,16 @@ function sw_rfmip(context, ::Type{OPC}, ::Type{SRC}, ::Type{VMR}, ::Type{FT}) wh
max_err_flux_up = maximum(abs.(flux_up .- comp_flux_up))
max_err_flux_dn = maximum(abs.(flux_dn .- comp_flux_dn))

println("=======================================")
println("Clear-sky shortwave test, opc = $opc")
println("max_err_flux_up = $max_err_flux_up")
println("max_err_flux_dn = $max_err_flux_dn")
color2 = :cyan

printstyled(
"Stand-alone clear-sky shortwave test with ncol = $ncol, nlev = $nlev, OPC = $opc, FT = $FT\n",
color = color2,
)
printstyled("device = $device\n\n", color = color2)
println("L∞ error in flux_up = $max_err_flux_up")
println("L∞ error in flux_dn = $max_err_flux_dn\n")


toler = FT(0.001)
@test maximum(abs.(flux_up .- comp_flux_up)) toler
Expand Down
Loading

0 comments on commit 9ef0b6c

Please sign in to comment.