Activation Function Experiments #441
Experiments on 5ff6e6b below; results are test.py mAP. Swish produces the best results, with the highest mAP and lowest validation losses across almost all epochs (not just the final epoch), but the difference is small and the increase in GPU memory is significant. LeakyReLU runs in-place, reducing GPU memory, whereas Swish (being a custom module) requires about 50% more GPU memory and PReLU requires about 30% more.

python3 train.py --img-size 320 --epochs 27 --batch-size 64 --accumulate 1 --nosave
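As context for the memory numbers above, here is a minimal measurement sketch (not from the repo; the layer sizes, batch size, and the peak_memory_mb helper are illustrative) using torch.cuda.max_memory_allocated() to compare activations:

import torch
import torch.nn as nn

def peak_memory_mb(make_act, img_size=320, batch=16):
    # One forward + backward pass through a small conv stack, then read the
    # peak CUDA allocation. make_act is a callable returning a fresh activation.
    torch.cuda.reset_peak_memory_stats()
    model = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1), make_act(),
        nn.Conv2d(64, 64, 3, padding=1), make_act(),
    ).cuda()
    x = torch.randn(batch, 3, img_size, img_size, device='cuda')
    model(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 1e6

print(peak_memory_mb(lambda: nn.LeakyReLU(0.1, inplace=True)))
print(peak_memory_mb(lambda: nn.PReLU(num_parameters=64)))
# Swish() / MemoryEfficientSwish() (defined later in this thread) can be passed the same way.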
|
@glenn-jocher Following lukemelas/EfficientNet-PyTorch#88, GPU memory consumption for Swish decreases if the Swish implementation inherits from torch.autograd.Function:

import torch
import torch.nn as nn

class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)

However, this will increase the training time. If you are still using Swish for some of your experiments and getting out-of-memory errors, it could be useful. |
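As a usage note, a small sketch (my addition, not from this repo; the swap_activations helper is hypothetical) of how a drop-in module like MemoryEfficientSwish could replace the activations of an already-built model:

import torch.nn as nn

def swap_activations(module, old=nn.LeakyReLU, new_factory=MemoryEfficientSwish):
    # Recursively replace every `old` activation submodule with a fresh
    # instance produced by `new_factory`.
    for name, child in module.named_children():
        if isinstance(child, old):
            setattr(module, name, new_factory())
        else:
            swap_activations(child, old, new_factory)
    return module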
@okanlv nice find!! It looks like Swish is indeed improving performance in this repo, so this new class may be very useful. I will test it on 1 epoch of COCO using the command below on a V100 GCP instance.

python3 train.py --data data/coco.data --cfg cfg/yolov3s.cfg --weights '' --epochs 1

Comparing it to the default LeakyReLU(0.1) and our current Swish() implementation:

class Swish(nn.Module):
    def forward(self, x):
        return x.mul_(torch.sigmoid(x))

That's strange: the two Swish versions are returning different losses and mAPs, with the memory-efficient version worse in both. I had expected them to produce exactly the same results. |
Hmm, I didn't expect that. If I find anything else, I will keep you updated. |
To double check, I trained to 27 epochs and got the same result: MemoryEfficientSwish() performs worse, 49.3 mAP vs 49.7 mAP for the default Swish() implementation. I don't exactly know why. I use Apex for mixed-precision training, BTW; not sure if that has any effect. |
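A quick side check (my addition, not from the thread): this does not reproduce Apex's casting rules, but it shows how one could probe whether raw half precision alone perturbs the Swish forward/backward noticeably on random data:

import torch

x32 = torch.randn(4096, device='cuda', requires_grad=True)
x16 = x32.detach().half().requires_grad_()

y32 = x32 * torch.sigmoid(x32)  # fp32 swish
y16 = x16 * torch.sigmoid(x16)  # fp16 swish
print('forward max diff :', (y32.half() - y16).abs().max().item())

y32.sum().backward()
y16.sum().backward()
print('backward max diff:', (x32.grad.half() - x16.grad).abs().max().item())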
Edit: Both forward and backward for both functions produce the same results, as expected. I have updated the code to plot the gradients of both functions.

@glenn-jocher, I might have found the problem: the in-place operation in Swish() modifies its input, which might interfere with the gradient computation.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x.mul_(torch.sigmoid(x))


class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    def __init__(self):
        super(MemoryEfficientSwish, self).__init__()

    def forward(self, x):
        return SwishImplementation.apply(x)


f1 = Swish()
f2 = MemoryEfficientSwish()

# 1st method
# returns RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
# x = torch.linspace(-5, 5, 1000)
# x1 = x.clone().detach()
# x1.requires_grad = True
# y1 = f1(x1)

# 2nd method
x = torch.linspace(-5, 5, 1000, requires_grad=True)
x_copy = x.clone().detach()
x_copy.requires_grad = True
x1 = x.clone()
x2 = x_copy.clone()
y1 = f1(x1)
y2 = f2(x2)

print('\nDid Swish change its input?')
print(not torch.allclose(x, x1))
print('\nDid MemoryEfficientSwish change its input?')
print(not torch.allclose(x, x2))

plt.xlim(-6, 6)
plt.ylim(-1, 6)
plt.plot(x.detach().numpy(), y1.detach().numpy())
plt.plot(x.detach().numpy(), y2.detach().numpy())
plt.title('Swish functions')
plt.legend(['Swish', 'MemoryEfficientSwish'], loc='upper left')
plt.show()

y1.backward(torch.ones_like(x))
y2.backward(torch.ones_like(x))
assert torch.allclose(y1, y2)
assert torch.allclose(x.grad, x_copy.grad)


def getBack(var_grad_fn):
    # Recursively walk the autograd graph and print the backward functions.
    print(var_grad_fn)
    for n in var_grad_fn.next_functions:
        if n[0]:
            try:
                tensor = getattr(n[0], 'variable')
                print('\t', n[0])
                # print('Tensor with grad found:', tensor)
                # print(' - gradient:', tensor.grad)
                print()
            except AttributeError as e:
                getBack(n[0])


print('\nTracing backward functions for Swish')
getBack(y1.grad_fn)
print('\nTracing backward functions for MemoryEfficientSwish')
getBack(y2.grad_fn)

plt.xlim(-6, 6)
plt.ylim(-1, 2)
plt.plot(x.detach().numpy(), x.grad.detach().numpy())
plt.plot(x.detach().numpy(), x_copy.grad.detach().numpy())
plt.title('Swish gradient functions')
plt.legend(['Swish', 'MemoryEfficientSwish'], loc='upper left')
plt.show() |
How much better are the results? |
@FranciscoReveriano I am referring to @glenn-jocher 's results in this thread. I have not trained the model myself. |
@okanlv ah, so you are saying that the inplace operator in Swish() is interfering with the gradient computation? That's odd, because I trained with Swish() both with and without the inplace operator. So do you think the better results with Swish() might be a random occurrence? |
I think they are a random occurrence. |
@glenn-jocher @FranciscoReveriano I am not sure actually because using |
@okanlv hmm interesting, ok, keep us updated! I think Swish might be something that we'll want to integrate more in the future, as it does seem to increase mAP a bit in most circumstances. |
Yeah I am looking more into understanding Swish. Might be very beneficial. |
@glenn-jocher Hi, I wonder: when you change the loss, say from SmoothL1 to GIoU, or the activation from ReLU to Swish, do you train the entire model from scratch, or do you load part of the pretrained weights from the version before the change as a starting point? |
@sudo-rm-covid19 from scratch always. |
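For completeness, since the question asks about the alternative: a hedged sketch of loading only the still-compatible pretrained tensors after such a change (the checkpoint path, key layout, and the model variable are assumptions, not this repo's exact format); the answer above is that this repo simply retrains from scratch:

import torch

ckpt = torch.load('weights/last.pt', map_location='cpu')  # hypothetical checkpoint path
pretrained = ckpt.get('model', ckpt)                      # assume a plain state_dict, possibly nested
model_sd = model.state_dict()                             # `model` is the newly built network
compatible = {k: v for k, v in pretrained.items()
              if k in model_sd and v.shape == model_sd[k].shape}
model_sd.update(compatible)
model.load_state_dict(model_sd)
print('loaded %d/%d pretrained tensors' % (len(compatible), len(model_sd)))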
This issue documents studies on the YOLOv3 activation function. The PyTorch 1.2 release updated some of the BatchNorm2d weight initializations (from uniform random on [0, 1] to all ones), so I thought this would be a good time to benchmark the model and test the repo default against 3 possible improvements (a sketch of how each candidate attaches to a conv block follows the list):
nn.LeakyReLU(0.1, inplace=True)
class Swish(nn.Module)
nn.PReLU(num_parameters=filters)
nn.PReLU(num_parameters=1)
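A minimal sketch of how each candidate could be attached to a conv + BatchNorm2d block (the build_conv helper and its arguments are illustrative, not the repo's actual model-building code; Swish is the custom class shown earlier in the thread):

import torch.nn as nn

def build_conv(in_ch, filters, act='leaky'):
    layers = [nn.Conv2d(in_ch, filters, 3, padding=1, bias=False),
              nn.BatchNorm2d(filters)]
    if act == 'leaky':
        layers.append(nn.LeakyReLU(0.1, inplace=True))    # current default
    elif act == 'swish':
        layers.append(Swish())                            # custom module
    elif act == 'prelu_channel':
        layers.append(nn.PReLU(num_parameters=filters))   # one learned slope per channel
    elif act == 'prelu':
        layers.append(nn.PReLU(num_parameters=1))         # single shared slope
    return nn.Sequential(*layers)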
I benchmarked 5ff6e6b with each of the above activations on the small coco_64img.data tutorial dataset.
PReLU looks promising, but we can't draw any conclusions from this small dataset. In my next post I will plot the results on the full COCO dataset trained to 10% of the final epochs, which should be a much more useful comparison.