feat: replace bce by focal loss in linknet loss #824
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##             main     #824      +/-   ##
==========================================
- Coverage   96.01%   95.98%   -0.03%
==========================================
  Files         131      131
  Lines        5019     5033      +14
==========================================
+ Hits         4819     4831      +12
- Misses        200      202       +2
```
Thanks, I added a comment on the implementation! Have you tried to check if it improves training perf?
```python
p_t = (seg_target[seg_mask] * pred_prob) + ((1 - seg_target[seg_mask]) * (1 - pred_prob))
# Compute alpha factor
alpha_factor = seg_target[seg_mask] * alpha + (1 - seg_target[seg_mask]) * (1 - alpha)
# Compute the final loss
focal_loss = (alpha_factor * (1. - p_t) ** gamma * bce_loss[seg_mask]).mean()
```
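For context, a minimal self-contained sketch of how these lines fit together end to end. The `out_map` logits argument and the default `alpha`/`gamma` values are assumptions for illustration, not part of the diff; `seg_target` (float 0/1 map) and `seg_mask` (bool validity mask) follow the names in the snippet:

```python
import torch
import torch.nn.functional as F

def focal_loss(out_map, seg_target, seg_mask, alpha=0.25, gamma=2.0):
    # Unreduced BCE so each element can be weighted before averaging
    bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
    pred_prob = torch.sigmoid(out_map[seg_mask])
    # p_t: probability the model assigns to the ground-truth class of each element
    p_t = seg_target[seg_mask] * pred_prob + (1 - seg_target[seg_mask]) * (1 - pred_prob)
    # Class-balancing weight: alpha for positives, 1 - alpha for negatives
    alpha_factor = seg_target[seg_mask] * alpha + (1 - seg_target[seg_mask]) * (1 - alpha)
    # (1 - p_t) ** gamma down-weights well-classified elements
    return (alpha_factor * (1.0 - p_t) ** gamma * bce_loss[seg_mask]).mean()
```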
I think we need to address the masking + reduction problem of the dice loss first: once masked, the tensor is flattened to 1D. So if one class has 10 times more masked region than another, this will be a problem.

I'd suggest changing

```python
seg_target[mask].mean()
```

to

```python
mask = mask.to(dtype=torch.float32)
# Average on N, H, W
class_loss = (seg_target * mask).sum((0, 2, 3)) / mask.sum((0, 2, 3))
loss = class_loss.mean()
```

or averaging it on H, W only before the final mean.
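To make the suggested reduction concrete, here is a toy illustration assuming NCHW tensors (all names, shapes, and values here are hypothetical):

```python
import torch

N, C, H, W = 2, 3, 8, 8
per_pixel_loss = torch.rand(N, C, H, W)   # e.g. an unreduced per-pixel loss map
mask = torch.rand(N, C, H, W) > 0.5       # bool mask of valid pixels

mask = mask.to(dtype=torch.float32)
# Sum over batch and spatial dims per class, divide by that class's valid-pixel
# count: each class contributes equally to the final mean regardless of its area
class_loss = (per_pixel_loss * mask).sum((0, 2, 3)) / mask.sum((0, 2, 3))
loss = class_loss.mean()
print(class_loss.shape)  # torch.Size([3]) -> one scalar per class
```

Note that a class whose mask is entirely `False` would divide by zero here; clamping `mask.sum((0, 2, 3))` to a minimum of 1 is one way to guard against that.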
I am not sure I understand well here: do you want to remove each `...[mask]` occurrence?
So the difference is the following:

- `my_tensor[mask]` is a 1D tensor with a number of elements equal to the number of `True` values in the mask
- `my_tensor * mask.to(dtype=torch.float32)` has the same shape as `my_tensor`; it only puts zeros on the elements that are masked out

Now if you perform a reduction operation like mean:

- in the first case, you divide by the number of elements in `my_tensor[mask]` = `mask.sum()`
- in the second one, you divide by the number of elements in `my_tensor`

And this extends to dimension-specific operations: if we index with the mask, we lose the separation of the dimensions and end up with a contiguous 1D tensor. For properly scaling the loss, the first case widely increases the contribution of classes with the highest amount of positive mask (it makes no difference if there is only a single class). And since we specifically want to help balance the less-frequent classes here, I suggest leveraging the second option with my suggestion above 👍
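A tiny toy example of the two behaviours (values made up for illustration):

```python
import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[True, False], [True, True]])

indexed = t[mask]                        # tensor([1., 3., 4.]): 1D, mask.sum() elements
multiplied = t * mask.to(torch.float32)  # same 2x2 shape, masked-out entry zeroed

print(indexed.mean())          # 8 / 3: divides by mask.sum()
print(multiplied.mean())       # 8 / 4: divides by t.numel()
print(multiplied.mean(dim=0))  # per-dimension reductions remain possible
```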
(and we need to do this for both the dice loss and the focal loss)
It will only make a difference in the multi-class case, and when the mask isn't all `True`, but that might be safer!
Either way, I think we should run a training with this configuration to make sure it yields a positive change 👍
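For the dice side, a hedged sketch of what that reduction could look like (my own formulation under the NCHW assumption, not the repository's actual code):

```python
import torch

def dice_loss(prob_map, seg_target, seg_mask, eps=1e-8):
    # prob_map, seg_target, seg_mask: (N, C, H, W) probabilities, 0/1 targets, bool mask
    mask = seg_mask.to(dtype=torch.float32)
    # Reduce over N, H, W separately per class so frequent classes don't dominate
    inter = (prob_map * seg_target * mask).sum((0, 2, 3))
    cardinality = ((prob_map + seg_target) * mask).sum((0, 2, 3))
    dice = 1 - 2 * inter / (cardinality + eps)
    return dice.mean()
```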
A few corrections in the loss computation and we'll be good to go!
Final cosmetic adjustments and we're good to go 👍
My bad, a few fixes to do on my previous suggestion
Thanks!
* feat: replace bce by focal loss in linknet loss
* fix: requested changes
* fix: mask reduction
* fix: mask reduction
* fix: loss reduction
* fix: final adjustements
* fix: final changes
Revert "feat: replace bce by focal loss in linknet loss (#824)"
This reverts commit 6511183.
* backup
* onnx classification
* fix: Fixed some ResNet architecture imprecisions (#828)
* feat: Added new resnets
* feat: Added ResNet101
* fix: Fixed ResNet31 & ResNet34 wide
* feat: Added new pretrained resnets
* style: Fixed isort
* fix: Fixed ResNet architectures
* refactor: Refactored LinkNet
* feat: Added more LinkNets
* fix: Fixed MAGResNet
* docs: Updated documentation
* refactor: Removed ResNet101
* fix: Fixed warning
* fix: Fixed a few bugs
* test: Updated unittests
* docs: Fixed docstrings
* update with new models
* feat: replace bce by focal loss in linknet loss (#824)
* feat: replace bce by focal loss in linknet loss
* fix: requested changes
* fix: mask reduction
* fix: mask reduction
* fix: loss reduction
* fix: final adjustements
* fix: final changes
* Revert "feat: replace bce by focal loss in linknet loss (#824)" (this reverts commit 6511183)
* Revert "fix: Fixed some ResNet architecture imprecisions (#828)" (this reverts commit 72e5e0d)
* happy codacy
* sapply suggestions
* fix-setup
* remove onnx from test req
* move onnx deps ftm to torch
* up
* up
* revert requirements
* fix
* update docstring
* up

Co-authored-by: F-G Fernandez <76527547+fg-mindee@users.noreply.github.com>
Co-authored-by: Charles Gaillard <charles@mindee.co>
Following a suggestion by @fg-mindee and @SiddhantBahuguna, this PR replaces the BCE loss with a focal loss in the LinkNet loss to increase recall on imbalanced classes.
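For reference, the focal loss of Lin et al. (2017) rescales the cross-entropy term as

$$\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \,\log(p_t)$$

where $p_t$ is the predicted probability of the ground-truth class: setting $\gamma = 0$ and $\alpha_t = 1$ recovers plain BCE, while $\gamma > 0$ down-weights easy, well-classified pixels so the rare positive class contributes more to the gradient.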
Any feedback is welcome!