Skip to content

Support multi-module exports in Inspector and ETRecord #8336

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ def generate_etrecord(
edge_dialect_program.exported_program,
)
else:
raise RuntimeError(
f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
)
if export_modules is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure i fully understand why you needed to add this check here?

raise RuntimeError(
f"Unsupported type of edge_dialect_program passed in {type(edge_dialect_program)}."
)

# When a BundledProgram is passed in, extract the reference outputs and save in a file
if isinstance(executorch_program, BundledProgram):
Expand Down
33 changes: 25 additions & 8 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,8 @@ def __init__(
Callable[[Union[int, str], Union[int, float]], Union[int, float]]
] = None,
enable_module_hierarchy: bool = False,
module_name: Optional[str] = None,
method_name: Optional[str] = None,
) -> None:
r"""
Initialize an `Inspector` instance with the underlying `EventBlock`\ s populated with data from the provided ETDump path or binary,
Expand All @@ -995,6 +997,8 @@ def __init__(
delegate_time_scale_converter: Optional function to convert the time scale of delegate profiling data. If not given, use the conversion ratio of
target_time_scale/source_time_scale.
enable_module_hierarchy: Enable submodules in the operator graph. Defaults to False.
module_name: Optional module name to inspect (used with multi-module exports).
method_name: Optional method name to inspect (used with multi-module exports).

Returns:
None
Expand Down Expand Up @@ -1059,9 +1063,13 @@ def __init__(
# Key str is method name; value is list of ProgramOutputs because of list of test cases
self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
self._enable_module_hierarchy = enable_module_hierarchy
self._consume_etrecord()
self._consume_etrecord(module_name, method_name)

def _consume_etrecord(self) -> None:
def _consume_etrecord(
self,
module_name: Optional[str] = None,
method_name: Optional[str] = None,
) -> None:
"""
If an ETRecord is provided, connect it to the EventBlocks and populate the Event metadata.

Expand All @@ -1081,15 +1089,23 @@ def _consume_etrecord(self) -> None:
bundled_input_index of the EventBlock.
"""

if self._etrecord is None:
return
if method_name is None and module_name is None:
method_name = FORWARD
edge_dialect_graph_key = EDGE_DIALECT_GRAPH_KEY
elif method_name is None or module_name is None:
raise ValueError(
"Either both method_name and module_name should be provided or neither should be provided"
)
else:
method_name = method_name
edge_dialect_graph_key = f"{module_name}/{method_name}"

# (1) Debug Handle Symbolification
for event_block in self.event_blocks:
event_block._gen_resolve_debug_handles(
self._etrecord._debug_handle_map[FORWARD],
self._etrecord._debug_handle_map[method_name],
(
self._etrecord._delegate_map[FORWARD]
self._etrecord._delegate_map[method_name]
if self._etrecord._delegate_map is not None
else None
),
Expand All @@ -1099,9 +1115,10 @@ def _consume_etrecord(self) -> None:
self.op_graph_dict = gen_graphs_from_etrecord(
etrecord=self._etrecord,
enable_module_hierarchy=self._enable_module_hierarchy,
edge_dialect_graph_key=edge_dialect_graph_key,
)
debug_handle_to_op_node_map = create_debug_handle_to_op_node_mapping(
self.op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
self.op_graph_dict[edge_dialect_graph_key],
)
for event_block in self.event_blocks:
for event in event_block.events:
Expand All @@ -1116,7 +1133,7 @@ def _consume_etrecord(self) -> None:
for event_block in self.event_blocks:
index = event_block.bundled_input_index
if index is not None:
event_block.reference_output = self._reference_outputs[FORWARD][
event_block.reference_output = self._reference_outputs[method_name][
index
]

Expand Down
6 changes: 4 additions & 2 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def is_debug_output(value: Value) -> bool:


def gen_graphs_from_etrecord(
etrecord: ETRecord, enable_module_hierarchy: bool = False
etrecord: ETRecord,
enable_module_hierarchy: bool = False,
edge_dialect_graph_key: str = EDGE_DIALECT_GRAPH_KEY,
) -> Mapping[str, OperatorGraph]:
op_graph_map = {}
if etrecord.graph_map is not None:
Expand All @@ -248,7 +250,7 @@ def gen_graphs_from_etrecord(
for name, exported_program in etrecord.graph_map.items()
}
if etrecord.edge_dialect_program is not None:
op_graph_map[EDGE_DIALECT_GRAPH_KEY] = FXOperatorGraph.gen_operator_graph(
op_graph_map[edge_dialect_graph_key] = FXOperatorGraph.gen_operator_graph(
etrecord.edge_dialect_program.graph_module,
enable_module_hierarchy=enable_module_hierarchy,
)
Expand Down
14 changes: 14 additions & 0 deletions devtools/inspector/inspector_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def main() -> None:
required=False,
help="Provide an optional tsv file path.",
)
parser.add_argument(
"--method_name",
required=False,
default=None,
help="Method Name to inspect (used with multi-module exports)",
)
parser.add_argument(
"--module_name",
required=False,
default=None,
help="Module Name to inspect (used with multi-module exports)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is module name intended to be here?

)
parser.add_argument("--compare_results", action="store_true")

args = parser.parse_args()
Expand All @@ -58,6 +70,8 @@ def main() -> None:
debug_buffer_path=args.debug_buffer_path,
source_time_scale=TimeScale(args.source_time_scale),
target_time_scale=TimeScale(args.target_time_scale),
module_name=args.module_name,
method_name=args.method_name,
)
inspector.print_data_tabular()
if args.tsv_path:
Expand Down
Loading