Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support Inpaint models #511

Merged
merged 19 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Flux Fill working!!
  • Loading branch information
stduhpf committed Dec 7, 2024
commit 0683c03120a711296a17da3d5b28d604553764da
2 changes: 1 addition & 1 deletion diffusion_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ struct FluxModel : public DiffusionModel {
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
}
};

Expand Down
34 changes: 28 additions & 6 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ namespace Flux {
Flux() {}
Flux(FluxParams params)
: params(params) {
int64_t pe_dim = params.hidden_size / params.num_heads;
int64_t pe_dim = params.hidden_size / params.num_heads;

blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
Expand Down Expand Up @@ -789,6 +789,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timestep,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe,
Expand All @@ -797,6 +798,7 @@ namespace Flux {
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
// context: (N, L, D)
// c_concat: NULL, or for (N,C+M, H, W) for Fill
// y: (N, adm_in_channels) tensor of class labels
// guidance: (N,)
// pe: (L, d_head/2, 2, 2)
Expand All @@ -806,6 +808,7 @@ namespace Flux {

int64_t W = x->ne[0];
int64_t H = x->ne[1];
int64_t C = x->ne[2];
int64_t patch_size = 2;
int pad_h = (patch_size - H % patch_size) % patch_size;
int pad_w = (patch_size - W % patch_size) % patch_size;
Expand All @@ -814,6 +817,19 @@ namespace Flux {
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]

if (c_concat != NULL) {
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);

masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);

masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);

img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
}

auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]

// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
Expand Down Expand Up @@ -841,7 +857,7 @@ namespace Flux {
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_INPAINT) {
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
}
for (auto pair : tensor_types) {
Expand Down Expand Up @@ -890,14 +906,18 @@ namespace Flux {
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<int> skip_layers = std::vector<int>()) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

x = to_backend(x);
context = to_backend(context);
x = to_backend(x);
context = to_backend(context);
if (c_concat != NULL) {
c_concat = to_backend(c_concat);
}
y = to_backend(y);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed) {
Expand All @@ -917,6 +937,7 @@ namespace Flux {
x,
timesteps,
context,
c_concat,
y,
guidance,
pe,
Expand All @@ -931,6 +952,7 @@ namespace Flux {
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor** output = NULL,
Expand All @@ -942,7 +964,7 @@ namespace Flux {
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y, guidance, skip_layers);
return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
};

GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
Expand Down Expand Up @@ -982,7 +1004,7 @@ namespace Flux {
struct ggml_tensor* out = NULL;

int t0 = ggml_time_ms();
compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
int t1 = ggml_time_ms();

print_ggml_tensor(out);
Expand Down
2 changes: 1 addition & 1 deletion model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,7 @@ SDVersion ModelLoader::get_sd_version() {
if (is_flux) {
is_inpaint = input_block_weight.ne[0] == 384;
if (is_inpaint) {
return VERSION_FLUX_INPAINT;
return VERSION_FLUX_FILL;
}
return VERSION_FLUX;
}
Expand Down
6 changes: 3 additions & 3 deletions model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ enum SDVersion {
VERSION_SVD,
VERSION_SD3,
VERSION_FLUX,
VERSION_FLUX_INPAINT,
VERSION_FLUX_FILL,
VERSION_COUNT,
};

static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX || version == VERSION_FLUX_INPAINT) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
return true;
}
return false;
Expand Down Expand Up @@ -67,7 +67,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
}

static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_INPAINT) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
return true;
}
return false;
Expand Down
63 changes: 47 additions & 16 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,6 @@ class StableDiffusionGGML {
} else if (sd_version_is_flux(version)) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
} else if (version == VERSION_LTXV) {
// TODO: cond for T5 only
cond_stage_model = std::make_shared<SimpleT5Embedder>(clip_backend, model_loader.tensor_storages_types);
diffusion_model = std::make_shared<LTXModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
} else {
if (id_embeddings_path.find("v2") != std::string::npos) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
Expand Down Expand Up @@ -798,6 +794,7 @@ class StableDiffusionGGML {
float skip_layer_start = 0.01,
float skip_layer_end = 0.2,
ggml_tensor* noise_mask = nullptr) {
LOG_DEBUG("Sample");
struct ggml_init_params params;
size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
for (int i = 1; i < 4; i++) {
Expand Down Expand Up @@ -1394,13 +1391,27 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
ggml_tensor* noise_mask = nullptr;
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
if (masked_image == NULL) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
}
// no mask, set the whole image as masked
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2] + 1, 1);
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
for (int64_t x = 0; x < masked_image->ne[0]; x++) {
for (int64_t y = 0; y < masked_image->ne[1]; y++) {
ggml_tensor_set_f32(masked_image, 1, x, y, 0);
for (int64_t c = 1; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
// TODO: this might be wrong
for (int64_t c = 0; c < init_latent->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
}
for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 1, x, y, c);
}
} else {
ggml_tensor_set_f32(masked_image, 1, x, y, 0);
for (int64_t c = 1; c < masked_image->ne[2]; c++) {
ggml_tensor_set_f32(masked_image, 0, x, y, c);
}
}
}
}
Expand Down Expand Up @@ -1676,6 +1687,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
ggml_tensor* masked_image;

if (sd_version_is_inpaint(sd_ctx->sd->version)) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
}
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_apply_mask(init_img, mask_img, masked_img);
ggml_tensor* masked_image_0 = NULL;
Expand All @@ -1685,17 +1700,33 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
} else {
masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
}
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], masked_image_0->ne[2] + 1, 1);
masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1);
for (int ix = 0; ix < masked_image_0->ne[0]; ix++) {
for (int iy = 0; iy < masked_image_0->ne[1]; iy++) {
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k + 1);
int mx = ix * 8;
int my = iy * 8;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k);
}
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
for (int x = 0; x < 8; x++) {
for (int y = 0; y < 8; y++) {
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
// TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
// python code was using "b (h 8) (w 8) -> b (8 8) h w"
ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y);
}
}
} else {
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
for (int k = 0; k < masked_image_0->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels);
}
}
int mx = ix * 8;
int my = iy * 8;
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
}
}
} else {
Expand Down
Loading