Skip to content

Commit 6de9ce3

Browse files
xin3hechensuyue
authored andcommitted
fix bug in get/set_module (#1268)
Signed-off-by: Xin He <xin3.he@intel.com> (cherry picked from commit dffcfe1)
1 parent 1a6526f commit 6de9ce3

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

neural_compressor/adaptor/torch_utils/smooth_quant.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -174,14 +174,13 @@ def get_module(model, key):
174174
model (torch.nn.Module): original model
175175
key (str): module name to be replaced
176176
"""
177-
attrs = key.split(".")
178177
module = model
179-
for attr in attrs:
180-
try:
181-
attr = int(attr)
182-
module = module[attr]
183-
except:
184-
module = getattr(module, attr)
178+
name_list = key.split(".")
179+
for name in name_list:
180+
if hasattr(module, name):
181+
module = getattr(module, name)
182+
else:
183+
module = module
185184
return module
186185

187186

@@ -193,15 +192,14 @@ def set_module(model, key, new_module):
193192
key (str): module name to be replaced
194193
new_module (torch.nn.Module): new module to be inserted
195194
"""
196-
attrs = key.split(".")
197195
module = model
198-
for attr in attrs[:-1]:
199-
try:
200-
attr = int(attr)
201-
module = module[attr]
202-
except:
203-
module = getattr(module, attr)
204-
setattr(module, attrs[-1], new_module)
196+
name_list = key.split(".")
197+
for name in name_list[:-1]:
198+
if hasattr(module, name):
199+
module = getattr(module, name)
200+
else:
201+
module = module
202+
setattr(module, name_list[-1], new_module)
205203

206204

207205
def cal_scale(input_max, weights, alpha, scale_type="orig"):

neural_compressor/adaptor/torch_utils/util.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,6 @@ def set_module(model, op_name, new_module):
620620
else:
621621
module = module
622622
setattr(module, name_list[-1], new_module)
623-
return module
624623

625624

626625
def simple_inference(model, input):

0 commit comments

Comments
 (0)