revert the L_k padding, introduced as part of #736 / #774
Green-Sky wants to merge 1 commit into leejet:master
Conversation
For the WAN model, flash attention is almost essential; otherwise generation is very slow. So in PR #778 I removed many of the restrictions on flash attention; it is now controlled almost entirely by parameters. As for possible side effects, that is for the user to weigh and decide.
@leejet, so we'll need to add an extra runtime flag and command-line parameter to disable the kv_pad? Fine by me, but isn't that a little cluttered for the interface? (Note that FA otherwise seems to work fine for image generation.) I could implement that flag, but it would necessarily conflict with the Wan PR; and since right now it just crashes for image generation, I wouldn't be able to test it. Wouldn't it be better to apply this PR for now, and I'll figure out how to not break it again inside the Wan branch?
No, not before. We should think of a solution on top of the Wan PR, since it touches that code.
I'm not sure if this will be of any help:

```diff
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 110bbbc..f583869 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1013,6 +1013,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     if (flash_attn) {
         // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
         bool can_use_flash_attn = true;
+        if (L_k < 256) {
+            can_use_flash_attn = false;
+        }
         if (can_use_flash_attn && L_k % 256 != 0) {
             kv_pad = GGML_PAD(L_k, 256) - L_k;
         }
```
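For concreteness, here is a standalone illustration (my own, not part of the patch) of what that padding amounts to, assuming `GGML_PAD` is the round-up macro as defined in ggml.h. With the 77-token CLIP context from #756, 179 zero-filled key/value rows get appended before flash attention runs:

```cpp
// Standalone illustration (not from the PR): GGML_PAD rounds L_k up to the
// next multiple of 256, so a 77-token context gains 179 padded rows.
#include <cstdio>

// round-up macro as defined in ggml.h
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const int L_k    = 77;                        // CLIP text encoder sequence length
    const int kv_pad = GGML_PAD(L_k, 256) - L_k;  // 256 - 77 = 179
    printf("L_k=%d kv_pad=%d padded L_k=%d\n", L_k, kv_pad, L_k + kv_pad);
    return 0;
}
```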
That would turn flash attention off for the length-77 tensors that seemed to be the problem, yes. But... if I may get way over my head here, would something like this make sense instead?

```diff
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 110bbbc..d5721d5 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1031,12 +1031,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     if (flash_attn) {
         // LOG_DEBUG(" uses flash attention");
         if (kv_pad != 0) {
             // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
             k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+            scale *= sqrtf((float)(L_k + kv_pad) / (float)L_k);
         }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
         v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);    // [N * n_head, L_k, d_head]
         if (kv_pad != 0) {
```

As I (barely) understand it, this would compensate for the attention mechanism being applied over the larger, padded tensor. It does seem to fix #756 for me, although I can't really say if it could cause other issues.
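As an aside on why a uniform rescale cannot fully compensate (my own back-of-the-envelope, not from the thread): each zero-filled key row contributes a logit of $s\, q \cdot 0 = 0$, so the softmax denominator picks up a constant leak term,

$$p_i = \frac{e^{s\, q \cdot k_i}}{\sum_{j=1}^{L_k} e^{s\, q \cdot k_j} + \mathit{kv\_pad}},$$

and since the padded value rows are also zero, the output is the true attention result scaled down by the ratio of real mass to total mass. Changing $s$ rescales the real logits but cannot cancel the additive $\mathit{kv\_pad}$ term; forcing the padded logits to $-\infty$ with a mask removes it exactly.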
This issue occurred because no mask was applied after kv_pad. I fixed it in commit aa5566f.
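For anyone following along, here is a toy numeric check (my own sketch, not commit aa5566f) comparing three options on a small score vector: the unpadded reference, zero padding, zero padding with the sqrt scale tweak from above, and zero padding with a -inf mask. Only the masked variant reproduces the reference exactly:

```cpp
// Toy comparison (not from the PR): softmax over real keys vs. the same keys
// zero-padded, zero-padded with the sqrt scale tweak, and zero-padded with a
// -inf mask. Only the masked variant matches the unpadded reference.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> softmax(std::vector<float> x) {
    float m = x[0];
    for (float v : x) m = std::max(m, v);
    float sum = 0.0f;
    for (float& v : x) { v = std::exp(v - m); sum += v; }
    for (float& v : x) v /= sum;
    return x;
}

int main() {
    const int L_k = 3, kv_pad = 5;                   // tiny stand-ins for 77 and 179
    std::vector<float> scores = {1.0f, 2.0f, 0.5f};  // scale * q.k_i for the real keys

    std::vector<float> ref = softmax(scores);        // reference: real keys only

    std::vector<float> padded = scores;              // zero padding leaks exp(0)=1 per row
    padded.resize(L_k + kv_pad, 0.0f);
    std::vector<float> leak = softmax(padded);

    float s = std::sqrt((float)(L_k + kv_pad) / (float)L_k);
    std::vector<float> rescaled;                     // sqrt tweak: rescale, then zero-pad
    for (float v : scores) rescaled.push_back(v * s);
    rescaled.resize(L_k + kv_pad, 0.0f);
    std::vector<float> tweak = softmax(rescaled);

    std::vector<float> masked = scores;              // mask: padded logits forced to -inf
    masked.resize(L_k + kv_pad, -INFINITY);
    std::vector<float> fixed = softmax(masked);

    for (int i = 0; i < L_k; i++)
        printf("key %d: ref=%.4f padded=%.4f sqrt-tweak=%.4f masked=%.4f\n",
               i, ref[i], leak[i], tweak[i], fixed[i]);
    return 0;
}
```

With these toy numbers, the padded and sqrt-tweak probabilities both deviate from the reference, while the masked column matches it to float precision, which is consistent with the mask being the fix that landed.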
fixes #756