Skip to content

Commit 11245fe

Browse files
authored
fix instantiation bug (#155)
* fix instantiation bug * some more parser updates * format and update CI * more parser/instantiator updates * disable polynomial tests and cleanup * add scaltype to instantiators/restore label tools * update project.toml and aqua
1 parent e92d25c commit 11245fe

File tree

13 files changed

+176
-174
lines changed

13 files changed

+176
-174
lines changed

.appveyor.yml

Lines changed: 0 additions & 36 deletions
This file was deleted.

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
os:
2727
- ubuntu-latest
2828
- macOS-latest
29-
# - windows-latest # run on AppVeyor instead
29+
- windows-latest
3030
arch:
3131
- x64
3232
- x86

Project.toml

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TensorOperations"
22
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
33
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Maarten Van Damme <maartenvd1994@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
4-
version = "4.0.7"
4+
version = "4.0.8"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -15,29 +15,29 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
1515
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1616
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
1717

18-
[weakdeps]
19-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
20-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
21-
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
22-
23-
[extensions]
24-
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
25-
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
26-
2718
[compat]
28-
Aqua = "0.6, 0.7"
19+
Aqua = "0.6, 0.7, 0.8"
2920
CUDA = "4,5"
3021
ChainRulesCore = "1"
22+
ChainRulesTestUtils = "1"
3123
DynamicPolynomials = "0.5"
3224
LRUCache = "1"
25+
LinearAlgebra = "1.6"
26+
Logging = "1.6"
3327
PackageExtensionCompat = "1"
28+
Random = "1"
3429
Strided = "2.0.4"
3530
StridedViews = "0.2"
31+
Test = "1"
3632
TupleTools = "1.1"
3733
VectorInterface = "0.4.1"
3834
cuTENSOR = "1"
3935
julia = "1.6"
4036

37+
[extensions]
38+
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
39+
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
40+
4141
[extras]
4242
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4343
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -50,3 +50,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
5050

5151
[targets]
5252
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging"]
53+
54+
[weakdeps]
55+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
56+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
57+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ Fast tensor operations using a convenient Einstein index notation.
1212
|:----------------:|:------------:|:------------:|:---------------------:|
1313
| [![CI][ci-img]][ci-url] | [![PkgEval][pkgeval-img]][pkgeval-url] | [![Codecov][codecov-img]][codecov-url] | [![Aqua QA][aqua-img]][aqua-url] |
1414

15-
1615
[docs-stable-img]: https://img.shields.io/badge/docs-stable-blue.svg
1716
[docs-stable-url]: https://jutho.github.io/TensorOperations.jl/stable
1817

src/implementation/indices.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function contract_indices(IA::NTuple{NA,Any}, IB::NTuple{NB,Any},
108108
end
109109

110110
#-------------------------------------------------------------------------------------------
111-
# Generate index information
111+
# Convert indices to einsum labels: useful for package extensions / add-ons (e.g. TBLIS)
112112
#-------------------------------------------------------------------------------------------
113113
const OFFSET_OPEN = 'a' - 1
114114
const OFFSET_CLOSED = 'A' - 1

src/indexnotation/analyzers.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,11 @@ function decomposegeneraltensor(ex)
5353
count(a -> isgeneraltensor(a), ex.args) == 1 # scalar multiplication: multiply scalar factors
5454
idx = findfirst(a -> isgeneraltensor(a), ex.args)
5555
(object, leftind, rightind, α, conj) = decomposegeneraltensor(ex.args[idx])
56-
scalars = Expr(:call)
57-
append!(scalars.args, deepcopy(ex.args))
58-
scalars.args[idx] = α
59-
return (object, leftind, rightind, scalars, conj)
56+
scalar = One()
57+
for i in 2:length(ex.args)
58+
scalar = simplify_scalarmul(scalar, i == idx ? α : ex.args[i])
59+
end
60+
return (object, leftind, rightind, scalar, conj)
6061
elseif ex.args[1] == :/ && length(ex.args) == 3 # scalar multiplication: muliply scalar factors
6162
if isscalarexpr(ex.args[3]) && isgeneraltensor(ex.args[2])
6263
(object, leftind, rightind, α, conj) = decomposegeneraltensor(ex.args[2])
@@ -208,6 +209,26 @@ function getallindices(ex)
208209
return Any[]
209210
end
210211

212+
function simplify_scalarmul(exa, exb)
213+
if exa === One()
214+
return exb
215+
elseif exb === One()
216+
return exa
217+
end
218+
if isexpr(exa, :call) && exa.args[1] == :* && isexpr(exb, :call) && exb.args[1] == :*
219+
return Expr(:call, :*, exa.args[2:end]..., exb.args[2:end]...)
220+
elseif isexpr(exa, :call) && exa.args[1] == *
221+
return Expr(:call, :*, exa.args[2:end]..., exb)
222+
elseif isexpr(exb, :call) && exb.args[1] == *
223+
return Expr(:call, :*, exa, exb.args[2:end]...)
224+
else
225+
return Expr(:call, :*, exa, exb)
226+
end
227+
end
228+
function simplify_scalarmul(exa, exb, exc, exd...)
229+
return simplify_scalarmul(simplify_scalarmul(exa, exb), exc, exd...)
230+
end
231+
211232
# # auxiliary routine
212233
# function unique2(itr)
213234
# out = collect(itr)

src/indexnotation/contractiontrees.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ function processcontractions(ex, treebuilder, treesorter, costcheck)
2121
elseif isexpr(ex, :macrocall) && ex.args[1] == Symbol("@notensor")
2222
return ex
2323
elseif isexpr(ex, :call) && ex.args[1] == :tensorscalar
24-
return Expr(:call, :tensorscalar,
25-
processcontractions(ex.args[2], treebuilder, treesorter, costcheck))
24+
return processcontractions(ex.args[2], treebuilder, treesorter, costcheck)
25+
# `tensorscalar` will be reinserted automatically
2626
elseif isassignment(ex) || isdefinition(ex)
2727
lhs, rhs = getlhs(ex), getrhs(ex)
2828
rhs, pre, post = _processcontractions(rhs, treebuilder, treesorter, costcheck)
@@ -57,14 +57,23 @@ end
5757

5858
function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexprs,
5959
postexprs)
60+
if isexpr(ex, :call) && ex.args[1] == :tensorscalar
61+
return insertcontractiontrees!(ex.args[2], treebuilder, treesorter, costcheck,
62+
preexprs, postexprs)
63+
end
6064
if isexpr(ex, :call)
6165
args = ex.args
6266
nargs = length(args)
6367
ex = Expr(:call, args[1],
6468
(insertcontractiontrees!(args[i], treebuilder, treesorter, costcheck,
6569
preexprs, postexprs) for i in 2:nargs)...)
6670
end
67-
if istensorcontraction(ex) && length(ex.args) > 3
71+
if !istensorcontraction(ex)
72+
return ex
73+
end
74+
if length(ex.args) <= 3
75+
return isempty(getindices(ex)) ? Expr(:call, :tensorscalar, ex) : ex
76+
else
6877
args = ex.args[2:end]
6978
network = map(getindices, args)
7079
for a in getallindices(ex)
@@ -137,7 +146,6 @@ function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexpr
137146
push!(postexprs, removelinenumbernode(costcompareex))
138147
return treeex
139148
end
140-
return ex
141149
end
142150

143151
function treecost(tree, network, costs)
@@ -175,9 +183,14 @@ function defaulttreesorter(args, tree)
175183
if isa(tree, Int)
176184
return args[tree]
177185
else
178-
return Expr(:call, :*,
179-
defaulttreesorter(args, tree[1]),
180-
defaulttreesorter(args, tree[2]))
186+
ex = Expr(:call, :*,
187+
defaulttreesorter(args, tree[1]),
188+
defaulttreesorter(args, tree[2]))
189+
if isempty(getindices(ex))
190+
return Expr(:call, :tensorscalar, ex)
191+
else
192+
return ex
193+
end
181194
end
182195
end
183196

0 commit comments

Comments
 (0)