Skip to content

Commit 15844ad

Browse files
committed
Allow both Bumper and LoopVectorization
1 parent 0d396e7 commit 15844ad

File tree

6 files changed

+60
-25
lines changed

6 files changed

+60
-25
lines changed

benchmark/benchmarks.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ function benchmark_evaluation()
3636
for turbo in (false, true), bumper in (false, true)
3737

3838
(turbo || bumper) && !(T in (Float32, Float64)) && continue
39-
turbo && bumper && continue
4039
if bumper
4140
try
4241
eval_tree_array(Node{T}(val=1.0), ones(T, 5, n), operators; turbo, bumper)
@@ -47,7 +46,15 @@ function benchmark_evaluation()
4746
end
4847
end
4948

50-
extra_key = turbo ? "_turbo" : (bumper ? "_bumper" : "")
49+
extra_key = if turbo && bumper
50+
"_turbo_bumper"
51+
elseif turbo
52+
"_turbo"
53+
elseif bumper
54+
"_bumper"
55+
else
56+
""
57+
end
5158
extra_kws = bumper ? (; bumper=Val(true)) : ()
5259
eval_tree_array(
5360
gen_random_tree_fixed_size(20, operators, 5, T),

ext/DynamicExpressionsBumperExt.jl

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ using Bumper: @no_escape, @alloc
44
using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce
55
using DynamicExpressions.UtilsModule: ResultOk, counttuple, is_bad_array
66

7-
import DynamicExpressions.ExtensionInterfaceModule: bumper_eval_tree_array
7+
import DynamicExpressions.ExtensionInterfaceModule:
8+
bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
89

910
function bumper_eval_tree_array(
10-
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, operators::OperatorEnum
11-
) where {T}
11+
tree::AbstractExpressionNode{T},
12+
cX::AbstractMatrix{T},
13+
operators::OperatorEnum,
14+
::Val{turbo},
15+
) where {T,turbo}
1216
result = similar(cX, axes(cX, 2))
1317
n = size(cX, 2)
1418
all_ok = Ref(false)
@@ -32,7 +36,8 @@ function bumper_eval_tree_array(
3236
branch_node -> branch_node,
3337
# In the evaluation kernel, we combine the branch nodes
3438
# with the arrays created by the leaf nodes:
35-
((args::Vararg{Any,M}) where {M}) -> dispatch_kerns!(operators, args...),
39+
((args::Vararg{Any,M}) where {M}) ->
40+
dispatch_kerns!(operators, args..., Val(turbo)),
3641
tree;
3742
break_sharing=Val(true),
3843
)
@@ -43,43 +48,55 @@ function bumper_eval_tree_array(
4348
return (result, all_ok[])
4449
end
4550

46-
function dispatch_kerns!(operators, branch_node, cumulator)
51+
function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where {turbo}
4752
cumulator.ok || return cumulator
4853

49-
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x)
54+
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo))
5055
return ResultOk(out, !is_bad_array(out))
5156
end
52-
function dispatch_kerns!(operators, branch_node, cumulator1, cumulator2)
57+
function dispatch_kerns!(
58+
operators, branch_node, cumulator1, cumulator2, ::Val{turbo}
59+
) where {turbo}
5360
cumulator1.ok || return cumulator1
5461
cumulator2.ok || return cumulator2
5562

56-
out = dispatch_kern2!(operators.binops, branch_node.op, cumulator1.x, cumulator2.x)
63+
out = dispatch_kern2!(
64+
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo)
65+
)
5766
return ResultOk(out, !is_bad_array(out))
5867
end
5968

60-
@generated function dispatch_kern1!(unaops, op_idx, cumulator)
69+
@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo}
6170
nuna = counttuple(unaops)
6271
quote
63-
Base.@nif($nuna, i -> i == op_idx, i -> let op = unaops[i]
64-
return kern1!(op, cumulator)
65-
end,)
72+
Base.@nif(
73+
$nuna,
74+
i -> i == op_idx,
75+
i -> let op = unaops[i]
76+
return bumper_kern1!(op, cumulator, Val(turbo))
77+
end,
78+
)
6679
end
6780
end
68-
@generated function dispatch_kern2!(binops, op_idx, cumulator1, cumulator2)
81+
@generated function dispatch_kern2!(
82+
binops, op_idx, cumulator1, cumulator2, ::Val{turbo}
83+
) where {turbo}
6984
nbin = counttuple(binops)
7085
quote
7186
Base.@nif(
72-
$nbin, i -> i == op_idx, i -> let op = binops[i]
73-
return kern2!(op, cumulator1, cumulator2)
87+
$nbin,
88+
i -> i == op_idx,
89+
i -> let op = binops[i]
90+
return bumper_kern2!(op, cumulator1, cumulator2, Val(turbo))
7491
end,
7592
)
7693
end
7794
end
78-
function kern1!(op::F, cumulator) where {F}
95+
function bumper_kern1!(op::F, cumulator, ::Val{false}) where {F}
7996
@. cumulator = op(cumulator)
8097
return cumulator
8198
end
82-
function kern2!(op::F, cumulator1, cumulator2) where {F}
99+
function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{false}) where {F}
83100
@. cumulator1 = op(cumulator1, cumulator2)
84101
return cumulator1
85102
end

ext/DynamicExpressionsLoopVectorizationExt.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ import DynamicExpressions.EvaluateEquationModule:
1212
deg2_l0_r0_eval,
1313
deg2_l0_eval,
1414
deg2_r0_eval
15-
import DynamicExpressions.ExtensionInterfaceModule: _is_loopvectorization_loaded
15+
import DynamicExpressions.ExtensionInterfaceModule:
16+
_is_loopvectorization_loaded, bumper_kern1!, bumper_kern2!
1617

1718
_is_loopvectorization_loaded(::Int) = true
1819

@@ -201,4 +202,14 @@ function deg2_r0_eval(
201202
end
202203
end
203204

205+
## Interface with Bumper.jl
206+
function bumper_kern1!(op::F, cumulator, ::Val{true}) where {F}
207+
@turbo @. cumulator = op(cumulator)
208+
return cumulator
209+
end
210+
function bumper_kern2!(op::F, cumulator1, cumulator2, ::Val{true}) where {F}
211+
@turbo @. cumulator1 = op(cumulator1, cumulator2)
212+
return cumulator1
213+
end
214+
204215
end

src/EvaluateEquation.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,13 @@ function eval_tree_array(
7272
if v_turbo isa Val{true} || v_bumper isa Val{true}
7373
@assert T in (Float32, Float64)
7474
end
75-
@assert !(v_turbo isa Val{true} && v_bumper isa Val{true})
76-
if v_bumper isa Val{true}
77-
return bumper_eval_tree_array(tree, cX, operators)
78-
end
7975
if v_turbo isa Val{true}
8076
_is_loopvectorization_loaded(0) ||
8177
error("Please load the LoopVectorization.jl package to use this feature.")
8278
end
79+
if v_bumper isa Val{true}
80+
return bumper_eval_tree_array(tree, cX, operators, v_turbo)
81+
end
8382

8483
result = _eval_tree_array(tree, cX, operators, v_turbo)
8584
return (result.x, result.ok && !is_bad_array(result.x))

src/ExtensionInterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ end
1414
function bumper_eval_tree_array(args...)
1515
return error("Please load the Bumper.jl package to use this feature.")
1616
end
17+
function bumper_kern1! end
18+
function bumper_kern2! end
1719

1820
_is_loopvectorization_loaded(_) = false
1921

test/test_evaluation.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ for turbo in [Val(false), Val(true)],
4343

4444
# Float16 not implemented:
4545
(turbo isa Val{true} || bumper isa Val{true}) && !(T in (Float32, Float64)) && continue
46-
turbo isa Val{true} && bumper isa Val{true} && continue
4746
@testset "Test evaluation of trees with turbo=$turbo, bumper=$bumper, T=$T" begin
4847
for (i_func, fnc) in enumerate(functions)
4948

0 commit comments

Comments
 (0)