Skip to content

Commit 79e851e

Browse files
committed
dont force unroll loop in reductions
1 parent 81dd6ca commit 79e851e

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

src/mapreduce.jl

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,13 @@ end
4545

4646

4747
@generated function _map!(f, dest, ::Size{S}, a::StaticArray...) where {S}
48-
exprs = Vector{Expr}(undef, prod(S))
49-
for i 1:prod(S)
50-
tmp = [:(a[$j][$i]) for j 1:length(a)]
51-
exprs[i] = :(dest[$i] = f($(tmp...)))
52-
end
48+
tmp = [:(a[$j][i]) for j 1:length(a)]
5349
return quote
5450
@_inline_meta
55-
@inbounds $(Expr(:block, exprs...))
51+
@inbounds @simd for i 1:$(prod(S))
52+
dest[i] = f($(tmp...))
53+
end
54+
return dest
5655
end
5756
end
5857

@@ -66,28 +65,28 @@ end
6665

6766
@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{()},
6867
::Size{S}, a::StaticArray...) where {S}
69-
tmp = [:(a[$j][1]) for j 1:length(a)]
70-
expr = :(f($(tmp...)))
71-
for i 2:prod(S)
72-
tmp = [:(a[$j][$i]) for j 1:length(a)]
73-
expr = :(op($expr, f($(tmp...))))
74-
end
68+
tmp = [:(a[$j][i]) for j 1:length(a)]
7569
return quote
7670
@_inline_meta
77-
@inbounds return $expr
71+
i = 1
72+
@inbounds s = f($(tmp...))
73+
@inbounds @simd for i = 2:$(prod(S))
74+
s = op(s, f($(tmp...)))
75+
end
76+
return s
7877
end
7978
end
80-
79+
8180
@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{(:init,)},
82-
::Size{S}, a::StaticArray...) where {S}
83-
expr = :(nt.init)
84-
for i 1:prod(S)
85-
tmp = [:(a[$j][$i]) for j 1:length(a)]
86-
expr = :(op($expr, f($(tmp...))))
87-
end
81+
::Size{S}, a::StaticArray...) where {S}
82+
tmp = [:(a[$j][i]) for j 1:length(a)]
8883
return quote
8984
@_inline_meta
90-
@inbounds return $expr
85+
@inbounds s = nt.init
86+
@inbounds @simd for i = 1:$(prod(S))
87+
s = op(s, f($(tmp...)))
88+
end
89+
return s
9190
end
9291
end
9392

@@ -98,7 +97,7 @@ end
9897
@inline _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S} =
9998
_mapreduce(f, op, Val(D), nt, sz, a)
10099

101-
100+
102101
@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
103102
::Size{S}, a::StaticArray) where {S,D}
104103
N = length(S)

test/mapreduce.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ end
5858

5959
@test iszero(sz) == iszero(z)
6060

61-
@test sum(sa) === sum(a)
61+
@test sum(sa) sum(a)
6262
@test sum(abs2, sa) === sum(abs2, a)
6363
@test sum(sa, dims=2) === RSArray2(sum(a, dims=2))
6464
@test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2))
@@ -85,7 +85,7 @@ end
8585
@test any(sb, dims=Val(2)) === RSArray2(any(b, dims=2))
8686
@test any(x->x>0, sa, dims=Val(2)) === RSArray2(any(x->x>0, a, dims=2))
8787

88-
@test mean(sa) === mean(a)
88+
@test mean(sa) mean(a)
8989
@test mean(abs2, sa) === mean(abs2, a)
9090
@test mean(sa, dims=Val(2)) === RSArray2(mean(a, dims=2))
9191
@test mean(abs2, sa, dims=Val(2)) === RSArray2(mean(abs2.(a), dims=2))

0 commit comments

Comments
 (0)