Merged
29 commits
b95e83a  CompatHelper: bump compat for cuTENSOR to 2, (keep existing compat) (Jan 19, 2024)
2eb0dc2  Remove cuTENSOR v1 (lkdvos, Feb 10, 2024)
729f76c  Rewrite cuTENSOR module for v2 (lkdvos, Feb 10, 2024)
8eae6b2  Rearrange cuTENSOR tests (lkdvos, Feb 10, 2024)
c97dc70  Re-enable `@cutensor` dependency check test (lkdvos, Feb 11, 2024)
1a925d0  Bump Julia compat to 1.8 (lkdvos, Feb 11, 2024)
f066505  use namespace enums (lkdvos, Feb 11, 2024)
16e865e  Formatter (lkdvos, Feb 11, 2024)
e99450a  Decouple default backend selection to reduce ambiguity errors (lkdvos, Apr 15, 2024)
fcf32d1  Rewrite cuTENSOR backend to support CuStridedView (ugly) (lkdvos, Apr 15, 2024)
43a218e  Formatter (lkdvos, Apr 15, 2024)
ef8c837  Clean-up cutensor implementation (lkdvos, Apr 17, 2024)
da4fd98  Update cutensor tests [no ci] (lkdvos, Apr 17, 2024)
7b18b2e  Remove code that got incorporated in cuTENSOR (lkdvos, May 5, 2024)
612f918  Only copy if Adapt actually changed something (lkdvos, May 5, 2024)
ce54f9f  Move tensorfree! (lkdvos, May 5, 2024)
d98f7f2  Improve alignment computation (lkdvos, May 5, 2024)
6b7d885  Remove some unused imports (lkdvos, May 5, 2024)
ee0692a  Bump CUDA version requirement (lkdvos, May 16, 2024)
56c0ae0  Merge branch 'master' into cutensor (lkdvos, May 26, 2024)
103c403  Update cutensor extension arg order (lkdvos, May 26, 2024)
36fb70f  fixup! Merge branch 'master' into cutensor (lkdvos, May 26, 2024)
0bfc23b  Bump cuTENSOR compat + Bump version v5.0.0-DEV (lkdvos, May 28, 2024)
60c6597  Fix method ambiguities (lkdvos, May 28, 2024)
de459f7  fixup! Update cutensor extension arg order (lkdvos, May 28, 2024)
0cb1ac4  Also update cuTENSOR compat (lkdvos, May 28, 2024)
267cb29  Change copyto! to copy! (lkdvos, May 29, 2024)
c82eac7  Remove unused trace! (lkdvos, May 29, 2024)
714835a  Fix alignment computation (lkdvos, May 29, 2024)
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6' # LTS version
- '1.8' # lowest supported version
- '1' # automatically expands to the latest stable 1.x release of Julia
os:
- ubuntu-latest
8 changes: 4 additions & 4 deletions Project.toml
@@ -1,7 +1,7 @@
name = "TensorOperations"
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Maarten Van Damme <maartenvd1994@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
version = "4.1.1"
version = "5.0.0-DEV"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -26,7 +26,7 @@ TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[compat]
Aqua = "0.6, 0.7, 0.8"
CUDA = "4,5"
CUDA = "5.4.0"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
DynamicPolynomials = "0.5"
@@ -40,8 +40,8 @@ StridedViews = "0.2"
Test = "1"
TupleTools = "1.1"
VectorInterface = "0.4.1"
cuTENSOR = "1"
julia = "1.6"
cuTENSOR = "2.1.1"
julia = "1.8"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
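
For context on these compat entries: CUDA and cuTENSOR are weak dependencies, and the TensorOperationscuTENSORExt extension only activates once both are loaded next to TensorOperations. A minimal sketch for checking that the extension actually loaded, assuming Julia ≥ 1.9 where Base.get_extension is available:

using CUDA, cuTENSOR, TensorOperations

# Base.get_extension returns the extension module, or `nothing` if it did not load.
ext = Base.get_extension(TensorOperations, :TensorOperationscuTENSORExt)
ext === nothing && @warn "TensorOperationscuTENSORExt is not loaded"
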
519 changes: 234 additions & 285 deletions ext/TensorOperationscuTENSORExt.jl

Large diffs are not rendered by default.
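
The rewritten extension is too large to render here; as a usage-level sketch of what it provides (assuming a CUDA-capable GPU; array sizes and index labels are arbitrary), contractions on CuArrays go through the usual @tensor interface:

using CUDA, cuTENSOR, TensorOperations

A = CUDA.rand(Float64, 8, 8, 8)
B = CUDA.rand(Float64, 8, 8)

# With the extension active, this contraction is handled by cuTENSOR rather
# than the CPU StridedBLAS/StridedNative backends.
@tensor C[i, j, l] := A[i, j, k] * B[k, l]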

35 changes: 28 additions & 7 deletions src/implementation/strided.jl
@@ -1,10 +1,32 @@
#-------------------------------------------------------------------------------------------
# StridedView implementation
#-------------------------------------------------------------------------------------------

# default backends
function tensoradd!(C::StridedView,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
α::Number, β::Number)
backend = eltype(C) isa BlasFloat ? StridedBLAS() : StridedNative()
return tensoradd!(C, A, pA, conjA, α, β, backend)
end
function tensortrace!(C::StridedView,
A::StridedView, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
α::Number, β::Number)
backend = eltype(C) isa BlasFloat ? StridedBLAS() : StridedNative()
return tensortrace!(C, A, p, q, conjA, α, β, backend)
end
function tensorcontract!(C::StridedView,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
B::StridedView, pB::Index2Tuple, conjB::Symbol,
pAB::Index2Tuple, α::Number, β::Number)
backend = eltype(C) isa BlasFloat ? StridedBLAS() : StridedNative()
return tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, backend)
end

function tensoradd!(C::StridedView,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
α::Number, β::Number,
backend::Union{StridedNative,StridedBLAS}=StridedNative())
::Union{StridedNative,StridedBLAS})
argcheck_tensoradd(C, A, pA)
dimcheck_tensoradd(C, A, pA)
if !istrivialpermutation(pA) && Base.mightalias(C, A)
@@ -21,7 +43,7 @@ end
function tensortrace!(C::StridedView,
A::StridedView, p::Index2Tuple, q::Index2Tuple, conjA::Symbol,
α::Number, β::Number,
backend::Union{StridedNative,StridedBLAS}=StridedNative())
::Union{StridedNative,StridedBLAS})
argcheck_tensortrace(C, A, p, q)
dimcheck_tensortrace(C, A, p, q)

@@ -41,12 +63,11 @@ function tensortrace!(C::StridedView,
return C
end

function tensorcontract!(C::StridedView{T},
function tensorcontract!(C::StridedView,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
B::StridedView, pB::Index2Tuple, conjB::Symbol,
pAB::Index2Tuple,
α::Number, β::Number,
backend::StridedBLAS=StridedBLAS()) where {T<:LinearAlgebra.BlasFloat}
α::Number, β::Number, ::StridedBLAS)
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

@@ -74,7 +95,7 @@ function tensorcontract!(C::StridedView{T,2},
A::StridedView{T,2}, pA::Index2Tuple{1,1}, conjA::Symbol,
B::StridedView{T,2}, pB::Index2Tuple{1,1}, conjB::Symbol,
pAB::Index2Tuple{1,1}, α::Number, β::Number,
backend::StridedBLAS=StridedBLAS()) where {T<:LinearAlgebra.BlasFloat}
::StridedBLAS) where {T}
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

@@ -97,7 +118,7 @@ function tensorcontract!(C::StridedView,
A::StridedView, pA::Index2Tuple, conjA::Symbol,
B::StridedView, pB::Index2Tuple, conjB::Symbol,
pAB::Index2Tuple, α::Number, β::Number,
backend::StridedNative)
::StridedNative)
argcheck_tensorcontract(C, A, pA, B, pB, pAB)
dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

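
The strided.jl changes above decouple backend selection from the backend-specific kernels: the methods without a backend argument pick StridedBLAS or StridedNative from the element type and forward to methods that take the backend as a mandatory positional argument (the previous default values are gone), which is what removes the method ambiguities. Below is a minimal sketch of the same pattern on a toy function; scale! and the ToyBackend types are illustrative and not part of the TensorOperations API:

using LinearAlgebra: BlasFloat, rmul!

abstract type ToyBackend end
struct ToyNative <: ToyBackend end
struct ToyBLAS <: ToyBackend end

# Selector method: no backend argument; choose one from the element type and forward.
function scale!(C::AbstractArray, α::Number)
    backend = eltype(C) <: BlasFloat ? ToyBLAS() : ToyNative()
    return scale!(C, α, backend)
end

# Kernel methods: the backend is a mandatory positional argument, so they can
# never overlap with (or become ambiguous against) the selector above.
scale!(C::AbstractArray, α::Number, ::ToyBLAS) = rmul!(C, α)
scale!(C::AbstractArray, α::Number, ::ToyNative) = (C .*= α; C)

With this layout, scale!(rand(4, 4), 2.0) routes to the ToyBLAS method, while an array of BigFloat elements falls back to ToyNative.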