Skip to content

Commit 9d60161

Browse files
committed
Update base for Update on "[ET-VK][ez] Make squeeze insertion requirements more strict"
## Context Refactor the `SqueezeUnsqueezeInputs` pass to be more clear about its intention. For Llama models, input shapes to 4 bit linear will oftentimes have the shape `[1, seq_len, dim]`; under the current implementation of the pass, the input would be squeezed to `[seq_len, dim]` even though the squeeze is not necessary. The original intention of thispass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator. ## Changes To avoid inserting unnecessary squeeze/unsqueezes, be more specific about when squeeze/unsqueeze should be added. I would like to consider refactoring this pass in the future, since the logic is currently a bit uninttuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see #8601 / D69673068). I think eventually it will be good to rewrite the pass to make shape management more explicit and self contained within the pass rather than inserting ops which are expected to be removed later on. Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/) [ghstack-poisoned]
1 parent 711282e commit 9d60161

File tree

1 file changed

+0
-12
lines changed

1 file changed

+0
-12
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3333

3434
layout(constant_id = 3) const int group_size = 64;
3535

36-
uint8_t get_first(const uint8_t packed) {
37-
return uint8_t((packed & 0xF0) >> 4);
38-
}
39-
40-
uint8_t get_second(const uint8_t packed) {
41-
return uint8_t(packed & 0x0F);
42-
}
43-
44-
uint8_t combine(const uint8_t first, const uint8_t second) {
45-
return uint8_t(first << 4 | second);
46-
}
47-
4836
/*
4937
* This shader computes a linear operator between a floating point input matrix
5038
* x and a weights matrix that is quantized to 4 bits.

0 commit comments

Comments
 (0)