[ITensors] [BUG] Automatic differentiation #73

Open
ArtemStrashko opened this issue Jun 3, 2022 · 4 comments
Labels
bug Something isn't working

Comments

@ArtemStrashko

Description of bug

Automatic differentiation behaves inconsistently: an error is thrown for one function, but not for an analogous, slightly more complicated version of it, and the error returns after a one-line rewrite of that version that leaves its output unchanged.

Minimal code demonstrating the bug or unexpected behavior

Minimal runnable code

using ChainRulesCore
using Zygote
using ITensors
@non_differentiable onehot(::Any...)

indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]

Differentiating the function below throws an error.

function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end

f(x) = f(x, state)
f'(mps)
# DimensionMismatch("cannot add ITensors with different numbers of indices")

However, if I just add two extra lines to that function, differentiation works well.

function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        phys_idx = commonind(mps[i], state[i])
        L -= sum([(partial_mps * onehot(phys_idx => j))[] for j in 1:1])
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end

f(x) = f(x, state)
f'(mps)
# works well

Finally, if I change just one line of that function, which does not affect its output, differentiation results in another error.

function f(mps, state)
    N = length(mps)
    L = 0.
    tens = 1.
    for i in 1:N
        tens *= mps[i]
        if i < N
            idx = commonind(mps[i], mps[i+1])
            partial_mps = tens * onehot(idx => 1)
        else
            partial_mps = tens
        end
        phys_idx = commonind(mps[i], state[i])
        L -= (partial_mps * onehot(phys_idx => 1))[]
        L += (partial_mps * state[i])[]
        tens *= state[i]
    end
    return L
end

f(x) = f(x, state)
f'(mps)
# DimensionMismatch("cannot add ITensors with different numbers of indices")

Expected output or behavior

I would expect all of the functions above to differentiate without errors.

Version information

  • Output from versioninfo():
julia> versioninfo()
Julia Version 1.7.0
Commit 3bf9d17731* (2021-11-30 12:12 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD EPYC 7742 64-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, znver2)
Environment:
  JULIA_NUM_THREADS = 128
  • Output from using Pkg; Pkg.status("ITensors"):
julia> using Pkg; Pkg.status("ITensors")
Status `~/.julia/environments/v1.7/Project.toml`
  [9136182c] ITensors v0.2.16
@ArtemStrashko ArtemStrashko added the bug Something isn't working label Jun 3, 2022
@mtfishman
Member

Very strange bug, thanks @ArtemStrashko. By the way, I would recommend updating to the latest version of ITensors (0.3.14), though I see this bug in the latest version as well, so updating won't fix this issue.

@mtfishman
Member

This seems to fix the initial error:

using Zygote
using ITensors

indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]

function f(mps, state)
  N = length(mps)
  L = 0.
  tens = 1.
  for i in 1:N
    tens *= mps[i]
    proj = if i < N
      idx = commonind(mps[i], mps[i+1])
      onehot(idx => 1)
    else
      ITensor(1.0)
    end
    partial_mps = tens * proj
    L += (partial_mps * state[i])[]
    tens *= state[i]
  end
  return L
end

f(x) = f(x, state)
f'(mps)

It seems like the issue stems from putting something differentiable (like tens) inside of the if-statement. I'm not sure why that's the case; perhaps there's an issue with one of our ChainRules definitions, or it is a Zygote issue.

@mtfishman
Member

Actually, there's a more minimal fix. These both run without errors for me:

using Zygote
using ITensors

indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]

function f(mps, state)
  N = length(mps)
  L = 0.
  tens = 1.
  for i in 1:N
    tens *= mps[i]
    if i < N
      idx = commonind(mps[i], mps[i+1])
      partial_mps = tens * onehot(idx => 1)
    else
      partial_mps = tens * ITensor(1.0)
    end
    L += (partial_mps * state[i])[]
    tens *= state[i]
  end
  return L
end

f(x) = f(x, state)
f'(mps)

and

using Zygote
using ITensors

indices = [Index(2) for _ in 1:10]
mps = randomMPS(indices; linkdims=10)
state = [randomITensor(index) for index in indices]

function f(mps, state)
  N = length(mps)
  L = 0.
  tens = 1.
  for i in 1:N
    tens *= mps[i]
    if i < N
      idx = commonind(mps[i], mps[i+1])
      partial_mps = tens * onehot(idx => 1)
    else
      partial_mps = tens * ITensor(1.0)
    end
    phys_idx = commonind(mps[i], state[i])
    L -= (partial_mps * onehot(phys_idx => 1))[]
    L += (partial_mps * state[i])[]
    tens *= state[i]
  end
  return L 
end

f(x) = f(x, state)
f'(mps)

The only change I made is rewriting:

      partial_mps = tens

to:

      partial_mps = tens * ITensor(1.0)

I can't say I understand why this is happening. It would be helpful to concoct a more minimal example to investigate this better.

@mtfishman
Member

mtfishman commented Jun 3, 2022

So oddly enough, this seems to be a Zygote bug, since it reproduces without ITensors:

using FiniteDifferences
using Zygote

function f(x)
  y = [[x]', [x]]
  r = 0.0
  o = 1.0
  for n in 1:2
    o *= y[n]
    if n < 2
      proj_o = o * [1.0]
    else
      # Error
      proj_o = o
      # Fix
      # proj_o = o * 1.0
    end
    r += proj_o
  end
  return r
end

x = 1.2
@show f(x)
@show central_fdm(5, 1)(f, x)
@show f'(x)

which throws the error:

f(x) = 2.6399999999999997
(central_fdm(5, 1))(f, x) = 3.4000000000000967
ERROR: LoadError: MethodError: no method matching +(::Float64, ::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
For element-wise addition, use broadcasting with dot syntax: scalar .+ array
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at ~/software/julia-1.7.3/share/julia/base/operators.jl:655
  +(::Union{Float16, Float32, Float64}, ::BigFloat) at ~/software/julia-1.7.3/share/julia/base/mpfr.jl:413
  +(::ChainRulesCore.Tangent{P}, ::P) where P at ~/.julia/packages/ChainRulesCore/GUvJT/src/tangent_arithmetic.jl:146
  ...
Stacktrace:
 [1] accum(x::Float64, y::LinearAlgebra.Adjoint{Float64, Vector{Float64}})
   @ Zygote ~/.julia/packages/Zygote/DkIUK/src/lib/lib.jl:17
[...]

while with the fix it outputs:

f(x) = 2.6399999999999997
(central_fdm(5, 1))(f, x) = 3.4000000000000967
(f')(x) = 3.4000000000000004
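
As a sanity check on those numbers, the loop can be unrolled by hand. This is only a sketch, and it assumes the usual reading of the code: `[x]'` is a 1×1 row vector, so multiplying it by a one-element column vector gives a scalar.

```latex
\begin{aligned}
n = 1:&\quad o = 1.0 \cdot [x]^{\top} = [x]^{\top}, \qquad \mathrm{proj}_o = [x]^{\top} [1.0] = x, \qquad r = x \\
n = 2:&\quad o = [x]^{\top} [x] = x^2, \qquad \mathrm{proj}_o = x^2, \qquad r = x + x^2 \\
&\quad f(x) = x + x^2, \qquad f'(x) = 1 + 2x \\
&\quad f(1.2) = 2.64, \qquad f'(1.2) = 3.4
\end{aligned}
```

Both the finite-difference estimate and the Zygote result with the fix agree with 3.4 up to floating-point rounding. Note also that `o` changes type between iterations (a row vector, then a scalar), which matches the two argument types in the `accum` call shown in the stack trace and may be relevant to the failure.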

@mtfishman mtfishman transferred this issue from ITensor/ITensors.jl Oct 25, 2024