@@ -3,6 +3,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import hashlib
+from collections import defaultdict
+
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
@@ -21,7 +24,7 @@ class FuseEqualPlaceholdersPass(ExportPass):
     """
     This pass optimizes memory usage by finding constant placeholders
     pointing to identical tensors and fusing them to one single placeholder
-    with multiple users.
+    with multiple users, using a cache for faster comparison.
     """
 
     def __init__(self, exported_program: ExportedProgram):
@@ -30,58 +33,54 @@ def __init__(self, exported_program: ExportedProgram):
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         modified = False
-        const_placeholder_nodes = []
-        for node in graph_module.graph.nodes:
-            if is_param_node(self.exported_program, node):
-                const_placeholder_nodes.append(node)
-
-        while const_placeholder_nodes:
 
-            # Find equal tensors
-            node1 = const_placeholder_nodes.pop()
-            eq_nodes = [node1]
-            tensor1 = get_param_tensor(self.exported_program, node1)
-            if tensor1 is None:
+        # Build a cache of params: mapping hash_key -> list of (node, tensor)
+        hash_buckets = defaultdict(list)
+        for node in graph_module.graph.nodes:
+            if not is_param_node(self.exported_program, node):
                 continue
+            tensor = get_param_tensor(self.exported_program, node)
+            if tensor is None:
+                continue
+            # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
+            # Ensure tensor is on CPU and contiguous
+            t_cpu = tensor.detach().cpu().contiguous()
+            data_bytes = t_cpu.numpy().tobytes()
+            key = (
+                str(t_cpu.dtype),
+                tuple(t_cpu.shape),
+                hashlib.sha1(data_bytes).hexdigest(),
+            )
+            hash_buckets[key].append((node, t_cpu))
 
-            for node2 in const_placeholder_nodes:
-                tensor2 = get_param_tensor(self.exported_program, node2)
-                if tensor2 is None:
-                    continue
-
-                if (
-                    tensor1.dtype == tensor2.dtype
-                    and tensor1.shape == tensor2.shape
-                    and torch.allclose(tensor1, tensor2, atol=1e-08)
-                ):
-                    eq_nodes.append(node2)
+        # For each bucket with more than one entry, fuse:
+        for nodes_tensors in hash_buckets.values():
+            if len(nodes_tensors) < 2:
+                continue
 
-            if len(eq_nodes) > 1:
-                common_name = node1.name + "_common"
-                common_kind = get_constant_placeholder_kind(
-                    self.exported_program, node1
+            # Create a new placeholder from the first in the list of equal placeholders.
+            rep_node, rep_tensor = nodes_tensors[0]
+            common_name = rep_node.name + "_common"
+            common_kind = get_constant_placeholder_kind(self.exported_program, rep_node)
+            common_persistent = True
+            with graph_module.graph.inserting_before(rep_node):
+                common_node = create_constant_placeholder(
+                    self.exported_program,
+                    graph_module.graph,
+                    common_name,
+                    common_kind,
+                    rep_tensor,
+                    common_persistent,
                 )
-                common_persisten_buffer = True
-
-                with graph_module.graph.inserting_before(node1):
-                    common_node = create_constant_placeholder(
-                        self.exported_program,
-                        graph_module.graph,
-                        common_name,
-                        common_kind,
-                        tensor1,
-                        common_persisten_buffer,
-                    )
-
-                for eq_node in eq_nodes:
-                    eq_node.replace_all_uses_with(common_node)
-                    delete_constant_placeholder(self.exported_program, eq_node)
-                    if eq_node != node1:
-                        const_placeholder_nodes.remove(eq_node)
 
+            # Replace uses and delete duplicates
+            for node, _ in nodes_tensors:
+                node.replace_all_uses_with(common_node)
+                delete_constant_placeholder(self.exported_program, node)
                 modified = True
 
         if modified:
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
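
A note on the approach for reviewers. Relative to the removed pairwise torch.allclose scan, this change replaces an O(n^2) comparison loop with a single O(n) hashing pass over the constant placeholders; the trade-off is that fusion now requires bit-exact equality of the raw tensor bytes rather than closeness within atol=1e-08. Below is a minimal standalone sketch of the same dtype + shape + SHA-1 key construction on plain torch tensors; the fingerprint and bucket_duplicates helpers and the toy tensors are illustrative only, not part of this change:

    import hashlib
    from collections import defaultdict

    import torch

    def fingerprint(tensor: torch.Tensor):
        # Same key construction as the pass: dtype + shape + SHA1 of raw bytes.
        t_cpu = tensor.detach().cpu().contiguous()
        data_bytes = t_cpu.numpy().tobytes()
        return (str(t_cpu.dtype), tuple(t_cpu.shape), hashlib.sha1(data_bytes).hexdigest())

    def bucket_duplicates(named_tensors):
        # Group tensor names by fingerprint; any bucket with more than one
        # entry corresponds to placeholders the pass would fuse into one.
        buckets = defaultdict(list)
        for name, tensor in named_tensors.items():
            buckets[fingerprint(tensor)].append(name)
        return [names for names in buckets.values() if len(names) > 1]

    a = torch.ones(4, 4)
    print(bucket_duplicates({"w1": a, "w2": a.clone(), "w3": torch.zeros(4, 4)}))
    # -> [['w1', 'w2']]

One caveat worth flagging (an observation, not something this diff addresses): t_cpu.numpy() raises for dtypes NumPy cannot represent, such as torch.bfloat16, so the sketch and the pass as written assume NumPy-convertible parameter dtypes.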