[ITensors] [BUG] Automatic differentiation #73
Comments
Very strange bug, thanks @ArtemStrashko. By the way, I would recommend updating to the latest version of ITensors (0.3.14), though I see this bug in the latest version as well, so updating won't fix this issue.
This seems to fix the initial error:

    using Zygote
    using ITensors
    indices = [Index(2) for _ in 1:10]
    mps = randomMPS(indices; linkdims=10)
    state = [randomITensor(index) for index in indices]
    function f(mps, state)
        N = length(mps)
        L = 0.
        tens = 1.
        for i in 1:N
            tens *= mps[i]
            proj = if i < N
                idx = commonind(mps[i], mps[i+1])
                onehot(idx => 1)
            else
                ITensor(1.0)
            end
            partial_mps = tens * proj
            L += (partial_mps * state[i])[]
            tens *= state[i]
        end
        return L
    end
    f(x) = f(x, state)
    f'(mps)

It seems like the issue stems from putting something differentiable (like |
Actually, there's a more minimal fix. These both run without errors for me:

    using Zygote
    using ITensors
    indices = [Index(2) for _ in 1:10]
    mps = randomMPS(indices; linkdims=10)
    state = [randomITensor(index) for index in indices]
    function f(mps, state)
        N = length(mps)
        L = 0.
        tens = 1.
        for i in 1:N
            tens *= mps[i]
            if i < N
                idx = commonind(mps[i], mps[i+1])
                partial_mps = tens * onehot(idx => 1)
            else
                partial_mps = tens * ITensor(1.0)
            end
            L += (partial_mps * state[i])[]
            tens *= state[i]
        end
        return L
    end
    f(x) = f(x, state)
    f'(mps)

and

    using Zygote
    using ITensors
    indices = [Index(2) for _ in 1:10]
    mps = randomMPS(indices; linkdims=10)
    state = [randomITensor(index) for index in indices]
    function f(mps, state)
        N = length(mps)
        L = 0.
        tens = 1.
        for i in 1:N
            tens *= mps[i]
            if i < N
                idx = commonind(mps[i], mps[i+1])
                partial_mps = tens * onehot(idx => 1)
            else
                partial_mps = tens * ITensor(1.0)
            end
            phys_idx = commonind(mps[i], state[i])
            L -= (partial_mps * onehot(phys_idx => 1))[]
            L += (partial_mps * state[i])[]
            tens *= state[i]
        end
        return L
    end
    f(x) = f(x, state)
    f'(mps)

The only change I made is rewriting:

    partial_mps = tens

to:

    partial_mps = tens * ITensor(1.0)

I can't say I understand why this is happening. It would be helpful to concoct a more minimal example to investigate this better.
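To make the workaround above concrete, here is a minimal sketch showing that contracting with `ITensor(1.0)` should be a numerical no-op, so the rewritten line changes only what Zygote sees and not the value of `f`. The index `i` and the tensor `tens` here are hypothetical stand-ins, not the tensors from the original function, and the sketch assumes the same ITensors version family used in this thread (where `randomITensor` is available).

```julia
using ITensors

i = Index(2)                 # hypothetical single dimension-2 index
tens = randomITensor(i)      # stand-in for the partially contracted tensor
same = tens * ITensor(1.0)   # contract with an order-0 (scalar) ITensor

# Both checks are expected to hold if the contraction is an identity operation.
@show inds(same) == inds(tens)
@show same ≈ tens
```

This says nothing about why the unmodified line fails under differentiation; it only checks that the fix leaves the forward computation unchanged.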
So oddly enough this seems to be a Zygote bug:

    using FiniteDifferences
    using Zygote
    function f(x)
        y = [[x]', [x]]
        r = 0.0
        o = 1.0
        for n in 1:2
            o *= y[n]
            if n < 2
                proj_o = o * [1.0]
            else
                # Error
                proj_o = o
                # Fix
                # proj_o = o * 1.0
            end
            r += proj_o
        end
        return r
    end
    x = 1.2
    @show f(x)
    @show central_fdm(5, 1)(f, x)
    @show f'(x)

which throws the error:

    f(x) = 2.6399999999999997
    (central_fdm(5, 1))(f, x) = 3.4000000000000967
    ERROR: LoadError: MethodError: no method matching +(::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
    For element-wise addition, use broadcasting with dot syntax: scalar .+ array
    Closest candidates are:
      +(::Any, ::Any, ::Any, ::Any...) at ~/software/julia-1.7.3/share/julia/base/operators.jl:655
      +(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/software/julia-1.7.3/share/julia/base/mpfr.jl:413
      +(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/GUvJT/src/tangent_arithmetic.jl:146
      ...
    Stacktrace:
     [1] accum(x::Float64, y::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
       @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:17
     [...]

while with the fix it outputs:

    f(x) = 2.6399999999999997
    (central_fdm(5, 1))(f, x) = 3.4000000000000004
    (f')(x) = 3.4000000000000004
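As a sanity check on the numbers above, the loop can be traced by hand. After the first iteration `o` is the row vector `[x]'` and contributes `x`; after the second it is the scalar `x^2`. So f(x) = x + x^2 and f'(x) = 1 + 2x, which gives 3.4 at x = 1.2. The sketch below checks this without any AD package, using only the standard library. The names `f_closed`, `df_closed`, and `central_diff` are hypothetical helpers introduced for illustration, and the plain central difference is a stand-in for `central_fdm(5, 1)`, not the same method. The final part reproduces the type mismatch reported in the stack trace: a `Float64` cannot be added to an `Adjoint` row vector.

```julia
using LinearAlgebra  # standard library; provides the Adjoint wrapper type

# Same forward pass as the reproducer above (with the fix applied), no AD package.
function f(x)
    y = [[x]', [x]]      # a 1-element row vector (Adjoint) and a 1-element column vector
    r = 0.0
    o = 1.0
    for n in 1:2
        o *= y[n]        # n = 1: o is the row vector [x]'; n = 2: o is the scalar x^2
        proj_o = n < 2 ? o * [1.0] : o * 1.0   # row vector * column vector gives a scalar
        r += proj_o
    end
    return r
end

# Hand-derived closed form and derivative (assumptions, checked numerically below).
f_closed(x) = x + x^2
df_closed(x) = 1 + 2x

# Plain central difference, standing in for central_fdm(5, 1).
central_diff(g, x; h=1e-6) = (g(x + h) - g(x - h)) / (2h)

x = 1.2
@show f(x) ≈ f_closed(x)
@show isapprox(central_diff(f, x), df_closed(x); atol=1e-6)

# The mismatch from the stack trace: Float64 + Adjoint row vector has no method.
mismatch = try
    1.0 + [1.0]'
    false
catch err
    err isa MethodError
end
@show mismatch
```

This only confirms what the correct gradient should be and which types collide in `accum`; it does not explain why Zygote ends up accumulating a scalar with an `Adjoint` in the failing branch.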
Description of bug
Automatic differentiation behaves inconsistently: an error is thrown for one function, but not for a slightly rewritten (though equivalent) version of it, nor for a slightly more complicated one.
Minimal code demonstrating the bug or unexpected behavior
Minimal runnable code
The function below leads to an error.
However, if I just add two extra lines to that function, differentiation works well.
Finally, if I change just one line of that function, which does not affect its output, differentiation results in another error.
Expected output or behavior
I would expect to get no errors when setting up automatic differentiation with the functions above.
Version information
versioninfo():
using Pkg; Pkg.status("ITensors"):