[Compression] Quantization: add module fusion #5400
Conversation
print(f"target_space_2={target_spaces_2}\n")
for module_name, wrapper in module_wrappers_2.items():
    print(f"module_name={module_name}\tconfig={wrapper.config}\twrapper={wrapper}\n")
print(f"target_space_2={target_spaces_2}\n")
Could we add a test under /test/algo?
done
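For reference, a minimal sketch of the kind of check such a test could make. SimpleModel and the config dict below are illustrative assumptions, not the exact fixtures added in this PR:

```python
# Sketch only: verify that every name listed in a fusion group resolves to a
# real submodule, so fusion can look the modules up by name later on.
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.batchnorm1 = nn.BatchNorm2d(8)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        return self.relu1(self.batchnorm1(self.conv1(x)))


def test_fuse_names_reference_existing_modules():
    model = SimpleModel()
    config = {
        'op_names': ['conv1'],
        'quant_dtype': 'int8',
        'quant_scheme': 'affine',
        'granularity': 'default',
        'fuse_names': [('conv1', 'batchnorm1')],
    }
    named = dict(model.named_modules())
    for group in config['fuse_names']:
        assert all(name in named for name in group)
```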
@@ -177,7 +180,10 @@ def select_modules_by_config(model: torch.nn.Module, config: Dict[str, Any]) ->
     for op_type in exclude_op_types:
         op_names.difference_update(type2names.get(op_type, set()))

-    return {module_name: name2module[module_name] for module_name in op_names}, config
+    if len(fuse_names) > 0 and len(op_names) > 1:
Should this be len(fuse_names) > 0 and len(fuse_names) != len(op_names)?
fixed
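For reference, the corrected guard could look roughly like this. Only the condition itself comes from the review; the standalone wrapper function and the error message are assumptions, and the real check sits inside select_modules_by_config:

```python
from typing import Collection


def check_fusion_coverage(fuse_names: Collection, op_names: Collection) -> None:
    # Condition suggested in the review: fusion information is only accepted
    # when it lines up with the selected op_names.
    if len(fuse_names) > 0 and len(fuse_names) != len(op_names):
        raise ValueError(
            f'fuse_names {list(fuse_names)} does not line up with '
            f'the selected op_names {sorted(op_names)}'
        )
```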
    if cur_module is None:
        raise ValueError(f"can't find {module_name} in the model")

    return cur_module
get_nested_attr could be used here.
okok
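For reference, a sketch of the suggestion: resolve a dotted module name with a nested-attribute helper instead of a hand-written loop. This standalone get_nested_attr only illustrates the idea and may not match the signature of the helper already in the codebase:

```python
from functools import reduce

import torch


def get_nested_attr(model: torch.nn.Module, module_name: str) -> torch.nn.Module:
    try:
        # 'features.0.conv' is resolved submodule by submodule via getattr
        return reduce(getattr, module_name.split('.'), model)
    except AttributeError as err:
        raise ValueError(f"can't find {module_name} in the model") from err
```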
'quant_dtype': 'int2',
'quant_scheme': 'affine',
'granularity': 'default',
'fuse_names': ["conv1", "batchnorm1"]
A little confused here; maybe [("conv1", "batchnorm1")]?
done
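For reference, the shape fuse_names appears to settle on after this exchange: a list of tuples, one tuple per fusion group. The surrounding keys are taken from the snippet above; this is a sketch, not the exact config in the PR:

```python
# fuse_names as a list of fusion groups, each group given as a tuple of
# module names (here a conv followed by the batch norm folded into it).
config = {
    'quant_dtype': 'int2',
    'quant_scheme': 'affine',
    'granularity': 'default',
    'fuse_names': [('conv1', 'batchnorm1')],
}
```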
self.fused_modules = fused_modules if fused_modules is not None else []
if len(self.fused_modules) > 0:
    self.is_bias = check_bias(self.module)  # used for fold_bn
    self.register_bias()
Call this function in wrap()? We don't wrap the module at this point.
okok
Consider the case where a module has already been wrapped by the pruning compressor; then it cannot be wrapped again for module fusion:
WARNING: Wrapper of module is wrapped, no need to wrap again.
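For reference, a sketch of how the two points above could combine: move the fusion set-up into wrap() and skip wrapping when another compressor has already wrapped the module. check_bias, register_bias and _nni_wrapper follow the names used in the diff; _logger and the exact control flow are assumptions:

```python
import logging

_logger = logging.getLogger(__name__)


class ModuleWrapper:  # sketch only; the real wrapper class carries more state
    def wrap(self):
        # skip if another compressor (e.g. pruning) already wrapped the module
        if hasattr(self.module, '_nni_wrapper'):
            _logger.warning('Wrapper of %s is wrapped, no need to wrap again.', self.name)
            return
        # fusion set-up moved here from __init__, as suggested above
        if len(self.fused_modules) > 0:
            self.is_bias = check_bias(self.module)  # used for fold_bn
            self.register_bias()
        # ... rest of the original wrap() logic ...
```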
delattr(self.module, 'original_bias')
if len(self.fused_modules) > 0 and not self.is_bias and check_bias(self.module):
    delattr(self.module, 'bias')
    self.module.register_parameter('bias', None)
What if the original module doesn't have a None bias?
If the original module doesn't have a None bias, self.is_bias is True here. For example, for a linear module created with bias=False, linear.bias is None.
I mean that some modules, like nn.Softmax, don't have a bias parameter at all. Maybe check_bias should have three states: 'Tensor', 'None', 'non-exist'?
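For reference, a sketch of the three-state check_bias floated here; it is an illustration of the idea, not the helper that ends up in the codebase:

```python
import torch


def check_bias(module: torch.nn.Module) -> str:
    # Distinguish a real bias tensor, an explicit None bias
    # (e.g. nn.Linear(..., bias=False)), and modules without
    # a bias attribute at all (e.g. nn.Softmax).
    if not hasattr(module, 'bias'):
        return 'non-exist'
    if isinstance(module.bias, torch.Tensor):
        return 'Tensor'
    return 'None'
```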
wrapper.register_fusion_info(fused_modules)
assert hasattr(wrapper.module, '_nni_wrapper'), \
    f'wrapper {wrapper.name} is not wrapped, please wrap it before register new wrapper'
wrapper.fused_modules = fused_modules
What if the wrapper already has a non-empty fused_modules?
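For reference, a sketch of a guard that would answer this question: refuse to silently overwrite fusion information that is already registered. The assertion mirrors the snippet above; the RuntimeError is an assumption about the desired behaviour:

```python
class ModuleWrapper:  # sketch only; the real wrapper class carries more state
    def register_fusion_info(self, fused_modules):
        assert hasattr(self.module, '_nni_wrapper'), \
            f'wrapper {self.name} is not wrapped, please wrap it before register new wrapper'
        if getattr(self, 'fused_modules', None):
            # do not overwrite previously registered fusion info
            raise RuntimeError(f'{self.name} already has fused modules: {self.fused_modules}')
        self.fused_modules = fused_modules
```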
@@ -177,7 +180,7 @@ def select_modules_by_config(model: torch.nn.Module, config: Dict[str, Any]) ->
     for op_type in exclude_op_types:
         op_names.difference_update(type2names.get(op_type, set()))

-    return {module_name: name2module[module_name] for module_name in op_names}, config
+    return {module_name: name2module[module_name] for module_name in op_names}, config, fuse_names
Please search for all call sites of select_modules_by_config and check that this return-value change is legal.
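For reference, the signature change seen from a call site; the variable names below are illustrative, and every real caller needs the same adjustment:

```python
# before this PR:
# modules, public_config = select_modules_by_config(model, config)

# after this PR, the third return value must be unpacked as well:
modules, public_config, fuse_names = select_modules_by_config(model, config)
```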
Description
Module fusion for quantization.
Test Options
Checklist
How to test