Skip to content

Add util to print out ops and frequency #2983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Callable
from typing import Any, Callable, Tuple

import torch

Expand Down Expand Up @@ -48,7 +48,7 @@ def export_to_edge(
inputs: Any,
pt2_quant: bool = False,
dump_graphs: bool = False,
) -> EdgeProgramManager:
) -> Tuple[EdgeProgramManager, ExportedProgram]:
# Export the model into an ExportedProgram.
expo_program = export_program(model, inputs, pt2_quant)

Expand All @@ -65,4 +65,4 @@ def export_to_edge(
f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}"
)

return edge_prog_manager
return edge_prog_manager, expo_program
12 changes: 11 additions & 1 deletion examples/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ReplacePT2DequantWithCadenceDequant,
ReplacePT2QuantWithCadenceQuant,
)
from .utils import print_ops_info


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
Expand All @@ -47,7 +48,9 @@ def export_model(model, example_inputs):
QuantFusion(patterns)(converted_model)

# Get edge program (note: the name will change to export_to_cadence in future PRs)
edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)
edge_prog_manager, expo_prog = export_to_edge(
converted_model, example_inputs, pt2_quant=True
)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
Expand All @@ -61,5 +64,12 @@ def export_model(model, example_inputs):
f"Final exported graph module:\n{exec_prog.exported_program().graph_module}"
)

# Print some information to terminal
print_ops_info(
expo_prog.graph_module,
edge_prog_manager.exported_program().graph_module,
cadence_prog_manager.exported_program().graph_module,
)

# Save the program as CadenceDemoModel.pte
save_pte_program(exec_prog, "CadenceDemoModel")
125 changes: 125 additions & 0 deletions examples/cadence/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from typing import Dict

import torch
from executorch.exir import memory
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from tabulate import tabulate


# Get the output size of a 1D convolution given the input size and parameters
Expand All @@ -23,3 +31,120 @@ def get_conv1d_output_size(
lout = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

return torch.Size((in_size[0], out_channels, lout))


# Return the overload packet for the edge op
def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
edge_op_namespace, edge_op_name = (
edge_op.namespace,
edge_op._schema.name.split("::")[1],
)
edge_op_overload_packet = getattr(
getattr(exir_ops.edge, edge_op_namespace), edge_op_name
)
return edge_op_overload_packet


# Get the frequency list of ops in a graph module
def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
freq = {}
# Loop over nodes to count the number of times each op occurs
for node in graph_module.graph.nodes:
if node.op == "call_function":
# Ignore getitem, alloc and view cases, we only want actual operations
if (
node.target == operator.getitem
or node.target.__name__ == "alloc"
or node.target == memory.view
):
continue
# If the op is already present, increment the count
if get_edge_overload_packet(node.target).__name__ in freq:
freq[get_edge_overload_packet(node.target).__name__] += 1
# else, add a new entry
else:
freq[get_edge_overload_packet(node.target).__name__] = 1
return freq


# Print the ops and how many times they occur multiple graph modules:
# from export, from to_edge, and from Jarvis. Print the available
# implementations for each op, and error out if the op is not supported.
def print_ops_info(
export_gm: torch.fx.GraphModule,
to_edge_gm: torch.fx.GraphModule,
jarvis_gm: torch.fx.GraphModule,
):
export_ops_count = get_ops_count(export_gm)
to_edge_ops_count = get_ops_count(to_edge_gm)
jarvis_ops_count = get_ops_count(jarvis_gm)

# De-duplicate the "<op>" and "<op>_copy" ops
keys_to_delete_and_add = []
for k1 in export_ops_count:
for k2 in {**to_edge_ops_count, **jarvis_ops_count}:
if k2.startswith(k1):
keys_to_delete_and_add.append((k1, k2))
break

for k in keys_to_delete_and_add:
export_ops_count[k[1]] = export_ops_count[k[0]]
del export_ops_count[k[0]]

removed_ops = []
# Get the counts of the ops that are removed from the final graph
for k in {**export_ops_count, **to_edge_ops_count}:
if k not in jarvis_ops_count:
removed_ops.append(k)

# Create a dict of ops and their counts to pass to tabulate
ops_count = [
[
op,
jarvis_ops_count[op],
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
export_ops_count[op] if op in export_ops_count else 0,
]
for op in jarvis_ops_count
]
sorted_ops_count = sorted(ops_count, key=lambda x: x[1], reverse=True)

# Create a dict of deleted ops and their counts to pass to tabulate
removed_ops_count = [
[
op,
0,
to_edge_ops_count[op] if op in to_edge_ops_count else 0,
export_ops_count[op] if op in export_ops_count else 0,
]
for op in removed_ops
]

# Print the final ops and their counts in a tabular format
logging.info(
tabulate(
sorted_ops_count,
headers=[
"Final Operators ", # one character longer than the longest op name
"Jarvis (Final) Graph",
"To_edge Graph",
"Export Graph",
],
tablefmt="outline",
)
)

# Print the removed ops and their counts in a tabular format (if any)
if removed_ops != []:
logging.info(
tabulate(
removed_ops_count,
headers=[
"Deleted Operators ", # one character longer than the longest op name
"Jarvis (Final) Graph",
"To_edge Graph",
"Export Graph",
],
tablefmt="outline",
)
)