14
14
import os
15
15
import sys
16
16
import time
17
+ from collections import defaultdict
17
18
from functools import partial
18
19
from multiprocessing .connection import Client
19
20
@@ -626,7 +627,7 @@ def compile(args, pte_filename, tokenizer):
626
627
call_delegate_inputs_dict = {name : [] for name in graph_names }
627
628
call_delegate_node_name_dict = {name : [] for name in graph_names }
628
629
outputs_dict = {name : [] for name in graph_names }
629
- input_nodes_dict = { name : [] for name in graph_names }
630
+ input_nodes_dict = defaultdict ( list )
630
631
for prog , graph_name in zip (exported_programs , graph_names ):
631
632
for node in prog .graph_module .graph .nodes :
632
633
if (
@@ -654,8 +655,11 @@ def compile(args, pte_filename, tokenizer):
654
655
655
656
if args .num_sharding > 0 :
656
657
bundle_progs_list = []
658
+ processed_bytes = []
659
+ call_delegate_node = []
660
+
657
661
for num in range (args .num_sharding - 1 , - 1 , - 1 ):
658
- processed_bytes = []
662
+ cur_inputs = []
659
663
for prog , graph_name in zip (exported_programs , graph_names ):
660
664
processed_bytes .append (
661
665
getattr (
@@ -669,28 +673,28 @@ def compile(args, pte_filename, tokenizer):
669
673
if node .op == "get_attr"
670
674
and node .name == f"lowered_module_{ num } "
671
675
]
672
- input_nodes_dict [graph_name ] = [
673
- node
674
- for node in call_delegate_node [0 ].args
675
- if node .op == "placeholder"
676
+ cur_inputs = [
677
+ node for node in call_delegate_node [0 ].args if node .op == "placeholder"
676
678
]
679
+ input_nodes_dict [graph_name ].append (cur_inputs )
680
+ prog_mgr , bundle_progs , partitioned_graph_names = generate_multi_graph_program (
681
+ compiler_specs = compiler_specs [0 ],
682
+ processed_bytes = processed_bytes ,
683
+ input_nodes_dict = input_nodes_dict ,
684
+ backend_config = executorch_config ,
685
+ constant_methods = llama_instance_list [
686
+ 1
687
+ ].llama_meta , # kv method meta
688
+ )
677
689
678
- prog_mgr , bundle_progs = generate_multi_graph_program (
679
- compiler_specs = compiler_specs [0 ],
680
- processed_bytes = processed_bytes ,
681
- input_nodes_dict = input_nodes_dict ,
682
- backend_config = executorch_config ,
683
- constant_methods = llama_instance_list [
684
- 1
685
- ].llama_meta , # kv method meta
686
- )
687
- bundle_progs_list .append (bundle_progs )
688
- for graph_name in graph_names :
689
- lower_module_dict [graph_name ].append (
690
- prog_mgr .exported_program (graph_name ).graph_module ._modules .get (
691
- "lowered_module_0"
692
- )
690
+ bundle_progs_list .append (bundle_progs )
691
+ for graph_name in partitioned_graph_names :
692
+ ori_graph_name , cur_idx = "_" .join (graph_name .split ("_" )[:- 1 ]), int (graph_name .split ("_" )[- 1 ])
693
+ lower_module_dict [ori_graph_name ].append (
694
+ prog_mgr .exported_program (f"{ graph_name } " ).graph_module ._modules .get (
695
+ "lowered_module_0"
693
696
)
697
+ )
694
698
695
699
exec_prog = generate_composite_llama_program (
696
700
graph_names = graph_names ,
@@ -723,7 +727,7 @@ def compile(args, pte_filename, tokenizer):
723
727
if node .op == "output"
724
728
]
725
729
726
- prog_mgr , _ = generate_multi_graph_program (
730
+ prog_mgr , _ , _ = generate_multi_graph_program (
727
731
compiler_specs = compiler_specs [0 ],
728
732
processed_bytes = processed_bytes ,
729
733
input_nodes_dict = input_nodes_dict ,
0 commit comments