Skip to content

Commit 10d9be1

Browse files
authored
[ET-VK][ez] Make squeeze insertion requirements more strict
Differential Revision: D72480178 Pull Request resolved: #9917
1 parent 9d60161 commit 10d9be1

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,38 @@ class SqueezeUnsqueezeInputs(ExportPass):
2727
exir_ops.edge.aten.gelu.default,
2828
}
2929

30+
def should_squeeze(self, op, shape: List[int]) -> bool: # pyre-ignore
31+
if len(shape) == 3:
32+
return shape[1] == 1 and shape[0] > 1
33+
if len(shape) == 4:
34+
# No need to squeeze if all dims are 1 except the width dim
35+
if all(dim == 1 for dim in shape[:-1]):
36+
return False
37+
# Otherwise, check for squeezable dim
38+
return 1 in shape[:-1]
39+
40+
# Prefer not to introduce additional orchestration ops by default
41+
return False
42+
3043
def call_operator(
3144
self,
3245
op, # pyre-ignore
3346
args: Tuple[Argument, ...],
3447
kwargs: Dict[str, Argument],
3548
meta: NodeMetadata,
3649
) -> ProxyValue:
37-
def _squeezable(shape: List[int]) -> bool:
38-
return len(shape) > 2 and 1 in shape
39-
4050
if op not in self._squeezable_ops:
4151
return super().call_operator(op, args, kwargs, meta)
42-
4352
# pyre-ignore[16]: `None` has no attribute `node`
4453
input_shape = args[0].node.meta["val"].shape
4554
output_shape = meta["val"].shape
46-
if not _squeezable(input_shape):
55+
56+
if not self.should_squeeze(op, input_shape):
4757
return super().call_operator(op, args, kwargs, meta)
4858

59+
def _squeezable(shape: List[int]) -> bool:
60+
return len(shape) > 2 and 1 in shape
61+
4962
# squeeze input tensor
5063
squeeze_shape = list(input_shape)
5164
while _squeezable(squeeze_shape):

0 commit comments

Comments
 (0)