|
6 | 6 | #include <utility> |
7 | 7 | #include <vector> |
8 | 8 |
|
9 | | -#include "common.hpp" |
| 9 | +#include "common_block.hpp" |
10 | 10 | #include "flux.hpp" |
11 | | -#include "ggml_extend.hpp" |
12 | 11 | #include "rope.hpp" |
13 | 12 |
|
14 | 13 | namespace Anima { |
15 | 14 | constexpr int ANIMA_GRAPH_SIZE = 65536; |
16 | 15 |
|
17 | | - __STATIC_INLINE__ struct ggml_tensor* patchify_2d(struct ggml_context* ctx, |
18 | | - struct ggml_tensor* x, |
19 | | - int64_t patch_size) { |
20 | | - // x: [W*r, H*q, T, C] |
21 | | - // return: [W, H, T, C*q*r] |
22 | | - if (patch_size == 1) { |
23 | | - return x; |
24 | | - } |
25 | | - GGML_ASSERT(x->ne[2] == 1); |
26 | | - |
27 | | - int64_t W = x->ne[0]; |
28 | | - int64_t H = x->ne[1]; |
29 | | - int64_t T = x->ne[2]; |
30 | | - int64_t C = x->ne[3]; |
31 | | - int64_t p = patch_size; |
32 | | - int64_t h = H / p; |
33 | | - int64_t w = W / p; |
34 | | - |
35 | | - GGML_ASSERT(T == 1); |
36 | | - GGML_ASSERT(h * p == H && w * p == W); |
37 | | - |
38 | | - // Reuse Flux patchify layout on a [W, H, C, N] view. |
39 | | - x = ggml_reshape_4d(ctx, x, W, H, C, T); // [W, H, C, N] |
40 | | - |
41 | | - // Flux patchify: [N, C, H, W] -> [N, h*w, C*p*p] |
42 | | - x = ggml_reshape_4d(ctx, x, p, w, p, h * C * T); // [p, w, p, h*C*N] |
43 | | - x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [p, p, w, h*C*N] |
44 | | - x = ggml_reshape_4d(ctx, x, p * p, w * h, C, T); // [p*p, h*w, C, N] |
45 | | - x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [p*p, C, h*w, N] |
46 | | - x = ggml_reshape_3d(ctx, x, p * p * C, w * h, T); // [C*p*p, h*w, N] |
47 | | - |
48 | | - // Return [w, h, T, C*p*p] |
49 | | - x = ggml_reshape_4d(ctx, x, p * p * C, w, h, T); // [C*p*p, w, h, N] |
50 | | - x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [w, h, N, C*p*p] |
51 | | - return x; |
52 | | - } |
53 | | - |
54 | | - __STATIC_INLINE__ struct ggml_tensor* unpatchify_2d(struct ggml_context* ctx, |
55 | | - struct ggml_tensor* x, |
56 | | - int64_t patch_size) { |
57 | | - // x: [W, H, T, C*q*r] |
58 | | - // return: [W*r, H*q, T, C] |
59 | | - if (patch_size == 1) { |
60 | | - return x; |
61 | | - } |
62 | | - GGML_ASSERT(x->ne[2] == 1); |
63 | | - |
64 | | - int64_t w = x->ne[0]; |
65 | | - int64_t h = x->ne[1]; |
66 | | - int64_t T = x->ne[2]; |
67 | | - int64_t p = patch_size; |
68 | | - int64_t nm = p * p; |
69 | | - int64_t Cp = x->ne[3]; |
70 | | - int64_t C = Cp / nm; |
71 | | - int64_t W = w * p; |
72 | | - int64_t H = h * p; |
73 | | - |
74 | | - GGML_ASSERT(T == 1); |
75 | | - GGML_ASSERT(C * nm == Cp); |
76 | | - |
77 | | - // [w, h, 1, C*p*p] -> [W, H, 1, C] |
78 | | - x = ggml_reshape_4d(ctx, x, w, h * C, p, p); // [w, h*C, p2, p1] |
79 | | - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 3, 1)); // [p2, w, p1, h*C] |
80 | | - x = ggml_reshape_4d(ctx, x, W, H, T, C); // [W, H, 1, C] |
81 | | - return x; |
82 | | - } |
83 | | - |
84 | 16 | __STATIC_INLINE__ struct ggml_tensor* apply_gate(struct ggml_context* ctx, |
85 | 17 | struct ggml_tensor* x, |
86 | 18 | struct ggml_tensor* gate) { |
@@ -491,7 +423,7 @@ namespace Anima { |
491 | 423 | int64_t text_embed_dim = 1024; |
492 | 424 | int64_t num_heads = 16; |
493 | 425 | int64_t head_dim = 128; |
494 | | - int64_t patch_size = 2; |
| 426 | + int patch_size = 2; |
495 | 427 | int64_t num_layers = 28; |
496 | 428 | std::vector<int> axes_dim = {44, 42, 42}; |
497 | 429 | int theta = 10000; |
@@ -533,24 +465,10 @@ namespace Anima { |
533 | 465 | int64_t W = x->ne[0]; |
534 | 466 | int64_t H = x->ne[1]; |
535 | 467 |
|
536 | | - x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); // [N*C, T, H, W] style |
537 | | - |
538 | | - int64_t pad_h = (patch_size - H % patch_size) % patch_size; |
539 | | - int64_t pad_w = (patch_size - W % patch_size) % patch_size; |
540 | | - if (pad_h > 0 || pad_w > 0) { |
541 | | - x = ggml_ext_pad(ctx->ggml_ctx, x, static_cast<int>(pad_w), static_cast<int>(pad_h), 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled); |
542 | | - } |
543 | | - |
544 | | - auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], x->ne[2], 1); |
545 | | - x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 3); // concat mask channel |
546 | | - |
547 | | - x = patchify_2d(ctx->ggml_ctx, x, patch_size); // [C*4, T, H/2, W/2] |
| 468 | + auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], 1, x->ne[3]); |
| 469 | + x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 2); // [N, C + 1, H, W] |
548 | 470 |
|
549 | | - int64_t w_len = x->ne[0]; |
550 | | - int64_t h_len = x->ne[1]; |
551 | | - int64_t t_len = x->ne[2]; |
552 | | - x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3], 1); |
553 | | - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, n_token, C] |
| 471 | + x = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); // [N, h*w, (C+1)*ph*pw] |
554 | 472 |
|
555 | 473 | x = x_embedder->forward(ctx, x); |
556 | 474 |
|
@@ -586,15 +504,9 @@ namespace Anima { |
586 | 504 | x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe); |
587 | 505 | } |
588 | 506 |
|
589 | | - x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, n_token, C*4] |
590 | | - |
591 | | - x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [n_token, C*4, N] |
592 | | - x = ggml_reshape_4d(ctx->ggml_ctx, x, w_len, h_len, t_len, x->ne[1]); // [C*4, T, H/2, W/2] |
593 | | - x = unpatchify_2d(ctx->ggml_ctx, x, patch_size); // [C, T, H, W] |
| 507 | + x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C] |
594 | 508 |
|
595 | | - x = ggml_ext_slice(ctx->ggml_ctx, x, 1, 0, H); // [C, T, H, W + pad] |
596 | | - x = ggml_ext_slice(ctx->ggml_ctx, x, 0, 0, W); // [C, T, H, W] |
597 | | - x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[3], x->ne[2]); // [N, C, H, W] |
| 509 | + x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, false); // [N, C, H, W] |
598 | 510 |
|
599 | 511 | return x; |
600 | 512 | } |
|
0 commit comments