1+ import torch
2+ import torch .nn as nn
3+ import torch .nn .functional as F
4+ import yaml
5+
6+ def get_VGG (config ):
7+ name = config ['model' ]
8+ model_list = config ['VGG_types' ]
9+
10+ if name == 'VGG11' :
11+ return VGGnet (model_list [name ])
12+ elif name == 'VGG13' :
13+ return VGGnet (model_list [name ])
14+ elif name == 'VGG16' :
15+ return VGGnet (model_list [name ])
16+ elif name == 'VGG19' :
17+ return VGGnet (model_list [name ])
18+ else :
19+ print ("There is no name in models" )
20+
21+
22+
23+ class VGGnet (nn .Module ):
24+ def __init__ (self , model , in_channels = 3 , num_classes = 36 , init_weights = True ):
25+ super (VGGnet , self ).__init__ ()
26+ self .in_channels = in_channels
27+
28+ self .conv_layers = self .create_conv_layers (model )
29+
30+ self .fcs = nn .Sequential (
31+ nn .Linear (512 , 4096 ),
32+ nn .ReLU (),
33+ nn .Dropout (p = 0.5 ),
34+ nn .Linear (4096 , 4096 ),
35+ nn .ReLU (),
36+ nn .Dropout (p = 0.5 ),
37+ nn .Linear (4096 , num_classes )
38+ )
39+
40+ if init_weights :
41+ self ._initialize_weights ()
42+
43+
44+ def forward (self , x ):
45+ x = self .conv_layers (x )
46+ x = x .view (- 1 , 512 )
47+ x = self .fcs (x )
48+ return x
49+
50+
51+ def _initialize_weights (self ):
52+ for m in self .modules ():
53+ if isinstance (m , nn .Conv2d ):
54+ nn .init .kaiming_normal_ (m .weight , mode = 'fan_out' , nonlinearity = 'relu' )
55+ if m .bias is not None :
56+ nn .init .constant_ (m .bias , 0 )
57+ elif isinstance (m , nn .BatchNorm2d ):
58+ nn .init .constant_ (m .weight , 1 )
59+ nn .init .constant_ (m .bias , 0 )
60+ elif isinstance (m , nn .Linear ):
61+ nn .init .normal_ (m .weight , 0 , 0.01 )
62+ nn .init .constant_ (m .bias , 0 )
63+
64+ def create_conv_layers (self , architecture ):
65+ layers = []
66+ in_channels = self .in_channels
67+
68+ for x in architecture :
69+ if type (x ) == int :
70+ out_channels = x
71+ layers += [nn .Conv2d (in_channels = in_channels , out_channels = out_channels ,
72+ kernel_size = (3 ,3 ), stride = (1 ,1 ), padding = (1 ,1 )),
73+ nn .BatchNorm2d (x ),
74+ nn .ReLU ()]
75+ in_channels = x
76+ elif x == 'M' :
77+ layers += [nn .MaxPool2d (kernel_size = (2 ,2 ), stride = (2 ,2 ))]
78+
79+ return nn .Sequential (* layers )
80+
81+
82+ # Open config file -> quick test
83+ def open_config_file ():
84+ with open ("/data/Github_Management/StartDeepLearningWithPytorch/Chapter04/config/config.yaml" , 'r' , encoding = 'utf-8' ) as stream :
85+ try :
86+ config = yaml .safe_load (stream ) # return into Dict
87+ except yaml .YAMLError as exc :
88+ print (exc )
89+ return config ['VGG_types' ]
90+
91+
92+
93+ if __name__ == '__main__' :
94+ print ('Quick Test...' )
95+
96+ models = open_config_file ()
97+ model = VGGnet (models ['VGG19' ])
98+ print (model )
99+
100+ input = torch .zeros ([1 ,3 ,32 ,32 ], dtype = torch .float32 )
101+ # model = VGG_19(32, 3)
102+ output = model (input )
103+
104+ print ('input_shape: {}, output_size: {}'
105+ .format (input .shape , output .shape ))
106+
0 commit comments