When profiling some code involving many contractions with small tensors, I noticed a lot of overhead from runtime dispatch and the resulting allocations and garbage collection when using the `StridedBLAS` backend. I've figured out some of it, but I thought it would be good to report here to see if more can be done.

I ran a profile for a dummy example contraction of small complex tensors that roughly reproduces the issue:
```julia
using TensorOperations

T = ComplexF64
L = randn(T, 8, 8, 7)
R = randn(T, 8, 8, 7)
O = randn(T, 4, 4, 7, 7)

function local_update!(psi::Array{T,3})::Array{T,3} where {T}
    @tensor begin
        psi[-1, -2, -3] =
            psi[1, 3, 5] *
            L[-1, 1, 2] *
            O[-2, 3, 2, 4] *
            R[-3, 5, 4]
    end
    return psi
end

psi0 = randn(T, 8, 4, 8)

# `@profview` is provided by ProfileView.jl or the VS Code Julia extension
@profview begin
    for _ in 1:100000
        local_update!(psi0)
    end
end
```
This shows a lot of runtime-dispatch overhead in `TensorOperations.tensoradd!` and `Strided._mapreduce_order!`. The effect becomes negligible for large tensor dimensions, but it turns out to be a real pain when a lot of these small contractions are performed.
For `TensorOperations.tensoradd!`, I managed to track the problem to an ambiguity caused by patterns like `flag2op(conjA)(A)`, where the return type at compile time can be a `Union` of two concrete `StridedView` types, with `typeof(identity)` and `typeof(conj)` as their `op` field types respectively. This leads to an issue here:

https://github.com/Jutho/TensorOperations.jl/blob/c1e37eca9c2d1ab468ddd1beb2ffb9eb0993bb4b/src/implementation/strided.jl#L99-L103

where the last argument in the call to `Strided._mapreducedim!` is a `Tuple` with mixed concrete and abstract (the `Union` described above) types in its type parameters, which seems to confuse inference.
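The inference pattern can be reproduced in isolation. The sketch below uses hypothetical names (`ViewDemo`, `flag2op_demo`, `wrap`), not the actual TensorOperations/Strided code: `ViewDemo` stands in for `StridedView`, whose type parameters record the `op` field type, so a flag-dependent `flag2op`-style call infers as a `Union` of two concrete wrapper types.

```julia
# Minimal stand-in for a view type whose type parameters include the op.
struct ViewDemo{T,F}
    data::Vector{T}
    op::F
end

flag2op_demo(conjflag::Bool) = conjflag ? conj : identity

# The op depends on a runtime value, so inference can only conclude
# Union{ViewDemo{T,typeof(conj)}, ViewDemo{T,typeof(identity)}}.
wrap(conjflag::Bool, v::Vector{T}) where {T} = ViewDemo(v, flag2op_demo(conjflag))

rts = Base.return_types(wrap, (Bool, Vector{ComplexF64}))
@show rts   # a single Union of two concrete ViewDemo types
```

Once a value of this `Union` type ends up as an element of the tuple passed to `Strided._mapreducedim!`, the tuple's type parameters mix concrete and abstract types, as described above.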
At that level I managed to fix things by just splitting into two branches to de-confuse the compiler:

```julia
opA = flag2op(conjA)
if opA isa typeof(identity)
    A′ = permutedims(identity(A), linearize(pA))
    op1 = Base.Fix2(scale, α)
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
    return C
elseif opA isa typeof(conj)
    A′ = permutedims(conj(A), linearize(pA))
    op1 = Base.Fix2(scale, α)
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
    return C
end
```
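The fix above is essentially manual union splitting: inside each `isa` branch the compiler knows the concrete function type, so everything downstream is statically dispatched. Julia's compiler can often split small unions automatically, but that tends to break down once the `Union`-typed value is buried inside a tuple passed to another function, as in the linked code. A standalone sketch of the pattern (hypothetical names, not the package code):

```julia
flag2op_demo(conjflag::Bool) = conjflag ? conj : identity

# Without the split: `op` has a Union type, and any container built from it
# carries that Union into its type parameters.
function apply_unsplit(conjflag::Bool, x::ComplexF64)
    op = flag2op_demo(conjflag)
    return op(x)
end

# With the split: each branch sees a concrete singleton function type, so
# calls made inside the branch are statically dispatched.
function apply_split(conjflag::Bool, x::ComplexF64)
    op = flag2op_demo(conjflag)
    if op isa typeof(identity)
        return identity(x)
    else
        return conj(x)
    end
end
```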
This gets rid of the runtime dispatch in `tensoradd!` and already makes a big difference for small contractions. I haven't descended all the way down into `Strided._mapreduce_dim!`, so I don't know what the issue is there.
So in the end my questions are:

- Is there a cleaner way to deal with the `conj` flags and their effect on `StridedView`s that could avoid these kinds of type ambiguities?
- Do you have any idea whether the runtime dispatch in `Strided._mapreduce_dim!` could be prevented by changing the `StridedBLAS` backend implementation here, and, if not, whether it could be avoided altogether by a reasonable change to Strided.jl?