Skip to content

Commit 3da1d38

Browse files
committed
refactor: avoid SubArray in CUDA kernel
1 parent 74f56d4 commit 3da1d38

File tree

2 files changed

+62
-61
lines changed

2 files changed

+62
-61
lines changed

ext/DynamicExpressionsCUDAExt.jl

Lines changed: 45 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,16 @@ module DynamicExpressionsCUDAExt
44
using CUDA: @cuda, CuArray, blockDim, blockIdx, threadIdx
55
using DynamicExpressions: OperatorEnum, AbstractExpressionNode
66
using DynamicExpressions.EvaluateModule: get_nbin, get_nuna
7-
using DynamicExpressions.AsArrayModule: as_array
7+
using DynamicExpressions.AsArrayModule:
8+
as_array,
9+
IDX_DEGREE,
10+
IDX_FEATURE,
11+
IDX_OP,
12+
IDX_EXECUTION_ORDER,
13+
IDX_SELF,
14+
IDX_L,
15+
IDX_R,
16+
IDX_CONSTANT
817
using DispatchDoctor: @stable
918

1019
import DynamicExpressions.EvaluateModule: eval_tree_array
@@ -59,12 +68,11 @@ end
5968
## in the input data by the number of nodes in the tree.
6069
## It has one extra row to store the constant values.
6170
gworkspace = @something(gpu_workspace, similar(gcX, num_elem + 1, num_nodes))
62-
gval = @view gworkspace[end, :]
6371
if _update_buffers
64-
copyto!(gval, val)
72+
copyto!(@view(gworkspace[end, :]), val)
6573
end
74+
val_idx = size(gworkspace, 1)
6675

67-
## Index arrays (much faster to have `@view` here)
6876
gbuffer = if !_update_buffers
6977
gpu_buffer
7078
elseif gpu_buffer === nothing
@@ -73,17 +81,8 @@ end
7381
copyto!(gpu_buffer, buffer)
7482
end
7583

76-
#! format: off
77-
gdegree = @view gbuffer[1, :]
78-
gfeature = @view gbuffer[2, :]
79-
gop = @view gbuffer[3, :]
80-
gexecution_order = @view gbuffer[4, :]
81-
gidx_self = @view gbuffer[5, :]
82-
gidx_l = @view gbuffer[6, :]
83-
gidx_r = @view gbuffer[7, :]
84-
gconstant = @view gbuffer[8, :]
85-
#! format: on
86-
# TODO: This is a bit dangerous as we're assuming exact indices
84+
# Removed @view definitions of gdegree, gfeature, etc.
85+
# We'll index directly into gbuffer using the constants above.
8786

8887
num_threads = 256
8988
num_blocks = nextpow(2, ceil(Int, num_elem * num_nodes / num_threads))
@@ -92,10 +91,9 @@ end
9291
_launch_gpu_kernel!(
9392
num_threads, num_blocks, num_launches, gworkspace,
9493
# Thread info:
95-
num_elem, num_nodes, gexecution_order,
96-
# Input data and tree
97-
operators, gcX, gidx_self, gidx_l, gidx_r,
98-
gdegree, gconstant, gval, gfeature, gop,
94+
num_elem, num_nodes,
95+
# We'll pass gbuffer directly to the kernel now:
96+
operators, gcX, gbuffer, val_idx,
9997
)
10098
#! format: on
10199

@@ -109,34 +107,30 @@ end
109107
@stable default_mode = "disable" function _launch_gpu_kernel!(
110108
num_threads, num_blocks, num_launches::Integer, buffer::AbstractArray{T,2},
111109
# Thread info:
112-
num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray{I},
113-
# Input data and tree
114-
operators::OperatorEnum, cX::AbstractArray{T,2}, idx_self::AbstractArray, idx_l::AbstractArray, idx_r::AbstractArray,
115-
degree::AbstractArray, constant::AbstractArray, val::AbstractArray{T,1}, feature::AbstractArray, op::AbstractArray,
116-
) where {I,T}
110+
num_elem::Integer, num_nodes::Integer,
111+
operators::OperatorEnum, cX::AbstractArray{T,2}, gbuffer::AbstractArray{Int32,2},
112+
val_idx::Integer
113+
) where {T}
117114
#! format: on
118115
nuna = get_nuna(typeof(operators))
119116
nbin = get_nbin(typeof(operators))
120117
(nuna > 10 || nbin > 10) &&
121118
error("Too many operators. Kernels are only compiled up to 10.")
122119
gpu_kernel! = create_gpu_kernel(operators, Val(nuna), Val(nbin))
123-
for launch in one(I):I(num_launches)
120+
for launch in one(Int32):Int32(num_launches)
124121
#! format: off
125122
if buffer isa CuArray
126123
@cuda threads=num_threads blocks=num_blocks gpu_kernel!(
127124
buffer,
128-
launch, num_elem, num_nodes, execution_order,
129-
cX, idx_self, idx_l, idx_r,
130-
degree, constant, val, feature, op
125+
launch, num_elem, num_nodes,
126+
cX, gbuffer, val_idx
131127
)
132128
else
133129
Threads.@threads for i in 1:(num_threads * num_blocks)
134130
gpu_kernel!(
135131
buffer,
136-
launch, num_elem, num_nodes, execution_order,
137-
cX, idx_self, idx_l, idx_r,
138-
degree, constant, val, feature, op,
139-
i
132+
launch, num_elem, num_nodes,
133+
cX, gbuffer, val_idx, i
140134
)
141135
end
142136
end
@@ -155,55 +149,53 @@ for nuna in 0:10, nbin in 0:10
155149
@eval function create_gpu_kernel(operators::OperatorEnum, ::Val{$nuna}, ::Val{$nbin})
156150
#! format: off
157151
function (
158-
# Storage:
159152
buffer,
160-
# Thread info:
161-
launch::Integer, num_elem::Integer, num_nodes::Integer, execution_order::AbstractArray,
162-
# Input data and tree
163-
cX::AbstractArray, idx_self::AbstractArray, idx_l::AbstractArray, idx_r::AbstractArray,
164-
degree::AbstractArray, constant::AbstractArray, val::AbstractArray, feature::AbstractArray, op::AbstractArray,
165-
# Override for unittesting:
153+
launch::Integer, num_elem::Integer, num_nodes::Integer,
154+
cX::AbstractArray, gbuffer::AbstractArray{Int32,2},
155+
val_idx::Integer,
166156
i=nothing,
167157
)
168-
i = i === nothing ? (blockIdx().x - 1) * blockDim().x + threadIdx().x : i
158+
i = @something(i, (blockIdx().x - 1) * blockDim().x + threadIdx().x)
169159
if i > num_elem * num_nodes
170160
return nothing
171161
end
172162

173163
node = (i - 1) % num_nodes + 1
174164
elem = (i - node) ÷ num_nodes + 1
175165

176-
#! format: off
166+
177167
@inbounds begin
178-
if execution_order[node] != launch
168+
if gbuffer[IDX_EXECUTION_ORDER, node] != launch
179169
return nothing
180170
end
181171

182-
cur_degree = degree[node]
183-
cur_idx = idx_self[node]
172+
# Use constants to index gbuffer:
173+
cur_degree = gbuffer[IDX_DEGREE, node]
174+
cur_idx = gbuffer[IDX_SELF, node]
175+
184176
if cur_degree == 0
185-
if constant[node] == 1
186-
cur_val = val[node]
177+
if gbuffer[IDX_CONSTANT, node] == 1
178+
cur_val = buffer[val_idx, node]
187179
buffer[elem, cur_idx] = cur_val
188180
else
189-
cur_feature = feature[node]
181+
cur_feature = gbuffer[IDX_FEATURE, node]
190182
buffer[elem, cur_idx] = cX[cur_feature, elem]
191183
end
192184
else
193185
if cur_degree == 1 && $nuna > 0
194-
cur_op = op[node]
195-
l_idx = idx_l[node]
186+
cur_op = gbuffer[IDX_OP, node]
187+
l_idx = gbuffer[IDX_L, node]
196188
Base.Cartesian.@nif(
197189
$nuna,
198190
i -> i == cur_op,
199191
i -> let op = operators.unaops[i]
200192
buffer[elem, cur_idx] = op(buffer[elem, l_idx])
201193
end
202194
)
203-
elseif $nbin > 0 # Note this check is to avoid type inference issues when binops is empty
204-
cur_op = op[node]
205-
l_idx = idx_l[node]
206-
r_idx = idx_r[node]
195+
elseif $nbin > 0
196+
cur_op = gbuffer[IDX_OP, node]
197+
l_idx = gbuffer[IDX_L, node]
198+
r_idx = gbuffer[IDX_R, node]
207199
Base.Cartesian.@nif(
208200
$nbin,
209201
i -> i == cur_op,

src/AsArray.jl

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,15 @@ Base.@kwdef struct TreeBuffer{
3737
buffer::D
3838
end
3939

40+
const IDX_DEGREE = 1
41+
const IDX_FEATURE = 2
42+
const IDX_OP = 3
43+
const IDX_EXECUTION_ORDER = 4
44+
const IDX_SELF = 5
45+
const IDX_L = 6
46+
const IDX_R = 7
47+
const IDX_CONSTANT = 8
48+
4049
function as_array(
4150
::Type{I},
4251
trees::Union{Tuple{N,Vararg{N}},AbstractVector{N}};
@@ -61,14 +70,14 @@ function as_array(
6170

6271
# Obtain arrays from the buffer. Each call to get_array consumes one "slot".
6372
#! format: off
64-
degree = @view buffer[1, :]
65-
feature = @view buffer[2, :]
66-
op = @view buffer[3, :]
67-
execution_order = @view buffer[4, :]
68-
idx_self = @view buffer[5, :]
69-
idx_l = @view buffer[6, :]
70-
idx_r = @view buffer[7, :]
71-
constant = @view buffer[8, :]
73+
degree = @view buffer[IDX_DEGREE, :]
74+
feature = @view buffer[IDX_FEATURE, :]
75+
op = @view buffer[IDX_OP, :]
76+
execution_order = @view buffer[IDX_EXECUTION_ORDER, :]
77+
idx_self = @view buffer[IDX_SELF, :]
78+
idx_l = @view buffer[IDX_L, :]
79+
idx_r = @view buffer[IDX_R, :]
80+
constant = @view buffer[IDX_CONSTANT, :]
7281
#! format: on
7382

7483
tree_buffers = TreeBuffer(;

0 commit comments

Comments (0)