[Dynamo] Fix TIMM benchmark compute_loss (pytorch#97423)
Fixes pytorch#97382

pytorch#95416 fixed a critical bug in the dynamo benchmarks: before that PR, AMP tests were falling back to eager mode. However, after that PR we found [a list of TIMM models whose amp + eager + training tests fail](https://docs.google.com/spreadsheets/d/1DEhirVOkj15Lu4UNawIUon9MqkVLaWqyT-DQPif5NHk/edit#gid=0).
We have now identified the root cause: high loss values make gradient checking harder, because small changes in accumulation order upset the accuracy checks. We should switch to the helper function `reduce_to_scalar_loss`, which the TorchBench tests already use.
After switching to `reduce_to_scalar_loss`, the TIMM accuracy pass rate grows from 67.74% to 91.94% in my local test. The remaining 5 failing models (ese_vovnet19b_dw, fbnetc_100, mnasnet_100, mobilevit_s, sebotnet33ts_256) need further investigation and handling, but the cause is likely similar.
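
For context (a minimal sketch, not the actual implementation in `torch._dynamo.testing`): a helper like `reduce_to_scalar_loss` collapses the raw model output into one small, well-scaled scalar, so tiny differences in floating-point accumulation order between eager and compiled runs stay within the accuracy tolerance instead of being amplified by a large cross-entropy value. The `scaled_compute_loss` variant added in the diff below divides that scalar by a further 1000 for the handful of models whose loss values are still too large.

```python
import torch


def mean_scalar_loss(out):
    # Illustrative sketch: collapse a model output (a tensor, or a nested
    # list/tuple/dict of tensors) into a single small scalar.
    if isinstance(out, torch.Tensor):
        # Averaging keeps the magnitude small, so a reordered floating-point
        # accumulation only perturbs the result within tolerance.
        return out.sum() / out.numel()
    if isinstance(out, (list, tuple)):
        return sum(mean_scalar_loss(x) for x in out) / len(out)
    if isinstance(out, dict):
        return sum(mean_scalar_loss(v) for v in out.values()) / len(out)
    raise NotImplementedError(f"cannot reduce {type(out)} to a scalar loss")


# Example: a large logits tensor reduces to an O(1) scalar that is easy to
# compare between an eager run and a compiled run.
logits = torch.randn(64, 1000) * 50.0
print(mean_scalar_loss(logits))
```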

Pull Request resolved: pytorch#97423
Approved by: https://github.com/Chillee
yanboliang authored and pytorchmergebot committed Mar 24, 2023
1 parent 5f5d675 commit d305d4a
Showing 3 changed files with 71 additions and 55 deletions.
48 changes: 24 additions & 24 deletions benchmarks/dynamo/ci_expected_accuracy/training_timm_models0.csv
@@ -1,25 +1,25 @@
 name,graph_breaks
-adv_inception_v3,7
-beit_base_patch16_224,7
-coat_lite_mini,7
-convmixer_768_32,4
-convnext_base,7
-crossvit_9_240,7
-cspdarknet53,9
-deit_base_distilled_patch16_224,7
-dla102,7
-dm_nfnet_f0,7
-dpn107,9
-eca_botnext26ts_256,9
-ese_vovnet19b_dw,9
-fbnetc_100,9
-gernet_l,9
-ghostnet_100,9
-gluon_inception_v3,7
-gmixer_24_224,7
-gmlp_s16_224,7
-hrnet_w18,4
-inception_v3,7
-jx_nest_base,7
-lcnet_050,9
-mixer_b16_224,7
+adv_inception_v3,9
+beit_base_patch16_224,9
+coat_lite_mini,9
+convmixer_768_32,6
+convnext_base,9
+crossvit_9_240,9
+cspdarknet53,11
+deit_base_distilled_patch16_224,9
+dla102,9
+dm_nfnet_f0,9
+dpn107,11
+eca_botnext26ts_256,11
+ese_vovnet19b_dw,11
+fbnetc_100,11
+gernet_l,11
+ghostnet_100,11
+gluon_inception_v3,9
+gmixer_24_224,9
+gmlp_s16_224,9
+hrnet_w18,6
+inception_v3,9
+jx_nest_base,9
+lcnet_050,11
+mixer_b16_224,9
58 changes: 29 additions & 29 deletions benchmarks/dynamo/ci_expected_accuracy/training_timm_models1.csv
@@ -1,30 +1,30 @@
 name,graph_breaks
-mixnet_l,9
-mnasnet_100,9
-mobilenetv2_100,9
-mobilenetv3_large_100,9
-mobilevit_s,9
-nfnet_l0,7
-pit_b_224,7
-pnasnet5large,6
-poolformer_m36,7
-regnety_002,9
-repvgg_a2,9
-res2net101_26w_4s,7
-res2net50_14w_8s,7
-res2next50,7
-resmlp_12_224,7
-resnest101e,7
-rexnet_100,9
-selecsls42b,7
-spnasnet_100,9
-swin_base_patch4_window7_224,7
-swsl_resnext101_32x16d,7
-tf_efficientnet_b0,9
-tf_mixnet_l,9
-tinynet_a,9
-tnt_s_patch16_224,7
-twins_pcpvt_base,7
-visformer_small,7
-vit_base_patch16_224,7
-volo_d1_224,7
+mixnet_l,11
+mnasnet_100,11
+mobilenetv2_100,11
+mobilenetv3_large_100,11
+mobilevit_s,11
+nfnet_l0,9
+pit_b_224,9
+pnasnet5large,8
+poolformer_m36,9
+regnety_002,11
+repvgg_a2,11
+res2net101_26w_4s,9
+res2net50_14w_8s,9
+res2next50,9
+resmlp_12_224,9
+resnest101e,9
+rexnet_100,11
+selecsls42b,9
+spnasnet_100,11
+swin_base_patch4_window7_224,9
+swsl_resnext101_32x16d,9
+tf_efficientnet_b0,11
+tf_mixnet_l,11
+tinynet_a,11
+tnt_s_patch16_224,9
+twins_pcpvt_base,9
+visformer_small,9
+vit_base_patch16_224,9
+volo_d1_224,9
20 changes: 18 additions & 2 deletions benchmarks/dynamo/timm_models.py
@@ -11,7 +11,7 @@
 import torch
 from common import BenchmarkRunner, main
 
-from torch._dynamo.testing import collect_results
+from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
 from torch._dynamo.utils import clone_inputs
 
 
@@ -84,6 +84,14 @@ def pip_install(package):
     "cait_m36_384": 4,
 }
 
+SCALED_COMPUTE_LOSS = {
+    "ese_vovnet19b_dw",
+    "fbnetc_100",
+    "mnasnet_100",
+    "mobilevit_s",
+    "sebotnet33ts_256",
+}
+
 
 def refresh_model_names():
     import glob
@@ -268,6 +276,10 @@ def load_model(
         self.target = self._gen_target(batch_size, device)
 
         self.loss = torch.nn.CrossEntropyLoss().to(device)
+
+        if model_name in SCALED_COMPUTE_LOSS:
+            self.compute_loss = self.scaled_compute_loss
+
         if is_training and not use_eval_mode:
             model.train()
         else:
@@ -319,7 +331,11 @@ def _gen_target(self, batch_size, device):
     def compute_loss(self, pred):
         # High loss values make gradient checking harder, as small changes in
         # accumulation order upsets accuracy checks.
-        return self.loss(pred, self.target) / 10.0
+        return reduce_to_scalar_loss(pred)
+
+    def scaled_compute_loss(self, pred):
+        # Loss values need zoom out further.
+        return reduce_to_scalar_loss(pred) / 1000.0
 
     def forward_pass(self, mod, inputs, collect_outputs=True):
         with self.autocast():
