Skip to content

Commit 049a14f

Browse files
committed
Update on "[ET-VK][ez] Make squeeze insertion requirements more strict"
## Context Refactor the `SqueezeUnsqueezeInputs` pass to be more clear about its intention. For Llama models, input shapes to 4 bit linear will oftentimes have the shape `[1, seq_len, dim]`; under the current implementation of the pass, the input would be squeezed to `[seq_len, dim]` even though the squeeze is not necessary. The original intention of thispass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator. ## Changes To avoid inserting unnecessary squeeze/unsqueezes, be more specific about when squeeze/unsqueeze should be added. I would like to consider refactoring this pass in the future, since the logic is currently a bit uninttuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see #8601 / D69673068). I think eventually it will be good to rewrite the pass to make shape management more explicit and self contained within the pass rather than inserting ops which are expected to be removed later on. Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/) [ghstack-poisoned]
2 parents b2a91e9 + f1e2f1a commit 049a14f

File tree

6 files changed

+725
-996
lines changed

6 files changed

+725
-996
lines changed

backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ class SqueezeUnsqueezeInputs(ExportPass):
3030
def should_squeeze(self, op, shape: List[int]) -> bool:
3131
if len(shape) == 3:
3232
return shape[1] == 1 and shape[0] > 1
33+
if len(shape) == 4:
34+
# No need to squeeze if all dims are 1 except the width dim
35+
if all(dim == 1 for dim in shape[:-1]):
36+
return False
37+
# Otherwise, check for squeezable dim
38+
return 1 in shape[:-1]
3339

3440
# Prefer not to introduce additional orchestration ops by default
3541
return False

0 commit comments

Comments
 (0)