Runtime dispatch overhead for small tensor contractions using the StridedBLAS backend #189

Description

@leburgel

When profiling some code involving many contractions of small tensors, I noticed a lot of overhead from runtime dispatch and the resulting allocations and garbage collection when using the StridedBLAS backend. I've figured out part of it, but I thought it would be good to report here to see if more can be done.

I ran a profile for a dummy contraction of small complex tensors which roughly reproduces the issue:

using TensorOperations

const T = ComplexF64
# const globals, so that untyped-global access does not pollute the profile
# with unrelated dynamic dispatch
const L = randn(T, 8, 8, 7)
const R = randn(T, 8, 8, 7)
const O = randn(T, 4, 4, 7, 7)

function local_update!(psi::Array{T,3})::Array{T,3} where {T}
    @tensor begin
        psi[-1, -2, -3] =
            psi[1, 3, 5] *
            L[-1, 1, 2] *
            O[-2, 3, 2, 4] *
            R[-3, 5, 4]
    end
    return psi
end

psi0 = randn(T, 8, 4, 8)
# @profview is provided by ProfileView.jl or the VS Code Julia extension
@profview begin
    for _ in 1:100000
        local_update!(psi0)
    end
end

which gives:
[flamegraph screenshot: small_contract_profile]

This shows a lot of runtime-dispatch overhead in TensorOperations.tensoradd! and Strided._mapreduce_order!. The effect becomes negligible for large tensor dimensions, but it turns out to be a real pain when a lot of these small contractions are performed.

For TensorOperations.tensoradd!, I managed to track the problem to an ambiguity caused by patterns like flag2op(conjA)(A), where the compile-time return type is a Union of two concrete StridedView types, with typeof(identity) and typeof(conj) as their respective op field types. This leads to an issue here:
https://github.com/Jutho/TensorOperations.jl/blob/c1e37eca9c2d1ab468ddd1beb2ffb9eb0993bb4b/src/implementation/strided.jl#L99-L103
where the last argument in the call to Strided._mapreducedim! is a Tuple whose type parameters mix concrete types with the abstract Union described above, which seems to derail inference.
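The inference problem can be reproduced in isolation. Below is a minimal sketch where my_flag2op and MyView are simplified stand-ins for TensorOperations' flag2op and StridedViews' StridedView, not the actual definitions: a boolean flag maps to one of two singleton function types, so any type that stores the op in a type parameter only infers as a Union.

```julia
# Simplified stand-in for flag2op: the two branches return values of
# different singleton types, so inference can only produce a Union.
my_flag2op(conjA::Bool) = conjA ? conj : identity

# Simplified stand-in for StridedView, which carries its `op` in the type.
struct MyView{T,F}
    data::Vector{T}
    op::F
end

wrap(A, conjA) = MyView(A, my_flag2op(conjA))

# The inferred return types are small Unions, not concrete types:
println(Base.return_types(my_flag2op, (Bool,)))
println(Base.return_types(wrap, (Vector{ComplexF64}, Bool)))
```

Any container typed with such a Union-typed value (like the Tuple passed to _mapreducedim!) then has an abstract element type.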

At that level I managed to fix things by just splitting into two branches to de-confuse the compiler:

opA = flag2op(conjA)
if opA isa typeof(identity)
    # manual union split: inside each branch, opA (and hence A′) is concretely typed
    A′ = permutedims(identity(A), linearize(pA))
    op1 = Base.Fix2(scale, α)
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
    return C
elseif opA isa typeof(conj)
    A′ = permutedims(conj(A), linearize(pA))
    op1 = Base.Fix2(scale, α)
    op2 = Base.Fix2(scale, β)
    Strided._mapreducedim!(op1, +, op2, size(C), (C, A′))
    return C
end

which gets rid of the runtime dispatch in tensoradd! and already makes a big difference for small contractions.
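For reference, a function barrier achieves the same effect without duplicating the branch bodies. The sketch below is a hypothetical simplification operating on plain Arrays, not the actual backend code: the union-typed op is passed through a call, and Julia compiles a specialized, type-stable method for each op type behind the barrier.

```julia
# Hypothetical function barrier: `op` is Union-typed at the call site,
# but concrete (typeof(identity) or typeof(conj)) inside `_permute`.
_permute(op::F, A, p) where {F} = permutedims(op(A), p)

function add_via_barrier!(C, A, conjA::Bool, p)
    op = conjA ? conj : identity   # inferred as a small Union here
    C .+= _permute(op, A, p)       # specialized method compiled per op type
    return C
end
```

Julia will usually union-split such a small Union automatically anyway; the point of the barrier (or of the manual split above) is to keep everything downstream, including the Tuple passed to _mapreducedim!, concretely typed.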

I haven't descended all the way down into Strided._mapreduce_dim!, so I don't know what the issue is there.

So in the end my questions are:

  • Is there a cleaner way to deal with the conj flags and their effect on StridedViews that could avoid these kinds of type ambiguities?
  • Do you have any idea whether the runtime dispatch in Strided._mapreduce_dim! could be prevented by changing the StridedBLAS backend implementation here, and if not, whether it could be avoided altogether by a reasonable change to Strided.jl?
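Regarding the first question, one conceivable direction (purely a sketch, not a proposal for the actual TensorOperations API) would be to lift the conjugation flag into the type domain, e.g. via Val, so each flag value maps to a single concrete op type and no Union arises in the first place:

```julia
# Hypothetical Val-based flag: each method has a concrete singleton
# return type, so no Union shows up in inference.
my_flag2op(::Val{false}) = identity
my_flag2op(::Val{true})  = conj

# Both infer to a single concrete type rather than a Union:
@assert only(Base.return_types(my_flag2op, (Val{false},))) === typeof(identity)
@assert only(Base.return_types(my_flag2op, (Val{true},)))  === typeof(conj)
```

The trade-off is that the flag must be known at compile time at the call site (wrapping a runtime Bool with Val(conjA) just moves the dynamic dispatch), so this only helps if the macro layer can bake the flag into the generated code.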
