Conversation

@ChrisRackauckas
Member

Start fixing some of the issues like #401
@ChrisRackauckas
Member Author

Defining Zygote.pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz] fixes the following:

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, CUDA, DiffEqSensitivity, Zygote, Test
CUDA.allowscalar(false) # Makes sure no slow scalar operations are occurring

# Generate the exogenous input signal and the target output signal
tspan = (0.1f0, Float32(10.0))
tsteps = range(tspan[1], tspan[2], length = 100)
t_vec = collect(tsteps)
ex = vec(ones(Float32,length(tsteps), 1))
f(x) = (atan(8.0 * x - 4.0) + atan(4.0)) / (2.0 * atan(4.0))

function hammerstein_system(u)
    # Hammerstein system: static nonlinearity f followed by a linear filter
    y = zeros(size(u))
    for k in 2:length(u)
        y[k] = 0.2 * f(u[k-1]) + 0.8 * y[k-1]
    end
    return y
end

ex = vec([ones(Float32,50,1) 2*ones(Float32,50,1)]) # exogenous signal: step from 1 to 2
ex = ex'
ode_data = gpu(Float32.(hammerstein_system(ex))) # signal we want to predict

#Define the ode layer
nn_dudt = FastChain(FastDense(2, 8, tanh), FastDense(8, 1))
u0 = Float32[0.0] |> gpu
p = initial_params(nn_dudt) |> gpu

function dudt2(u,p,t,ex)
  # Concatenate the state with the exogenous input at time t (10 samples per unit time)
  nn_dudt(vcat(u,ex[Int(round(t*10))]), p)
end

@test vcat(u0,ex[Int(round(1.0*10))]) isa CuArray

_dudt2(u,p,t) = dudt2(u,p,t,ex)
prob_gpu = ODEProblem(_dudt2, u0, tspan, nothing)

# Runs on a GPU
function predict_neuralode(p)
  _prob_gpu = remake(prob_gpu,p=p)
  gpu(solve(_prob_gpu, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-6))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    N = length(pred)
    l = sum(abs2, ode_data[1:N]' .- pred)/N
    return l
end
Zygote.gradient(loss_neuralode,p)
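
The failing piece boils down to differentiating vcat of a CuArray with a scalar, as in the vcat(u, ex[...]) call inside dudt2. A minimal sketch, assuming the vcat pullback still dispatches Number blocks through pull_block_vert:

using Zygote, CUDA
CUDA.allowscalar(false)

u = gpu(rand(Float32, 2))
# The pullback of vcat(::CuArray, ::Number) has to read one element out of the
# CuArray cotangent for the Number block; that is scalar indexing, which
# allowscalar(false) forbids unless the @allowscalar method above is defined.
Zygote.gradient(u -> sum(vcat(u, 1f0)), u)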

@DhairyaLGandhi could we revive that so it can be merged?

@ChrisRackauckas
Member Author

The second example fails:

using DiffEqFlux, Flux, Optim, OrdinaryDiffEq, CUDA, DiffEqSensitivity, Plots
using Zygote
Zygote.pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]

u0 = [1.1; 1.1] |> gpu
tspan = (0.0f0,25.0f0)

ann = FastChain(FastDense(2,16,tanh), FastDense(16,16,tanh), FastDense(16,1))
p1 = initial_params(ann)
p2 = Float32[0.5,-0.5]
p3 = [p1;p2]
θ = Float32[u0;p3]

function dudt_(u,p,t)
    x, y = u
    # Move the last two parameters to the CPU for scalar use; this CPU/GPU
    # mixing inside the RHS is what the adjoint trips over (see the MWE below)
    pend = cpu(p[end-1:end])
    [cpu(ann(gpu(u),p[1:length(p1)]))[1], pend[1]*y + pend[2]*x]
end
prob = ODEProblem{false}(dudt_,u0,tspan,p3)

function predict_adjoint(θ)
  gpu(Array(solve(prob,Tsit5(),u0=cpu(θ[1:2]),p=θ[3:end],saveat=0.0:1:25.0,sensealg=QuadratureAdjoint())))
end
loss_adjoint(θ) = sum(abs2, predict_adjoint(θ)[2,:] .- 1)
l = loss_adjoint(θ)

cb = function (θ,l)
  println(l)
  #display(plot(solve(remake(prob,p=Flux.data(p3),u0=Flux.data(u0)),Tsit5(),saveat=0.1),ylim=(0,6)))
  return false
end

loss1 = loss_adjoint(θ)
Zygote.gradient(loss_adjoint,θ)

but it has an easy non-DiffEq MWE:

p = gpu(rand(5))
function f(p)
    pend = cpu(p[end-1:end])
    pend[1]*2 + pend[2]*5
end
Zygote.gradient(f,p)
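
One way to sidestep this in the MWE, sketched under the assumption that only a weighted sum of the last two entries is needed: keep the slice and the reduction on the GPU rather than pulling them back with cpu() inside the differentiated function.

using Zygote, CUDA

p = gpu(rand(Float32, 5))
w = gpu(Float32[2, 5])        # illustrative coefficients matching the MWE
g(p) = sum(p[end-1:end] .* w) # range indexing: no scalar reads, no device transfer
Zygote.gradient(g, p)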

@DhairyaLGandhi
Member

OneElement shouldn't be leaked into user-facing code. We already got most of the performance from generalizing our accumulation strategy. Do you think we can remove OneElement?
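
For reference, a minimal sketch of where OneElement surfaces (assuming a Zygote version that returns it for scalar getindex gradients):

using Zygote

g = Zygote.gradient(x -> x[3], rand(5))[1]
typeof(g)  # Zygote.OneElement: a lazy one-hot array with a single nonzero at index 3
collect(g) # materializes to [0.0, 0.0, 1.0, 0.0, 0.0]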
