@@ -731,53 +731,43 @@ def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool:
 
 
 def greedy(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
+    extra_padding: int = 0,
+    *,
     allow_overlapping_allocations: bool = True,
 ) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
-    alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
-    alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
-    allow_overlapping_allocations: If set to true, allows for allocations that overlap
-        in their lifetime but are at different offsets in the storage. By default true.
-        This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
-        allocations disabled
+
+    Args:
+        alignment: Memory alignment requirement
+        specs: Set of TensorSpec objects with updated lifetimes
+        graph_module: Graph module
+        graph_signature: Graph signature
+        extra_padding: Additional padding to add to each memory buffer (in bytes)
+        allow_overlapping_allocations: If set to true, allows allocations that
+            overlap in their lifetime but are at different offsets in the
+            storage. True by default. This flag exists so that Vulkan can use
+            MemoryPlanningPass with overlapping allocations disabled.
+
+    Returns:
+        MemoryAlgoResult containing the allocation decisions
     """
     greedy_result = MemoryAlgoResult({}, [])
-    # padding allocation with 64 bytes.
-    # this requirement is really for XNNPACK backend which can read tensors
-    # beyond the end of the tensor. This is done for performance
-    # optimizations in XNNPACK.
-    # While accounting for backend specific requirement is not the right choice
-    # in backend agnostic memory planning, we do it here as it seems most appropriate.
-    # Right now this applies to greedy only so any other
-    # algorithm that plans memory for XNNPACK backend will
-    # not have this.
-    extra_padded_bytes = 0
-    if _contains_xnnpack_delegate(graph_module):
-        extra_padded_bytes = 64
     spec2obj = {}
     shared_objects = defaultdict(list)
-    # Don't do assertion in collect_specs_from_nodes if we have already encountered
-    # and ignored some to_out_variant errors.
-    do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
+
     # For each tensor, pick the available shared object with the size closest
     # to that of the tensor. If there is no available shared object left,
     # create a new one.
     import bisect
 
     sorted_specs = []
-    for spec in collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=do_assertion,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-    ):
+    for spec in specs:
         bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
+
     sorted_specs.reverse()
 
     for spec in sorted_specs:
@@ -806,15 +796,13 @@ def greedy(
     for mem_id in shared_objects:
         input_total_size = 0
         if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
-            # pyre-fixme[6]: For 1st argument expected
-            # `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
+            assert isinstance(bufsizes, list)
             if len(bufsizes) > mem_id:
-                # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
                 input_total_size = bufsizes[mem_id]
         total_sizes[mem_id] = materialize_buffer(
             shared_objects[mem_id], input_total_size
         )
-        total_sizes[mem_id] += extra_padded_bytes
+        total_sizes[mem_id] += extra_padding
 
     # Since we now know the number of shared objects we need and the size of
     # each shared object, we can assign offset in the memory buffer for each
@@ -837,73 +825,101 @@ def greedy(
     greedy_result.bufsizes = total_sizes
     return greedy_result
 
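With this change, `greedy` no longer collects specs itself; the caller passes in a pre-filtered set. A minimal sketch of driving it directly, assuming `filtered_specs`, `graph_module`, and `graph_signature` were prepared the way `apply_algo` below prepares them (the alignment value here is arbitrary):

```python
# Hypothetical driver for the new signature; `filtered_specs` is a
# Set[TensorSpec] whose lifetimes update_all_tensors_lifetime() already set.
result = greedy(
    16,                  # alignment
    filtered_specs,
    graph_module,
    graph_signature,
    extra_padding=0,     # apply_algo passes 64 when an XNNPACK delegate exists
    allow_overlapping_allocations=True,  # keyword-only
)
print(result.bufsizes)  # one buffer size per mem_id
```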
+class MemoryPlanningAlgorithmSuite:
+    def __init__(
+        self,
+        algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
+    ) -> None:
+        if algo_list is None:
+            algo_list = [greedy]
+        self.algo_list: List[Callable[..., MemoryAlgoResult]] = algo_list
 
-def memory_planning_algorithm_suite(
-    graph_module: torch.fx.GraphModule,
-    alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
-    allow_overlapping_allocations: bool = True,
-    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
-) -> List[int]:
-    r"""
-    Memory planning algorithm suite that runs a list of memory planning algorithms
-    and returns the result of the algorithm that minimizes the total memory usage.
-    """
-    if algo_list is None:
-        algo_list = [greedy]
-    mem_algo_results = {}
-    for algo in algo_list:
-        if isinstance(algo, functools.partial):
-            name = algo.func.__name__
-        else:
-            name = getattr(algo, "__name__", None)
-        # Run this memory planning algorithm and store the result in mem_algo_results
-        # with the name of the algorithm as the key.
-        mem_algo_results[name] = algo(
-            graph_module,
-            alignment,
-            graph_signature,
-            alloc_graph_input,
-            alloc_graph_output,
-        )
+    def __call__(
+        self,
+        alignment: int,
+        specs: Set[TensorSpec],
+        graph_module: torch.fx.GraphModule,
+        graph_signature: ExportGraphSignature,
+        extra_padding: int,
+    ) -> List[int]:
+        r"""
+        Memory planning algorithm suite that runs a list of memory planning
+        algorithms and returns the result of the algorithm that minimizes the
+        total memory usage.
+
+        Args:
+            alignment: Memory alignment requirement
+            specs: Set of TensorSpec objects with updated lifetimes
+            graph_module: The graph module to allocate memory for
+            graph_signature: Graph signature
+            extra_padding: Additional padding to add to each memory buffer (in bytes)
+
+        Returns:
+            List of buffer sizes for each memory hierarchy
+        """
+        mem_algo_results = {}
+        for algo in self.algo_list:
+            if isinstance(algo, functools.partial):
+                name = algo.func.__name__
+            else:
+                name = getattr(algo, "__name__", None)
+
+            # Run this memory planning algorithm and store the result in
+            # mem_algo_results, keyed by the name of the algorithm.
+            mem_algo_results[name] = algo(
+                alignment,
+                specs,
+                graph_module,
+                graph_signature,
+                extra_padding,
+            )
 
-    # All the algorithms should have the same number of buffers allocated.
-    assert (
-        len(
-            {
-                len(mem_algo_result.bufsizes)
-                for mem_algo_result in mem_algo_results.values()
-            }
-        )
-        == 1
-    ), "Different memory planning algorithms should have the same number of buffers allocated."
-
-    # Find the algorithm that minimizes the total memory usage.
-    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
-    logging.debug(f"Best memory planning algo for this model is {best_algo}")
-    bufsizes = mem_algo_results[best_algo].bufsizes
-
-    # Update the mem_id and mem_offset for each spec in the graph module based on the
-    # values provided by the best memory planning algorithm.
-    for spec in mem_algo_results[best_algo].spec_dict:
-        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
-        spec.mem_id = spec_alloc_result.mem_id
-        spec.mem_offset = spec_alloc_result.mem_offset
-        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+        # All the algorithms should have the same number of buffers allocated.
+        assert (
+            len(
+                {
+                    len(mem_algo_result.bufsizes)
+                    for mem_algo_result in mem_algo_results.values()
+                }
+            )
+            == 1
+        ), "Different memory planning algorithms should have the same number of buffers allocated."
 
-    return bufsizes
+        # Find the algorithm that minimizes the total memory usage.
+        best_algo = min(
+            mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes)
+        )
+        logging.debug(f"Best memory planning algo for this model is {best_algo}")
+        bufsizes = mem_algo_results[best_algo].bufsizes
 
+        # Update the mem_id and mem_offset for each spec in the graph module
+        # based on the values provided by the best memory planning algorithm.
+        for spec in mem_algo_results[best_algo].spec_dict:
+            spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+            spec.mem_id = spec_alloc_result.mem_id
+            spec.mem_offset = spec_alloc_result.mem_offset
+            spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+        return bufsizes
 
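Usage note: the suite is now a callable object rather than a free function, so the algorithm list is bound at construction time. A sketch of plausible configurations, assuming `naive` and `greedy` from this module:

```python
import functools

# Default behavior: run greedy only.
suite = MemoryPlanningAlgorithmSuite()

# Run both planners; __call__ keeps whichever minimizes total memory.
suite = MemoryPlanningAlgorithmSuite(algo_list=[naive, greedy])

# functools.partial also works, e.g. for Vulkan, which needs overlapping
# allocations disabled; __call__ logs algo.func.__name__ for partials.
vulkan_greedy = functools.partial(greedy, allow_overlapping_allocations=False)
suite = MemoryPlanningAlgorithmSuite(algo_list=[vulkan_greedy])
```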
 def naive(
-    graph_module: torch.fx.GraphModule,
     alignment: int,
-    graph_signature: Optional[ExportGraphSignature] = None,
-    alloc_graph_input: bool = True,
-    alloc_graph_output: bool = True,
+    specs: Set[TensorSpec],
+    graph_module: torch.fx.GraphModule,
+    graph_signature: ExportGraphSignature,
+    extra_padding: int,
 ) -> MemoryAlgoResult:
-
+    """Naive algorithm to allocate memory for tensors in the graph.
+
+    This algorithm simply allocates memory for each tensor sequentially,
+    without reusing memory.
+
+    Args:
+        alignment: Memory alignment requirement
+        specs: Set of TensorSpec objects with updated lifetimes
+        graph_module: Graph module
+        graph_signature: Graph signature
+        extra_padding: Additional padding to add to each memory buffer (in bytes)
+
+    Returns:
+        MemoryAlgoResult containing the allocation decisions
+    """
     naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
@@ -918,14 +934,9 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
     if bufsizes is None:
         bufsizes = [0, 0]
-
     bufsizes = typing.cast(List[int], bufsizes)
-    for spec in collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-    ):
+
+    for spec in specs:
         spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
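The hunk header above references `_allocate_buf(bufsizes, mem_id, allocated)`, whose body is outside this diff. Going only by the comment, a bump allocator of that shape would look roughly like the following sketch; the buffer-growing behavior is an assumption, not code from this commit:

```python
def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
    # Sketch of a bump allocator: grow the buffer for `mem_id` by `allocated`
    # bytes and return the offset at which the new allocation starts.
    if mem_id >= len(bufsizes):
        bufsizes.extend([0] * (mem_id + 1 - len(bufsizes)))
    offset = bufsizes[mem_id]
    bufsizes[mem_id] += allocated
    return offset
```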
@@ -1027,7 +1038,7 @@ def insert_calls_to_free(
 
 def apply_algo(
     algo: Callable[
-        [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
+        ...,
         List[int],
     ],
     graph_module: torch.fx.GraphModule,
@@ -1048,10 +1059,46 @@ def apply_algo(
     TODO: make these optimizations once we have some baseline working.
     """
 
-    specs = update_all_tensors_lifetime(graph_module, graph_signature)
+    # Extract the nodes and their lifespans from the graph_module
+    specs = update_all_tensors_lifetime(graph_module, graph_signature)
+
+    # Filter specs based on alloc_graph_input and alloc_graph_output
+    filtered_specs = set()
+    graph_input_tensors = get_graph_input_tensors(
+        graph_module.graph.nodes, graph_signature
+    )
+    graph_output_tensors = get_graph_output_tensors(graph_module.graph.nodes)
+
+    for spec in specs:
+        # Apply the same filtering as collect_specs_from_nodes
+        if not alloc_graph_input and spec in graph_input_tensors:
+            continue
+        if not alloc_graph_output and spec in graph_output_tensors:
+            continue
+        if spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
+            continue
+        # In training we flag weights that have an associated gradient, as
+        # these need to be memory planned because their value is updated at
+        # each training step.
+        if spec.const and not getattr(spec, "weight_has_gradient", False):
+            continue
+        filtered_specs.add(spec)
+
+    # Get extra padding for XNNPACK if needed
+    extra_padding = 0
+    if _contains_xnnpack_delegate(graph_module):
+        extra_padding = 64
+
+    # Pass the filtered specs to the algorithm
     bufsizes: List[int] = algo(
-        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        alignment,
+        filtered_specs,
+        graph_module,
+        graph_signature,
+        extra_padding,
     )
+
     insert_calls_to_free(graph_module, specs)
 
     def handle_submodule(
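The filtering loop above inlines the predicate that `collect_specs_from_nodes` previously applied internally. Factored into a standalone helper (the name `_should_plan_spec` is hypothetical, not part of this change), the same logic reads:

```python
def _should_plan_spec(
    spec: TensorSpec,
    alloc_graph_input: bool,
    alloc_graph_output: bool,
    graph_input_tensors: Set[TensorSpec],
    graph_output_tensors: Set[TensorSpec],
) -> bool:
    # Hypothetical standalone equivalent of the inline filter in apply_algo().
    if not alloc_graph_input and spec in graph_input_tensors:
        return False
    if not alloc_graph_output and spec in graph_output_tensors:
        return False
    # Unbounded dynamic shapes cannot be planned statically.
    if spec.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND:
        return False
    # Skip constants unless they are trainable weights updated in place.
    if spec.const and not getattr(spec, "weight_has_gradient", False):
        return False
    return True
```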
@@ -1063,6 +1110,7 @@ def handle_submodule(
         # memory planning for the submodule needs to be aware of the amount
         # of buffer already allocated.
         submodule.input_mem_buffer_sizes = bufsizes
+
         bufsizes = apply_algo(
             algo,
             submodule,