Getting NaN at reliability loss occasionally during training #21
I'm not sure what is the cause of the error in your case. Anyway, what you can do if you lack memory is simply to accumulate the gradient over each individual image:

```python
class MyTrainer(trainer.Trainer):
    """ This class implements the network training.
        Below is the function I need to overload to explain how to do the backprop.
    """
    def forward_backward(self, inputs):
        img1, img2 = inputs.pop('img1'), inputs.pop('img2')
        batch_size = len(img1)
        sum_loss = 0
        sum_details = {}
        for i in range(batch_size):
            sl = slice(i, i+1)  # select a single image pair
            output = self.net(imgs=[img1[sl], img2[sl]])
            # slice the remaining inputs (aflow, masks, ...) to the same single pair;
            # the network output is already batch-1, so it is not sliced again
            allvars = dict({k: v[sl] for k, v in inputs.items()}, **output)
            loss, details = self.loss_func(**allvars)
            if torch.is_grad_enabled(): loss.backward()
            sum_loss += loss
            sum_details = {k: v + sum_details.get(k, 0) for k, v in details.items()}
        return sum_loss, sum_details
```

Without any other modification of the code, this should lead to exactly the same results as the original code (yet using much less memory). (Disclaimer: I didn't actually execute this code, but in principle it should work perfectly.) Also, I noticed one thing is weird in your screenshot: the ...
Thanks for the quick reply and the information. About the present issue, I see what you mean...
But as you mentioned... I also tried to debug where the problem is. It seems to me that it happens in relation to the for-loop over tqdm(self.loader). Therefore, I think this problem is more related to the conversion between PIL, numpy, and PyTorch while preparing the data, somewhere between CatPairDataset and PairLoader... However, yes, this could be the root of the problem of why I get the NaN reliability loss... which I am going to talk about next...
I found that the problem, which happens more often with batch_size = 1, is caused by msk at line 41 in reliability_loss containing only False values. So, I checked further in self.sampler, which directed me to NghSampler2. It seems that this is because (mask == False).all() at L337. I found that this is both because "aflow" is all NaN and because, after assigning b1, x1, y1, aflow[b1, :, y1, x1] is all NaN. So, either the value of aflow or the assignment of b1, x1, y1 is the problem, but I am more inclined to think that the value of aflow is the problem.
Here, I also have a screen capture where I found that (mask == False).all() and "aflow" is all NaN ...
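As a reference for that check, here is a minimal, hypothetical sketch of the kind of test described above (the names aflow and mask follow the thread; the function itself is not code from the repository):

```python
import torch

def is_degenerate_pair(aflow, mask):
    """Diagnose the failure mode described above: an absolute-flow tensor that is
    entirely NaN, which in turn makes the sampler's validity mask all-False."""
    all_nan = torch.isnan(aflow).all().item()
    n_valid = int(mask.sum().item())
    print(f"aflow all NaN: {all_nan}, valid pixels in mask: {n_valid}")
    return all_nan or n_valid == 0
```

Placing such a check just before the loss computation makes it easy to log or skip the offending pairs.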
Update! I found the place in the scripts that causes the ...
This somewhat follows the suggestion in https://stackoverflow.com/questions/39554660/np-arrays-being-immutable-assignment-destination-is-read-only (a generic sketch of that fix is given after this comment).
So, it would be great if you would kindly provide some answers to my previous questions...
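For context, the fix suggested in that StackOverflow answer is roughly the following (a generic sketch only, not the actual change made in the scripts):

```python
import numpy as np

# Arrays created from a shared buffer (e.g. np.frombuffer, or np.asarray on a
# PIL image) can be read-only; in-place assignment then raises
# "ValueError: assignment destination is read-only".
arr = np.frombuffer(b"\x00" * 16, dtype=np.float32)  # read-only view
if not arr.flags.writeable:
    arr = arr.copy()  # work on a writeable copy instead
arr[:] = np.nan       # the in-place write now succeeds
```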
Ok, I see the problem much better now. The problem with NaN is indeed due to the fact that sometimes one of the training pairs contains not a single valid pixel. When training with a larger batch size this is unlikely to happen for the whole batch, but in your case, well of course with batch_size = 1, it does. A simple fix is to change the loss to

loss = pixel_loss[msk].sum() / (1 + msk.sum())

Note that it will not be exactly equal to the original loss, which gave an equal weight to each valid pixel over the entire batch, whereas now each valid pixel gets a weight that depends on the number of valid pixels per image. Another solution would be to ensure that the pair loader never returns image pairs with 'all-invalid' pixels :)
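To make that suggestion concrete, here is a minimal, hypothetical sketch of the masked average (the names pixel_loss and msk follow the comment above; this is not the exact code from the repository):

```python
import torch

def masked_mean_loss(pixel_loss, msk):
    """Average the per-pixel loss over valid pixels only. The '+ 1' in the
    denominator keeps the result finite (zero instead of NaN) when a pair
    contains no valid pixel at all."""
    return pixel_loss[msk].sum() / (1 + msk.sum())

# Example: a pair with no valid pixel no longer produces NaN.
pixel_loss = torch.rand(1, 64, 64)
msk = torch.zeros(1, 64, 64, dtype=torch.bool)
print(masked_mean_loss(pixel_loss, msk))  # tensor(0.)
```

As noted above, this weights pixels per image rather than per batch, so the numbers are not bitwise identical to the original loss.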
I see. Thank you very much, @jerome-revaud! Also, I am not sure if it is too late to say this... I think I got all the answers... so this issue can be closed.
@GabbySuwichaya have you resolved the problem of the MMA@3 drop? I also observe the same drop even for batch_size 8 and cannot figure out why.
@FangGet ... I have resolved my problem by using a batch size of 4. Also, I find that it is unnecessary to remove the warning. I get slightly better stats with a batch size of 4 than with 8 (here, I use N=16). This could be because of my GPU power. Also, my problem was somehow based on an older version of R2D2 (commit b23a2c4f4f608, adding MMA plots, Jan/28). I am not sure if the problem you have comes from the same causes. However, it is probably best to debug whether you get the NaN loss at APLoss by any chance, and whether that is because "aflow" is all NaN. The problem I had was caused by NaN aflow.
OK, I will check it, thank you.
Could you please suggest how to solve this problem and why it happens?
During training, I occasionally get NaN at the reliability loss, which happens more often when the batch size is set to a small number such as 1 or 2 (threads = 1).
Here, I also attached a screenshot of when it happens...
I have used the default settings in train.py, except that batch size = 1. My computer does not have enough GPU memory when batch size > 4.
Initially, I suspected that this problem happens due to the lack of corresponding pixels between the two images. Therefore, I tried to skip any samples that cause the NaN loss by forcing MyTrainer.forward_backward() in train.py to return before calculating loss.backward(), as shown in the attached screen capture (a rough sketch is given below), adding continue and print(details) in the if-condition at Line 55, and continuing training until it finished 25 epochs.
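Since the screenshot is not reproduced here, the workaround was roughly of the following form (a hypothetical sketch only; the attribute names self.net and self.loss_func follow the snippet quoted earlier on this page):

```python
import torch

def forward_backward(self, inputs):
    """Replacement for MyTrainer.forward_backward that skips pairs whose loss
    comes out non-finite instead of backpropagating them."""
    output = self.net(imgs=[inputs.pop('img1'), inputs.pop('img2')])
    allvars = dict(inputs, **output)
    loss, details = self.loss_func(**allvars)
    if not torch.isfinite(loss):
        print('skipping pair with non-finite loss:', details)
        return loss.detach() * 0, details  # nothing to backpropagate
    if torch.is_grad_enabled():
        loss.backward()
    return loss, details
```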
However, my trained R2D2 WASF N16 has a drop in MMA performance, as shown in the attached plot below...
Here, I have used the default setting with WASF N16 and min_size = 256, max_size = 1024.
The performance of my trained model is denoted as R2D2 WASF N16 (Self-trained), and
it is compared with the downloaded pretrained R2D2_WASF_N16.pt, using the same feature extraction settings.
But the performance drop is quite obvious:
the MMA @ 3px is 0.67 instead of 0.72 +/- 0.01.