Skip to content

Commit c951390

Browse files
limintang authored and facebook-github-bot committed
Add a parameter to output delegate summary in llama export (#8174)
Summary: Print delegation summary when the verbose parameter is set. Differential Revision: D68991594
1 parent e63c923 commit c951390

File tree

5 files changed

+26
-11
lines changed

5 files changed

+26
-11
lines changed

devtools/backend_debug/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from executorch.devtools.backend_debug.delegation_info import (
88
DelegationBreakdown,
99
get_delegation_info,
10+
print_delegation_info,
1011
)
1112

12-
__all__ = ["DelegationBreakdown", "get_delegation_info"]
13+
__all__ = ["DelegationBreakdown", "get_delegation_info", "print_delegation_info"]

devtools/backend_debug/delegation_info.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import re
88
from collections import defaultdict
99
from dataclasses import asdict, dataclass
10+
from tabulate import tabulate
1011
from typing import Dict
1112

1213
import pandas as pd
@@ -174,3 +175,10 @@ def _insert_op_occurrences_dict(node_name: str, delegated: bool) -> None:
174175
num_delegated_subgraphs=delegated_subgraph_counter,
175176
delegation_by_operator=op_occurrences_dict,
176177
)
178+
179+
180+
def print_delegation_info(graph_module: torch.fx.GraphModule):
181+
delegation_info = get_delegation_info(graph_module)
182+
print(delegation_info.get_summary())
183+
df = delegation_info.get_operator_delegation_dataframe()
184+
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))

examples/models/llama/export_llama_lib.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch
2424

2525
from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
26-
from executorch.devtools.backend_debug import get_delegation_info
26+
from executorch.devtools.backend_debug import print_delegation_info
2727

2828
from executorch.devtools.etrecord import generate_etrecord
2929
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
@@ -46,7 +46,6 @@
4646
get_vulkan_quantizer,
4747
)
4848
from executorch.util.activation_memory_profiler import generate_memory_trace
49-
from tabulate import tabulate
5049

5150
from ..model_factory import EagerModelFactory
5251
from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -801,12 +800,6 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
801800
for partitioner in partitioners:
802801
logging.info(f"--> {partitioner.__class__.__name__}")
803802

804-
def print_delegation_info(graph_module: torch.fx.GraphModule):
805-
delegation_info = get_delegation_info(graph_module)
806-
print(delegation_info.get_summary())
807-
df = delegation_info.get_operator_delegation_dataframe()
808-
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))
809-
810803
additional_passes = []
811804
if args.model in TORCHTUNE_DEFINED_MODELS:
812805
additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
get_soc_to_chipset_map,
5252
update_spill_fill_size,
5353
)
54+
55+
from executorch.devtools.backend_debug import print_delegation_info
5456
from executorch.examples.models.llama.source_transformation.quantize import (
5557
get_quant_embedding_transform,
5658
)
@@ -389,6 +391,7 @@ def lowering_modules(
389391
num_sharding=1,
390392
passes_job=OrderedDict(),
391393
shared_buffer=False,
394+
verbose=False,
392395
):
393396
executorch_config = ExecutorchBackendConfig(
394397
# For shared buffer, user must pass the memory address
@@ -440,6 +443,10 @@ def lowering_modules(
440443
edge_prog_mgr = edge_prog_mgr.to_backend(partitioner)
441444
if num_sharding > 1:
442445
update_spill_fill_size(edge_prog_mgr.exported_program())
446+
447+
if verbose:
448+
print_delegation_info(edge_prog_mgr.exported_program().graph_module)
449+
443450
exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
444451
with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
445452
exec_prog_mgr.write_to_file(file)
@@ -667,6 +674,10 @@ def compile(args, pte_filename, tokenizer):
667674
)
668675
compiler_specs[0][0].value = option_to_flatbuffer(qnn_executorch_options)
669676

677+
if args.verbose:
678+
for exported_program in exported_programs:
679+
print_delegation_info(exported_program.graph_module)
680+
670681
executorch_config = ExecutorchBackendConfig(
671682
# For shared buffer, user must pass the memory address
672683
# which is allocated by RPC memory to executor runner.
@@ -980,6 +991,8 @@ def _build_parser():
980991
help="Fallback to cpu embedding operator and type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '4,32'.",
981992
)
982993

994+
parser.add_argument("-v", "--verbose", action="store_true")
995+
983996
return parser
984997

985998

examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,8 @@ SmartMaskIoMgr::SmartMaskIoMgr(
557557
const bool use_int64_token)
558558
: IoMgrBase(modules),
559559
shard_layers_({num_layers}),
560-
prefill_cache_len_(prefill_cache_len),
561560
kv_cache_len_(kv_cache_len),
561+
prefill_cache_len_(prefill_cache_len),
562562
vocab_size_(vocab_size),
563563
num_layers_(num_layers),
564564
head_dim_(head_dim),
@@ -1002,7 +1002,7 @@ void SmartMaskIoMgr::prepare_prefill_io(
10021002

10031003
// [O]: logits
10041004
int logit_index = 0;
1005-
Result<TensorInfo> logits = methods_meta[0]->output_tensor_meta(0);
1005+
Result<TensorInfo> logits = methods_meta[0]->output_tensor_meta(logit_index);
10061006
prefill_logits_ = std::make_unique<TensorImpl>(
10071007
logits->scalar_type(),
10081008
logits->sizes().size(),

0 commit comments

Comments
 (0)