From b1edc3db5699ad476b0049efa6dd8ed271f2341e Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Fri, 12 Apr 2024 09:48:37 -0700 Subject: [PATCH] Add util to print out ops and frequency (#2983) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/2983 As titled. Reviewed By: cccclai Differential Revision: D56001227 fbshipit-source-id: cefef12662e03171136f03138fb814d61a28a0f3 --- examples/cadence/aot/compiler.py | 6 +- examples/cadence/aot/export_example.py | 12 ++- examples/cadence/aot/utils.py | 125 +++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 4 deletions(-) diff --git a/examples/cadence/aot/compiler.py b/examples/cadence/aot/compiler.py index c9df9ef2ba..36a5b30855 100644 --- a/examples/cadence/aot/compiler.py +++ b/examples/cadence/aot/compiler.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Callable +from typing import Any, Callable, Tuple import torch @@ -48,7 +48,7 @@ def export_to_edge( inputs: Any, pt2_quant: bool = False, dump_graphs: bool = False, -) -> EdgeProgramManager: +) -> Tuple[EdgeProgramManager, ExportedProgram]: # Export the model into an ExportedProgram. 
expo_program = export_program(model, inputs, pt2_quant) @@ -65,4 +65,4 @@ def export_to_edge( f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}" ) - return edge_prog_manager + return edge_prog_manager, expo_program diff --git a/examples/cadence/aot/export_example.py b/examples/cadence/aot/export_example.py index 23cc9fd789..864df963f7 100644 --- a/examples/cadence/aot/export_example.py +++ b/examples/cadence/aot/export_example.py @@ -22,6 +22,7 @@ ReplacePT2DequantWithCadenceDequant, ReplacePT2QuantWithCadenceQuant, ) +from .utils import print_ops_info FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -47,7 +48,9 @@ def export_model(model, example_inputs): QuantFusion(patterns)(converted_model) # Get edge program (note: the name will change to export_to_cadence in future PRs) - edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True) + edge_prog_manager, expo_prog = export_to_edge( + converted_model, example_inputs, pt2_quant=True + ) # Run a couple required passes for quant/dequant ops cadence_prog_manager = edge_prog_manager.transform( @@ -61,5 +64,12 @@ def export_model(model, example_inputs): f"Final exported graph module:\n{exec_prog.exported_program().graph_module}" ) + # Print some information to terminal + print_ops_info( + expo_prog.graph_module, + edge_prog_manager.exported_program().graph_module, + cadence_prog_manager.exported_program().graph_module, + ) + # Save the program as CadenceDemoModel.pte save_pte_program(exec_prog, "CadenceDemoModel") diff --git a/examples/cadence/aot/utils.py b/examples/cadence/aot/utils.py index 73f863eed9..b511f95e80 100644 --- a/examples/cadence/aot/utils.py +++ b/examples/cadence/aot/utils.py @@ -4,7 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import logging +import operator +from typing import Dict + import torch +from executorch.exir import memory +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket +from tabulate import tabulate # Get the output size of a 1D convolution given the input size and parameters @@ -23,3 +31,120 @@ def get_conv1d_output_size( lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1 return torch.Size((in_size[0], out_channels, lout)) + + +# Return the overload packet for the edge op +def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket: + edge_op_namespace, edge_op_name = ( + edge_op.namespace, + edge_op._schema.name.split("::")[1], + ) + edge_op_overload_packet = getattr( + getattr(exir_ops.edge, edge_op_namespace), edge_op_name + ) + return edge_op_overload_packet + + +# Get the frequency list of ops in a graph module +def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]: + freq = {} + # Loop over nodes to count the number of times each op occurs + for node in graph_module.graph.nodes: + if node.op == "call_function": + # Ignore getitem, alloc and view cases, we only want actual operations + if ( + node.target == operator.getitem + or node.target.__name__ == "alloc" + or node.target == memory.view + ): + continue + # If the op is already present, increment the count + if get_edge_overload_packet(node.target).__name__ in freq: + freq[get_edge_overload_packet(node.target).__name__] += 1 + # else, add a new entry + else: + freq[get_edge_overload_packet(node.target).__name__] = 1 + return freq + + +# Print the ops and how many times they occur in multiple graph modules: +# from export, from to_edge, and from Jarvis. Print the available +# implementations for each op, and error out if the op is not supported.
+def print_ops_info( + export_gm: torch.fx.GraphModule, + to_edge_gm: torch.fx.GraphModule, + jarvis_gm: torch.fx.GraphModule, +): + export_ops_count = get_ops_count(export_gm) + to_edge_ops_count = get_ops_count(to_edge_gm) + jarvis_ops_count = get_ops_count(jarvis_gm) + + # De-duplicate the "" and "_copy" ops + keys_to_delete_and_add = [] + for k1 in export_ops_count: + for k2 in {**to_edge_ops_count, **jarvis_ops_count}: + if k2.startswith(k1): + keys_to_delete_and_add.append((k1, k2)) + break + + for k in keys_to_delete_and_add: + export_ops_count[k[1]] = export_ops_count[k[0]] + del export_ops_count[k[0]] + + removed_ops = [] + # Get the names of the ops that are removed from the final graph + for k in {**export_ops_count, **to_edge_ops_count}: + if k not in jarvis_ops_count: + removed_ops.append(k) + + # Create a list of ops and their counts to pass to tabulate + ops_count = [ + [ + op, + jarvis_ops_count[op], + to_edge_ops_count[op] if op in to_edge_ops_count else 0, + export_ops_count[op] if op in export_ops_count else 0, + ] + for op in jarvis_ops_count + ] + sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True) + + # Create a list of deleted ops and their counts to pass to tabulate + removed_ops_count = [ + [ + op, + 0, + to_edge_ops_count[op] if op in to_edge_ops_count else 0, + export_ops_count[op] if op in export_ops_count else 0, + ] + for op in removed_ops + ] + + # Print the final ops and their counts in a tabular format + logging.info( + tabulate( + sorted_ops_count, + headers=[ + "Final Operators ", # one character longer than the longest op name + "Jarvis (Final) Graph", + "To_edge Graph", + "Export Graph", + ], + tablefmt="outline", + ) + ) + + # Print the removed ops and their counts in a tabular format (if any) + if removed_ops != []: + logging.info( + tabulate( + removed_ops_count, + headers=[ + "Deleted Operators ", # one character longer than the longest op name + "Jarvis (Final) Graph", + "To_edge Graph", + "Export
Graph", + ], + tablefmt="outline", + ) + )