diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py
index a83c5cb5eefe0..25f5ae09fc189 100644
--- a/test/quantization/fx/test_numeric_suite_fx.py
+++ b/test/quantization/fx/test_numeric_suite_fx.py
@@ -2034,6 +2034,34 @@ def forward(self, x):
             m, (torch.randn(1, 1, 4, 4),), results_len=0)
 
+    def test_linear_kwargs_shadow(self):
+
+        class M(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w1 = nn.Parameter(torch.empty(4, 4))
+                self.b1 = nn.Parameter(torch.zeros(4))
+                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
+
+            def forward(self, x):
+                x = F.linear(input=x, weight=self.w1, bias=self.b1)
+                return x
+
+        # note: FX graph mode quantization does not have good support
+        # for kwargs-only right now, so we pass in two unquantized
+        # models
+        m = M().eval()
+        mt = torch.fx.symbolic_trace(m)
+        mt_copy = copy.deepcopy(mt)
+
+        mt_shadows_mt_copy = add_shadow_loggers(
+            'a', mt, 'b', mt_copy, OutputLogger)
+
+        mt_shadows_mt_copy(torch.randn(4, 4))
+        act_compare_dict = extract_shadow_logger_info(
+            mt_shadows_mt_copy, OutputLogger, 'b')
+        self.assertTrue(len(act_compare_dict) == 1)
+
 
 class TestFXNumericSuiteCoreAPIsModels(FXNumericSuiteQuantizationTestCase):
     """
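Aside on what the new test exercises (not part of the patch): when every input to a call is passed as a keyword, the traced node's `args` tuple is empty and the inputs live only in `kwargs`, which is exactly the case the old `node.args[0]` indexing could not handle. A minimal, hypothetical sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class KwargsLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # all inputs are passed as keywords, so the traced node
        # ends up with empty args and everything in kwargs
        return F.linear(input=x, weight=self.w)


gm = torch.fx.symbolic_trace(KwargsLinear())
for node in gm.graph.nodes:
    if node.op == 'call_function':
        # prints: args=() and kwargs={'input': ..., 'weight': ...}
        print(node.args, node.kwargs)
        # node.args[0] would raise IndexError here
```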
diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py
index 23e235c891db5..2709c73ce8517 100644
--- a/torch/ao/ns/fx/graph_passes.py
+++ b/torch/ao/ns/fx/graph_passes.py
@@ -27,6 +27,40 @@
 from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set
 
+def _get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
+    """
+    Given a node, gets the idx'th input to that node, normalizing
+    args and kwargs to the best of its ability.
+    """
+    try:
+        norm_args_and_kwargs = node.normalized_arguments(
+            gm, normalize_to_only_use_kwargs=True)
+        if norm_args_and_kwargs is not None:
+            norm_args, norm_kwargs = norm_args_and_kwargs
+            assert len(norm_args) + len(norm_kwargs) > idx
+            if idx < len(norm_args):
+                return norm_args[idx]
+            else:
+                # note: in Python 3.7+ dicts are ordered
+                return list(norm_kwargs.values())[idx]
+        else:
+            assert len(node.args) + len(node.kwargs) > idx
+            if idx < len(node.args):
+                return node.args[idx]  # type: ignore[return-value]
+            else:
+                kwargs_idx = idx - len(node.args)
+                return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
+    except RuntimeError:
+        # this RuntimeError happens when node argument normalization
+        # requires typehints to proceed, such as for torch.add where
+        # either the first, second or both arguments could be tensors
+        assert len(node.args) + len(node.kwargs) > idx
+        if idx < len(node.args):
+            return node.args[idx]  # type: ignore[return-value]
+        else:
+            kwargs_idx = idx - len(node.args)
+            return list(node.kwargs.values())[kwargs_idx]  # type: ignore[return-value]
+
 def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
     fqn = None
     if hasattr(gm, '_node_name_to_scope'):
@@ -38,8 +72,7 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
             assert isinstance(node.target, str)
             module = getattr_from_fqn(gm, node.target)
             if is_activation_post_process(module):
-                assert isinstance(node.args[0], Node)
-                node_to_use_for_fqn = node.args[0]
+                node_to_use_for_fqn = _get_normalized_nth_input(node, gm, 0)
             fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
     return fqn  # type: ignore[return-value]
 
@@ -104,7 +137,7 @@ def load_arg(a):
 
     for node in gm.graph.nodes:
         if node.op == 'output':
-            new_graph.output(map_arg(node.args[0], load_arg))
+            new_graph.output(map_arg(_get_normalized_nth_input(node, gm, 0), load_arg))
             continue
 
         if (
@@ -121,7 +154,7 @@ def load_arg(a):
             # second (x + 1 versus 1 + x).
             arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
             for node_arg_idx in arg_indices_to_log:
-                node_arg = node.args[node_arg_idx]
+                node_arg = _get_normalized_nth_input(node, gm, node_arg_idx)
                 if type(node_arg) == Node:
                     # create a single input logger
                     prev_node = env[node_arg.name]
@@ -133,7 +166,7 @@ def load_arg(a):
                         fqn=fqn)
                 elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                     # create N input loggers, one for each node
-                    for arg_idx, arg in enumerate(node_arg):
+                    for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                         prev_node = env[arg.name]
                         env[prev_node.name] = _insert_logger_after_node(
                             prev_node, gm, logger_cls, '_ns_logger_', node.name,
@@ -341,19 +374,23 @@ def _copy_node_from_a_to_c(
         assert node_a.target in ('dequantize', 'to'), \
             f"target {node_a.target} is not implemented"
         if node_a.target == 'dequantize':
-            arg_copy = _copy_node_from_a_to_c(node_a.args[0], gm_a, gm_b, graph_c)  # type: ignore[arg-type]
+            arg_copy = _copy_node_from_a_to_c(
+                _get_normalized_nth_input(node_a, gm_a, 0),
+                gm_a, gm_b, graph_c)  # type: ignore[arg-type]
             node_a_copy_name = \
                 get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
             node_a_copy = graph_c.create_node(
                 node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name)
             return node_a_copy
         else:  # to
-            arg_copy = _copy_node_from_a_to_c(node_a.args[0], gm_a, gm_b, graph_c)  # type: ignore[arg-type]
+            arg_copy = _copy_node_from_a_to_c(
+                _get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c)  # type: ignore[arg-type]
             node_a_copy_name = \
                 get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
             node_a_copy = graph_c.create_node(
-                node_a.op, node_a.target, (arg_copy, node_a.args[1]), {},
-                node_a_copy_name)
+                node_a.op, node_a.target,
+                (arg_copy, _get_normalized_nth_input(node_a, gm_a, 1)),
+                {}, node_a_copy_name)
             return node_a_copy
 
     else:
@@ -375,7 +412,7 @@ def _can_insert_copy_of_subgraph_a(
     cur_node = subgraph_a.end_node
     while cur_node != subgraph_a.start_node:
         nodes.append(cur_node)
-        cur_node = cur_node.args[0]  # type: ignore[assignment]
+        cur_node = _get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
     nodes.append(cur_node)
     nodes.reverse()
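The new helper builds on `torch.fx.Node.normalized_arguments`, which binds a call's args/kwargs to the target's signature; with `normalize_to_only_use_kwargs=True` the result comes back as a single kwargs dict ordered like the signature, so positional indexing into `.values()` is well defined. A rough standalone illustration (the function `f` and the names here are hypothetical, and normalization can return `None` for targets FX cannot match):

```python
import torch
import torch.nn.functional as F


def f(x):
    # mixed positional and keyword arguments
    return F.linear(x, weight=torch.randn(4, 4))


gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target is F.linear:
        pair = node.normalized_arguments(gm, normalize_to_only_use_kwargs=True)
        if pair is not None:
            # typically prints ['input', 'weight', 'bias'],
            # ordered by the signature, with defaults filled in
            print(list(pair.kwargs.keys()))
```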
@@ -396,14 +433,40 @@ def _can_insert(node_a_arg, gm_a):
 
     # For each node, check if we handle the copy behavior. This follows the
     # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
-    for node_a_arg in nodes[0].args[num_non_param_args_node_a:]:
-        if not _can_insert(node_a_arg, gm_a):
-            return False
+    for node_a in nodes:
 
-    for node in nodes[1:]:
-        for node_a_arg in node.args[1:]:
-            if not _can_insert(node_a_arg, gm_a):
-                return False
+        local_num_non_param_args_node_a = num_non_param_args_node_a \
+            if node_a is nodes[0] else 1
+
+        norm_args_kwargs = node_a.normalized_arguments(
+            gm_a, normalize_to_only_use_kwargs=True)
+        if norm_args_kwargs is not None:
+            norm_args, norm_kwargs = norm_args_kwargs
+        else:
+            norm_args, norm_kwargs = node_a.args, node_a.kwargs
+
+        cur_idx = 0
+
+        while cur_idx < len(norm_args):
+            if cur_idx == 0:
+                pass
+            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
+                pass
+            else:
+                if not _can_insert(norm_args[cur_idx], gm_a):
+                    return False
+            cur_idx += 1
+
+        for kwarg_name, kwarg_val in norm_kwargs.items():
+            # the first input(s) will be stitched from the base graph,
+            # so they do not need to be checked here
+            if cur_idx == 0:
+                pass
+            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
+                pass
+            else:
+                if not _can_insert(kwarg_val, gm_a):
+                    return False
+            cur_idx += 1
 
     return True
 
@@ -429,7 +492,7 @@ def _insert_copy_of_subgraph_a_after_input_node_c(
     nodes_of_a = [subgraph_a.end_node]
     cur_node = subgraph_a.end_node
     while cur_node != subgraph_a.start_node:
-        cur_node = cur_node.args[0]  # type: ignore[assignment]
+        cur_node = _get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
         nodes_of_a.insert(0, cur_node)
 
     # go through nodes of a in order, and insert them into the graph of c
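To make the index bookkeeping above concrete: the check walks the normalized args and kwargs as one flat, ordered sequence and skips the leading one or two entries, since those will be stitched in from the base graph rather than copied. A toy sketch with hypothetical values:

```python
# Toy model of the check above: treat normalized args + kwargs as one
# ordered sequence, and skip the leading input(s) that graph C supplies.
args = ()                                # hypothetical: kwargs-only call
kwargs = {'input': 'x', 'weight': 'w', 'bias': 'b'}
num_stitched_inputs = 1                  # 2 for two-input ops like add

flat = list(args) + list(kwargs.values())
needs_check = flat[num_stitched_inputs:]  # only these go through _can_insert
print(needs_check)                        # ['w', 'b']
```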
@@ -514,44 +577,60 @@ def _insert_copy_of_node_a_after_input_node_c(
         assert isinstance(input_node_c, list)
         graph_c = input_node_c[0].graph
 
-    # generically handle all args and kwargs except for the input
-    # Note: this hasn't been tested with many ops, logic may change.
-    new_args: List[Any] = []
-    # assumes that the first arg is the input
-    num_non_param_args = 1 if input_node_c_2 is None else 2
-    for node_a_arg in node_a.args[num_non_param_args:]:
-        if isinstance(node_a_arg, Node):
-            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
-            node_a_arg_copy = _copy_node_from_a_to_c(arg_a, gm_a, gm_b, graph_c)
-            new_args.append(node_a_arg_copy)
-        elif isinstance(node_a_arg, (int, float, torch.dtype)):
-            new_args.append(node_a_arg)
-        elif isinstance(node_a_arg, (list, tuple)):
-            for el in node_a_arg:
+    norm_args_kwargs = node_a.normalized_arguments(
+        gm_a, normalize_to_only_use_kwargs=True)
+    if norm_args_kwargs is not None:
+        norm_args, norm_kwargs = norm_args_kwargs
+    else:
+        norm_args, norm_kwargs = node_a.args, node_a.kwargs
+
+    new_args = []
+    new_kwargs = {}
+
+    def _copy_arg(arg):
+        # copy the other inputs from the other graph
+        if isinstance(arg, Node):
+            arg = return_first_non_observer_node(arg, gm_a)
+            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
+            return arg
+        elif isinstance(arg, (int, float, torch.dtype)):
+            return arg
+        elif isinstance(arg, (list, tuple)):
+            for el in arg:
                 assert not isinstance(el, Node), \
                     "handling of Node inside list is not implemented"
-            new_args.append(node_a_arg)
+            return arg
         else:
             raise AssertionError(
-                f"handling for arg of type {type(node_a_arg)} is not implemented")
-
-    new_kwargs: Dict[str, Any] = {}
-    for node_a_k, node_a_kwarg in node_a.kwargs.items():
-        if isinstance(node_a_kwarg, Node):
-            kwarg_a = return_first_non_observer_node(node_a_kwarg, gm_a)
-            node_a_kwarg_copy = _copy_node_from_a_to_c(kwarg_a, gm_a, gm_b, graph_c)
-            new_kwargs[node_a_k] = node_a_kwarg_copy
+                f"handling for arg of type {type(arg)} is not implemented")
+
+    cur_idx = 0
+
+    while cur_idx < len(norm_args):
+        if cur_idx == 0:
+            new_arg = input_node_c
+        elif cur_idx == 1 and input_node_c_2 is not None:
+            new_arg = input_node_c_2
         else:
-            new_kwargs[node_a_k] = node_a_kwarg
+            new_arg = _copy_arg(norm_args[cur_idx])
+        new_args.append(new_arg)
+        cur_idx += 1
+
+    for kwarg_name, kwarg_val in norm_kwargs.items():
+        # stitch the inputs from base graph
+        if cur_idx == 0:
+            new_kwargs[kwarg_name] = input_node_c
+        elif cur_idx == 1 and input_node_c_2 is not None:
+            new_kwargs[kwarg_name] = input_node_c_2
+        else:
+            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
+        cur_idx += 1
+
+    new_args = tuple(new_args)  # type: ignore[assignment]
 
     node_a_shadows_c_name = \
         get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
 
-    if input_node_c_2:
-        input_node_c_args = [input_node_c, input_node_c_2]
-    else:
-        input_node_c_args = [input_node_c]
-
     if node_a.op == 'call_module':
         # if target is a module, we point to the module from gm_b
         new_mod_copy_name = \
@@ -561,13 +640,13 @@ def _insert_copy_of_node_a_after_input_node_c(
         mod_a = getattr_from_fqn(gm_a, node_a.target)
         setattr(gm_b, new_mod_copy_name, mod_a)
         node_a_shadows_c = graph_c.create_node(
-            node_a.op, new_mod_copy_name, (*input_node_c_args, *new_args),
+            node_a.op, new_mod_copy_name, new_args,
             new_kwargs, node_a_shadows_c_name)
         return node_a_shadows_c
     else:
         assert node_a.op in ('call_function', 'call_method')
         node_a_shadows_c = graph_c.create_node(
-            node_a.op, node_a.target, (*input_node_c_args, *new_args),
+            node_a.op, node_a.target, new_args,
             new_kwargs, node_a_shadows_c_name)
         return node_a_shadows_c
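The stitching rule itself, isolated from FX for clarity: position 0 (and position 1 for two-input ops) is replaced with the input node(s) from graph C, and every other arg or kwarg is copied over from graph A. A plain-Python sketch with hypothetical values (`stitch` and `copy_fn` are illustrative names, not part of the patch):

```python
# Toy model of the stitching rule: position 0 (and optionally 1) come
# from graph C, everything else is copied from graph A.
def stitch(norm_args, norm_kwargs, input_c, input_c_2=None, copy_fn=lambda a: a):
    new_args, new_kwargs, idx = [], {}, 0
    for a in norm_args:
        if idx == 0:
            new_args.append(input_c)
        elif idx == 1 and input_c_2 is not None:
            new_args.append(input_c_2)
        else:
            new_args.append(copy_fn(a))
        idx += 1
    for k, v in norm_kwargs.items():
        if idx == 0:
            new_kwargs[k] = input_c
        elif idx == 1 and input_c_2 is not None:
            new_kwargs[k] = input_c_2
        else:
            new_kwargs[k] = copy_fn(v)
        idx += 1
    return tuple(new_args), new_kwargs


# kwargs-only call like F.linear(input=x, weight=w, bias=b):
print(stitch((), {'input': 'x_a', 'weight': 'w_a', 'bias': 'b_a'}, 'x_c'))
# -> ((), {'input': 'x_c', 'weight': 'w_a', 'bias': 'b_a'})
```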
@@ -649,14 +728,6 @@ def load_arg(a):
                 subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                     end_node_b_to_matched_subgraph_a_and_name[node_b]
 
-                if len(node_b.args) == 0:
-                    print(
-                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
-                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
-                        ', kwargs-only node not handled yet')
-                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
-                    continue
-
                 all_op_types_support_shadowing = (
                     op_type_supports_shadowing(subgraph_a.start_node) and
                     op_type_supports_shadowing(node_b)
@@ -727,21 +798,22 @@ def load_arg(a):
 
                 # if necessary, log the input of node_c
                 if should_log_inputs:
-                    if isinstance(node_b.args[0], Node):
-                        prev_node_c = env_c[node_b.args[0].name]
+                    prev_node_b = _get_normalized_nth_input(node_b, gm_b, 0)
+                    if isinstance(prev_node_b, Node):
+                        prev_node_c = env_c[prev_node_b.name]
                         env_c[prev_node_c.name] = _insert_logger_after_node(
                             prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                             node_b.name, name_b, ref_name, ref_node_type_b,
                             NSSingleResultValuesType.NODE_INPUT.value,
                             index_within_arg=0, index_of_arg=0,
                             fqn=fqn_base_b)
-                    elif isinstance(node_b.args[0], list):
+                    elif isinstance(prev_node_b, list):
                         # first, save the prev_node instances, because they
                         # will be overwritten in the env after the first logger
                         # is added
-                        prev_node_c_list = [env_c[arg.name] for arg in node_b.args[0]]
+                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]
 
-                        for arg_idx, arg in enumerate(node_b.args[0]):
+                        for arg_idx, arg in enumerate(prev_node_b):
                             prev_node_c = prev_node_c_list[arg_idx]
                             env_c[prev_node_c.name] = _insert_logger_after_node(
                                 prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
@@ -751,7 +823,7 @@ def load_arg(a):
                                 fqn=fqn_base_b)
                     else:
                         # logging of inputs which are not lists is not supported yet
                        raise AssertionError(f"type {type(prev_node_b)} is not handled yet")
 
                 # subgraph so far:
                 #
                 # (prev_node_c)+ -> (logger_c_input)?
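For context, the enclosing pass builds graph C by copying nodes one at a time while keeping an `env_c` map from old node names to new nodes; `Graph.node_copy` with an `arg_transform` callback is the standard FX idiom it relies on. A minimal sketch of that idiom, independent of this patch:

```python
import torch


def f(x):
    return torch.relu(x + 1)


gm = torch.fx.symbolic_trace(f)
new_graph = torch.fx.Graph()
env = {}  # maps old node name -> new node

for node in gm.graph.nodes:
    # remap each copied node's inputs into the new graph
    new_node = new_graph.node_copy(node, lambda n: env[n.name])
    env[node.name] = new_node

new_gm = torch.fx.GraphModule(gm, new_graph)
print(new_gm(torch.randn(2)))
```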
@@ -777,13 +849,14 @@ def load_arg(a):
 
                 # cast dtype from the dtype of node_c's input to the dtype of
                 # node_a's input (dequant, etc)
-                prev_node_c = node_c.args[0]
+                # prev_node_c = node_c.args[0]
+                prev_node_c = _get_normalized_nth_input(node_c, gm_b, 0)
                 if should_log_inputs:
                     # skip the input logger when inserting a dtype cast
                     if isinstance(prev_node_c, Node):
-                        prev_node_c = prev_node_c.args[0]
+                        prev_node_c = _get_normalized_nth_input(prev_node_c, gm_b, 0)
                     elif isinstance(prev_node_c, list):
-                        prev_node_c = [arg.args[0] for arg in prev_node_c]
+                        prev_node_c = [_get_normalized_nth_input(arg, gm_b, 0) for arg in prev_node_c]
                 dtype_cast_node = _insert_dtype_cast_after_node(
                     subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                     node_b.name + '_dtype_cast_', logger_cls,
@@ -838,7 +911,8 @@ def load_arg(a):
                 node_c_second_non_param_arg = None
                 num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                 if num_non_param_args_node_a == 2:
-                    node_c_second_non_param_arg = node_c.args[1]
+                    # node_c_second_non_param_arg = node_c.args[1]
+                    node_c_second_non_param_arg = _get_normalized_nth_input(node_c, gm_b, 1)
                 node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                     dtype_cast_node, node_c_second_non_param_arg,
                     subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
@@ -860,8 +934,8 @@ def load_arg(a):
                     # input_logger = env_c[dtype_cast_node.name]
                     # Find the first node in the subgraph
                     cur_node = node_a_shadows_c
-                    while cur_node.args[0] != input_logger:
-                        cur_node = cur_node.args[0]  # type: ignore[assignment]
+                    while _get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:
+                        cur_node = _get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
                     if isinstance(input_logger, Node):
                         input_logger_mod = getattr(gm_b, input_logger.name)
                         input_logger_mod.ref_node_name = cur_node.name
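One caveat the helper's `except RuntimeError` branch handles: for overloaded ops such as `torch.add`, FX cannot pick a schema without type hints, so `normalized_arguments` can raise instead of normalizing, and the helper then falls back to raw `args`/`kwargs`. A quick standalone check (exact behavior may vary across torch versions):

```python
import torch


def f(x, y):
    return torch.add(x, y)


gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target is torch.add:
        try:
            pair = node.normalized_arguments(gm, normalize_to_only_use_kwargs=True)
            print('normalized:', pair)
        except RuntimeError:
            # ambiguous schemas: fall back to node.args / node.kwargs
            print('fallback:', node.args, node.kwargs)
```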