Commit ae87928

Update on "[ET-VK][ez] Make squeeze insertion requirements more strict"
## Context

Refactor the `SqueezeUnsqueezeInputs` pass to make its intention clearer.

For Llama models, inputs to 4-bit linear often have the shape `[1, seq_len, dim]`; under the current implementation of the pass, such an input is squeezed to `[seq_len, dim]` even though the squeeze is not necessary. The original intention of this pass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator.

## Changes

To avoid inserting unnecessary squeeze/unsqueeze ops, be more specific about when a squeeze/unsqueeze should be added.

I would like to consider refactoring this pass in the future, since the logic is currently a bit unintuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see #8601 / D69673068). I think it will eventually be good to rewrite the pass so that shape management is explicit and self-contained within the pass, rather than inserting ops that are expected to be removed later on.

Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/)

[ghstack-poisoned]
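
The stricter condition described above can be illustrated with a small sketch. This is not the code of `SqueezeUnsqueezeInputs`; the helper names `should_squeeze` and `squeezed_shape` are hypothetical, and the rule shown (squeeze only 3-D inputs whose middle dimension is 1) is an assumption drawn from the shapes quoted in the message.

```python
from typing import Sequence, Tuple


def should_squeeze(shape: Sequence[int]) -> bool:
    """Hypothetical predicate: True only for 3-D shapes [batch_size, 1, dim]."""
    return len(shape) == 3 and shape[1] == 1


def squeezed_shape(shape: Sequence[int]) -> Tuple[int, ...]:
    """Return the shape the 4-bit linear op would see under this assumed rule."""
    if should_squeeze(shape):
        # [batch_size, 1, dim] -> [batch_size, dim]
        return (shape[0], shape[2])
    # Anything else, including [1, seq_len, dim], is left untouched.
    return tuple(shape)


if __name__ == "__main__":
    print(squeezed_shape([4, 1, 64]))    # squeezed
    print(squeezed_shape([1, 128, 64]))  # left as-is
```

Under this assumed rule, a Llama-style `[1, seq_len, dim]` input passes through unchanged, so no squeeze/unsqueeze pair is inserted around the operator.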
2 parents 6a9c32e + 9d60161 commit ae87928

1 file changed (+0, -12 lines)
backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

Lines changed: 0 additions & 12 deletions
@@ -33,18 +33,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 layout(constant_id = 3) const int group_size = 64;
 
-uint8_t get_first(const uint8_t packed) {
-  return uint8_t((packed & 0xF0) >> 4);
-}
-
-uint8_t get_second(const uint8_t packed) {
-  return uint8_t(packed & 0x0F);
-}
-
-uint8_t combine(const uint8_t first, const uint8_t second) {
-  return uint8_t(first << 4 | second);
-}
-
 /*
  * This shader computes a linear operator between a floating point input matrix
  * x and a weights matrix that is quantized to 4 bits.
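
For reference, the three helpers removed in this hunk pack and unpack two 4-bit values in one byte. The sketch below mirrors them in Python rather than GLSL so it can be run directly. It assumes inputs are already in the 0..255 range for a packed byte and 0..15 for each nibble. The final mask stands in for the `uint8_t(...)` truncation in the shader.

```python
def get_first(packed: int) -> int:
    """Return the high nibble (bits 7..4) of a packed byte."""
    return (packed & 0xF0) >> 4


def get_second(packed: int) -> int:
    """Return the low nibble (bits 3..0) of a packed byte."""
    return packed & 0x0F


def combine(first: int, second: int) -> int:
    """Pack two nibbles into one byte, keeping only the low 8 bits."""
    return ((first << 4) | second) & 0xFF


if __name__ == "__main__":
    packed = combine(0xA, 0x3)
    print(hex(packed), get_first(packed), get_second(packed))
```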
