@@ -44,12 +44,14 @@ def __init__(
4444 graph_module : torch .fx .GraphModule ,
4545 alloc_graph_input : bool ,
4646 alloc_graph_output : bool ,
47+ alloc_mutable_buffers : bool ,
4748 graph_signature : Optional [ExportGraphSignature ] = None ,
4849 ) -> None :
4950 self .graph_module = graph_module
5051 self .graph_signature = graph_signature
5152 self .alloc_graph_input = alloc_graph_input
5253 self .alloc_graph_output = alloc_graph_output
54+ self .alloc_mutable_buffers = alloc_mutable_buffers
5355
5456 @classmethod
5557 def mem_obj_id_match (
@@ -149,6 +151,7 @@ def verify_storage_reuse(
149151 ignore_const = True ,
150152 ignore_graph_input = not self .alloc_graph_input ,
151153 ignore_graph_output = not self .alloc_graph_output ,
154+ ignore_mutable_buffers = not self .alloc_mutable_buffers ,
152155 do_assertion = False ,
153156 ignore_out_var_node = False ,
154157 dedup = True ,
@@ -374,6 +377,7 @@ def collect_specs_from_nodes( # noqa: C901
374377 graph_signature : Optional [ExportGraphSignature ] = None ,
375378 ignore_graph_input : bool = False ,
376379 ignore_graph_output : bool = False ,
380+ ignore_mutable_buffers : bool = False ,
377381 ignore_const : bool = True ,
378382 ignore_out_var_node : bool = True ,
379383 dedup : bool = True ,
@@ -414,6 +418,9 @@ def collect_specs_from_nodes( # noqa: C901
414418 if _is_inplace_node (node ):
415419 continue
416420
421+ if _is_mutable_buffer (node , graph_signature ) and ignore_mutable_buffers :
422+ continue
423+
417424 if do_assertion :
418425 internal_assert (
419426 node .op in ("placeholder" , "output" )
@@ -469,6 +476,7 @@ def update_all_tensors_lifetime(
469476 Set the lifetime for all the tensors encountered in the Fx graph.
470477 """
471478 specs = set ()
479+
472480 for node_idx , node in enumerate (graph_module .graph .nodes ):
473481 for spec in collect_specs_from_nodes (
474482 filter_nodes (itertools .chain ([node ], node .args , node .kwargs .values ())),
@@ -1053,6 +1061,7 @@ def apply_algo(
10531061 graph_signature : Optional [ExportGraphSignature ] = None ,
10541062 alloc_graph_input : bool = True ,
10551063 alloc_graph_output : bool = True ,
1064+ alloc_mutable_buffers : bool = True ,
10561065) -> List [int ]:
10571066 """
10581067 Recursively apply algo to graph_module and its submodules for control flow.
@@ -1065,19 +1074,18 @@ def apply_algo(
10651074 storage with tensors in the outer module.
10661075 TODO: make these optimizations once we have some baseline working.
10671076 """
1068-
10691077 # Extract the nodes and their lifespans from the graph_module
10701078 # Difficult to just filter the list of specs returned by this due to
10711079 # how we flag trainable weights.
10721080 _ = update_all_tensors_lifetime (graph_module , graph_signature )
1073-
10741081 # Filter specs based on alloc_graph_input and alloc_graph_output
10751082 specs = collect_specs_from_nodes (
10761083 graph_module .graph .nodes ,
10771084 graph_signature ,
10781085 do_assertion = False ,
10791086 ignore_graph_input = not alloc_graph_input ,
10801087 ignore_graph_output = not alloc_graph_output ,
1088+ ignore_mutable_buffers = not alloc_mutable_buffers ,
10811089 )
10821090
10831091 # Get extra padding for XNNPACK if needed
0 commit comments