Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using VectorInterface
using TupleTools: invperm
using LinearAlgebra

_conj(conjA::Symbol) = conjA == :C ? :N : :C
trivtuple(N) = ntuple(identity, N)

@non_differentiable TensorOperations.tensorstructure(args...)
Expand Down Expand Up @@ -44,7 +43,7 @@ end

function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
C,
A, pA::Index2Tuple, conjA::Symbol,
A, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number, backend::Backend...)
C′ = tensoradd!(copy(C), A, pA, conjA, α, β, backend...)

Expand All @@ -59,20 +58,20 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
dA = @thunk begin
ipA = invperm(linearize(pA))
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA == :N ? conj(α) : α, Zero(),
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(),
backend...)
return projectA(_dA)
end
dα = @thunk begin
_dα = tensorscalar(tensorcontract(A, ((), linearize(pA)), _conj(conjA),
ΔC, (trivtuple(numind(pA)), ()), :N,
_dα = tensorscalar(tensorcontract(A, ((), linearize(pA)), !conjA,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), backend...))
return projectα(_dα)
end
dβ = @thunk begin
# TODO: consider using `inner`
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pA))), :C,
ΔC, (trivtuple(numind(pA)), ()), :N,
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pA))), true,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), backend...))
return projectβ(_dβ)
end
Expand All @@ -85,8 +84,8 @@ end

function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
C,
A, pA::Index2Tuple, conjA::Symbol,
B, pB::Index2Tuple, conjB::Symbol,
A, pA::Index2Tuple, conjA::Bool,
B, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number, backend::Backend...)
C′ = tensorcontract!(copy(C), A, pA, conjA, B, pB, conjB, pAB, α, β, backend...)
Expand All @@ -105,42 +104,42 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipA = (invperm(linearize(pA)), ())
conjΔC = conjA == :C ? :C : :N
conjB′ = conjA == :C ? conjB : _conj(conjB)
conjΔC = conjA
conjB′ = conjA ? conjB : !conjB
_dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B), typeof(α)))
_dA = tensorcontract!(_dA,
ΔC, pΔC, conjΔC,
B, reverse(pB), conjB′,
ipA,
conjA == :C ? α : conj(α), Zero(), backend...)
conjA ? α : conj(α), Zero(), backend...)
return projectA(_dA)
end
dB = @thunk begin
ipB = (invperm(linearize(pB)), ())
conjΔC = conjB == :C ? :C : :N
conjA′ = conjB == :C ? conjA : _conj(conjA)
conjΔC = conjB
conjA′ = conjB ? conjA : !conjA
_dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A), typeof(α)))
_dB = tensorcontract!(_dB,
A, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
ipB,
conjB == :C ? α : conj(α), Zero(), backend...)
conjB ? α : conj(α), Zero(), backend...)
return projectB(_dB)
end
dα = @thunk begin
C_αβ = tensorcontract(A, pA, conjA,
B, pB, conjB,
pAB, One(), backend...)
# TODO: consider using `inner`
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(pAB))), :C,
ΔC, (trivtuple(numind(pAB)), ()), :N,
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), backend...))
return projectα(_dα)
end
dβ = @thunk begin
# TODO: consider using `inner`
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pAB))), :C,
ΔC, (trivtuple(numind(pAB)), ()), :N,
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), backend...))
return projectβ(_dβ)
end
Expand All @@ -156,7 +155,7 @@ end
# note that this requires `one` to be defined, which is already not the case for regular
# arrays when tracing multiple indices at the same time.
function ChainRulesCore.rrule(::typeof(tensortrace!), C,
A, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number, backend::Backend...)
C′ = tensortrace!(copy(C), A, p, q, conjA, α, β, backend...)

Expand All @@ -179,20 +178,20 @@ function ChainRulesCore.rrule(::typeof(tensortrace!), C,
_dA = tensorproduct!(_dA, ΔC, (trivtuple(numind(p)), ()), conjA,
E, ((), trivtuple(numind(q))), conjA,
(ip, ()),
conjA == :N ? conj(α) : α, Zero(), backend...)
conjA ? α : conj(α), Zero(), backend...)
return projectA(_dA)
end
dα = @thunk begin
C_αβ = tensortrace(A, p, q, :N, One(), backend...)
C_αβ = tensortrace(A, p, q, false, One(), backend...)
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(p))),
_conj(conjA),
ΔC, (trivtuple(numind(p)), ()), :N,
!conjA,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), backend...))
return projectα(_dα)
end
dβ = @thunk begin
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(p))), :C,
ΔC, (trivtuple(numind(p)), ()), :N,
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(p))), true,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), backend...))
return projectβ(_dβ)
end
Expand All @@ -210,7 +209,7 @@ function _kron(Es::NTuple{N,Any}, backend::Backend...) where {N}
E2 = _kron(Base.tail(Es), backend...)
p2 = ((), trivtuple(2 * N - 2))
p = ((1, (2 .+ trivtuple(N - 1))...), (2, ((N + 1) .+ trivtuple(N - 1))...))
return tensorproduct(p, E1, ((1, 2), ()), :N, E2, p2, :N, One(), backend...)
return tensorproduct(p, E1, ((1, 2), ()), false, E2, p2, false, One(), backend...)
end

# NCON functions
Expand Down
52 changes: 26 additions & 26 deletions ext/TensorOperationscuTENSORExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ function TO.tensorscalar(C::CuStridedView)
return ndims(C) == 0 ? CUDA.@allowscalar(C[]) : throw(DimensionMismatch())
end

function tensorop(A::StridedCuArray, conjA::Symbol=:N)
return (eltype(A) <: Real || conjA === :N) ? OP_IDENTITY : OP_CONJ
function tensorop(A::StridedCuArray, conjA::Bool=false)
return (eltype(A) <: Real || !conjA) ? OP_IDENTITY : OP_CONJ
end
function tensorop(A::CuStridedView, conjA::Symbol=:N)
return if (eltype(A) <: Real || !xor(conjA === :C, A.op === conj))
function tensorop(A::CuStridedView, conjA::Bool=false)
return if (eltype(A) <: Real || !xor(conjA, A.op === conj))
OP_IDENTITY
else
OP_CONJ
Expand All @@ -64,28 +64,28 @@ end

for ArrayType in SUPPORTED_CUARRAYS
@eval function TO.tensoradd!(C::$ArrayType, A::$ArrayType, pA::Index2Tuple,
conjA::Symbol,
conjA::Bool,
α::Number, β::Number)
return tensoradd!(C, A, pA, conjA, α, β, cuTENSORBackend())
end
@eval function TO.tensorcontract!(C::$ArrayType,
A::$ArrayType, pA::Index2Tuple, conjA::Symbol,
B::$ArrayType, pB::Index2Tuple, conjB::Symbol,
A::$ArrayType, pA::Index2Tuple, conjA::Bool,
B::$ArrayType, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple, α::Number, β::Number)
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, cuTENSORBackend())
end
@eval function TO.tensortrace!(C::$ArrayType,
A::$ArrayType, p::Index2Tuple, q::Index2Tuple,
conjA::Symbol,
conjA::Bool,
α::Number, β::Number)
return tensortrace!(C, A, p, q, conjA, α, β, cuTENSORBackend())
end
@eval function TO.tensoradd_type(TC, ::$ArrayType, pA::Index2Tuple, conjA::Symbol)
@eval function TO.tensoradd_type(TC, ::$ArrayType, pA::Index2Tuple, conjA::Bool)
return CUDA.CuArray{TC,TO.numind(pA)}
end
@eval function TO.tensorcontract_type(TC,
::$ArrayType, pA::Index2Tuple, conjA::Symbol,
::$ArrayType, pB::Index2Tuple, conjB::Symbol,
::$ArrayType, pA::Index2Tuple, conjA::Bool,
::$ArrayType, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple)
return CUDA.CuArray{TC,TO.numind(pAB)}
end
Expand All @@ -95,7 +95,7 @@ end
# making sure that if the backend is specified, arrays are converted to CuArrays

function TO.tensoradd!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number,
backend::cuTENSORBackend)
C_cuda = adapt(CuArray, C)
Expand All @@ -105,8 +105,8 @@ function TO.tensoradd!(C::AbstractArray,
return C
end
function TO.tensorcontract!(C::AbstractArray,
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
B::AbstractArray, pB::Index2Tuple, conjB::Symbol,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number, backend::cuTENSORBackend)
C_cuda = adapt(CuArray, C)
Expand All @@ -117,7 +117,7 @@ function TO.tensorcontract!(C::AbstractArray,
return C
end
function TO.tensortrace!(C::AbstractArray,
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number, backend::cuTENSORBackend)
C_cuda = adapt(CuArray, C)
A_cuda = adapt(CuArray, A)
Expand All @@ -126,7 +126,7 @@ function TO.tensortrace!(C::AbstractArray,
return C
end

function TO.tensoralloc_add(TC, A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
function TO.tensoralloc_add(TC, A::AbstractArray, pA::Index2Tuple, conjA::Bool,
istemp::Bool,
::cuTENSORBackend)
ttype = CuArray{TC,TO.numind(pA)}
Expand All @@ -135,8 +135,8 @@ function TO.tensoralloc_add(TC, A::AbstractArray, pA::Index2Tuple, conjA::Symbol
end

function TO.tensoralloc_contract(TC,
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
B::AbstractArray, pB::Index2Tuple, conjB::Symbol,
A::AbstractArray, pA::Index2Tuple, conjA::Bool,
B::AbstractArray, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
istemp::Bool, ::cuTENSORBackend)
ttype = CuArray{TC,TO.numind(pAB)}
Expand All @@ -152,22 +152,22 @@ end
# Convert all implementations to StridedViews
# This should work for wrapper types that are supported by StridedViews
function TO.tensoradd!(C::AnyCuArray,
A::AnyCuArray, pA::Index2Tuple, conjA::Symbol,
A::AnyCuArray, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number, backend::cuTENSORBackend)
tensoradd!(StridedView(C), StridedView(A), pA, conjA, α, β, backend)
return C
end
function TO.tensorcontract!(C::AnyCuArray, A::AnyCuArray,
pA::Index2Tuple, conjA::Symbol, B::AnyCuArray,
pB::Index2Tuple, conjB::Symbol, pAB::Index2Tuple, α::Number,
pA::Index2Tuple, conjA::Bool, B::AnyCuArray,
pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, α::Number,
β::Number,
backend::cuTENSORBackend)
tensorcontract!(StridedView(C), StridedView(A), pA, conjA,
StridedView(B), pB, conjB, pAB, α, β, backend)
return C
end
function TO.tensortrace!(C::AnyCuArray,
A::AnyCuArray, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
A::AnyCuArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number, backend::cuTENSORBackend)
tensortrace!(StridedView(C), StridedView(A), p, q, conjA, α, β, backend)
return C
Expand All @@ -178,7 +178,7 @@ end
#-------------------------------------------------------------------------------------------

function TO.tensoradd!(C::CuStridedView,
A::CuStridedView, pA::Index2Tuple, conjA::Symbol,
A::CuStridedView, pA::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::cuTENSORBackend)
# convert arguments
Ainds, Cinds = collect.(TO.add_labels(pA))
Expand All @@ -198,8 +198,8 @@ function TO.tensoradd!(C::CuStridedView,
end

function TO.tensorcontract!(C::CuStridedView,
A::CuStridedView, pA::Index2Tuple, conjA::Symbol,
B::CuStridedView, pB::Index2Tuple, conjB::Symbol,
A::CuStridedView, pA::Index2Tuple, conjA::Bool,
B::CuStridedView, pB::Index2Tuple, conjB::Bool,
pAB::Index2Tuple,
α::Number, β::Number, ::cuTENSORBackend)
# convert arguments
Expand All @@ -217,7 +217,7 @@ function TO.tensorcontract!(C::CuStridedView,
end

function TO.tensortrace!(C::CuStridedView,
A::CuStridedView, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
A::CuStridedView, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
α::Number, β::Number, ::cuTENSORBackend)
# convert arguments
Ainds, Cinds = collect.(TO.trace_labels(p, q))
Expand Down
Loading