Skip to content

Commit

Permalink
Move AngularDiscretization to NoScatLWRTE
Browse files Browse the repository at this point in the history
  • Loading branch information
sriharshakandala committed Jun 12, 2024
1 parent ad01895 commit 95e9529
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 28 deletions.
1 change: 1 addition & 0 deletions ext/RRTMGPCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module RRTMGPCUDAExt

import ClimaComms
import RRTMGP.Parameters as RP
import RRTMGP.AngularDiscretizations.AngularDiscretization
import RRTMGP.Fluxes: FluxLW, FluxSW
import RRTMGP.Fluxes: add_to_flux!
import RRTMGP.Fluxes: set_flux_to_zero!
Expand Down
16 changes: 10 additions & 6 deletions ext/cuda/rte_longwave_1scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ function rte_lw_noscat_solve!(
src_lw::SourceLWNoScat,
bcs_lw::LwBCs,
op::OneScalar,
angle_disc::AngularDiscretization,
as::GrayAtmosphericState,
)
nlay, ncol = AtmosphericStates.get_dims(as)
nlev = nlay + 1
tx, bx = _configure_threadblock(ncol)
args = (flux_lw, src_lw, bcs_lw, op, nlay, ncol, as)
args = (flux_lw, src_lw, bcs_lw, op, angle_disc, nlay, ncol, as)
@cuda always_inline = true threads = (tx) blocks = (bx) rte_lw_noscat_solve_CUDA!(args...)
return nothing
end
Expand All @@ -19,6 +20,7 @@ function rte_lw_noscat_solve_CUDA!(
src_lw::SourceLWNoScat,
bcs_lw::LwBCs,
op::OneScalar,
angle_disc::AngularDiscretization,
nlay,
ncol,
as::GrayAtmosphericState,
Expand All @@ -27,11 +29,11 @@ function rte_lw_noscat_solve_CUDA!(
nlev = nlay + 1
igpt, ibnd = 1, 1
τ = op.τ
Ds = op.angle_disc.gauss_Ds
Ds = angle_disc.gauss_Ds
(; flux_up, flux_dn, flux_net) = flux_lw
if gcol ncol
compute_optical_props!(op, as, src_lw, gcol)
rte_lw_noscat!(src_lw, bcs_lw, op, gcol, flux_lw, igpt, ibnd, nlay, nlev)
rte_lw_noscat!(src_lw, bcs_lw, op, angle_disc, gcol, flux_lw, igpt, ibnd, nlay, nlev)
@inbounds for ilev in 1:nlev
flux_net[ilev, gcol] = flux_up[ilev, gcol] - flux_dn[ilev, gcol]
end
Expand All @@ -46,14 +48,15 @@ function rte_lw_noscat_solve!(
src_lw::SourceLWNoScat,
bcs_lw::LwBCs,
op::OneScalar,
angle_disc::AngularDiscretization,
as::AtmosphericState,
lookup_lw::LookUpLW,
lookup_lw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
)
nlay, ncol = AtmosphericStates.get_dims(as)
nlev = nlay + 1
tx, bx = _configure_threadblock(ncol)
args = (flux, flux_lw, src_lw, bcs_lw, op, nlay, ncol, as, lookup_lw, lookup_lw_cld)
args = (flux, flux_lw, src_lw, bcs_lw, op, angle_disc, nlay, ncol, as, lookup_lw, lookup_lw_cld)
@cuda always_inline = true threads = (tx) blocks = (bx) rte_lw_noscat_solve_CUDA!(args...)
return nothing
end
Expand All @@ -64,6 +67,7 @@ function rte_lw_noscat_solve_CUDA!(
src_lw::SourceLWNoScat,
bcs_lw::LwBCs,
op::OneScalar,
angle_disc::AngularDiscretization,
nlay,
ncol,
as::AtmosphericState,
Expand All @@ -75,7 +79,7 @@ function rte_lw_noscat_solve_CUDA!(
(; major_gpt2bnd) = lookup_lw.band_data
n_gpt = length(major_gpt2bnd)
τ = op.τ
Ds = op.angle_disc.gauss_Ds
Ds = angle_disc.gauss_Ds
if gcol ncol
flux_up_lw = flux_lw.flux_up
flux_dn_lw = flux_lw.flux_dn
Expand All @@ -84,7 +88,7 @@ function rte_lw_noscat_solve_CUDA!(
ibnd = major_gpt2bnd[igpt]
igpt == 1 && set_flux_to_zero!(flux_lw, gcol)
compute_optical_props!(op, as, src_lw, gcol, igpt, lookup_lw, lookup_lw_cld)
rte_lw_noscat!(src_lw, bcs_lw, op, gcol, flux, igpt, ibnd, nlay, nlev)
rte_lw_noscat!(src_lw, bcs_lw, op, angle_disc, gcol, flux, igpt, ibnd, nlay, nlev)
add_to_flux!(flux_lw, flux, gcol)
end
@inbounds begin
Expand Down
2 changes: 1 addition & 1 deletion ext/cuda/rte_shortwave_1scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function rte_sw_noscat_solve_CUDA!(flux_sw::FluxSW, op::OneScalar, bcs_sw::SwBCs
gcol = threadIdx().x + (blockIdx().x - 1) * blockDim().x # global id
nlev = nlay + 1
n_gpt, igpt = 1, 1
FT = eltype(op.angle_disc.gauss_Ds)
FT = eltype(bcs_sw.cos_zenith)
solar_frac = FT(1)
if gcol ncol
flux_up_sw = flux_sw.flux_up
Expand Down
7 changes: 2 additions & 5 deletions src/optics/Optics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,19 @@ calculations accounting for extinction and emission
# Fields
$(DocStringExtensions.FIELDS)
"""
struct OneScalar{D, V, AD <: AngularDiscretization} <: AbstractOpticalProps
struct OneScalar{D, V} <: AbstractOpticalProps
"storage for optical thickness"
layerdata::D
"view into optical depth"
τ::V
"Angular discretization"
angle_disc::AD
end
Adapt.@adapt_structure OneScalar

function OneScalar(::Type{FT}, ncol::Int, nlay::Int, ::Type{DA}) where {FT <: AbstractFloat, DA}
layerdata = DA{FT, 3}(undef, 1, nlay, ncol)
τ = view(layerdata, 1, :, :)
ad = AngularDiscretization(FT, DA, 1)
V = typeof(τ)
return OneScalar{typeof(layerdata), V, typeof(ad)}(layerdata, τ, ad)
return OneScalar{typeof(layerdata), V}(layerdata, τ)
end

"""
Expand Down
8 changes: 6 additions & 2 deletions src/optics/RTE.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module RTE
using Adapt
import ClimaComms
using ..AngularDiscretizations
using ..AtmosphericStates
using DocStringExtensions
using ..Sources
Expand Down Expand Up @@ -32,7 +33,7 @@ configurations for a non-scattering longwave simulation.
# Fields
$(DocStringExtensions.FIELDS)
"""
struct NoScatLWRTE{C, OP, SL <: SourceLWNoScat, BC <: LwBCs, FXBL, FXL <: FluxLW}
struct NoScatLWRTE{C, OP, SL <: SourceLWNoScat, BC <: LwBCs, FXBL, FXL <: FluxLW, AD}
"ClimaComms context"
context::C
"optical properties"
Expand All @@ -45,6 +46,8 @@ struct NoScatLWRTE{C, OP, SL <: SourceLWNoScat, BC <: LwBCs, FXBL, FXL <: FluxLW
fluxb::FXBL
"longwave fluxes"
flux::FXL
"Angular discretization"
angle_disc::AD
end
Adapt.@adapt_structure NoScatLWRTE

Expand All @@ -64,7 +67,8 @@ function NoScatLWRTE(
bcs = LwBCs(sfc_emis, inc_flux)
fluxb = FluxLW(ncol, nlay, FT, DA)
flux = FluxLW(ncol, nlay, FT, DA)
return NoScatLWRTE(context, op, src, bcs, fluxb, flux)
ad = AngularDiscretization(FT, DA, 1)
return NoScatLWRTE(context, op, src, bcs, fluxb, flux, ad)
end

"""
Expand Down
12 changes: 6 additions & 6 deletions src/rte/RTESolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ include("shortwave2stream.jl")
Non-scattering RTE solver for the longwave problem, using gray optics.
"""
solve_lw!((; context, flux, src, bcs, op)::NoScatLWRTE, as::GrayAtmosphericState) =
rte_lw_noscat_solve!(context.device, flux, src, bcs, op, as)
solve_lw!((; context, flux, src, bcs, op, angle_disc)::NoScatLWRTE, as::GrayAtmosphericState) =
rte_lw_noscat_solve!(context.device, flux, src, bcs, op, angle_disc, as)

"""
solve_lw!((; context, flux, src, bcs, op)::TwoStreamLWRTE, as::GrayAtmosphericState)
Expand All @@ -47,14 +47,14 @@ solve_lw!((; context, flux, src, bcs, op)::TwoStreamLWRTE, as::GrayAtmosphericSt
Non-scattering RTE solver for the longwave problem, using RRTMGP optics.
"""
solve_lw!(
(; context, fluxb, flux, src, bcs, op)::NoScatLWRTE,
(; context, fluxb, flux, src, bcs, op, angle_disc)::NoScatLWRTE,
as::AtmosphericState,
lookup_lw::LookUpLW,
lookup_lw_cld::Union{LookUpCld, PadeCld},
) = rte_lw_noscat_solve!(context.device, fluxb, flux, src, bcs, op, as, lookup_lw, lookup_lw_cld)
) = rte_lw_noscat_solve!(context.device, fluxb, flux, src, bcs, op, angle_disc, as, lookup_lw, lookup_lw_cld)

solve_lw!((; context, fluxb, flux, src, bcs, op)::NoScatLWRTE, as::AtmosphericState, lookup_lw::LookUpLW) =
rte_lw_noscat_solve!(context.device, fluxb, flux, src, bcs, op, as, lookup_lw, nothing)
solve_lw!((; context, fluxb, flux, src, bcs, op, angle_disc)::NoScatLWRTE, as::AtmosphericState, lookup_lw::LookUpLW) =
rte_lw_noscat_solve!(context.device, fluxb, flux, src, bcs, op, angle_disc, as, lookup_lw, nothing)

"""
solve_lw!(
Expand Down
17 changes: 10 additions & 7 deletions src/rte/longwave1scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@ function rte_lw_noscat_solve!(
src_lw::SourceLWNoScat,
bcs_lw::LwBCs,
op::OneScalar,
angle_disc::AngularDiscretization,
as::GrayAtmosphericState,
)
nlay, ncol = AtmosphericStates.get_dims(as)
nlev = nlay + 1
igpt, ibnd = 1, 1
τ = op.τ
Ds = op.angle_disc.gauss_Ds
Ds = angle_disc.gauss_Ds
(; flux_up, flux_dn, flux_net) = flux_lw
@inbounds begin
ClimaComms.@threaded device for gcol in 1:ncol
compute_optical_props!(op, as, src_lw, gcol)
rte_lw_noscat!(src_lw, bcs_lw, op, gcol, flux_lw, igpt, ibnd, nlay, nlev)
rte_lw_noscat!(src_lw, bcs_lw, op, angle_disc, gcol, flux_lw, igpt, ibnd, nlay, nlev)
for ilev in 1:nlev
flux_net[ilev, gcol] = flux_up[ilev, gcol] - flux_dn[ilev, gcol]
end
Expand All @@ -31,6 +32,7 @@ function rte_lw_noscat_solve!(
src_lw::SourceLWNoScat,
bcs_lw::LwBCs,
op::OneScalar,
angle_disc::AngularDiscretization,
as::AtmosphericState,
lookup_lw::LookUpLW,
lookup_lw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
Expand All @@ -40,7 +42,7 @@ function rte_lw_noscat_solve!(
(; major_gpt2bnd) = lookup_lw.band_data
n_gpt = length(major_gpt2bnd)
τ = op.τ
Ds = op.angle_disc.gauss_Ds
Ds = angle_disc.gauss_Ds
flux_up_lw = flux_lw.flux_up
flux_dn_lw = flux_lw.flux_dn
flux_net_lw = flux_lw.flux_net
Expand All @@ -50,7 +52,7 @@ function rte_lw_noscat_solve!(
ibnd = major_gpt2bnd[igpt]
igpt == 1 && set_flux_to_zero!(flux_lw, gcol)
compute_optical_props!(op, as, src_lw, gcol, igpt, lookup_lw, lookup_lw_cld)
rte_lw_noscat!(src_lw, bcs_lw, op, gcol, flux, igpt, ibnd, nlay, nlev)
rte_lw_noscat!(src_lw, bcs_lw, op, angle_disc, gcol, flux, igpt, ibnd, nlay, nlev)
add_to_flux!(flux_lw, flux, gcol)
end
end
Expand Down Expand Up @@ -110,6 +112,7 @@ Transport for no-scattering longwave problem.
src_lw::SourceLWNoScat,
bcs_lw::LwBCs,
op::OneScalar,
angle_disc::AngularDiscretization,
gcol,
flux::FluxLW,
igpt,
Expand All @@ -123,9 +126,9 @@ Transport for no-scattering longwave problem.
(; sfc_emis, inc_flux) = bcs_lw
(; flux_up, flux_dn) = flux

Ds = op.angle_disc.gauss_Ds
w_μ = op.angle_disc.gauss_wts
n_μ = op.angle_disc.n_gauss_angles
Ds = angle_disc.gauss_Ds
w_μ = angle_disc.gauss_wts
n_μ = angle_disc.n_gauss_angles
τ = op.τ
FT = eltype(τ)
τ_thresh = 100 * eps(FT) # or abs(eps(FT))?
Expand Down
2 changes: 1 addition & 1 deletion src/rte/shortwave1scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function rte_sw_noscat_solve!(
nlev = nlay + 1
n_gpt, igpt = 1, 1
cos_zenith = bcs_sw.cos_zenith
FT = eltype(op.angle_disc.gauss_Ds)
FT = eltype(cos_zenith)
solar_frac = FT(1)
@inbounds begin
ClimaComms.@threaded device for gcol in 1:ncol
Expand Down

0 comments on commit 95e9529

Please sign in to comment.