Skip to content

Commit

Permalink
Add util to print out ops and frequency (#2983)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2983

As titled.

Reviewed By: cccclai

Differential Revision: D56001227

fbshipit-source-id: cefef12662e03171136f03138fb814d61a28a0f3
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Apr 12, 2024
1 parent 5b7c4ba commit b1edc3d
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 4 deletions.
6 changes: 3 additions & 3 deletions examples/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Callable
from typing import Any, Callable, Tuple

import torch

Expand Down Expand Up @@ -48,7 +48,7 @@ def export_to_edge(
inputs: Any,
pt2_quant: bool = False,
dump_graphs: bool = False,
) -> EdgeProgramManager:
) -> Tuple[EdgeProgramManager, ExportedProgram]:
# Export the model into an ExportedProgram.
expo_program = export_program(model, inputs, pt2_quant)

Expand All @@ -65,4 +65,4 @@ def export_to_edge(
f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}"
)

return edge_prog_manager
return edge_prog_manager, expo_program
12 changes: 11 additions & 1 deletion examples/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ReplacePT2DequantWithCadenceDequant,
ReplacePT2QuantWithCadenceQuant,
)
from .utils import print_ops_info


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
Expand All @@ -47,7 +48,9 @@ def export_model(model, example_inputs):
QuantFusion(patterns)(converted_model)

# Get edge program (note: the name will change to export_to_cadence in future PRs)
edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)
edge_prog_manager, expo_prog = export_to_edge(
converted_model, example_inputs, pt2_quant=True
)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
Expand All @@ -61,5 +64,12 @@ def export_model(model, example_inputs):
f"Final exported graph module:\n{exec_prog.exported_program().graph_module}"
)

# Print some information to terminal
print_ops_info(
expo_prog.graph_module,
edge_prog_manager.exported_program().graph_module,
cadence_prog_manager.exported_program().graph_module,
)

# Save the program as CadenceDemoModel.pte
save_pte_program(exec_prog, "CadenceDemoModel")
125 changes: 125 additions & 0 deletions examples/cadence/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from typing import Dict

import torch
from executorch.exir import memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from tabulate import tabulate


# Get the output size of a 1D convolution given the input size and parameters
Expand All @@ -23,3 +31,120 @@ def get_conv1d_output_size(
lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

return torch.Size((in_size[0], out_channels, lout))


# Return the overload packet for the edge op
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
edge_op_namespace, edge_op_name = (
edge_op.namespace,
edge_op._schema.name.split("::")[1],
)
edge_op_overload_packet = getattr(
getattr(exir_ops.edge, edge_op_namespace), edge_op_name
)
return edge_op_overload_packet


# Get the frequency list of ops in a graph module
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
freq = {}
# Loop over nodes to count the number of times each op occurs
for node in graph_module.graph.nodes:
if node.op == "call_function":
# Ignore getitem, alloc and view cases, we only want actual operations
if (
node.target == operator.getitem
or node.target.__name__ == "alloc"
or node.target == memory.view
):
continue
# If the op is already present, increment the count
if get_edge_overload_packet(node.target).__name__ in freq:
freq[get_edge_overload_packet(node.target).__name__] += 1
# else, add a new entry
else:
freq[get_edge_overload_packet(node.target).__name__] = 1
return freq


# Print the ops and how many times they occur multiple graph modules:
# from export, from to_edge, and from Jarvis. Print the available
# implementations for each op, and error out if the op is not supported.
def print_ops_info(
export_gm: torch.fx.GraphModule,
to_edge_gm: torch.fx.GraphModule,
jarvis_gm: torch.fx.GraphModule,
):
export_ops_count = get_ops_count(export_gm)
to_edge_ops_count = get_ops_count(to_edge_gm)
jarvis_ops_count = get_ops_count(jarvis_gm)

# De-duplicate the "<op>" and "<op>_copy" ops
keys_to_delete_and_add = []
for k1 in export_ops_count:
for k2 in {**to_edge_ops_count, **jarvis_ops_count}:
if k2.startswith(k1):
keys_to_delete_and_add.append((k1, k2))
break

for k in keys_to_delete_and_add:
export_ops_count[k[1]] = export_ops_count[k[0]]
del export_ops_count[k[0]]

removed_ops = []
# Get the counts of the ops that are removed from the final graph
for k in {**export_ops_count, **to_edge_ops_count}:
if k not in jarvis_ops_count:
removed_ops.append(k)

# Create a dict of ops and their counts to pass to tabulate
ops_count = [
[
op,
jarvis_ops_count[op],
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
export_ops_count[op] if op in export_ops_count else 0,
]
for op in jarvis_ops_count
]
sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)

# Create a dict of deleted ops and their counts to pass to tabulate
removed_ops_count = [
[
op,
0,
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
export_ops_count[op] if op in export_ops_count else 0,
]
for op in removed_ops
]

# Print the final ops and their counts in a tabular format
logging.info(
tabulate(
sorted_ops_count,
headers=[
"Final Operators ", # one character longer than the longest op name
"Jarvis (Final) Graph",
"To_edge Graph",
"Export Graph",
],
tablefmt="outline",
)
)

# Print the removed ops and their counts in a tabular format (if any)
if removed_ops != []:
logging.info(
tabulate(
removed_ops_count,
headers=[
"Deleted Operators ", # one character longer than the longest op name
"Jarvis (Final) Graph",
"To_edge Graph",
"Export Graph",
],
tablefmt="outline",
)
)

0 comments on commit b1edc3d

Please sign in to comment.