Skip to content

Commit abbd9ba

Browse files
authored
Avoid using strings in nconstyle checks and errors (#219)
* Avoid using strings and update nconerror function names * Fix typo * Change docstring to comment to make documenter happy * apply code suggestions
1 parent 0d241b1 commit abbd9ba

File tree

4 files changed

+17
-22
lines changed

4 files changed

+17
-22
lines changed

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ end
340340
# NCON functions
341341
@non_differentiable TensorOperations.ncontree(args...)
342342
@non_differentiable TensorOperations.nconoutput(args...)
343-
@non_differentiable TensorOperations.isnconstyle(args...)
343+
@non_differentiable TensorOperations.check_nconstyle(args...)
344344
@non_differentiable TensorOperations.indexordertree(args...)
345345

346346
end

src/implementation/ncon.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function ncon(
2828
)
2929
length(tensors) == length(network) == length(conjlist) ||
3030
throw(ArgumentError("number of tensors and of index lists should be the same"))
31-
nconstylecheck(network) # asserts that the network is in ncon style
31+
check_nconstyle(network)
3232
output′ = nconoutput(network, output)
3333

3434
if length(tensors) == 1

src/indexnotation/contractiontrees.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ function defaulttreesorter(args, tree, depth)
215215
end
216216

217217
function defaulttreebuilder(network)
218-
if isnconstyle(network)
218+
if check_nconstyle(Bool, network)
219219
tree = ncontree(network)
220220
else
221221
tree = Any[1, 2]

src/indexnotation/ncontree.jl

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,32 @@
1-
const NCONSTYLE = "Valid ncon style network"
1+
# Verify if a list of indices specifies a tensor contraction in ncon style.
2+
check_nconstyle(::Type{Bool}, network) = _check_nconstyle(network, Val(true))
3+
check_nconstyle(network) = _check_nconstyle(network, Val(false))
24

3-
# check if a list of indices specifies a tensor contraction in ncon style
4-
function isnconstyle(network)
5-
return _nconstyle_error(network) == NCONSTYLE
6-
end
7-
8-
function _nconstyle_error(network)
5+
function _check_nconstyle(network, ::Val{check}) where {check}
96
allindices = Vector{Int}()
107
for ind in network
11-
all(i -> isa(i, Integer), ind) || return "All indices must be integers"
8+
all(i -> isa(i, Integer), ind) || return check ? false :
9+
throw(IndexError("All indices must be integers"))
1210
append!(allindices, ind)
1311
end
1412
while length(allindices) > 0
1513
i = pop!(allindices)
1614
if i > 0 # positive labels represent contractions or traces and should appear twice
1715
k = findfirst(isequal(i), allindices)
18-
k === nothing && return "Index $i appears only once in the network"
16+
isnothing(k) && return check ? false :
17+
throw(IndexError(lazy"Index $i appears only once in the network"))
1918
l = findnext(isequal(i), allindices, k + 1)
20-
l !== nothing && return "Index $i appears more than twice in the network"
19+
!isnothing(l) && return check ? false :
20+
throw(IndexError(lazy"Index $i appears more than twice in the network"))
2121
deleteat!(allindices, k)
2222
elseif i < 0 # negative labels represent open indices and should appear once
23-
findfirst(isequal(i), allindices) === nothing || return "Index $i appears more than once in the network"
23+
isnothing(findfirst(isequal(i), allindices)) || return check ? false :
24+
throw(IndexError(lazy"Index $i appears more than once in the network"))
2425
else # i == 0
25-
return "Index 0 is not allowed in the network"
26+
return check ? false : throw(IndexError("Index 0 is not allowed in the network"))
2627
end
2728
end
28-
return NCONSTYLE
29-
end
30-
31-
function nconstylecheck(network)
32-
err = _nconstyle_error(network)
33-
err === NCONSTYLE || throw(ArgumentError(err))
34-
return nothing
29+
return check ? true : nothing
3530
end
3631

3732
function ncontree(network)

0 commit comments

Comments
 (0)