
WIP: fix some mixed CPU/GPU transfer adjoint issues #571

Closed
wants to merge 1 commit into from

Conversation

ChrisRackauckas (Member)

Start fixing some of the issues like #401
@ChrisRackauckas (Member, Author)

Defining Zygote.pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz] fixes the following:

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, CUDA, DiffEqSensitivity, Test
using Zygote
CUDA.allowscalar(false) # make sure no slow scalar operations are occurring

# generate the exogenous input signal and the output signal
tspan = (0.1f0, Float32(10.0))
tsteps = range(tspan[1], tspan[2], length = 100)
t_vec = collect(tsteps)
ex = vec(ones(Float32,length(tsteps), 1))
f(x) = (atan(8.0 * x - 4.0) + atan(4.0)) / (2.0 * atan(4.0))

function hammerstein_system(u)
    y = zeros(size(u))
    for k in 2:length(u)
        y[k] = 0.2 * f(u[k-1]) + 0.8 * y[k-1]
    end
    return y
end

ex = vec([ones(Float32,50,1) 2*ones(Float32,50,1)]) # exogenous signal
ex = ex'
ode_data = gpu(Float32.(hammerstein_system(ex))) #signal we want to predict

# Define the ODE layer
nn_dudt = FastChain(FastDense(2, 8, tanh), FastDense(8, 1))
u0 = Float32[0.0] |> gpu
p = initial_params(nn_dudt) |> gpu

function dudt2(u,p,t,ex)
  # look up the exogenous input at the current (scaled) time
  nn_dudt(vcat(u,ex[Int(round(t*10))]), p)
end

@test vcat(u0,ex[Int(round(1.0*10))]) isa CuArray

_dudt2(u,p,t) = dudt2(u,p,t,ex)
prob_gpu = ODEProblem(_dudt2, u0, tspan, nothing)

# Runs on a GPU
function predict_neuralode(p)
  _prob_gpu = remake(prob_gpu,p=p)
  gpu(solve(_prob_gpu, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-6))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    N = length(pred)
    l = sum(abs2, ode_data[1:N]' .- pred)/N
    return l
end
Zygote.gradient(loss_neuralode,p)
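
To see the failing primitive in isolation: Zygote's vcat adjoint splits the incoming gradient with pull_block_vert, and the ::Number method reads a single scalar out of the gradient array, which allowscalar(false) rejects on the GPU. A minimal sketch of the failure (assuming pull_block_vert's current internal signature):

using CUDA, Zygote
CUDA.allowscalar(false)

Δ = CUDA.rand(Float32, 2)
# pull_block_vert(sz, Δ, A::Number) reduces to Δ[sz], a scalar read, so this
# throws a scalar-indexing error until the @allowscalar method above is defined
Zygote.pull_block_vert(2, Δ, 1.0f0)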

@DhairyaLGandhi could we revive that and get it merged?

@ChrisRackauckas (Member, Author)

The second example fails:

using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, CUDA, DiffEqSensitivity, Plots
using Zygote
Zygote.pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz] # candidate fix from above

u0 = [1.1; 1.1] |> gpu
tspan = (0.0f0,25.0f0)

ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1))
p1 = initial_params(ann)
p2 = Float32[0.5,-0.5]
p3 = [p1;p2]
θ = Float32[u0;p3]

function dudt_(u,p,t)
    x, y = u
    pend = cpu(p[end-1:end]) # the last two parameters are used on the CPU
    [cpu(ann(gpu(u),p[1:length(p1)]))[1], pend[1]*y + pend[2]*x]
end
prob = ODEProblem{false}(dudt_,u0,tspan,p3)

function predict_adjoint(θ)
  gpu(Array(solve(prob,Tsit5(),u0=cpu(θ[1:2]),p=θ[3:end],saveat=0.0:1:25.0,sensealg=QuadratureAdjoint())))
end
loss_adjoint(θ) = sum(abs2, predict_adjoint(θ)[2,:] .- 1)
l = loss_adjoint(θ)

cb = function (θ,l)
  println(l)
  #display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6)))
  return false
end

loss1 = loss_adjoint(θ)
Zygote.gradient(loss_adjoint,θ)

but it has an easy non-DiffEq MWE:

using CUDA, Flux, Zygote
p = gpu(rand(5))
function f(p)
    pend = cpu(p[end-1:end])
    pend[1]*2 + pend[2]*5
end
Zygote.gradient(f,p)
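
For comparison, the same slicing differentiates cleanly when everything stays on the CPU (cpu is a no-op on a plain Array), which isolates the failure to the device transfer inside the pullback. A quick CPU-only sketch:

p_cpu = rand(5)
function f_cpu(p)
    pend = p[end-1:end]
    pend[1]*2 + pend[2]*5
end
Zygote.gradient(f_cpu, p_cpu) # ([0.0, 0.0, 0.0, 2.0, 5.0],)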

@DhairyaLGandhi (Member)

OneElement shouldn't be leaked to user-facing code. We already got most of the performance from generalizing our accumulation strategy. Do you think we can remove OneElement?
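
(For context: OneElement is the lazy one-hot array Zygote returns as the gradient of scalar indexing, normally an internal detail. A minimal CPU sketch of where it surfaces, assuming a recent Zygote:)

using Zygote
x = rand(3)
g = Zygote.gradient(v -> v[2], x)[1]
typeof(g) # Zygote.OneElement — the lazy one-hot gradient of getindex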
