Skip to content

Commit e64baa3

Browse files
authored
refactor: reuse DiT's patchify/unpatchify functions (leejet#1304)
1 parent cec4aed commit e64baa3

12 files changed

Lines changed: 150 additions & 361 deletions

File tree

src/anima.hpp

Lines changed: 7 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -6,81 +6,13 @@
66
#include <utility>
77
#include <vector>
88

9-
#include "common.hpp"
9+
#include "common_block.hpp"
1010
#include "flux.hpp"
11-
#include "ggml_extend.hpp"
1211
#include "rope.hpp"
1312

1413
namespace Anima {
1514
constexpr int ANIMA_GRAPH_SIZE = 65536;
1615

17-
__STATIC_INLINE__ struct ggml_tensor* patchify_2d(struct ggml_context* ctx,
18-
struct ggml_tensor* x,
19-
int64_t patch_size) {
20-
// x: [W*r, H*q, T, C]
21-
// return: [W, H, T, C*q*r]
22-
if (patch_size == 1) {
23-
return x;
24-
}
25-
GGML_ASSERT(x->ne[2] == 1);
26-
27-
int64_t W = x->ne[0];
28-
int64_t H = x->ne[1];
29-
int64_t T = x->ne[2];
30-
int64_t C = x->ne[3];
31-
int64_t p = patch_size;
32-
int64_t h = H / p;
33-
int64_t w = W / p;
34-
35-
GGML_ASSERT(T == 1);
36-
GGML_ASSERT(h * p == H && w * p == W);
37-
38-
// Reuse Flux patchify layout on a [W, H, C, N] view.
39-
x = ggml_reshape_4d(ctx, x, W, H, C, T); // [W, H, C, N]
40-
41-
// Flux patchify: [N, C, H, W] -> [N, h*w, C*p*p]
42-
x = ggml_reshape_4d(ctx, x, p, w, p, h * C * T); // [p, w, p, h*C*N]
43-
x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [p, p, w, h*C*N]
44-
x = ggml_reshape_4d(ctx, x, p * p, w * h, C, T); // [p*p, h*w, C, N]
45-
x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [p*p, C, h*w, N]
46-
x = ggml_reshape_3d(ctx, x, p * p * C, w * h, T); // [C*p*p, h*w, N]
47-
48-
// Return [w, h, T, C*p*p]
49-
x = ggml_reshape_4d(ctx, x, p * p * C, w, h, T); // [C*p*p, w, h, N]
50-
x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [w, h, N, C*p*p]
51-
return x;
52-
}
53-
54-
__STATIC_INLINE__ struct ggml_tensor* unpatchify_2d(struct ggml_context* ctx,
55-
struct ggml_tensor* x,
56-
int64_t patch_size) {
57-
// x: [W, H, T, C*q*r]
58-
// return: [W*r, H*q, T, C]
59-
if (patch_size == 1) {
60-
return x;
61-
}
62-
GGML_ASSERT(x->ne[2] == 1);
63-
64-
int64_t w = x->ne[0];
65-
int64_t h = x->ne[1];
66-
int64_t T = x->ne[2];
67-
int64_t p = patch_size;
68-
int64_t nm = p * p;
69-
int64_t Cp = x->ne[3];
70-
int64_t C = Cp / nm;
71-
int64_t W = w * p;
72-
int64_t H = h * p;
73-
74-
GGML_ASSERT(T == 1);
75-
GGML_ASSERT(C * nm == Cp);
76-
77-
// [w, h, 1, C*p*p] -> [W, H, 1, C]
78-
x = ggml_reshape_4d(ctx, x, w, h * C, p, p); // [w, h*C, p2, p1]
79-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 3, 1)); // [p2, w, p1, h*C]
80-
x = ggml_reshape_4d(ctx, x, W, H, T, C); // [W, H, 1, C]
81-
return x;
82-
}
83-
8416
__STATIC_INLINE__ struct ggml_tensor* apply_gate(struct ggml_context* ctx,
8517
struct ggml_tensor* x,
8618
struct ggml_tensor* gate) {
@@ -491,7 +423,7 @@ namespace Anima {
491423
int64_t text_embed_dim = 1024;
492424
int64_t num_heads = 16;
493425
int64_t head_dim = 128;
494-
int64_t patch_size = 2;
426+
int patch_size = 2;
495427
int64_t num_layers = 28;
496428
std::vector<int> axes_dim = {44, 42, 42};
497429
int theta = 10000;
@@ -533,24 +465,10 @@ namespace Anima {
533465
int64_t W = x->ne[0];
534466
int64_t H = x->ne[1];
535467

536-
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]); // [N*C, T, H, W] style
537-
538-
int64_t pad_h = (patch_size - H % patch_size) % patch_size;
539-
int64_t pad_w = (patch_size - W % patch_size) % patch_size;
540-
if (pad_h > 0 || pad_w > 0) {
541-
x = ggml_ext_pad(ctx->ggml_ctx, x, static_cast<int>(pad_w), static_cast<int>(pad_h), 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
542-
}
543-
544-
auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], x->ne[2], 1);
545-
x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 3); // concat mask channel
546-
547-
x = patchify_2d(ctx->ggml_ctx, x, patch_size); // [C*4, T, H/2, W/2]
468+
auto padding_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[0], x->ne[1], 1, x->ne[3]);
469+
x = ggml_concat(ctx->ggml_ctx, x, padding_mask, 2); // [N, C + 1, H, W]
548470

549-
int64_t w_len = x->ne[0];
550-
int64_t h_len = x->ne[1];
551-
int64_t t_len = x->ne[2];
552-
x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3], 1);
553-
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, n_token, C]
471+
x = DiT::pad_and_patchify(ctx, x, patch_size, patch_size); // [N, h*w, (C+1)*ph*pw]
554472

555473
x = x_embedder->forward(ctx, x);
556474

@@ -586,15 +504,9 @@ namespace Anima {
586504
x = block->forward(ctx, x, encoder_hidden_states, embedded_timestep, temb, image_pe);
587505
}
588506

589-
x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, n_token, C*4]
590-
591-
x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [n_token, C*4, N]
592-
x = ggml_reshape_4d(ctx->ggml_ctx, x, w_len, h_len, t_len, x->ne[1]); // [C*4, T, H/2, W/2]
593-
x = unpatchify_2d(ctx->ggml_ctx, x, patch_size); // [C, T, H, W]
507+
x = final_layer->forward(ctx, x, embedded_timestep, temb); // [N, h*w, ph*pw*C]
594508

595-
x = ggml_ext_slice(ctx->ggml_ctx, x, 1, 0, H); // [C, T, H, W + pad]
596-
x = ggml_ext_slice(ctx->ggml_ctx, x, 0, 0, W); // [C, T, H, W]
597-
x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[3], x->ne[2]); // [N, C, H, W]
509+
x = DiT::unpatchify_and_crop(ctx->ggml_ctx, x, H, W, patch_size, patch_size, false); // [N, C, H, W]
598510

599511
return x;
600512
}

src/common.hpp renamed to src/common_block.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef __COMMON_HPP__
2-
#define __COMMON_HPP__
1+
#ifndef __COMMON_BLOCK_HPP__
2+
#define __COMMON_BLOCK_HPP__
33

44
#include "ggml_extend.hpp"
55

@@ -590,4 +590,4 @@ class VideoResBlock : public ResBlock {
590590
}
591591
};
592592

593-
#endif // __COMMON_HPP__
593+
#endif // __COMMON_BLOCK_HPP__

src/common_dit.hpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#ifndef __COMMON_DIT_HPP__
2+
#define __COMMON_DIT_HPP__
3+
4+
#include "ggml_extend.hpp"
5+
6+
namespace DiT {
7+
ggml_tensor* patchify(ggml_context* ctx,
8+
ggml_tensor* x,
9+
int pw,
10+
int ph,
11+
bool patch_last = true) {
12+
// x: [N, C, H, W]
13+
// return: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
14+
int64_t N = x->ne[3];
15+
int64_t C = x->ne[2];
16+
int64_t H = x->ne[1];
17+
int64_t W = x->ne[0];
18+
int64_t h = H / ph;
19+
int64_t w = W / pw;
20+
21+
GGML_ASSERT(h * ph == H && w * pw == W);
22+
23+
x = ggml_reshape_4d(ctx, x, pw, w, ph, h * C * N); // [N*C*h, ph, w, pw]
24+
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, ph, pw]
25+
x = ggml_reshape_4d(ctx, x, pw * ph, w * h, C, N); // [N, C, h*w, ph*pw]
26+
if (patch_last) {
27+
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, ph*pw]
28+
x = ggml_reshape_3d(ctx, x, pw * ph * C, w * h, N); // [N, h*w, C*ph*pw]
29+
} else {
30+
x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, h*w, C, ph*pw]
31+
x = ggml_reshape_3d(ctx, x, C * pw * ph, w * h, N); // [N, h*w, ph*pw*C]
32+
}
33+
return x;
34+
}
35+
36+
ggml_tensor* unpatchify(ggml_context* ctx,
37+
ggml_tensor* x,
38+
int64_t h,
39+
int64_t w,
40+
int ph,
41+
int pw,
42+
bool patch_last = true) {
43+
// x: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
44+
// return: [N, C, H, W]
45+
int64_t N = x->ne[2];
46+
int64_t C = x->ne[0] / ph / pw;
47+
int64_t H = h * ph;
48+
int64_t W = w * pw;
49+
50+
GGML_ASSERT(C * ph * pw == x->ne[0]);
51+
52+
if (patch_last) {
53+
x = ggml_reshape_4d(ctx, x, pw * ph, C, w * h, N); // [N, h*w, C, ph*pw]
54+
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, ph*pw]
55+
} else {
56+
x = ggml_reshape_4d(ctx, x, C, pw * ph, w * h, N); // [N, h*w, ph*pw, C]
57+
x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, C, h*w, ph*pw]
58+
}
59+
60+
x = ggml_reshape_4d(ctx, x, pw, ph, w, h * C * N); // [N*C*h, w, ph, pw]
61+
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, ph, w, pw]
62+
x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*ph, w*pw]
63+
64+
return x;
65+
}
66+
67+
ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
68+
ggml_tensor* x,
69+
int ph,
70+
int pw) {
71+
int64_t W = x->ne[0];
72+
int64_t H = x->ne[1];
73+
74+
int pad_h = (ph - H % ph) % ph;
75+
int pad_w = (pw - W % pw) % pw;
76+
x = ggml_ext_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0, ctx->circular_x_enabled, ctx->circular_y_enabled);
77+
return x;
78+
}
79+
80+
ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
81+
ggml_tensor* x,
82+
int ph,
83+
int pw,
84+
bool patch_last = true) {
85+
x = pad_to_patch_size(ctx, x, ph, pw);
86+
x = patchify(ctx->ggml_ctx, x, ph, pw, patch_last);
87+
return x;
88+
}
89+
90+
ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
91+
ggml_tensor* x,
92+
int64_t H,
93+
int64_t W,
94+
int ph,
95+
int pw,
96+
bool patch_last = true) {
97+
int pad_h = (ph - H % ph) % ph;
98+
int pad_w = (pw - W % pw) % pw;
99+
int64_t h = ((H + pad_h) / ph);
100+
int64_t w = ((W + pad_w) / pw);
101+
x = unpatchify(ctx, x, h, w, ph, pw, patch_last); // [N, C, H + pad_h, W + pad_w]
102+
x = ggml_ext_slice(ctx, x, 1, 0, H); // [N, C, H, W + pad_w]
103+
x = ggml_ext_slice(ctx, x, 0, 0, W); // [N, C, H, W]
104+
return x;
105+
}
106+
} // namespace DiT
107+
108+
#endif // __COMMON_DIT_HPP__

src/control.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#ifndef __CONTROL_HPP__
22
#define __CONTROL_HPP__
33

4-
#include "common.hpp"
5-
#include "ggml_extend.hpp"
4+
#include "common_block.hpp"
65
#include "model.h"
76

87
#define CONTROL_NET_GRAPH_SIZE 1536

0 commit comments

Comments
 (0)