@@ -31,7 +31,8 @@ using cuTENSOR.CUDA: with_workspace, default_stream
31
31
# this might be dependency-piracy, but removes a dependency from the main package
32
32
using cuTENSOR. CUDA. Adapt: adapt
33
33
34
- function TensorOperations. tensorscalar (C:: CuArray )
34
+ const TO = TensorOperations
35
+ function TO. tensorscalar (C:: CuArray )
35
36
return ndims (C) == 0 ? tensorscalar (collect (C)) : throw (DimensionMismatch ())
36
37
end
37
38
50
51
# Operations
51
52
# -------------------------------------------------------------------------------------------
52
53
53
- function TensorOperations . tensoradd! (C:: CuArray , pC:: Index2Tuple ,
54
- A :: CuArray , conjA :: Symbol , α:: Number , β:: Number )
55
- TensorOperations . argcheck_tensoradd (C, pC, A)
54
+ function TO . tensoradd! (C:: CuArray , pC:: Index2Tuple , A :: CuArray , conjA :: Symbol ,
55
+ α:: Number , β:: Number )
56
+ TO . argcheck_tensoradd (C, pC, A)
56
57
57
58
T = eltype (C)
58
59
conjA == :N || conjA == :C ||
@@ -78,12 +79,12 @@ function TensorOperations.tensoradd!(C::CuArray, pC::Index2Tuple,
78
79
return C
79
80
end
80
81
81
- function TensorOperations . tensorcontract! (C:: CuArray , pC:: Index2Tuple ,
82
- A:: CuArray , pA:: Index2Tuple , conjA:: Symbol ,
83
- B:: CuArray , pB:: Index2Tuple , conjB:: Symbol ,
84
- α, β)
85
- TensorOperations . argcheck_tensorcontract (C, pC, A, pA, B, pB)
86
- TensorOperations . dimcheck_tensorcontract (C, pC, A, pA, B, pB)
82
+ function TO . tensorcontract! (C:: CuArray , pC:: Index2Tuple ,
83
+ A:: CuArray , pA:: Index2Tuple , conjA:: Symbol ,
84
+ B:: CuArray , pB:: Index2Tuple , conjB:: Symbol ,
85
+ α, β)
86
+ TO . argcheck_tensorcontract (C, pC, A, pA, B, pB)
87
+ TO . dimcheck_tensorcontract (C, pC, A, pA, B, pB)
87
88
88
89
conjA == :N || conjA == :C ||
89
90
throw (ArgumentError (" Value of conjA should be :N or :C instead of $conjA " ))
@@ -98,9 +99,9 @@ function TensorOperations.tensorcontract!(C::CuArray, pC::Index2Tuple,
98
99
99
100
typeCompute = cutensorComputeType (T)
100
101
101
- NoA = TensorOperations . numout (pA)
102
- NoB = TensorOperations . numin (pB)
103
- Nc = TensorOperations . numin (pA)
102
+ NoA = TO . numout (pA)
103
+ NoB = TO . numin (pB)
104
+ Nc = TO . numin (pA)
104
105
105
106
modeoA = ntuple (n -> n, NoA)
106
107
modeoB = ntuple (n -> NoA + n, NoB)
@@ -148,9 +149,9 @@ function TensorOperations.tensorcontract!(C::CuArray, pC::Index2Tuple,
148
149
return C
149
150
end
150
151
151
- function TensorOperations . tensortrace! (C:: CuArray , pC:: Index2Tuple ,
152
- A:: CuArray , pA:: Index2Tuple , conjA:: Symbol , α, β)
153
- TensorOperations . argcheck_tensortrace (C, pC, A, pA)
152
+ function TO . tensortrace! (C:: CuArray , pC:: Index2Tuple ,
153
+ A:: CuArray , pA:: Index2Tuple , conjA:: Symbol , α, β)
154
+ TO . argcheck_tensortrace (C, pC, A, pA)
154
155
T = eltype (C)
155
156
NA, NC = ndims (A), ndims (C)
156
157
@@ -197,92 +198,93 @@ end
197
198
# Allocations
198
199
# -------------------------------------------------------------------------------------------
199
200
200
- function TensorOperations . tensoradd_type (TC, pC:: Index2Tuple , :: CuArray , conjA:: Symbol )
201
- return CuArray{TC,TensorOperations . numind (pC)}
201
+ function TO . tensoradd_type (TC, pC:: Index2Tuple , :: CuArray , conjA:: Symbol )
202
+ return CuArray{TC,TO . numind (pC)}
202
203
end
203
204
204
- function TensorOperations . tensorcontract_type (TC, pC:: Index2Tuple , :: CuArray ,
205
- pA:: Index2Tuple , conjA:: Symbol , :: CuArray ,
206
- pB:: Index2Tuple , conjB:: Symbol )
207
- return CuArray{TC,TensorOperations . numind (pC)}
205
+ function TO . tensorcontract_type (TC, pC:: Index2Tuple ,
206
+ :: CuArray , pA:: Index2Tuple , conjA:: Symbol ,
207
+ :: CuArray , pB:: Index2Tuple , conjB:: Symbol )
208
+ return CuArray{TC,TO . numind (pC)}
208
209
end
209
210
210
211
# -------------------------------------------------------------------------------------------
211
212
# Backend
212
213
# -------------------------------------------------------------------------------------------
213
214
214
- const cuTENSORBackend = TensorOperations . Backend{:cuTENSOR }
215
+ const cuTENSORBackend = TO . Backend{:cuTENSOR }
215
216
216
- function TensorOperations . tensoradd! (C:: AbstractArray , pC:: Index2Tuple ,
217
- A:: AbstractArray , conjA:: Symbol , α:: Number , β:: Number ,
218
- backend:: cuTENSORBackend )
217
+ function TO . tensoradd! (C:: AbstractArray , pC:: Index2Tuple ,
218
+ A:: AbstractArray , conjA:: Symbol , α:: Number , β:: Number ,
219
+ backend:: cuTENSORBackend )
219
220
C_cuda = adapt (CuArray, C)
220
221
tensoradd! (C_cuda, pC, A, conjA, α, β, backend)
221
222
copyto! (C, collect (C_cuda))
222
223
return C
223
224
end
224
225
225
- function TensorOperations . tensoradd! (C:: CuArray , pC:: Index2Tuple ,
226
- A:: AbstractArray , conjA:: Symbol , α:: Number , β:: Number ,
227
- :: cuTENSORBackend )
226
+ function TO . tensoradd! (C:: CuArray , pC:: Index2Tuple ,
227
+ A:: AbstractArray , conjA:: Symbol , α:: Number , β:: Number ,
228
+ :: cuTENSORBackend )
228
229
return tensoradd! (C, pC, adapt (CuArray, A), conjA, α, β)
229
230
end
230
231
231
- function TensorOperations . tensorcontract! (C:: AbstractArray , pC:: Index2Tuple ,
232
- A:: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
233
- B:: AbstractArray , pB:: Index2Tuple , conjB:: Symbol ,
234
- α, β, backend:: cuTENSORBackend )
232
+ function TO . tensorcontract! (C:: AbstractArray , pC:: Index2Tuple ,
233
+ A:: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
234
+ B:: AbstractArray , pB:: Index2Tuple , conjB:: Symbol ,
235
+ α, β, backend:: cuTENSORBackend )
235
236
C_cuda = adapt (CuArray, C)
236
237
tensorcontract! (C_cuda, pC, A, pA, conjA, B, pB, conjB, α, β, backend)
237
238
copyto! (C, collect (C_cuda))
238
239
return C
239
240
end
240
- function TensorOperations . tensorcontract! (C:: CuArray , pC:: Index2Tuple ,
241
- A:: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
242
- B:: AbstractArray , pB:: Index2Tuple , conjB:: Symbol ,
243
- α, β, :: cuTENSORBackend )
241
+ function TO . tensorcontract! (C:: CuArray , pC:: Index2Tuple ,
242
+ A:: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
243
+ B:: AbstractArray , pB:: Index2Tuple , conjB:: Symbol ,
244
+ α, β, :: cuTENSORBackend )
244
245
return tensorcontract! (C, pC, adapt (CuArray, A), pA, conjA, adapt (CuArray, B), pB,
245
246
conjB, α, β)
246
247
end
247
248
248
- function TensorOperations . tensortrace! (C:: AbstractArray , pC:: Index2Tuple ,
249
- A:: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
250
- α, β, backend:: cuTENSORBackend )
249
+ function TO . tensortrace! (C:: AbstractArray , pC:: Index2Tuple ,
250
+ A:: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
251
+ α, β, backend:: cuTENSORBackend )
251
252
C_cuda = adapt (CuArray, C)
252
253
tensortrace! (C_cuda, pC, A, pA, conjA, α, β, backend)
253
254
copyto! (C, collect (C_cuda))
254
255
return C
255
256
end
256
- function TensorOperations . tensortrace! (C:: CuArray , pC:: Index2Tuple ,
257
- A:: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
258
- α, β, :: cuTENSORBackend )
257
+ function TO . tensortrace! (C:: CuArray , pC:: Index2Tuple ,
258
+ A:: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
259
+ α, β, :: cuTENSORBackend )
259
260
return tensortrace! (C, pC, adapt (CuArray, A), pA, conjA, α, β)
260
261
end
261
262
262
- function TensorOperations . tensoradd_type (TC, pC:: Index2Tuple , :: AbstractArray ,
263
- conjA:: Symbol , :: cuTENSORBackend )
264
- return CuArray{TC,TensorOperations . numind (pC)}
263
+ function TO . tensoradd_type (TC, pC:: Index2Tuple , :: AbstractArray ,
264
+ conjA:: Symbol , :: cuTENSORBackend )
265
+ return CuArray{TC,TO . numind (pC)}
265
266
end
266
267
267
- function TensorOperations . tensorcontract_type (TC, pC:: Index2Tuple , :: AbstractArray ,
268
- pA:: Index2Tuple , conjA:: Symbol ,
269
- :: AbstractArray ,
270
- pB :: Index2Tuple , conjB :: Symbol , :: cuTENSORBackend )
271
- return CuArray{TC,TensorOperations . numind (pC)}
268
+ function TO . tensorcontract_type (TC, pC:: Index2Tuple ,
269
+ :: AbstractArray , pA:: Index2Tuple , conjA:: Symbol ,
270
+ :: AbstractArray , pB :: Index2Tuple , conjB :: Symbol ,
271
+ :: cuTENSORBackend )
272
+ return CuArray{TC,TO . numind (pC)}
272
273
end
273
274
274
- function TensorOperations . tensoralloc_add (TC, pC, A:: AbstractArray , conjA, istemp,
275
- :: cuTENSORBackend )
276
- ttype = CuArray{TC,TensorOperations . numind (pC)}
277
- structure = TensorOperations . tensoradd_structure (pC, A, conjA)
275
+ function TO . tensoralloc_add (TC, pC, A:: AbstractArray , conjA, istemp,
276
+ :: cuTENSORBackend )
277
+ ttype = CuArray{TC,TO . numind (pC)}
278
+ structure = TO . tensoradd_structure (pC, A, conjA)
278
279
return tensoralloc (ttype, structure, istemp):: ttype
279
280
end
280
281
281
- function TensorOperations. tensoralloc_contract (TC, pC, A:: AbstractArray , pA, conjA,
282
- B:: AbstractArray , pB, conjB, istemp,
283
- :: cuTENSORBackend )
284
- ttype = CuArray{TC,TensorOperations. numind (pC)}
285
- structure = TensorOperations. tensorcontract_structure (pC, A, pA, conjA, B, pB, conjB)
282
+ function TO. tensoralloc_contract (TC, pC,
283
+ A:: AbstractArray , pA, conjA,
284
+ B:: AbstractArray , pB, conjB,
285
+ istemp, :: cuTENSORBackend )
286
+ ttype = CuArray{TC,TO. numind (pC)}
287
+ structure = TO. tensorcontract_structure (pC, A, pA, conjA, B, pB, conjB)
286
288
return tensoralloc (ttype, structure, istemp):: ttype
287
289
end
288
290
0 commit comments