Skip to content

Commit a674c63

Browse files
authored
[test] added torchvision models to test model zoo (#3132)
* [test] added torchvision models to test model zoo * polish code * polish code * polish code * polish code * polish code * polish code
1 parent 1216d1e commit a674c63

File tree

5 files changed

+162
-26
lines changed

5 files changed

+162
-26
lines changed

tests/kit/model_zoo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import diffusers, timm
1+
from . import diffusers, timm, torchvision
22
from .registry import model_zoo
33

44
__all__ = ['model_zoo']

tests/kit/model_zoo/registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,13 @@
99
class ModelAttribute:
1010
"""
1111
Attributes of a model.
12+
13+
Args:
14+
has_control_flow (bool): Whether the model contains branching in its forward method.
15+
has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models.
1216
"""
1317
has_control_flow: bool = False
18+
has_stochastic_depth_prob: bool = False
1419

1520

1621
class ModelZooRegistry(dict):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .torchvision import *
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from collections import namedtuple
2+
3+
import torch
4+
import torchvision
5+
import torchvision.models as tm
6+
from packaging import version
7+
8+
from ..registry import ModelAttribute, model_zoo
9+
10+
data_gen_fn = lambda: dict(x=torch.rand(4, 3, 224, 224))
11+
output_transform_fn = lambda x: dict(output=x)
12+
13+
# special data gen fn
14+
inception_v3_data_gen_fn = lambda: dict(x=torch.rand(4, 3, 299, 299))
15+
16+
17+
# special model fn
18+
def swin_s():
19+
from torchvision.models.swin_transformer import Swin_T_Weights, _swin_transformer
20+
21+
# adapted from torchvision.models.swin_transformer.swin_small
22+
weights = None
23+
weights = Swin_T_Weights.verify(weights)
24+
progress = True
25+
26+
return _swin_transformer(
27+
patch_size=[4, 4],
28+
embed_dim=96,
29+
depths=[2, 2, 6, 2],
30+
num_heads=[3, 6, 12, 24],
31+
window_size=[7, 7],
32+
stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic
33+
weights=weights,
34+
progress=progress,
35+
)
36+
37+
38+
# special output transform fn
39+
google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.GoogLeNetOutputs
40+
) else dict(output=x)
41+
swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val
42+
for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
43+
inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.InceptionOutputs
44+
) else dict(output=x)
45+
46+
model_zoo.register(name='torchvision_alexnet',
47+
model_fn=tm.alexnet,
48+
data_gen_fn=data_gen_fn,
49+
output_transform_fn=output_transform_fn)
50+
model_zoo.register(name='torchvision_densenet121',
51+
model_fn=tm.densenet121,
52+
data_gen_fn=data_gen_fn,
53+
output_transform_fn=output_transform_fn)
54+
model_zoo.register(name='torchvision_efficientnet_b0',
55+
model_fn=tm.efficientnet_b0,
56+
data_gen_fn=data_gen_fn,
57+
output_transform_fn=output_transform_fn,
58+
model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
59+
model_zoo.register(name='torchvision_googlenet',
60+
model_fn=tm.googlenet,
61+
data_gen_fn=data_gen_fn,
62+
output_transform_fn=google_net_output_transform_fn)
63+
model_zoo.register(name='torchvision_inception_v3',
64+
model_fn=tm.inception_v3,
65+
data_gen_fn=inception_v3_data_gen_fn,
66+
output_transform_fn=inception_v3_output_transform_fn)
67+
model_zoo.register(name='torchvision_mobilenet_v2',
68+
model_fn=tm.mobilenet_v2,
69+
data_gen_fn=data_gen_fn,
70+
output_transform_fn=output_transform_fn)
71+
model_zoo.register(name='torchvision_mobilenet_v3_small',
72+
model_fn=tm.mobilenet_v3_small,
73+
data_gen_fn=data_gen_fn,
74+
output_transform_fn=output_transform_fn)
75+
model_zoo.register(name='torchvision_mnasnet0_5',
76+
model_fn=tm.mnasnet0_5,
77+
data_gen_fn=data_gen_fn,
78+
output_transform_fn=output_transform_fn)
79+
model_zoo.register(name='torchvision_resnet18',
80+
model_fn=tm.resnet18,
81+
data_gen_fn=data_gen_fn,
82+
output_transform_fn=output_transform_fn)
83+
model_zoo.register(name='torchvision_regnet_x_16gf',
84+
model_fn=tm.regnet_x_16gf,
85+
data_gen_fn=data_gen_fn,
86+
output_transform_fn=output_transform_fn)
87+
model_zoo.register(name='torchvision_resnext50_32x4d',
88+
model_fn=tm.resnext50_32x4d,
89+
data_gen_fn=data_gen_fn,
90+
output_transform_fn=output_transform_fn)
91+
model_zoo.register(name='torchvision_shufflenet_v2_x0_5',
92+
model_fn=tm.shufflenet_v2_x0_5,
93+
data_gen_fn=data_gen_fn,
94+
output_transform_fn=output_transform_fn)
95+
model_zoo.register(name='torchvision_squeezenet1_0',
96+
model_fn=tm.squeezenet1_0,
97+
data_gen_fn=data_gen_fn,
98+
output_transform_fn=output_transform_fn)
99+
100+
model_zoo.register(name='torchvision_vgg11',
101+
model_fn=tm.vgg11,
102+
data_gen_fn=data_gen_fn,
103+
output_transform_fn=output_transform_fn)
104+
model_zoo.register(name='torchvision_wide_resnet50_2',
105+
model_fn=tm.wide_resnet50_2,
106+
data_gen_fn=data_gen_fn,
107+
output_transform_fn=output_transform_fn)
108+
109+
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
110+
model_zoo.register(name='torchvision_vit_b_16',
111+
model_fn=tm.vit_b_16,
112+
data_gen_fn=data_gen_fn,
113+
output_transform_fn=output_transform_fn)
114+
model_zoo.register(name='torchvision_convnext_base',
115+
model_fn=tm.convnext_base,
116+
data_gen_fn=data_gen_fn,
117+
output_transform_fn=output_transform_fn,
118+
model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
119+
120+
if version.parse(torchvision.__version__) >= version.parse('0.13.0'):
121+
model_zoo.register(
122+
name='torchvision_swin_s',
123+
model_fn=swin_s,
124+
data_gen_fn=data_gen_fn,
125+
output_transform_fn=swin_s_output_output_transform_fn,
126+
)
127+
model_zoo.register(name='torchvision_efficientnet_v2_s',
128+
model_fn=tm.efficientnet_v2_s,
129+
data_gen_fn=data_gen_fn,
130+
output_transform_fn=output_transform_fn,
131+
model_attribute=ModelAttribute(has_stochastic_depth_prob=True))

tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,43 @@
11
import torch
2-
import torchvision
3-
import torchvision.models as tm
4-
from packaging import version
52

63
from colossalai.fx import symbolic_trace
4+
from tests.kit.model_zoo import model_zoo
75

86

97
def test_torchvision_models():
10-
MODEL_LIST = [
11-
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
12-
tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
13-
]
14-
15-
RANDOMIZED_MODELS = [tm.efficientnet_b0]
16-
17-
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
18-
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
19-
RANDOMIZED_MODELS.append(tm.convnext_small)
20-
218
torch.backends.cudnn.deterministic = True
9+
tv_sub_registry = model_zoo.get_sub_registry('torchvision')
2210

23-
data = torch.rand(2, 3, 224, 224)
11+
for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items():
12+
data = data_gen_fn()
2413

25-
for model_cls in MODEL_LIST:
26-
if model_cls in RANDOMIZED_MODELS:
27-
# remove the impact of randomicity
28-
model = model_cls(stochastic_depth_prob=0)
14+
if model_attribute is not None and model_attribute.has_stochastic_depth_prob:
15+
model = model_fn(stochastic_depth_prob=0)
2916
else:
30-
model = model_cls()
17+
model = model_fn()
3118

3219
gm = symbolic_trace(model)
3320

3421
model.eval()
3522
gm.eval()
3623

37-
with torch.no_grad():
38-
fx_out = gm(data)
39-
non_fx_out = model(data)
40-
assert torch.allclose(
41-
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
24+
try:
25+
with torch.no_grad():
26+
fx_out = gm(**data)
27+
non_fx_out = model(**data)
28+
transformed_out = output_transform_fn(fx_out)
29+
transformed_non_fx_out = output_transform_fn(non_fx_out)
30+
31+
assert len(transformed_out) == len(transformed_non_fx_out)
32+
33+
for key in transformed_out.keys():
34+
fx_val = transformed_out[key]
35+
non_fx_val = transformed_non_fx_out[key]
36+
assert torch.allclose(
37+
fx_val,
38+
non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}'
39+
except Exception as e:
40+
print(name, e)
4241

4342

4443
if __name__ == '__main__':

0 commit comments

Comments
 (0)