Skip to content

Commit 2a56132

Browse files
committed
force strided specialisation
1 parent 42e9c9c commit 2a56132

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/implementation/strided.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ end
9191
#-------------------------------------------------------------------------------------------
9292
# StridedView implementation
9393
#-------------------------------------------------------------------------------------------
94+
struct Adder end
95+
(::Adder)(x, y) = VectorInterface.add(x, y)
96+
struct Scaler{T}
97+
α::T
98+
end
99+
(s::Scaler)(x) = scale(x, s.α)
100+
(s::Scaler)(x, y) = scale(x * y, s.α)
101+
94102
function stridedtensoradd!(C::StridedView,
95103
A::StridedView, pA::Index2Tuple,
96104
α::Number, β::Number,
@@ -102,9 +110,7 @@ function stridedtensoradd!(C::StridedView,
102110
end
103111

104112
A′ = permutedims(A, linearize(pA))
105-
op1 = Base.Fix2(scale, α)
106-
op2 = Base.Fix2(scale, β)
107-
Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
113+
Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), size(C), (C, A′))
108114
return C
109115
end
110116

@@ -125,9 +131,7 @@ function stridedtensortrace!(C::StridedView,
125131
newsize = (size(C)..., tracesize...)
126132

127133
A′ = SV(A.parent, newsize, newstrides, A.offset, A.op)
128-
op1 = Base.Fix2(scale, α)
129-
op2 = Base.Fix2(scale, β)
130-
Strided._mapreducedim!(op1, +, op2, newsize, (C, A′))
134+
Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), newsize, (C, A′))
131135
return C
132136
end
133137

@@ -170,8 +174,6 @@ function stridedtensorcontract!(C::StridedView,
170174
(osizeA..., osizeB..., one.(csizeA)...))
171175
tsize = (osizeA..., osizeB..., csizeA...)
172176

173-
op1 = Base.Fix2(scale, α) *
174-
op2 = Base.Fix2(scale, β)
175-
Strided._mapreducedim!(op1, +, op2, tsize, (CS, AS, BS))
177+
Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), tsize, (CS, AS, BS))
176178
return C
177179
end

0 commit comments

Comments
 (0)