Closed
Description
Motivation and description
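Extend `NNlib.batched_mul` to arrays with more than three dimensions, treating every dimension after the second as a batch dimension. The use case is the multi-head attention layer in FluxML/Flux.jl#2146.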
Possible Implementation
"""
    batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}

Batched matrix multiplication for arrays with an arbitrary number of batch dimensions.

Dimensions 1 and 2 of `x` and `y` are the matrix dimensions. All trailing dimensions
(3 through `N`) are batch dimensions and must be identical in `x` and `y`.

The method works in three steps:
1. Flatten the batch dimensions into a single axis.
2. Multiply with the 3-array `NNlib.batched_mul`.
3. Reshape the result so its trailing dimensions equal the original batch dimensions.

For every batch index `I`, `z[:, :, I...]` is the product of `x[:, :, I...]` and
`y[:, :, I...]`.

# Throws
- `DimensionMismatch` if `size(x)[3:end] != size(y)[3:end]`.
"""
function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    # Validate explicitly rather than with `@assert`: assertions may be disabled and
    # raise `AssertionError`, whereas callers expect a catchable `DimensionMismatch`.
    batch_size == size(y)[3:end] || throw(DimensionMismatch(
        "batch dimensions must match, got size(x) = $(size(x)) and size(y) = $(size(y))"))
    # Column-major reshape flattens the batch dimensions identically for both inputs,
    # so batch slice k of `x2` pairs with batch slice k of `y2`. The resulting 3-arrays
    # are handled by NNlib's `AbstractArray{T,3}` method, which is more specific than
    # this one.
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = NNlib.batched_mul(x2, y2)
    # Restore the original batch shape on the output.
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end