Skip to content

Try inlining matrix field multiply_matrix_at_index #2311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,8 @@ steps:
- label: "Unit: operator matrices (CPU)"
key: unit_operator_matrices_cpu
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/operator_matrices.jl"
agents:
slurm_mem: 64GB

- label: "Unit: operator matrices (GPU)"
key: unit_operator_matrices_gpu
Expand All @@ -807,7 +809,7 @@ steps:
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_gpus: 1
slurm_mem: 40GB
slurm_mem: 64GB

- label: "Unit: field names"
key: unit_field_names
Expand Down Expand Up @@ -943,6 +945,8 @@ steps:
- label: "Unit: matrix field broadcasting (CPU)"
key: unit_matrix_field_broadcasting_cpu_non_scalar_3
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl"
agents:
slurm_mem: 20GB

- label: "Unit: matrix field broadcasting (CPU)"
key: unit_matrix_field_broadcasting_cpu_non_scalar_4
Expand Down Expand Up @@ -1141,7 +1145,7 @@ steps:
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_gpus: 1
slurm_mem: 10GB
slurm_mem: 20GB

- label: "Unit: matrix field broadcasting (GPU)"
key: unit_matrix_field_broadcasting_gpu_non_scalar_5
Expand Down
91 changes: 91 additions & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,95 @@ include("CommonSpaces/CommonSpaces.jl")
include("deprecated.jl")
include("to_device.jl")

# For complex nested types (ex. wrapped SMatrix / broadcast expressions) we hit
# a recursion limit and de-optimize We know the recursion will terminate due to
# the fact that bitstype fields cannot be self referential so there are no
# cycles in these methods (bounded tree) TODO: enforce inference termination
# some other way

if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(Operators.column)
m.recursion_relation = dont_limit
end
for m in methods(Operators.column_args)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.multiply_matrix_at_index)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.unique_and_non_overlapping_values)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.union_values)
m.recursion_relation = dont_limit
end
for m in methods(Operators.reconstruct_placeholder_broadcasted)
m.recursion_relation = dont_limit
end
for m in methods(Operators._reconstruct_placeholder_broadcasted)
m.recursion_relation = dont_limit
end
for m in methods(Operators.get_node)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.has_field)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.get_field)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.broadcasted_has_field)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.broadcasted_get_field)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.wrapped_prop_names)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.filtered_child_names)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.subtree_at_name)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.is_valid_name)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.get_subtree_at_name)
m.recursion_relation = dont_limit
end
for m in methods(MatrixFields.concrete_field_vector_within_subtree)
m.recursion_relation = dont_limit
end
for m in methods(DataLayouts.get_struct_linear)
m.recursion_relation = dont_limit
end
for m in methods(DataLayouts.set_struct_linear!)
m.recursion_relation = dont_limit
end
for m in methods(DataLayouts.get_struct)
m.recursion_relation = dont_limit
end
for m in methods(DataLayouts.set_struct!)
m.recursion_relation = dont_limit
end
for m in methods(Operators.call_bc_f)
m.recursion_relation = dont_limit
end
for m in methods(Operators.getidx)
m.recursion_relation = dont_limit
end
for m in methods(Operators.stencil_interior)
m.recursion_relation = dont_limit
end
for m in methods(Operators.stencil_left_boundary)
m.recursion_relation = dont_limit
end
for m in methods(Operators.stencil_right_boundary)
m.recursion_relation = dont_limit
end
end

end # module
14 changes: 0 additions & 14 deletions src/DataLayouts/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,3 @@ Base.@propagate_inbounds function set_struct!(
@inbounds array[index] = val
val
end

# For complex nested types (ex. wrapped SMatrix) we hit a recursion limit and de-optimize
# We know the recursion will terminate due to the fact that bitstype fields
# cannot be self referential so there are no cycles in get/set_struct (bounded tree)
# TODO: enforce inference termination some other way
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(get_struct)
m.recursion_relation = dont_limit
end
for m in methods(set_struct!)
m.recursion_relation = dont_limit
end
end
14 changes: 0 additions & 14 deletions src/DataLayouts/struct_linear_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,3 @@ Base.@propagate_inbounds function set_struct_linear!(
@inbounds array[start_index] = val
val
end

# For complex nested types (ex. wrapped SMatrix) we hit a recursion limit and de-optimize
# We know the recursion will terminate due to the fact that bitstype fields
# cannot be self referential so there are no cycles in get/set_struct (bounded tree)
# TODO: enforce inference termination some other way
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(get_struct_linear)
m.recursion_relation = dont_limit
end
for m in methods(set_struct_linear!)
m.recursion_relation = dont_limit
end
end
33 changes: 0 additions & 33 deletions src/MatrixFields/field_name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,36 +167,3 @@ get_subtree_at_name(name, tree) =
get_subtree_at_name(name, subtrees_at_name[1])
end

################################################################################

# This is required for type-stability as of Julia 1.9.
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(has_field)
m.recursion_relation = dont_limit
end
for m in methods(get_field)
m.recursion_relation = dont_limit
end
for m in methods(broadcasted_has_field)
m.recursion_relation = dont_limit
end
for m in methods(broadcasted_get_field)
m.recursion_relation = dont_limit
end
for m in methods(wrapped_prop_names)
m.recursion_relation = dont_limit
end
for m in methods(filtered_child_names)
m.recursion_relation = dont_limit
end
for m in methods(subtree_at_name)
m.recursion_relation = dont_limit
end
for m in methods(is_valid_name)
m.recursion_relation = dont_limit
end
for m in methods(get_subtree_at_name)
m.recursion_relation = dont_limit
end
end
8 changes: 0 additions & 8 deletions src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,14 +359,6 @@ concrete_field_vector_within_subtree(tree, vector) =
Fields.FieldVector{T}(NamedTuple{internal_names}(internal_entries))
end

# This is required for type-stability as of Julia 1.9.
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(concrete_field_vector_within_subtree)
m.recursion_relation = dont_limit
end
end

################################################################################

struct FieldNameDictStyle <: Base.Broadcast.BroadcastStyle end
Expand Down
10 changes: 0 additions & 10 deletions src/MatrixFields/field_name_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,13 +455,3 @@ function expand_child_values(
end
end

# This is required for type-stability as of Julia 1.9.
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(unique_and_non_overlapping_values)
m.recursion_relation = dont_limit
end
for m in methods(union_values)
m.recursion_relation = dont_limit
end
end
34 changes: 10 additions & 24 deletions src/MatrixFields/matrix_multiplication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,15 +341,7 @@ boundary_modified_ud(_, ud, column_space, i) = ud
boundary_modified_ud(::BottomRightMatrixCorner, ud, column_space, i) =
min(Operators.right_idx(column_space) - i, ud)

# TODO: Use @propagate_inbounds here, and remove @inbounds from this function.
# As of Julia 1.8, doing this increases compilation time by more than an order
# of magnitude, and it also makes type inference fail for some complicated
# matrix field broadcast expressions (in particular, those that involve products
# of linear combinations of matrix fields). Not using @propagate_inbounds causes
# matrix field broadcast expressions to take roughly 3 or 4 times longer to
# evaluate, but this is less significant than the decrease in compilation time.
# matrix-matrix multiplication
function multiply_matrix_at_index(
Base.@propagate_inbounds function multiply_matrix_at_index(
space,
idx,
hidx,
Expand All @@ -374,9 +366,11 @@ function multiply_matrix_at_index(

# Precompute the row that is needed from matrix1 so that it does not get
# recomputed multiple times.
matrix1_row = @inbounds Operators.getidx(space, matrix1, idx, hidx)
TM1R = Operators.getidx_return_type(matrix1)
matrix1_row = @inbounds Operators.getidx(space, matrix1, idx, hidx)::TM1R

matrix2 = arg
TM2R = Operators.getidx_return_type(matrix2)
column_space2 = column_axes(matrix2, column_space1)
ld2, ud2 = outer_diagonals(eltype(matrix2))
prod_ld, prod_ud = outer_diagonals(prod_type)
Expand All @@ -395,7 +389,7 @@ function multiply_matrix_at_index(
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
Base.@_inline_meta
if isnothing(bc) || boundary_modified_ld1 <= d <= boundary_modified_ud1
@inbounds Operators.getidx(space, matrix2, idx + d, hidx)
@inbounds Operators.getidx(space, matrix2, idx + d, hidx)::TM2R
else
zero(eltype(matrix2)) # This row is outside the matrix.
end
Expand Down Expand Up @@ -437,7 +431,7 @@ function multiply_matrix_at_index(
return BandMatrixRow{prod_ld}(prod_entries...)
end
# matrix-vector multiplication
function multiply_matrix_at_index(
Base.@propagate_inbounds function multiply_matrix_at_index(
space,
idx,
hidx,
Expand All @@ -454,6 +448,7 @@ function multiply_matrix_at_index(
arg,
typeof(lg),
)
TM1R = Operators.getidx_return_type(matrix1)

column_space1 = column_axes(matrix1, space)
ld1, ud1 = outer_diagonals(eltype(matrix1))
Expand All @@ -462,13 +457,14 @@ function multiply_matrix_at_index(

# Precompute the row that is needed from matrix1 so that it does not get
# recomputed multiple times.
matrix1_row = @inbounds Operators.getidx(space, matrix1, idx, hidx)
matrix1_row = @inbounds Operators.getidx(space, matrix1, idx, hidx)::TM1R

vector = arg
TVR = Operators.getidx_return_type(vector)
prod_value = rzero(prod_type)
@inbounds for d in boundary_modified_ld1:boundary_modified_ud1
value1 = matrix1_row[d]
value2 = Operators.getidx(space, vector, idx + d, hidx)
value2 = Operators.getidx(space, vector, idx + d, hidx)::TVR
value2_lg = Geometry.LocalGeometry(space, idx + d, hidx)
prod_value =
radd(prod_value, rmul_with_projection(value1, value2, value2_lg))
Expand Down Expand Up @@ -513,13 +509,3 @@ Base.@propagate_inbounds Operators.stencil_right_boundary(
arg,
) = multiply_matrix_at_index(space, idx, hidx, matrix1, arg, bc, eltype(arg))

# For matrix field broadcast expressions involving 4 or more matrices, we
# sometimes hit a recursion limit and de-optimize.
# We know that the recursion will terminate due to the fact that broadcast
# expressions are not self-referential.
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(multiply_matrix_at_index)
m.recursion_relation = dont_limit
end
end
32 changes: 0 additions & 32 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3840,16 +3840,6 @@ end
end
end

if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(call_bc_f)
m.recursion_relation = dont_limit
end
for m in methods(getidx)
m.recursion_relation = dont_limit
end
end

# setidx! methods for copyto!
Base.@propagate_inbounds function setidx!(
parent_space,
Expand Down Expand Up @@ -3928,17 +3918,6 @@ Base.@propagate_inbounds column(sbc::StencilBroadcasted{S}, inds...) where {S} =
column(sbc.axes, inds...),
)

#TODO: the optimizer dies with column broadcast expressions over a certain complexity
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(column)
m.recursion_relation = dont_limit
end
for m in methods(column_args)
m.recursion_relation = dont_limit
end
end

function Base.similar(
bc::Base.Broadcast.Broadcasted{S},
::Type{Eltype},
Expand Down Expand Up @@ -4163,17 +4142,6 @@ promote_bc(bc::SetDivergence{<:Geometry.AxisTensor}, ::Type{FT}) where {FT} =
promote_bc(bc::SetCurl{<:Geometry.AxisTensor}, ::Type{FT}) where {FT} =
SetCurl(promote_axis_tensor(bc.val, FT))


if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(reconstruct_placeholder_broadcasted)
m.recursion_relation = dont_limit
end
for m in methods(_reconstruct_placeholder_broadcasted)
m.recursion_relation = dont_limit
end
end

"""
use_fd_shmem()

Expand Down
5 changes: 0 additions & 5 deletions src/Operators/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,6 @@ Base.@propagate_inbounds function get_node(
data[ij]
end

dont_limit = (args...) -> true
for m in methods(get_node)
m.recursion_relation = dont_limit
end

Base.@propagate_inbounds function get_local_geometry(
space::Union{
Spaces.AbstractSpectralElementSpace,
Expand Down
Loading