Description of bug
Automatic differentiation behaves inconsistently: an error is thrown for one function, but not for a slightly rewritten (though analogous) and slightly more complicated version of it.
Minimal code demonstrating the bug or unexpected behavior
Minimal runnable code
using ChainRulesCore
using Zygote
using ITensors
@non_differentiable onehot(::Any...)
indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]
The function below leads to an error.
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
# DimensionMismatch("cannot add ITensors with different numbers of indices")
However, if I just add two extra lines to that function, differentiation works well.
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        phys_idx = commonind(mps[i], state[i])
        L -= sum([(partial_mps * onehot(phys_idx => j))[] for j in 1:1])
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
# works well
Finally, if I change just one line of that second function in a way that does not affect its output, differentiation fails again with the same error as the first version.
function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        phys_idx = commonind(mps[i], state[i])
        L -= (partial_mps * onehot(phys_idx => 1))[]
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end
f(x) = f(x, state)
f'(mps)
# DimensionMismatch("cannot add ITensors with different numbers of indices")
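The claim that the one-line change leaves the output unchanged rests on the fact that a sum over the single-element range 1:1 equals its only term. A minimal plain-Julia sketch of that identity follows; it does not use ITensors, and `term` is a hypothetical stand-in for the scalar `(partial_mps * onehot(phys_idx => j))[]`, not part of the code above.

```julia
# Hypothetical stand-in for the scalar (partial_mps * onehot(phys_idx => j))[].
# Any function of j works here; this one is chosen only for illustration.
term(j) = 0.5 * j + 1.0

# Form used in the version that differentiates without an error.
with_sum = sum([term(j) for j in 1:1])

# Form used in the version that throws DimensionMismatch.
direct = term(1)

# Both forms produce the same value, so only the AD code path differs.
@assert with_sum == direct
```

So the two versions should only differ in which differentiation rules are hit (the array comprehension and `sum` versus a direct contraction), not in the value being computed.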
Expected output or behavior
I would expect to get no errors when setting up automatic differentiation with the functions above.
Version information
- Output from versioninfo():
julia> versioninfo()
Julia Version 1.7.0
Commit 3bf9d17731* (2021-11-30 12:12 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD EPYC 7742 64-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, znver2)
Environment:
  JULIA_NUM_THREADS = 128
- Output from using Pkg; Pkg.status("ITensors"):
julia> using Pkg; Pkg.status("ITensors")
Status `~/.julia/environments/v1.7/Project.toml`
[9136182c] ITensors v0.2.16