91
91
# -------------------------------------------------------------------------------------------
92
92
# StridedView implementation
93
93
# -------------------------------------------------------------------------------------------
94
+ struct Adder end
95
+ (:: Adder )(x, y) = VectorInterface. add (x, y)
96
+ struct Scaler{T}
97
+ α:: T
98
+ end
99
+ (s:: Scaler )(x) = scale (x, s. α)
100
+ (s:: Scaler )(x, y) = scale (x * y, s. α)
101
+
94
102
function stridedtensoradd! (C:: StridedView ,
95
103
A:: StridedView , pA:: Index2Tuple ,
96
104
α:: Number , β:: Number ,
@@ -102,9 +110,7 @@ function stridedtensoradd!(C::StridedView,
102
110
end
103
111
104
112
A′ = permutedims (A, linearize (pA))
105
- op1 = Base. Fix2 (scale, α)
106
- op2 = Base. Fix2 (scale, β)
107
- Strided. _mapreducedim! (op1, + , op2, size (C), (C, A′))
113
+ Strided. _mapreducedim! (Scaler (α), Adder (), Scaler (β), size (C), (C, A′))
108
114
return C
109
115
end
110
116
@@ -125,9 +131,7 @@ function stridedtensortrace!(C::StridedView,
125
131
newsize = (size (C)... , tracesize... )
126
132
127
133
A′ = SV (A. parent, newsize, newstrides, A. offset, A. op)
128
- op1 = Base. Fix2 (scale, α)
129
- op2 = Base. Fix2 (scale, β)
130
- Strided. _mapreducedim! (op1, + , op2, newsize, (C, A′))
134
+ Strided. _mapreducedim! (Scaler (α), Adder (), Scaler (β), newsize, (C, A′))
131
135
return C
132
136
end
133
137
@@ -170,8 +174,6 @@ function stridedtensorcontract!(C::StridedView,
170
174
(osizeA... , osizeB... , one .(csizeA)... ))
171
175
tsize = (osizeA... , osizeB... , csizeA... )
172
176
173
- op1 = Base. Fix2 (scale, α) ∘ *
174
- op2 = Base. Fix2 (scale, β)
175
- Strided. _mapreducedim! (op1, + , op2, tsize, (CS, AS, BS))
177
+ Strided. _mapreducedim! (Scaler (α), Adder (), Scaler (β), tsize, (CS, AS, BS))
176
178
return C
177
179
end
0 commit comments