
[Do Not Merge] Multi-partitioned graph in one QNN delegate #8175

Draft · wants to merge 1 commit into base: main
4 changes: 3 additions & 1 deletion backends/qualcomm/aot/python/PyQnnManagerAdaptor.h
@@ -63,6 +63,7 @@ class PyQnnManager {
std::vector<uint8_t> tensor_data;
std::vector<uint8_t*> tensor_ptr;
std::vector<uint64_t> tensor_size;
std::unordered_map<std::string, int> partition_num;
uint64_t total_tensor_size = 0;
for (size_t i = 0; i < qcirs.size(); ++i) {
py::buffer_info info(py::buffer(qcirs[i].cast<py::bytes>()).request());
@@ -147,7 +148,8 @@ class PyQnnManager {
&params));
}
graphs.emplace_back(qcir::CreateGraphDirect(
builder_, graph->name()->str().c_str(), &nodes, &tensors));
builder_, (graph->name()->str() + "_" + std::to_string(partition_num[graph->name()->str()])).c_str(), &nodes, &tensors));
partition_num[graph->name()->str()] = partition_num[graph->name()->str()] + 1;
}
}
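
The map added above gives every partition of a repeated graph name a unique context name. A minimal Python sketch of the naming scheme (illustrative only; the authoritative logic is the C++ above):

```python
from collections import defaultdict

# Mirrors the C++ unordered_map: count occurrences per graph name and
# suffix each partition with its running index.
partition_num = defaultdict(int)

def suffixed_name(graph_name: str) -> str:
    idx = partition_num[graph_name]
    partition_num[graph_name] += 1
    return f"{graph_name}_{idx}"

# e.g. ["forward", "forward", "kv"] -> ["forward_0", "forward_1", "kv_0"]
print([suffixed_name(n) for n in ["forward", "forward", "kv"]])
```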

4 changes: 4 additions & 0 deletions backends/qualcomm/qnn_preprocess.py
@@ -20,6 +20,10 @@
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
flatbuffer_to_option,
option_to_flatbuffer,
)
from executorch.exir.backend.backend_details import (
BackendDetails,
CompileSpec,
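The newly imported helpers are used to rewrite the backend's serialized compile-spec options in place. A hedged sketch of the round trip, mirroring the usage added to utils.py below (`rename_graph` is a made-up wrapper, not an API in this PR):

```python
from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
    flatbuffer_to_option,
    option_to_flatbuffer,
)

def rename_graph(compile_spec, new_name: str) -> None:
    # Deserialize the option flatbuffer, mutate one field, re-serialize.
    opt = flatbuffer_to_option(compile_spec.value)
    opt.graph_name = new_name  # e.g. "multi_graph", as in utils.py below
    compile_spec.value = option_to_flatbuffer(opt)
```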
20 changes: 13 additions & 7 deletions backends/qualcomm/utils/utils.py
@@ -818,6 +818,7 @@ def generate_multi_graph_program(
executorch_in_order,
executorch_out_order,
) = ({}, {}, {}, {}, {})
# Graph names will be suffixed with _{num}, one per partition.
qnn_mgr = PyQnnManagerAdaptor.QnnManager(
generate_qnn_executorch_option(compiler_specs), processed_bytes
)
@@ -831,15 +832,16 @@

# We need the IO order to correctly map QNN tensors onto the nn.Module IOs
for graph_name in graph_names:
ori_graph_name, cur_idx = "_".join(graph_name.split("_")[:-1]), int(graph_name.split("_")[-1])
if input_nodes_dict:
# input
input_names = [node.name for node in input_nodes_dict[graph_name]]
input_names = [node.name for node in input_nodes_dict[ori_graph_name][cur_idx]]
qnn_input_names = [
wrapper.GetName() for wrapper in graph_inputs[graph_name]
]
# The inputs of an intermediate module, which include call_function nodes,
# cannot be reordered by node name
if len(input_names) == len(qnn_input_names):
if len(input_names) == len(qnn_input_names) and cur_idx == 0:
input_order_list = []
for input_name in input_names:
# e.g., input_0_tokens_0
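
Only the first partition (cur_idx == 0) carries the module's named placeholders, so only it is reordered. A hedged sketch of the ordering idea, with illustrative tensor names:

```python
# Assumption: module input names embed their position (e.g. "input_0_tokens_0"),
# while QNN may report the same tensors in a different order.
input_names = ["input_0_tokens_0", "input_1_pos_0"]      # from the nn.Module
qnn_input_names = ["input_1_pos_0", "input_0_tokens_0"]  # from the QNN graph

# Permutation mapping module input order onto QNN input order.
input_order_list = [qnn_input_names.index(name) for name in input_names]
assert input_order_list == [1, 0]
```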
@@ -868,7 +870,7 @@
bundle_progs = [
from_context_binary(
ctx_path=binary_info,
op_name=f"loader_{graph_name}_{int(time.time())}",
op_name=graph_name,
soc_model=compiler_options.soc_info.soc_model,
custom_info={
"graph_inputs": graph_inputs[graph_name],
@@ -877,10 +879,10 @@
"qnn_in_order": qnn_in_order.get(graph_name, None),
"executorch_in_order": executorch_in_order.get(graph_name, None),
"executorch_out_order": executorch_out_order.get(graph_name, None),
},
)
for graph_name in graph_names
]
# leverage ExecutorchProgramManager for generating pte with multi-methods
edge_prog_mgr = to_edge(
{
@@ -898,11 +900,15 @@
n.meta[OpContextLoader.meta_ctx_bin] = binary_info
break

opt = flatbuffer_to_option(compiler_specs[0].value)
opt.graph_name = "multi_graph"
new_opt = option_to_flatbuffer(opt)
compiler_specs[0].value = new_opt
edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs))
exec_prog = edge_prog_mgr.to_executorch(
config=backend_config or ExecutorchBackendConfig()
)
return exec_prog, bundle_progs
return exec_prog, bundle_progs, graph_names


def generate_composite_llama_program(
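The `_{num}` suffix attached in PyQnnManagerAdaptor.h is what the parsing above undoes. A minimal sketch of the split (example input only):

```python
def split_partitioned_name(graph_name: str) -> tuple[str, int]:
    # "prefill_forward_1" -> ("prefill_forward", 1); the trailing field is
    # the partition index added by PyQnnManagerAdaptor.h.
    ori_graph_name = "_".join(graph_name.split("_")[:-1])
    cur_idx = int(graph_name.split("_")[-1])
    return ori_graph_name, cur_idx

assert split_partitioned_name("prefill_forward_1") == ("prefill_forward", 1)
```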
48 changes: 26 additions & 22 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -14,6 +14,7 @@
import os
import sys
import time
from collections import defaultdict
from functools import partial
from multiprocessing.connection import Client

@@ -626,7 +627,7 @@ def compile(args, pte_filename, tokenizer):
call_delegate_inputs_dict = {name: [] for name in graph_names}
call_delegate_node_name_dict = {name: [] for name in graph_names}
outputs_dict = {name: [] for name in graph_names}
input_nodes_dict = {name: [] for name in graph_names}
input_nodes_dict = defaultdict(list)
for prog, graph_name in zip(exported_programs, graph_names):
for node in prog.graph_module.graph.nodes:
if (
@@ -654,8 +655,11 @@

if args.num_sharding > 0:
bundle_progs_list = []
processed_bytes = []
call_delegate_node = []

for num in range(args.num_sharding - 1, -1, -1):
processed_bytes = []
cur_inputs = []
for prog, graph_name in zip(exported_programs, graph_names):
processed_bytes.append(
getattr(
@@ -669,28 +673,28 @@
if node.op == "get_attr"
and node.name == f"lowered_module_{num}"
]
input_nodes_dict[graph_name] = [
node
for node in call_delegate_node[0].args
if node.op == "placeholder"
cur_inputs = [
node for node in call_delegate_node[0].args if node.op == "placeholder"
]
input_nodes_dict[graph_name].append(cur_inputs)
prog_mgr, bundle_progs, partitioned_graph_names = generate_multi_graph_program(
compiler_specs=compiler_specs[0],
processed_bytes=processed_bytes,
input_nodes_dict=input_nodes_dict,
backend_config=executorch_config,
constant_methods=llama_instance_list[
1
].llama_meta, # kv method meta
)

prog_mgr, bundle_progs = generate_multi_graph_program(
compiler_specs=compiler_specs[0],
processed_bytes=processed_bytes,
input_nodes_dict=input_nodes_dict,
backend_config=executorch_config,
constant_methods=llama_instance_list[
1
].llama_meta, # kv method meta
)
bundle_progs_list.append(bundle_progs)
for graph_name in graph_names:
lower_module_dict[graph_name].append(
prog_mgr.exported_program(graph_name).graph_module._modules.get(
"lowered_module_0"
)
bundle_progs_list.append(bundle_progs)
for graph_name in partitioned_graph_names:
ori_graph_name, cur_idx = "_".join(graph_name.split("_")[:-1]), int(graph_name.split("_")[-1])
lower_module_dict[ori_graph_name].append(
prog_mgr.exported_program(f"{graph_name}").graph_module._modules.get(
"lowered_module_0"
)
)

exec_prog = generate_composite_llama_program(
graph_names=graph_names,
@@ -723,7 +727,7 @@ def compile(args, pte_filename, tokenizer):
if node.op == "output"
]

prog_mgr, _ = generate_multi_graph_program(
prog_mgr, _, _ = generate_multi_graph_program(
compiler_specs=compiler_specs[0],
processed_bytes=processed_bytes,
input_nodes_dict=input_nodes_dict,
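After lowering, the modules come back under suffixed graph names and are regrouped per original graph, in partition order. A hedged sketch with illustrative names (strings stand in for the lowered modules):

```python
from collections import defaultdict

partitioned_graph_names = ["forward_0", "forward_1", "kv_forward_0"]  # example
lower_module_dict = defaultdict(list)

for graph_name in partitioned_graph_names:
    ori_graph_name, _ = graph_name.rsplit("_", 1)
    # Stand-in for prog_mgr.exported_program(graph_name)
    #   .graph_module._modules.get("lowered_module_0")
    lower_module_dict[ori_graph_name].append(f"lowered::{graph_name}")

assert lower_module_dict["forward"] == ["lowered::forward_0", "lowered::forward_1"]
```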
4 changes: 2 additions & 2 deletions examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -13,9 +13,9 @@
import torch.nn as nn
import torch.nn.functional as F
from executorch.examples.models.llama.llama_transformer import (
ModelArgs,
precompute_freqs_cis,
ModelArgs
)
from executorch.examples.models.llama.rope import precompute_freqs_cis


def apply_rotary_emb_single(
4 changes: 4 additions & 0 deletions runtime/executor/method.cpp
@@ -1179,6 +1179,10 @@ Error Method::execute_instruction() {
}
} break;
case executorch_flatbuffer::InstructionArguments::DelegateCall: {
ET_LOG(Info, "CHECK n_delegate_: %zu", n_delegate_);
ET_LOG(Info, "CHECK n_chains_: %zu", n_chains_);
ET_LOG(Info, "CHECK num instructions of cur_chain: %zu", instructions->size());

EXECUTORCH_SCOPE_PROF("DELEGATE_CALL");
internal::EventTracerProfileOpScope event_tracer_op_scope =
internal::EventTracerProfileOpScope(event_tracer_, "DELEGATE_CALL");