ns for fx: expose hook to define custom weight extraction functions (pytorch#62047)

Summary:
Pull Request resolved: pytorch#62047

Adds a hook that lets users define a weight extraction function for a
custom type.

Example usage:
```
op_to_type_to_weight_extraction_fn = \
    get_op_to_type_to_weight_extraction_fn()
op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
    torch.quantization.ns.weight_utils.get_linear_fun_weight

results = extract_weights_impl(
    'a', m1, 'b', m2,
    op_to_type_to_weight_extraction_fn=op_to_type_to_weight_extraction_fn)
```
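
For reference, a weight extraction function for a `call_function` node takes the FX node and the owning `GraphModule` and returns a tensor, matching built-in extractors such as `get_linear_fun_weight`. Below is a minimal sketch of a user-defined extractor; it assumes the wrapped op receives its weight as the second positional argument via a `get_attr` node, and the function name is illustrative rather than part of this PR:

```
import torch
from torch.fx import GraphModule, Node

def get_wrapped_linear_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # For _wrapped_linear(x, w, b), the weight is the second argument.
    weight_node = node.args[1]
    assert isinstance(weight_node, Node) and weight_node.op == 'get_attr'
    # Resolve the get_attr target (e.g. 'w1') to the underlying parameter.
    obj = gm
    for atom in str(weight_node.target).split('.'):
        obj = getattr(obj, atom)
    return obj.detach()
```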

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_user_defined_function
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D29853625

fbshipit-source-id: 183916ef54ba303bc818e0eba00b52e33c4633ad
vkuzo authored and facebook-github-bot committed Jul 23, 2021
1 parent 07c6a12 commit 04c95a0
Showing 3 changed files with 109 additions and 65 deletions.
test/quantization/fx/test_numeric_suite_fx.py (34 additions, 4 deletions)
```
@@ -52,6 +52,9 @@
     get_base_name_for_op,
     add_op_to_sets_of_related_ops,
 )
+from torch.quantization.ns.weight_utils import (
+    get_op_to_type_to_weight_extraction_fn,
+)
 from torch.quantization._numeric_suite_fx import (
     extract_weights,
     _extract_weights_impl,
@@ -263,6 +266,10 @@ def _wrapped_hardswish_fp16(x):
 def _wrapped_sigmoid(x):
     return F.sigmoid(x)
 
+@torch.fx.wrap
+def _wrapped_linear(x, w, b):
+    return F.linear(x, w, b)
+
 
 
 class TestFXGraphMatcher(QuantizationTestCase):
@@ -1576,33 +1583,56 @@ def test_user_defined_function(self):
         Verify that NS APIs work on user defined functions
         """
         class M1(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w1 = nn.Parameter(torch.empty(1, 1))
+                self.b1 = nn.Parameter(torch.zeros(1))
+                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
+
             def forward(self, x):
                 x = F.hardswish(x)
                 x = x.sigmoid()
+                x = F.linear(x, self.w1, self.b1)
                 return x
 
         class M2(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w1 = nn.Parameter(torch.empty(1, 1))
+                self.b1 = nn.Parameter(torch.zeros(1))
+                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
+
             def forward(self, x):
                 x = _wrapped_hardswish(x)
                 x = _wrapped_sigmoid(x)
+                x = _wrapped_linear(x, self.w1, self.b1)
                 return x
 
         qconfig_dict = {'': torch.quantization.default_qconfig}
         m1 = prepare_fx(M1().eval(), qconfig_dict)
         m2 = prepare_fx(M2().eval(), qconfig_dict)
-        data = torch.randn(4, 4)
+        data = torch.randn(1, 1)
 
         base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
         add_op_to_sets_of_related_ops(
             base_name_to_sets_of_related_ops, _wrapped_hardswish, F.hardswish)
         add_op_to_sets_of_related_ops(
             base_name_to_sets_of_related_ops, _wrapped_sigmoid, F.sigmoid)
+        add_op_to_sets_of_related_ops(
+            base_name_to_sets_of_related_ops, _wrapped_linear, F.linear)
 
+        op_to_type_to_weight_extraction_fn = \
+            get_op_to_type_to_weight_extraction_fn()
+        op_to_type_to_weight_extraction_fn['call_function'][_wrapped_linear] = \
+            torch.quantization.ns.weight_utils.get_linear_fun_weight
+
         # test compare weights
         results = _extract_weights_impl(
             'a', m1, 'b', m2,
-            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops)
-        self.assertTrue(len(results) == 0)
+            base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
+            op_to_type_to_weight_extraction_fn=op_to_type_to_weight_extraction_fn)
+        self.assertTrue(len(results) == 1)
+        self.assertTrue(len(results['_wrapped_linear']['weight']) == 2)
 
         # test unshadowed activations
 
@@ -1617,7 +1647,7 @@ def forward(self, x):
 
         # check activation result correctness
         act_compare_dict = extract_logger_info(m1_ns, m2_ns, OutputLogger, 'b')
-        self.assertTrue(len(act_compare_dict) == 2)
+        self.assertTrue(len(act_compare_dict) == 3)
         self.assert_ns_compare_dict_valid(act_compare_dict)
 
         # test shadowed activations
```
torch/quantization/_numeric_suite_fx.py (10 additions, 4 deletions)
```
@@ -128,11 +128,13 @@ def _extract_weights_one_model(
     model: GraphModule,
     nodes_and_names_to_instrument: List[Tuple[Node, str]],
     results: NSResultsType,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
 ) -> None:
     torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
     for node, ref_name in nodes_and_names_to_instrument:
         res_type = NSSingleResultValuesType.WEIGHT.value
-        extracted_weight = extract_weight_from_node(node, model)
+        extracted_weight = extract_weight_from_node(
+            node, model, op_to_type_to_weight_extraction_fn)
         if extracted_weight:
             if ref_name not in results:
                 results[ref_name] = {res_type: {}}
@@ -146,6 +148,7 @@ def _extract_weights_impl(
     gm_b: GraphModule,
     base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
     unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
 ) -> NSResultsType:
     torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
     matched_subgraph_pairs = get_matching_subgraph_pairs(
@@ -163,9 +166,11 @@
     # populate the results, one model at a time
     results: NSResultsType = {}
     _extract_weights_one_model(
-        model_name_a, gm_a, nodes_and_names_to_instrument_a, results)
+        model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
+        op_to_type_to_weight_extraction_fn)
     _extract_weights_one_model(
-        model_name_b, gm_b, nodes_and_names_to_instrument_b, results)
+        model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
+        op_to_type_to_weight_extraction_fn)
 
     # fill in missing fqn entries
     maybe_add_missing_fqns(results)
@@ -183,6 +188,7 @@ def extract_weights(
     model_b: nn.Module,
     base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
     unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
 ) -> NSResultsType:
     torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
     base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
@@ -202,7 +208,7 @@ def extract_weights(
     gm_b._node_name_to_scope = model_b._node_name_to_scope
     return _extract_weights_impl(
         model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
-        unmatchable_types_map)
+        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
 
 
 def _add_loggers_one_model(
```
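
With the argument threaded through `_extract_weights_one_model` and `_extract_weights_impl`, the public `extract_weights` API accepts the hook as well. A hedged end-to-end sketch (`model_a`/`model_b` are placeholder modules, and `_wrapped_linear` is assumed to be defined as in the test above):

```
from torch.quantization.ns.weight_utils import (
    get_op_to_type_to_weight_extraction_fn,
    get_linear_fun_weight,
)
from torch.quantization._numeric_suite_fx import extract_weights

# Start from the default mapping and register the custom op.
extraction_map = get_op_to_type_to_weight_extraction_fn()
extraction_map['call_function'][_wrapped_linear] = get_linear_fun_weight

results = extract_weights(
    'a', model_a, 'b', model_b,
    op_to_type_to_weight_extraction_fn=extraction_map)
```
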
torch/quantization/ns/weight_utils.py (65 additions, 57 deletions)
```
@@ -151,65 +151,70 @@ def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
     (weight, _bias), _name = packed_weight.__getstate__()
     return weight
 
-OP_TO_TYPE_TO_WEIGHT_EXTRACTION_FN: Dict[str, Dict[Callable, Callable]] = {
-    'call_module': {
-        # Conv
-        nn.Conv1d: mod_weight_detach,
-        nn.Conv2d: mod_weight_detach,
-        nn.Conv3d: mod_weight_detach,
-        nni.ConvReLU1d: mod_0_weight_detach,
-        nni.ConvReLU2d: mod_0_weight_detach,
-        nni.ConvReLU3d: mod_0_weight_detach,
-        nnq.Conv1d: mod_weight_bias_0,
-        nniqat.ConvBn1d: mod_weight_detach,
-        nniqat.ConvBnReLU1d: mod_weight_detach,
-        nniq.ConvReLU1d: mod_weight_bias_0,
-        nnq.Conv2d: mod_weight_bias_0,
-        nnqat.Conv2d: mod_weight_detach,
-        nniqat.ConvBn2d: mod_weight_detach,
-        nniqat.ConvBnReLU2d: mod_weight_detach,
-        nniqat.ConvReLU2d: mod_weight_detach,
-        nniq.ConvReLU2d: mod_weight_bias_0,
-        nnq.Conv3d: mod_weight_bias_0,
-        nnqat.Conv3d: mod_weight_detach,
-        nniqat.ConvBn3d: mod_weight_detach,
-        nniqat.ConvBnReLU3d: mod_weight_detach,
-        nniqat.ConvReLU3d: mod_weight_detach,
-        nniq.ConvReLU3d: mod_weight_bias_0,
-        # Linear
-        nn.Linear: mod_weight_detach,
-        nnq.Linear: mod_weight_bias_0,
-        nni.LinearReLU: mod_0_weight_detach,
-        nniq.LinearReLU: mod_weight_bias_0,
-        nnqat.Linear: mod_weight_detach,
-        nnqd.Linear: mod_weight_bias_0,
-        nniqat.LinearReLU: mod_weight_detach,
-        nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
-        # LSTM
-        nn.LSTM: get_lstm_weight,
-        nnqd.LSTM: get_qlstm_weight,
-    },
-    'call_function': {
-        # Conv
-        F.conv1d: get_conv_fun_weight,
-        F.conv2d: get_conv_fun_weight,
-        F.conv3d: get_conv_fun_weight,
-        toq.conv1d: get_qconv_fun_weight,
-        toq.conv2d: get_qconv_fun_weight,
-        toq.conv3d: get_qconv_fun_weight,
-        toq.conv1d_relu: get_qconv_fun_weight,
-        toq.conv2d_relu: get_qconv_fun_weight,
-        toq.conv3d_relu: get_qconv_fun_weight,
-        # Linear
-        F.linear: get_linear_fun_weight,
-        toq.linear: get_qlinear_fun_weight,
-        toq.linear_relu: get_qlinear_fun_weight,
-    },
-}
+def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
+
+    op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
+        'call_module': {
+            # Conv
+            nn.Conv1d: mod_weight_detach,
+            nn.Conv2d: mod_weight_detach,
+            nn.Conv3d: mod_weight_detach,
+            nni.ConvReLU1d: mod_0_weight_detach,
+            nni.ConvReLU2d: mod_0_weight_detach,
+            nni.ConvReLU3d: mod_0_weight_detach,
+            nnq.Conv1d: mod_weight_bias_0,
+            nniqat.ConvBn1d: mod_weight_detach,
+            nniqat.ConvBnReLU1d: mod_weight_detach,
+            nniq.ConvReLU1d: mod_weight_bias_0,
+            nnq.Conv2d: mod_weight_bias_0,
+            nnqat.Conv2d: mod_weight_detach,
+            nniqat.ConvBn2d: mod_weight_detach,
+            nniqat.ConvBnReLU2d: mod_weight_detach,
+            nniqat.ConvReLU2d: mod_weight_detach,
+            nniq.ConvReLU2d: mod_weight_bias_0,
+            nnq.Conv3d: mod_weight_bias_0,
+            nnqat.Conv3d: mod_weight_detach,
+            nniqat.ConvBn3d: mod_weight_detach,
+            nniqat.ConvBnReLU3d: mod_weight_detach,
+            nniqat.ConvReLU3d: mod_weight_detach,
+            nniq.ConvReLU3d: mod_weight_bias_0,
+            # Linear
+            nn.Linear: mod_weight_detach,
+            nnq.Linear: mod_weight_bias_0,
+            nni.LinearReLU: mod_0_weight_detach,
+            nniq.LinearReLU: mod_weight_bias_0,
+            nnqat.Linear: mod_weight_detach,
+            nnqd.Linear: mod_weight_bias_0,
+            nniqat.LinearReLU: mod_weight_detach,
+            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
+            # LSTM
+            nn.LSTM: get_lstm_weight,
+            nnqd.LSTM: get_qlstm_weight,
+        },
+        'call_function': {
+            # Conv
+            F.conv1d: get_conv_fun_weight,
+            F.conv2d: get_conv_fun_weight,
+            F.conv3d: get_conv_fun_weight,
+            toq.conv1d: get_qconv_fun_weight,
+            toq.conv2d: get_qconv_fun_weight,
+            toq.conv3d: get_qconv_fun_weight,
+            toq.conv1d_relu: get_qconv_fun_weight,
+            toq.conv2d_relu: get_qconv_fun_weight,
+            toq.conv3d_relu: get_qconv_fun_weight,
+            # Linear
+            F.linear: get_linear_fun_weight,
+            toq.linear: get_qlinear_fun_weight,
+            toq.linear_relu: get_qlinear_fun_weight,
+        },
+    }
+
+    return op_to_type_to_weight_extraction_fn
 
 def extract_weight_from_node(
     node: Node,
     gm: GraphModule,
+    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
 ) -> Optional[NSSingleResultType]:
     res_type = NSSingleResultValuesType.WEIGHT.value
 
@@ -219,8 +224,11 @@ def extract_weight_from_node(
     if hasattr(gm, '_node_name_to_scope'):
         fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]
 
+    if op_to_type_to_weight_extraction_fn is None:
+        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()
+
     if node.op == 'call_function':
-        function_mapping = OP_TO_TYPE_TO_WEIGHT_EXTRACTION_FN['call_function']
+        function_mapping = op_to_type_to_weight_extraction_fn['call_function']
         for target_fn_type, weight_extraction_fn in function_mapping.items():
             if node.target == target_fn_type:
                 weight = weight_extraction_fn(node, gm)
@@ -239,7 +247,7 @@ def extract_weight_from_node(
     # for call_module, we need to look up the modules to do the type check
     assert isinstance(node.target, str)
     mod = getattr_from_fqn(gm, node.target)
-    module_mapping = OP_TO_TYPE_TO_WEIGHT_EXTRACTION_FN['call_module']
+    module_mapping = op_to_type_to_weight_extraction_fn['call_module']
    for target_mod_type, weight_extraction_fn in module_mapping.items():
         if type(mod) == target_mod_type:
             weight = weight_extraction_fn(mod)
```
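
Replacing the module-level `OP_TO_TYPE_TO_WEIGHT_EXTRACTION_FN` constant with a factory function means each caller gets a fresh mapping it can mutate without affecting other callers, while `extract_weight_from_node` falls back to the defaults when no mapping is passed. A small sketch of the resulting behavior (assuming the names above):

```
# Each call builds a new dict, so local registrations stay local.
map_a = get_op_to_type_to_weight_extraction_fn()
map_b = get_op_to_type_to_weight_extraction_fn()

map_a['call_function'][_wrapped_linear] = get_linear_fun_weight
assert _wrapped_linear not in map_b['call_function']
```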
