
[Compression] Quantization: add module fusion #5400

Merged
merged 64 commits into microsoft:master on Mar 9, 2023

Conversation

@Bonytu Bonytu (Contributor) commented Feb 23, 2023

Description

Module fusion for quantization.
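
For readers skimming the diff: module fusion here means treating a pair such as Conv2d followed by BatchNorm2d as one unit during quantization, typically by folding the BatchNorm statistics into the convolution's weight and bias. A minimal sketch of that folding (illustrative only, not the exact code added in this PR):

import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Fold eval-mode BatchNorm statistics into the conv so the fused pair can be
    # quantized as a single module.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # per-output-channel scale
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias
    return fused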

Test Options

  • fast test
  • full test - HPO
  • full test - NAS
  • full test - compression

Checklist

  • test case
  • doc

How to test

print(f"target_space_2={target_spaces_2}\n")
for module_name, wrapper in module_wrappers_2.items():
    print(f"module_name={module_name}\tconfig={wrapper.config}\twrapper={wrapper}\n")
print(f"target_space_2={target_spaces_2}\n")
Contributor

could add a test under /test/algo?

Contributor Author

done

@@ -177,7 +180,10 @@ def select_modules_by_config(model: torch.nn.Module, config: Dict[str, Any]) ->
for op_type in exclude_op_types:
    op_names.difference_update(type2names.get(op_type, set()))

return {module_name: name2module[module_name] for module_name in op_names}, config
if len(fuse_names) > 0 and len(op_names) > 1:
Contributor

Should this be len(fuse_names) > 0 and len(fuse_names) != len(op_names)?

Contributor Author

fixed

if cur_module is None:
    raise ValueError(f"can\'t find {module_name} in the model")

return cur_module
Contributor

get_nested_attr could be used here.

Contributor Author

okok
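
For context, a get_nested_attr-style helper usually just walks the dotted module path with getattr; a minimal sketch (the real helper in NNI may have a different name or signature):

from functools import reduce
import torch

def get_nested_attr_sketch(model: torch.nn.Module, module_name: str) -> torch.nn.Module:
    # Resolve a dotted path such as 'layer1.0.conv1' by repeated getattr lookups.
    return reduce(getattr, module_name.split('.'), model)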

'quant_dtype': 'int2',
'quant_scheme': 'affine',
'granularity': 'default',
'fuse_names': ["conv1", "batchnorm1"]
Contributor

A little confused; maybe it should be [("conv1", "batchnorm1")]?

Contributor Author

done
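
With the change discussed above, the config entry would look roughly like this (keys copied from the snippet in the diff; the exact accepted format may differ):

config = {
    'quant_dtype': 'int2',
    'quant_scheme': 'affine',
    'granularity': 'default',
    'fuse_names': [("conv1", "batchnorm1")]  # one tuple per fusion group
}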

self.fused_modules = fused_modules if fused_modules is not None else []
if len(self.fused_modules) > 0:
    self.is_bias = check_bias(self.module)  # used for fold_bn
    self.register_bias()
Contributor

Should this function be called in wrap()? We don't wrap the module at this point.

Contributor Author

okok

Contributor Author

Consider this case: if a module has already been wrapped by the pruning compressor, then it cannot be wrapped again for module fusion.
WARNING: Wrapper of module is wrapped, no need to wrap again.

delattr(self.module, 'original_bias')
if len(self.fused_modules) > 0 and not self.is_bias and check_bias(self.module):
    delattr(self.module, 'bias')
    self.module.register_parameter('bias', None)
Contributor

What if the original module doesn't have a None bias?

Contributor Author

If the original module doesn't have a None bias, then self.is_bias is True here. For example, in a Linear module, if we set bias=False, then linear.bias is None.

Contributor

I mean that some modules, like nn.Softmax, do not have a bias parameter at all.
Maybe check_bias should have three states: 'Tensor', 'None', 'non-exist'?
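
A sketch of the three-state check suggested here (the return values are illustrative strings, not necessarily what the code ends up using):

import torch

def check_bias_sketch(module: torch.nn.Module) -> str:
    # 'non-exist': the module has no bias attribute at all (e.g. nn.Softmax)
    # 'None':      the attribute exists but holds None (e.g. nn.Linear(..., bias=False))
    # 'Tensor':    a real bias parameter is present
    if not hasattr(module, 'bias'):
        return 'non-exist'
    return 'Tensor' if isinstance(module.bias, torch.Tensor) else 'None'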

wrapper.register_fusion_info(fused_modules)
assert hasattr(wrapper.module, '_nni_wrapper'), \
    f'wrapper {wrapper.name} is not wrapped, please wrap it before register new wrapper'
wrapper.fused_modules = fused_modules
Contributor

What if the wrapper already has a non-empty fused_modules?
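
One possible guard for the case raised here, continuing the snippet above (a sketch only, not necessarily what the PR settled on):

# Refuse to silently overwrite fusion information registered earlier.
if getattr(wrapper, 'fused_modules', None):
    raise ValueError(f'wrapper {wrapper.name} already has fused modules: {wrapper.fused_modules}')
wrapper.fused_modules = fused_modules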

@Bonytu Bonytu requested a review from J-shang March 3, 2023 04:16
@@ -177,7 +180,7 @@ def select_modules_by_config(model: torch.nn.Module, config: Dict[str, Any]) ->
for op_type in exclude_op_types:
    op_names.difference_update(type2names.get(op_type, set()))

return {module_name: name2module[module_name] for module_name in op_names}, config
return {module_name: name2module[module_name] for module_name in op_names}, config, fuse_names
Contributor

Please search for select_modules_by_config and check that this return-value change is legal at every call site.
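
Roughly, every caller now has to unpack three values instead of two (variable names here are illustrative):

# before: modules, public_config = select_modules_by_config(model, config)
modules, public_config, fuse_names = select_modules_by_config(model, config)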

@J-shang J-shang changed the title module fusion in the quantization [Compression] Quantization: add module fusion Mar 9, 2023
@J-shang J-shang merged commit f875f47 into microsoft:master Mar 9, 2023
super-dainiu pushed a commit to super-dainiu/nni that referenced this pull request May 27, 2023