Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 0 additions & 36 deletions .appveyor.yml

This file was deleted.

2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
os:
- ubuntu-latest
- macOS-latest
# - windows-latest # run on AppVeyor instead
- windows-latest
arch:
- x64
- x86
Expand Down
27 changes: 16 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorOperations"
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Maarten Van Damme <maartenvd1994@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
version = "4.0.7"
version = "4.0.8"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -15,29 +15,29 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[compat]
Aqua = "0.6, 0.7"
Aqua = "0.6, 0.7, 0.8"
CUDA = "4,5"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
DynamicPolynomials = "0.5"
LRUCache = "1"
LinearAlgebra = "1.6"
Logging = "1.6"
PackageExtensionCompat = "1"
Random = "1"
Strided = "2.0.4"
StridedViews = "0.2"
Test = "1"
TupleTools = "1.1"
VectorInterface = "0.4.1"
cuTENSOR = "1"
julia = "1.6"

[extensions]
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -50,3 +50,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging"]

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ Fast tensor operations using a convenient Einstein index notation.
|:----------------:|:------------:|:------------:|:---------------------:|
| [![CI][ci-img]][ci-url] | [![PkgEval][pkgeval-img]][pkgeval-url] | [![Codecov][codecov-img]][codecov-url] | [![Aqua QA][aqua-img]][aqua-url] |


[docs-stable-img]: https://img.shields.io/badge/docs-stable-blue.svg
[docs-stable-url]: https://jutho.github.io/TensorOperations.jl/stable

Expand Down
2 changes: 1 addition & 1 deletion src/implementation/indices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function contract_indices(IA::NTuple{NA,Any}, IB::NTuple{NB,Any},
end

#-------------------------------------------------------------------------------------------
# Generate index information
# Convert indices to einsum labels: useful for package extensions / add-ons (e.g. TBLIS)
#-------------------------------------------------------------------------------------------
const OFFSET_OPEN = 'a' - 1
const OFFSET_CLOSED = 'A' - 1
Expand Down
29 changes: 25 additions & 4 deletions src/indexnotation/analyzers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ function decomposegeneraltensor(ex)
count(a -> isgeneraltensor(a), ex.args) == 1 # scalar multiplication: multiply scalar factors
idx = findfirst(a -> isgeneraltensor(a), ex.args)
(object, leftind, rightind, α, conj) = decomposegeneraltensor(ex.args[idx])
scalars = Expr(:call)
append!(scalars.args, deepcopy(ex.args))
scalars.args[idx] = α
return (object, leftind, rightind, scalars, conj)
scalar = One()
for i in 2:length(ex.args)
scalar = simplify_scalarmul(scalar, i == idx ? α : ex.args[i])
end
return (object, leftind, rightind, scalar, conj)
elseif ex.args[1] == :/ && length(ex.args) == 3 # scalar multiplication: multiply scalar factors
if isscalarexpr(ex.args[3]) && isgeneraltensor(ex.args[2])
(object, leftind, rightind, α, conj) = decomposegeneraltensor(ex.args[2])
Expand Down Expand Up @@ -208,6 +209,26 @@ function getallindices(ex)
return Any[]
end

# Combine two scalar (sub)expressions into a single product expression,
# flattening nested `:*` calls and eliding multiplicative identities.
#
# Arguments: `exa`, `exb` — scalar expressions, or `One()` (the
# multiplicative identity from VectorInterface).
# Returns: an expression equivalent to `exa * exb`; when either operand is
# already a `:*` call, its factors are spliced into a single flat
# `Expr(:call, :*, ...)` instead of nesting products.
function simplify_scalarmul(exa, exb)
    # multiplying by the identity is a no-op
    if exa === One()
        return exb
    elseif exb === One()
        return exa
    end
    # NOTE(fix): compare `args[1]` against the Symbol `:*`, not the function
    # object `*`. The head of a `:call` expression here is a Symbol (see the
    # `== :*` / `== :/` checks elsewhere in this file), so the original
    # `exa.args[1] == *` comparisons could never match and the two
    # single-sided flattening branches were unreachable.
    if isexpr(exa, :call) && exa.args[1] == :* && isexpr(exb, :call) && exb.args[1] == :*
        return Expr(:call, :*, exa.args[2:end]..., exb.args[2:end]...)
    elseif isexpr(exa, :call) && exa.args[1] == :*
        return Expr(:call, :*, exa.args[2:end]..., exb)
    elseif isexpr(exb, :call) && exb.args[1] == :*
        return Expr(:call, :*, exa, exb.args[2:end]...)
    else
        return Expr(:call, :*, exa, exb)
    end
end
# Varargs form: fold the binary `simplify_scalarmul` pairwise, left to right.
function simplify_scalarmul(exa, exb, exc, exd...)
    return simplify_scalarmul(simplify_scalarmul(exa, exb), exc, exd...)
end

# # auxiliary routine
# function unique2(itr)
# out = collect(itr)
Expand Down
27 changes: 20 additions & 7 deletions src/indexnotation/contractiontrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ function processcontractions(ex, treebuilder, treesorter, costcheck)
elseif isexpr(ex, :macrocall) && ex.args[1] == Symbol("@notensor")
return ex
elseif isexpr(ex, :call) && ex.args[1] == :tensorscalar
return Expr(:call, :tensorscalar,
processcontractions(ex.args[2], treebuilder, treesorter, costcheck))
return processcontractions(ex.args[2], treebuilder, treesorter, costcheck)
# `tensorscalar` will be reinserted automatically
elseif isassignment(ex) || isdefinition(ex)
lhs, rhs = getlhs(ex), getrhs(ex)
rhs, pre, post = _processcontractions(rhs, treebuilder, treesorter, costcheck)
Expand Down Expand Up @@ -57,14 +57,23 @@ end

function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexprs,
postexprs)
if isexpr(ex, :call) && ex.args[1] == :tensorscalar
return insertcontractiontrees!(ex.args[2], treebuilder, treesorter, costcheck,
preexprs, postexprs)
end
if isexpr(ex, :call)
args = ex.args
nargs = length(args)
ex = Expr(:call, args[1],
(insertcontractiontrees!(args[i], treebuilder, treesorter, costcheck,
preexprs, postexprs) for i in 2:nargs)...)
end
if istensorcontraction(ex) && length(ex.args) > 3
if !istensorcontraction(ex)
return ex
end
if length(ex.args) <= 3
return isempty(getindices(ex)) ? Expr(:call, :tensorscalar, ex) : ex
else
args = ex.args[2:end]
network = map(getindices, args)
for a in getallindices(ex)
Expand Down Expand Up @@ -137,7 +146,6 @@ function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexpr
push!(postexprs, removelinenumbernode(costcompareex))
return treeex
end
return ex
end

function treecost(tree, network, costs)
Expand Down Expand Up @@ -175,9 +183,14 @@ function defaulttreesorter(args, tree)
if isa(tree, Int)
return args[tree]
else
return Expr(:call, :*,
defaulttreesorter(args, tree[1]),
defaulttreesorter(args, tree[2]))
ex = Expr(:call, :*,
defaulttreesorter(args, tree[1]),
defaulttreesorter(args, tree[2]))
if isempty(getindices(ex))
return Expr(:call, :tensorscalar, ex)
else
return ex
end
end
end

Expand Down
Loading