Poor performance relative to PyTorch #886

Closed · jessebett opened this issue Oct 7, 2019 · 5 comments · Fixed by FluxML/Zygote.jl#1044
@jessebett (Contributor)

In addition to the numerical stability differences between Tracker and Zygote described in #876, Zygote is performing considerably worse than the equivalent PyTorch code for that example.

Here is the PyTorch code:

import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset

# dummy data
x = torch.rand(100000,113)
y = torch.sum(x**2,dim=1, keepdim=True)
dataset = TensorDataset(x,y)
dataloader = DataLoader(dataset, batch_size=256)

model = nn.Sequential(
    nn.Linear(113,1000),
    nn.Linear(1000,1)
    )

criterion = nn.BCEWithLogitsLoss()  # binary cross entropy
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-4, betas=(0.9, 0.99)
    )

model.train() # put the model in training mode (autograd is enabled by default)
for (x,y) in dataloader:
  y_hat_logit = model(x)
  loss = criterion(y_hat_logit,y)
  print(loss.float())

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

Here is the Flux code:

using Flux
using Statistics: mean
using MLDataUtils

# dummy data
x = Float32.(rand(113,100000))
y = sum(x.^2,dims=1)

dataset = batchview((x,y),size=256)

model = Chain(
              Dense(113,1000),
              Dense(1000,1)
             )

criterion(logits,y) = mean(Flux.logitbinarycrossentropy.(logits,y))
optimizer = Flux.ADAM(1e-4,(0.9,0.99))

for (x,y) in dataset
  θ = params(model)
  loss,back = Flux.Zygote.pullback(()->criterion(model(x),y),θ)
  println(loss)
  grads = back(1.)
  Flux.Optimise.update!(optimizer,θ,grads)
end
1. The above PyTorch code is much faster than the Flux code.
2. The Flux code, after a few iterations, results in NaNs, where the PyTorch code does not. Possibly the same issue as #876 ("Model optimization fails (NaNs) with Zygote.pullback but works with Tracker.forward").
@MikeInnes (Member)

As with debugging the NaN issue, it would be helpful to strip away as much of this code as possible (e.g. removing the dense layer, loss function) in ways that still show the perf bug. If we can narrow it down to e.g. just one function it'll probably be very easy to fix. I would definitely like to track this down, although the numerical issue should probably be the priority.
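
One possible stripped-down timing harness (a sketch added for illustration, not from the thread; it isolates a single Dense layer's pullback and compares a Float64 seed against a Float32 one, which is where later comments locate the slowdown):

using Flux

# Minimal reproducer sketch: one Dense layer, one batch, time only the backward pass.
d = Dense(113, 1000)
x = rand(Float32, 113, 256)
θ = params(d)

loss, back = Flux.Zygote.pullback(() -> sum(d(x)), θ)
@time back(1.)    # Float64 seed, as in the original training loop
@time back(1f0)   # Float32 seed matching the model's eltype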

@jessebett (Contributor, Author)

@MikeInnes re stripping away as much as possible. Yep. These performance and numerical concerns showed up while doing an assignment, so I just wanted to include the circumstances that showed the issues as they came up in the real world. Can look into where the performance difference shows up later.

One thing I'll look at is whether nn.BCEWithLogitsLoss() is computing the same thing as mean(Flux.logitbinarycrossentropy.(logits,y)).
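
For reference, PyTorch's BCEWithLogitsLoss averages the numerically stable per-element form max(z, 0) - z*y + log(1 + exp(-|z|)), so the two criteria should agree up to floating-point error. A quick check (a sketch for illustration, not from the thread):

using Flux
using Statistics: mean

z = Float32[-3, -0.5, 0, 0.5, 3]   # logits
y = Float32[ 0,    1, 1,   0, 1]   # binary targets

flux_loss    = mean(Flux.logitbinarycrossentropy.(z, y))
pytorch_loss = mean(@. max(z, 0) - z*y + log(1 + exp(-abs(z))))   # PyTorch's stable formula
@show flux_loss pytorch_loss   # expect agreement to Float32 precision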

@mcabbott (Member) commented Feb 8, 2020

With #1031, this gives a warning "Debug: Chain(...) has output of eltype Float32 but receives gradient of eltype Float64". Changing back(1.) to back(1f0) then cuts the time from 26 seconds to 1.5 seconds.
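
In other words, the only change needed in the Flux loop above is the seed passed to the pullback (a minimal sketch of that change, based on this comment):

loss, back = Flux.Zygote.pullback(() -> criterion(model(x), y), θ)
grads = back(1f0)   # was back(1.); a Float64 seed promotes every gradient to Float64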

@AStupidBear (Contributor) commented May 9, 2020

This can be reproduced on CPU but not on GPU.

using Flux
using Statistics: mean
using MLDataUtils

# dummy data
x = rand(Float32, 113, 100000) |> gpu
y = sum(x.^2, dims = 1) |> gpu

dataset = batchview((x, y), size = 256)

model = Chain(Dense(113, 1000, relu), Dense(1000, 1)) |> gpu

criterion(logits, y) = mean(Flux.logitbinarycrossentropy.(logits, y))
optimizer = Flux.ADAMW(1e-4, (0.9, 0.999), 1e-5)

for (x,y) in dataset
  θ = params(model)
  loss,back = Flux.Zygote.pullback(θ) do
    criterion(model(x), y)
  end
  println(loss)
  grads = back(1f0)
  Flux.Optimise.update!(optimizer, θ, grads)
end

@mcabbott (Member) commented Sep 6, 2021

Effect of FluxML/Zygote.jl#1044 on this, on CPU:

julia> @time for (x,y) in dataset
         θ = params(model)
         loss,back = Flux.Zygote.pullback(()->criterion(model(x),y),θ)
         # println(loss)
         grads = back(1.)
         Flux.Optimise.update!(optimizer,θ,grads)
       end
 17.277373 seconds (238.11 k allocations: 1.927 GiB, 0.37% gc time, 0.73% compilation time)  # before
  1.771140 seconds (239.28 k allocations: 2.091 GiB, 6.30% gc time, 6.42% compilation time)  # after
  
  1.671819 seconds (238.49 k allocations: 2.089 GiB, 3.72% gc time, 7.03% compilation time) # with back(1f0)

All timings after warm-up; quite noisy.
