@@ -165,7 +165,18 @@ Base.:(==)(A::CompositeMap, B::CompositeMap) =
165
165
(eltype (A) == eltype (B) && all (A. maps .== B. maps))
166
166
167
167
# multiplication with vectors/matrices
168
- _unsafe_mul! (y, A:: CompositeMap , x:: AbstractVector ) = _compositemul! (y, A, x)
168
+ function Base.:(* )(A:: CompositeMap , x:: AbstractVector )
169
+ MulStyle (A) === TwoArg () ?
170
+ foldr (* , reverse (A. maps), init= x) :
171
+ invoke (* , Tuple{LinearMap, AbstractVector}, A, x)
172
+ end
173
+
174
+ function _unsafe_mul! (y, A:: CompositeMap , x:: AbstractVector )
175
+ MulStyle (A) === TwoArg () ?
176
+ copyto! (y, foldr (* , reverse (A. maps), init= x)) :
177
+ _compositemul! (y, A, x)
178
+ return y
179
+ end
169
180
_unsafe_mul! (y, A:: CompositeMap , x:: AbstractMatrix ) = _compositemul! (y, A, x)
170
181
171
182
function _compositemul! (y, A:: CompositeMap{<:Any,<:Tuple{LinearMap}} , x,
@@ -174,10 +185,50 @@ function _compositemul!(y, A::CompositeMap{<:Any,<:Tuple{LinearMap}}, x,
174
185
return _unsafe_mul! (y, A. maps[1 ], x)
175
186
end
176
187
function _compositemul! (y, A:: CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}} , x,
177
- source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )),
188
+ source = nothing ,
189
+ dest = nothing )
190
+ if isnothing (source)
191
+ z = convert (AbstractArray, A. maps[1 ] * x)
192
+ _unsafe_mul! (y, A. maps[2 ], z)
193
+ return y
194
+ else
195
+ _unsafe_mul! (source, A. maps[1 ], x)
196
+ _unsafe_mul! (y, A. maps[2 ], source)
197
+ return y
198
+ end
199
+ end
200
+ _compositemul! (y, A:: CompositeMap{<:Any,<:LinearMapTuple} , x, s = nothing , d = nothing ) =
201
+ _compositemulN! (y, A, x, s, d)
202
+ function _compositemul! (y, A:: CompositeMap{<:Any,<:LinearMapVector} , x,
203
+ source = nothing ,
178
204
dest = nothing )
179
- _unsafe_mul! (source, A. maps[1 ], x)
180
- _unsafe_mul! (y, A. maps[2 ], source)
205
+ N = length (A. maps)
206
+ if N == 1
207
+ return _unsafe_mul! (y, A. maps[1 ], x)
208
+ elseif N == 2
209
+ return _unsafe_mul! (y, A. maps[2 ] * A. maps[1 ], x)
210
+ else
211
+ return _compositemulN! (y, A, x, source, dest)
212
+ end
213
+ end
214
+
215
+ function _compositemulN! (y, A:: CompositeMap , x,
216
+ src = nothing ,
217
+ dst = nothing )
218
+ N = length (A. maps) # ≥ 3
219
+ source = isnothing (src) ?
220
+ convert (AbstractArray, A. maps[1 ] * x) :
221
+ _unsafe_mul! (src, A. maps[1 ], x)
222
+ dest = isnothing (dst) ?
223
+ convert (AbstractArray, A. maps[2 ] * source) :
224
+ _unsafe_mul! (dst, A. maps[2 ], source)
225
+ dest, source = source, dest # alternate dest and source
226
+ for n in 3 : N- 1
227
+ dest = _resize (dest, (size (A. maps[n], 1 ), size (x)[2 : end ]. .. ))
228
+ _unsafe_mul! (dest, A. maps[n], source)
229
+ dest, source = source, dest # alternate dest and source
230
+ end
231
+ _unsafe_mul! (y, A. maps[N], source)
181
232
return y
182
233
end
183
234
@@ -197,48 +248,3 @@ function _resize(dest::AbstractMatrix, sz::Tuple{<:Integer,<:Integer})
197
248
size (dest) == sz && return dest
198
249
similar (dest, sz)
199
250
end
200
-
201
- function _compositemul! (y, A:: CompositeMap{<:Any,<:LinearMapTuple} , x,
202
- source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )),
203
- dest = similar (y, (size (A. maps[2 ],1 ), size (x)[2 : end ]. .. )))
204
- N = length (A. maps)
205
- _unsafe_mul! (source, A. maps[1 ], x)
206
- for n in 2 : N- 1
207
- dest = _resize (dest, (size (A. maps[n],1 ), size (x)[2 : end ]. .. ))
208
- _unsafe_mul! (dest, A. maps[n], source)
209
- dest, source = source, dest # alternate dest and source
210
- end
211
- _unsafe_mul! (y, A. maps[N], source)
212
- return y
213
- end
214
-
215
- function _compositemul! (y, A:: CompositeMap{<:Any,<:LinearMapVector} , x)
216
- N = length (A. maps)
217
- if N == 1
218
- return _unsafe_mul! (y, A. maps[1 ], x)
219
- elseif N == 2
220
- return _compositemul2! (y, A, x)
221
- else
222
- return _compositemulN! (y, A, x)
223
- end
224
- end
225
-
226
- function _compositemul2! (y, A:: CompositeMap{<:Any,<:LinearMapVector} , x,
227
- source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )))
228
- _unsafe_mul! (source, A. maps[1 ], x)
229
- _unsafe_mul! (y, A. maps[2 ], source)
230
- return y
231
- end
232
- function _compositemulN! (y, A:: CompositeMap{<:Any,<:LinearMapVector} , x,
233
- source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )),
234
- dest = similar (y, (size (A. maps[2 ],1 ), size (x)[2 : end ]. .. )))
235
- N = length (A. maps)
236
- _unsafe_mul! (source, A. maps[1 ], x)
237
- for n in 2 : N- 1
238
- dest = _resize (dest, (size (A. maps[n],1 ), size (x)[2 : end ]. .. ))
239
- _unsafe_mul! (dest, A. maps[n], source)
240
- dest, source = source, dest # alternate dest and source
241
- end
242
- _unsafe_mul! (y, A. maps[N], source)
243
- return y
244
- end
0 commit comments