@@ -174,14 +174,13 @@ def get_module(model, key):
174174 model (torch.nn.Module): original model
175175 key (str): module name to be replaced
176176 """
177- attrs = key .split ("." )
178177 module = model
179- for attr in attrs :
180- try :
181- attr = int ( attr )
182- module = module [ attr ]
183- except :
184- module = getattr ( module , attr )
178+ name_list = key . split ( "." )
179+ for name in name_list :
180+ if hasattr ( module , name ):
181+ module = getattr ( module , name )
182+ else :
183+ module = module
185184 return module
186185
187186
@@ -193,15 +192,14 @@ def set_module(model, key, new_module):
193192 key (str): module name to be replaced
194193 new_module (torch.nn.Module): new module to be inserted
195194 """
196- attrs = key .split ("." )
197195 module = model
198- for attr in attrs [: - 1 ]:
199- try :
200- attr = int ( attr )
201- module = module [ attr ]
202- except :
203- module = getattr ( module , attr )
204- setattr (module , attrs [- 1 ], new_module )
196+ name_list = key . split ( "." )
197+ for name in name_list [: - 1 ] :
198+ if hasattr ( module , name ):
199+ module = getattr ( module , name )
200+ else :
201+ module = module
202+ setattr (module , name_list [- 1 ], new_module )
205203
206204
207205def cal_scale (input_max , weights , alpha , scale_type = "orig" ):
0 commit comments