Skip to content

Commit 4df9b05

Browse files
committed
also format extensions
1 parent 3dda298 commit 4df9b05

File tree

1 file changed

+62
-60
lines changed

1 file changed

+62
-60
lines changed

ext/TensorOperationscuTENSORExt.jl

Lines changed: 62 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ using cuTENSOR.CUDA: with_workspace, default_stream
3131
# this might be dependency-piracy, but removes a dependency from the main package
3232
using cuTENSOR.CUDA.Adapt: adapt
3333

34-
function TensorOperations.tensorscalar(C::CuArray)
34+
const TO = TensorOperations
35+
function TO.tensorscalar(C::CuArray)
3536
return ndims(C) == 0 ? tensorscalar(collect(C)) : throw(DimensionMismatch())
3637
end
3738

@@ -50,9 +51,9 @@ end
5051
# Operations
5152
#-------------------------------------------------------------------------------------------
5253

53-
function TensorOperations.tensoradd!(C::CuArray, pC::Index2Tuple,
54-
A::CuArray, conjA::Symbol, α::Number, β::Number)
55-
TensorOperations.argcheck_tensoradd(C, pC, A)
54+
function TO.tensoradd!(C::CuArray, pC::Index2Tuple, A::CuArray, conjA::Symbol,
55+
α::Number, β::Number)
56+
TO.argcheck_tensoradd(C, pC, A)
5657

5758
T = eltype(C)
5859
conjA == :N || conjA == :C ||
@@ -78,12 +79,12 @@ function TensorOperations.tensoradd!(C::CuArray, pC::Index2Tuple,
7879
return C
7980
end
8081

81-
function TensorOperations.tensorcontract!(C::CuArray, pC::Index2Tuple,
82-
A::CuArray, pA::Index2Tuple, conjA::Symbol,
83-
B::CuArray, pB::Index2Tuple, conjB::Symbol,
84-
α, β)
85-
TensorOperations.argcheck_tensorcontract(C, pC, A, pA, B, pB)
86-
TensorOperations.dimcheck_tensorcontract(C, pC, A, pA, B, pB)
82+
function TO.tensorcontract!(C::CuArray, pC::Index2Tuple,
83+
A::CuArray, pA::Index2Tuple, conjA::Symbol,
84+
B::CuArray, pB::Index2Tuple, conjB::Symbol,
85+
α, β)
86+
TO.argcheck_tensorcontract(C, pC, A, pA, B, pB)
87+
TO.dimcheck_tensorcontract(C, pC, A, pA, B, pB)
8788

8889
conjA == :N || conjA == :C ||
8990
throw(ArgumentError("Value of conjA should be :N or :C instead of $conjA"))
@@ -98,9 +99,9 @@ function TensorOperations.tensorcontract!(C::CuArray, pC::Index2Tuple,
9899

99100
typeCompute = cutensorComputeType(T)
100101

101-
NoA = TensorOperations.numout(pA)
102-
NoB = TensorOperations.numin(pB)
103-
Nc = TensorOperations.numin(pA)
102+
NoA = TO.numout(pA)
103+
NoB = TO.numin(pB)
104+
Nc = TO.numin(pA)
104105

105106
modeoA = ntuple(n -> n, NoA)
106107
modeoB = ntuple(n -> NoA + n, NoB)
@@ -148,9 +149,9 @@ function TensorOperations.tensorcontract!(C::CuArray, pC::Index2Tuple,
148149
return C
149150
end
150151

151-
function TensorOperations.tensortrace!(C::CuArray, pC::Index2Tuple,
152-
A::CuArray, pA::Index2Tuple, conjA::Symbol, α, β)
153-
TensorOperations.argcheck_tensortrace(C, pC, A, pA)
152+
function TO.tensortrace!(C::CuArray, pC::Index2Tuple,
153+
A::CuArray, pA::Index2Tuple, conjA::Symbol, α, β)
154+
TO.argcheck_tensortrace(C, pC, A, pA)
154155
T = eltype(C)
155156
NA, NC = ndims(A), ndims(C)
156157

@@ -197,92 +198,93 @@ end
197198
# Allocations
198199
#-------------------------------------------------------------------------------------------
199200

200-
function TensorOperations.tensoradd_type(TC, pC::Index2Tuple, ::CuArray, conjA::Symbol)
201-
return CuArray{TC,TensorOperations.numind(pC)}
201+
function TO.tensoradd_type(TC, pC::Index2Tuple, ::CuArray, conjA::Symbol)
202+
return CuArray{TC,TO.numind(pC)}
202203
end
203204

204-
function TensorOperations.tensorcontract_type(TC, pC::Index2Tuple, ::CuArray,
205-
pA::Index2Tuple, conjA::Symbol, ::CuArray,
206-
pB::Index2Tuple, conjB::Symbol)
207-
return CuArray{TC,TensorOperations.numind(pC)}
205+
function TO.tensorcontract_type(TC, pC::Index2Tuple,
206+
::CuArray, pA::Index2Tuple, conjA::Symbol,
207+
::CuArray, pB::Index2Tuple, conjB::Symbol)
208+
return CuArray{TC,TO.numind(pC)}
208209
end
209210

210211
#-------------------------------------------------------------------------------------------
211212
# Backend
212213
#-------------------------------------------------------------------------------------------
213214

214-
const cuTENSORBackend = TensorOperations.Backend{:cuTENSOR}
215+
const cuTENSORBackend = TO.Backend{:cuTENSOR}
215216

216-
function TensorOperations.tensoradd!(C::AbstractArray, pC::Index2Tuple,
217-
A::AbstractArray, conjA::Symbol, α::Number, β::Number,
218-
backend::cuTENSORBackend)
217+
function TO.tensoradd!(C::AbstractArray, pC::Index2Tuple,
218+
A::AbstractArray, conjA::Symbol, α::Number, β::Number,
219+
backend::cuTENSORBackend)
219220
C_cuda = adapt(CuArray, C)
220221
tensoradd!(C_cuda, pC, A, conjA, α, β, backend)
221222
copyto!(C, collect(C_cuda))
222223
return C
223224
end
224225

225-
function TensorOperations.tensoradd!(C::CuArray, pC::Index2Tuple,
226-
A::AbstractArray, conjA::Symbol, α::Number, β::Number,
227-
::cuTENSORBackend)
226+
function TO.tensoradd!(C::CuArray, pC::Index2Tuple,
227+
A::AbstractArray, conjA::Symbol, α::Number, β::Number,
228+
::cuTENSORBackend)
228229
return tensoradd!(C, pC, adapt(CuArray, A), conjA, α, β)
229230
end
230231

231-
function TensorOperations.tensorcontract!(C::AbstractArray, pC::Index2Tuple,
232-
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
233-
B::AbstractArray, pB::Index2Tuple, conjB::Symbol,
234-
α, β, backend::cuTENSORBackend)
232+
function TO.tensorcontract!(C::AbstractArray, pC::Index2Tuple,
233+
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
234+
B::AbstractArray, pB::Index2Tuple, conjB::Symbol,
235+
α, β, backend::cuTENSORBackend)
235236
C_cuda = adapt(CuArray, C)
236237
tensorcontract!(C_cuda, pC, A, pA, conjA, B, pB, conjB, α, β, backend)
237238
copyto!(C, collect(C_cuda))
238239
return C
239240
end
240-
function TensorOperations.tensorcontract!(C::CuArray, pC::Index2Tuple,
241-
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
242-
B::AbstractArray, pB::Index2Tuple, conjB::Symbol,
243-
α, β, ::cuTENSORBackend)
241+
function TO.tensorcontract!(C::CuArray, pC::Index2Tuple,
242+
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
243+
B::AbstractArray, pB::Index2Tuple, conjB::Symbol,
244+
α, β, ::cuTENSORBackend)
244245
return tensorcontract!(C, pC, adapt(CuArray, A), pA, conjA, adapt(CuArray, B), pB,
245246
conjB, α, β)
246247
end
247248

248-
function TensorOperations.tensortrace!(C::AbstractArray, pC::Index2Tuple,
249-
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
250-
α, β, backend::cuTENSORBackend)
249+
function TO.tensortrace!(C::AbstractArray, pC::Index2Tuple,
250+
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
251+
α, β, backend::cuTENSORBackend)
251252
C_cuda = adapt(CuArray, C)
252253
tensortrace!(C_cuda, pC, A, pA, conjA, α, β, backend)
253254
copyto!(C, collect(C_cuda))
254255
return C
255256
end
256-
function TensorOperations.tensortrace!(C::CuArray, pC::Index2Tuple,
257-
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
258-
α, β, ::cuTENSORBackend)
257+
function TO.tensortrace!(C::CuArray, pC::Index2Tuple,
258+
A::AbstractArray, pA::Index2Tuple, conjA::Symbol,
259+
α, β, ::cuTENSORBackend)
259260
return tensortrace!(C, pC, adapt(CuArray, A), pA, conjA, α, β)
260261
end
261262

262-
function TensorOperations.tensoradd_type(TC, pC::Index2Tuple, ::AbstractArray,
263-
conjA::Symbol, ::cuTENSORBackend)
264-
return CuArray{TC,TensorOperations.numind(pC)}
263+
function TO.tensoradd_type(TC, pC::Index2Tuple, ::AbstractArray,
264+
conjA::Symbol, ::cuTENSORBackend)
265+
return CuArray{TC,TO.numind(pC)}
265266
end
266267

267-
function TensorOperations.tensorcontract_type(TC, pC::Index2Tuple, ::AbstractArray,
268-
pA::Index2Tuple, conjA::Symbol,
269-
::AbstractArray,
270-
pB::Index2Tuple, conjB::Symbol, ::cuTENSORBackend)
271-
return CuArray{TC,TensorOperations.numind(pC)}
268+
function TO.tensorcontract_type(TC, pC::Index2Tuple,
269+
::AbstractArray, pA::Index2Tuple, conjA::Symbol,
270+
::AbstractArray, pB::Index2Tuple, conjB::Symbol,
271+
::cuTENSORBackend)
272+
return CuArray{TC,TO.numind(pC)}
272273
end
273274

274-
function TensorOperations.tensoralloc_add(TC, pC, A::AbstractArray, conjA, istemp,
275-
::cuTENSORBackend)
276-
ttype = CuArray{TC,TensorOperations.numind(pC)}
277-
structure = TensorOperations.tensoradd_structure(pC, A, conjA)
275+
function TO.tensoralloc_add(TC, pC, A::AbstractArray, conjA, istemp,
276+
::cuTENSORBackend)
277+
ttype = CuArray{TC,TO.numind(pC)}
278+
structure = TO.tensoradd_structure(pC, A, conjA)
278279
return tensoralloc(ttype, structure, istemp)::ttype
279280
end
280281

281-
function TensorOperations.tensoralloc_contract(TC, pC, A::AbstractArray, pA, conjA,
282-
B::AbstractArray, pB, conjB, istemp,
283-
::cuTENSORBackend)
284-
ttype = CuArray{TC,TensorOperations.numind(pC)}
285-
structure = TensorOperations.tensorcontract_structure(pC, A, pA, conjA, B, pB, conjB)
282+
function TO.tensoralloc_contract(TC, pC,
283+
A::AbstractArray, pA, conjA,
284+
B::AbstractArray, pB, conjB,
285+
istemp, ::cuTENSORBackend)
286+
ttype = CuArray{TC,TO.numind(pC)}
287+
structure = TO.tensorcontract_structure(pC, A, pA, conjA, B, pB, conjB)
286288
return tensoralloc(ttype, structure, istemp)::ttype
287289
end
288290

0 commit comments

Comments
 (0)