@@ -4,11 +4,15 @@ using Bumper: @no_escape, @alloc
44using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce
55using DynamicExpressions. UtilsModule: ResultOk, counttuple, is_bad_array
66
7- import DynamicExpressions. ExtensionInterfaceModule: bumper_eval_tree_array
7+ import DynamicExpressions. ExtensionInterfaceModule:
8+ bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
89
910function bumper_eval_tree_array (
10- tree:: AbstractExpressionNode{T} , cX:: AbstractMatrix{T} , operators:: OperatorEnum
11- ) where {T}
11+ tree:: AbstractExpressionNode{T} ,
12+ cX:: AbstractMatrix{T} ,
13+ operators:: OperatorEnum ,
14+ :: Val{turbo} ,
15+ ) where {T,turbo}
1216 result = similar (cX, axes (cX, 2 ))
1317 n = size (cX, 2 )
1418 all_ok = Ref (false )
@@ -32,7 +36,8 @@ function bumper_eval_tree_array(
3236 branch_node -> branch_node,
3337 # In the evaluation kernel, we combine the branch nodes
3438 # with the arrays created by the leaf nodes:
35- ((args:: Vararg{Any,M} ) where {M}) -> dispatch_kerns! (operators, args... ),
39+ ((args:: Vararg{Any,M} ) where {M}) ->
40+ dispatch_kerns! (operators, args... , Val (turbo)),
3641 tree;
3742 break_sharing= Val (true ),
3843 )
@@ -43,43 +48,55 @@ function bumper_eval_tree_array(
4348 return (result, all_ok[])
4449end
4550
46- function dispatch_kerns! (operators, branch_node, cumulator)
51+ function dispatch_kerns! (operators, branch_node, cumulator, :: Val{turbo} ) where {turbo}
4752 cumulator. ok || return cumulator
4853
49- out = dispatch_kern1! (operators. unaops, branch_node. op, cumulator. x)
54+ out = dispatch_kern1! (operators. unaops, branch_node. op, cumulator. x, Val (turbo) )
5055 return ResultOk (out, ! is_bad_array (out))
5156end
52- function dispatch_kerns! (operators, branch_node, cumulator1, cumulator2)
57+ function dispatch_kerns! (
58+ operators, branch_node, cumulator1, cumulator2, :: Val{turbo}
59+ ) where {turbo}
5360 cumulator1. ok || return cumulator1
5461 cumulator2. ok || return cumulator2
5562
56- out = dispatch_kern2! (operators. binops, branch_node. op, cumulator1. x, cumulator2. x)
63+ out = dispatch_kern2! (
64+ operators. binops, branch_node. op, cumulator1. x, cumulator2. x, Val (turbo)
65+ )
5766 return ResultOk (out, ! is_bad_array (out))
5867end
5968
60- @generated function dispatch_kern1! (unaops, op_idx, cumulator)
69+ @generated function dispatch_kern1! (unaops, op_idx, cumulator, :: Val{turbo} ) where {turbo}
6170 nuna = counttuple (unaops)
6271 quote
63- Base. @nif ($ nuna, i -> i == op_idx, i -> let op = unaops[i]
64- return kern1! (op, cumulator)
65- end ,)
72+ Base. @nif (
73+ $ nuna,
74+ i -> i == op_idx,
75+ i -> let op = unaops[i]
76+ return bumper_kern1! (op, cumulator, Val (turbo))
77+ end ,
78+ )
6679 end
6780end
68- @generated function dispatch_kern2! (binops, op_idx, cumulator1, cumulator2)
81+ @generated function dispatch_kern2! (
82+ binops, op_idx, cumulator1, cumulator2, :: Val{turbo}
83+ ) where {turbo}
6984 nbin = counttuple (binops)
7085 quote
7186 Base. @nif (
72- $ nbin, i -> i == op_idx, i -> let op = binops[i]
73- return kern2! (op, cumulator1, cumulator2)
87+ $ nbin,
88+ i -> i == op_idx,
89+ i -> let op = binops[i]
90+ return bumper_kern2! (op, cumulator1, cumulator2, Val (turbo))
7491 end ,
7592 )
7693 end
7794end
78- function kern1 ! (op:: F , cumulator) where {F}
95+ function bumper_kern1 ! (op:: F , cumulator, :: Val{false} ) where {F}
7996 @. cumulator = op (cumulator)
8097 return cumulator
8198end
82- function kern2 ! (op:: F , cumulator1, cumulator2) where {F}
99+ function bumper_kern2 ! (op:: F , cumulator1, cumulator2, :: Val{false} ) where {F}
83100 @. cumulator1 = op (cumulator1, cumulator2)
84101 return cumulator1
85102end
0 commit comments