@@ -27,25 +27,38 @@ class SqueezeUnsqueezeInputs(ExportPass):
27
27
exir_ops .edge .aten .gelu .default ,
28
28
}
29
29
30
+ def should_squeeze (self , op , shape : List [int ]) -> bool : # pyre-ignore
31
+ if len (shape ) == 3 :
32
+ return shape [1 ] == 1 and shape [0 ] > 1
33
+ if len (shape ) == 4 :
34
+ # No need to squeeze if all dims are 1 except the width dim
35
+ if all (dim == 1 for dim in shape [:- 1 ]):
36
+ return False
37
+ # Otherwise, check for squeezable dim
38
+ return 1 in shape [:- 1 ]
39
+
40
+ # Prefer not to introduce additional orchestration ops by default
41
+ return False
42
+
30
43
def call_operator (
31
44
self ,
32
45
op , # pyre-ignore
33
46
args : Tuple [Argument , ...],
34
47
kwargs : Dict [str , Argument ],
35
48
meta : NodeMetadata ,
36
49
) -> ProxyValue :
37
- def _squeezable (shape : List [int ]) -> bool :
38
- return len (shape ) > 2 and 1 in shape
39
-
40
50
if op not in self ._squeezable_ops :
41
51
return super ().call_operator (op , args , kwargs , meta )
42
-
43
52
# pyre-ignore[16]: `None` has no attribute `node`
44
53
input_shape = args [0 ].node .meta ["val" ].shape
45
54
output_shape = meta ["val" ].shape
46
- if not _squeezable (input_shape ):
55
+
56
+ if not self .should_squeeze (op , input_shape ):
47
57
return super ().call_operator (op , args , kwargs , meta )
48
58
59
+ def _squeezable (shape : List [int ]) -> bool :
60
+ return len (shape ) > 2 and 1 in shape
61
+
49
62
# squeeze input tensor
50
63
squeeze_shape = list (input_shape )
51
64
while _squeezable (squeeze_shape ):
0 commit comments