if not self._has_warned:
    warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
    self._has_warned = True
If I remove the if check from line 554, I see the warning appear every time forward() is invoked. Here is what I did:
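import torch
from torchvision.models.detection import retinanet_resnet50_fpn

model = retinanet_resnet50_fpn(pretrained=True)
model.eval()
# Eager-mode forward pass; detection models take a list of 3D image tensors
img = torch.rand([3, 512, 512])
model([img])
for p in model.parameters():
    p.requires_grad_(False)
# Script the module itself; torch.jit.script's second positional parameter is not an example input
scripted_model = torch.jit.script(model)
# Call the scripted model repeatedly and watch how often the warning is printed
for _ in range(10):
    img_test = torch.rand([3, 512, 512])
    scripted_model([img_test])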
Sounds like it's not fixed yet on the TorchScript side, so we can't remove it. I'll close the issue, and we can revisit it in the future once it's patched.
🚀 Feature
Several detection models make use of an attribute, _has_warned, which was originally introduced so that JIT-scripted models would not throw the same warning multiple times:
vision/torchvision/models/detection/retinanet.py, lines 373 to 374 in 9778d26
vision/torchvision/models/detection/retinanet.py, lines 554 to 556 in 9778d26
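To make the mechanism concrete, here is a minimal pure-Python sketch of the warn-once flag pattern referenced above. ToyDetector and count_emitted are hypothetical names, and no torch code is involved. With Python's "always" filter active, the sketch shows that the flag alone is what limits the output to a single warning per model instance.

```python
import warnings


class ToyDetector:
    """Hypothetical stand-in for a detection model; not torchvision code."""

    def __init__(self) -> None:
        # Plays the role of _has_warned: records whether the warning was already emitted
        self._has_warned = False

    def forward(self) -> str:
        if not self._has_warned:
            warnings.warn("ToyDetector always returns a (losses, detections) tuple")
            self._has_warned = True
        return "detections"


def count_emitted(model: ToyDetector, calls: int) -> int:
    # Record every warning; the "always" filter disables Python's own
    # once-per-location deduplication, so only the flag can suppress repeats
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        for _ in range(calls):
            model.forward()
    return len(caught)


print(count_emitted(ToyDetector(), 10))  # prints 1
```

Because the flag lives on the instance, a fresh model warns once again, while repeated calls on the same instance stay silent after the first.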
JIT should now be able to handle this automatically, without the need for this variable. We should therefore remove the variable and confirm that calling forward() multiple times on a JIT-scripted model produces only a single warning. If that's indeed the case, we can remove it from all places:
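One way to run the check proposed above is to count warnings across repeated calls. The sketch below uses a hypothetical helper, count_warnings, built only on the standard warnings module, and a hypothetical warns_every_time function that mimics forward() with the _has_warned guard removed. It assumes that warnings raised inside TorchScript code reach Python's warnings machinery, which is not verified here.

```python
import warnings
from typing import Callable


def count_warnings(fn: Callable[[], object], calls: int, message_substring: str) -> int:
    """Call fn() `calls` times and return how many matching warnings were emitted in total."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" bypasses Python's once-per-location filter so raw emissions are counted
        warnings.simplefilter("always")
        for _ in range(calls):
            fn()
    return sum(1 for w in caught if message_substring in str(w.message))


def warns_every_time() -> None:
    # Hypothetical forward() with the _has_warned guard removed
    warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")


print(count_warnings(warns_every_time, 10, "(Losses, Detections)"))  # prints 10
```

With torch and torchvision installed, one might pass `lambda: scripted_model([torch.rand(3, 512, 512)])` as fn: a result of 1 would suggest JIT deduplicates the warning on its own, while 10 would match the repeated warnings reported in this thread.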