
Commit 8d5cf8f

Fix pix2pix latent inputs + improve inpainting a bit + fix naming
1 parent 4024765 commit 8d5cf8f

File tree

stable-diffusion.cpp

1 file changed: 89 additions, 57 deletions
@@ -1072,6 +1072,30 @@ class StableDiffusionGGML {
         return latent;
     }
 
+    ggml_tensor*
+    get_first_stage_encoding_mode(ggml_context* work_ctx, ggml_tensor* moments) {
+        // ldm.modules.distributions.distributions.DiagonalGaussianDistribution.mode
+        ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]);
+        struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent);
+        ggml_tensor_set_f32_randn(noise, rng);
+        // noise = load_tensor_from_file(work_ctx, "noise.bin");
+        {
+            float mean = 0;
+            for (int i = 0; i < latent->ne[3]; i++) {
+                for (int j = 0; j < latent->ne[2]; j++) {
+                    for (int k = 0; k < latent->ne[1]; k++) {
+                        for (int l = 0; l < latent->ne[0]; l++) {
+                            // mode and mean are the same for gaussians
+                            mean = ggml_tensor_get_f32(moments, l, k, j, i);
+                            ggml_tensor_set_f32(latent, mean, l, k, j, i);
+                        }
+                    }
+                }
+            }
+        }
+        return latent;
+    }
+
     ggml_tensor* compute_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode) {
         int64_t W = x->ne[0];
         int64_t H = x->ne[1];
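
Note: the VAE encoder's `moments` tensor packs 2C channels, with the mean in the first C and the log-variance in the second C (hence `moments->ne[2] / 2` above). `get_first_stage_encoding` draws a reparameterized sample from that distribution, while the new `get_first_stage_encoding_mode` copies only the mean channels, since a Gaussian's mode equals its mean (the `noise` tensor it fills is never read). A minimal per-element sketch of the difference, assuming that packing:

    #include <cmath>
    #include <cstdio>

    // Per-element view of the two encodings (assumed layout: mean in the
    // first half of the channel dim, logvar in the second half).
    float encode_sample(float mean, float logvar, float noise) {
        float std_dev = std::exp(0.5f * logvar); // std = exp(logvar / 2)
        return mean + std_dev * noise;           // reparameterized sample
    }

    float encode_mode(float mean) {
        return mean; // mode == mean for a Gaussian; fully deterministic
    }

    int main() {
        printf("sample = %f, mode = %f\n",
               encode_sample(0.3f, -1.0f, 0.5f), encode_mode(0.3f));
        return 0;
    }
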
@@ -1250,7 +1274,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                            float slg_scale = 0,
                            float skip_layer_start = 0.01,
                            float skip_layer_end = 0.2,
-                           ggml_tensor* masked_image = NULL) {
+                           ggml_tensor* masked_latent = NULL) {
     if (seed < 0) {
         // Generally, when using the provided command line, the seed is always >0.
         // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1439,42 +1463,43 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
     LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
     ggml_tensor* noise_mask = nullptr;
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
-        if (masked_image == NULL) {
-            int64_t mask_channels = 1;
-            if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
-                mask_channels = 8 * 8; // flatten the whole mask
-            }
-            // no mask, set the whole image as masked
-            masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
-            for (int64_t x = 0; x < masked_image->ne[0]; x++) {
-                for (int64_t y = 0; y < masked_image->ne[1]; y++) {
-                    if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
-                        // TODO: this might be wrong
-                        for (int64_t c = 0; c < init_latent->ne[2]; c++) {
-                            ggml_tensor_set_f32(masked_image, 0, x, y, c);
-                        }
-                        for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) {
-                            ggml_tensor_set_f32(masked_image, 1, x, y, c);
-                        }
-                    } else {
-                        ggml_tensor_set_f32(masked_image, 1, x, y, 0);
-                        for (int64_t c = 1; c < masked_image->ne[2]; c++) {
-                            ggml_tensor_set_f32(masked_image, 0, x, y, c);
-                        }
+        int64_t mask_channels = 1;
+        if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+            mask_channels = 8 * 8; // flatten the whole mask
+        }
+        auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+        // no mask, set the whole image as masked
+        for (int64_t x = 0; x < empty_latent->ne[0]; x++) {
+            for (int64_t y = 0; y < empty_latent->ne[1]; y++) {
+                if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                    // TODO: this might be wrong
+                    for (int64_t c = 0; c < init_latent->ne[2]; c++) {
+                        ggml_tensor_set_f32(empty_latent, 0, x, y, c);
+                    }
+                    for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
+                        ggml_tensor_set_f32(empty_latent, 1, x, y, c);
+                    }
+                } else {
+                    ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
+                    for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
+                        ggml_tensor_set_f32(empty_latent, 0, x, y, c);
                     }
                 }
             }
         }
-        cond.c_concat = masked_image;
-        uncond.c_concat = masked_image;
-        // noise_mask = masked_image;
+        if (masked_latent == NULL) {
+            masked_latent = empty_latent;
+        }
+        cond.c_concat = masked_latent;
+        uncond.c_concat = empty_latent;
+        // noise_mask = masked_latent;
     } else if (sd_ctx->sd->version == VERSION_INSTRUCT_PIX2PIX) {
-        cond.c_concat = masked_image;
-        auto empty_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image->ne[0], masked_image->ne[1], masked_image->ne[2], masked_image->ne[3]);
-        ggml_set_f32(empty_img, 0);
-        uncond.c_concat = empty_img;
+        cond.c_concat = masked_latent;
+        auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], masked_latent->ne[2], masked_latent->ne[3]);
+        ggml_set_f32(empty_latent, 0);
+        uncond.c_concat = empty_latent;
     } else {
-        noise_mask = masked_image;
+        noise_mask = masked_latent;
     }
 
     for (int b = 0; b < batch_count; b++) {
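
Note: the effect of this hunk is that for inpaint-style models `cond.c_concat` now receives the caller's `masked_latent` (falling back to the all-masked `empty_latent` when none is given), while `uncond.c_concat` always receives `empty_latent` instead of sharing the conditional tensor. A hypothetical distillation of the branch, with `SDVersion`, `Tensor`, and `pick_c_concat` invented for illustration:

    // Toy types standing in for the real sd_ctx / ggml_tensor machinery.
    struct Tensor;
    enum class SDVersion { Inpaint, InstructPix2Pix, Other };

    struct ConcatPair {
        Tensor* cond;   // what ends up in cond.c_concat
        Tensor* uncond; // what ends up in uncond.c_concat
    };

    // masked: caller-supplied masked latent (may be null)
    // empty:  "whole image masked" latent built above
    // zeros:  zero-filled latent of the same shape as masked
    ConcatPair pick_c_concat(SDVersion v, Tensor* masked, Tensor* empty, Tensor* zeros) {
        switch (v) {
            case SDVersion::Inpaint:
                return {masked ? masked : empty, empty}; // uncond never sees the real mask
            case SDVersion::InstructPix2Pix:
                return {masked, zeros};                  // uncond image conditioning is blank
            default:
                return {nullptr, nullptr};               // c_concat unused; mask drives noise_mask
        }
    }
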
@@ -1744,71 +1769,78 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
 
     sd_image_to_tensor(init_image.data, init_img);
 
-    ggml_tensor* init_latent = NULL;
+    ggml_tensor* masked_latent;
+
+    ggml_tensor* init_latent = NULL;
+    ggml_tensor* init_moments = NULL;
     if (!sd_ctx->sd->use_tiny_autoencoder) {
-        ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-        init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+        init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+        init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
     } else {
         init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
     }
 
-    ggml_tensor* masked_image;
-
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8; // flatten the whole mask
         }
         ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+        // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
+        sd_image_to_tensor(init_image.data, init_img);
         sd_apply_mask(init_img, mask_img, masked_img);
-        ggml_tensor* masked_image_0 = NULL;
+        ggml_tensor* masked_latent_0 = NULL;
         if (!sd_ctx->sd->use_tiny_autoencoder) {
             ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
-            masked_image_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+            masked_latent_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
         } else {
-            masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+            masked_latent_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
         }
-        masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1);
-        for (int ix = 0; ix < masked_image_0->ne[0]; ix++) {
-            for (int iy = 0; iy < masked_image_0->ne[1]; iy++) {
+        masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent_0->ne[0], masked_latent_0->ne[1], mask_channels + masked_latent_0->ne[2], 1);
+        for (int ix = 0; ix < masked_latent_0->ne[0]; ix++) {
+            for (int iy = 0; iy < masked_latent_0->ne[1]; iy++) {
                 int mx = ix * 8;
                 int my = iy * 8;
                 if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
-                    for (int k = 0; k < masked_image_0->ne[2]; k++) {
-                        float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
-                        ggml_tensor_set_f32(masked_image, v, ix, iy, k);
+                    for (int k = 0; k < masked_latent_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
+                        ggml_tensor_set_f32(masked_latent, v, ix, iy, k);
                     }
                     // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
                     for (int x = 0; x < 8; x++) {
                         for (int y = 0; y < 8; y++) {
                             float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
                             // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
                             // python code was using "b (h 8) (w 8) -> b (8 8) h w"
-                            ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y);
+                            ggml_tensor_set_f32(masked_latent, m, ix, iy, masked_latent_0->ne[2] + x * 8 + y);
                         }
                     }
                 } else {
                     float m = ggml_tensor_get_f32(mask_img, mx, my);
-                    ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
-                    for (int k = 0; k < masked_image_0->ne[2]; k++) {
-                        float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
-                        ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels);
+                    ggml_tensor_set_f32(masked_latent, m, ix, iy, 0);
+                    for (int k = 0; k < masked_latent_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent_0, ix, iy, k);
+                        ggml_tensor_set_f32(masked_latent, v, ix, iy, k + mask_channels);
                    }
                 }
             }
         }
     } else if (sd_ctx->sd->version == VERSION_INSTRUCT_PIX2PIX) {
-        // Not actually masked, we're just hijacking the masked_image variable since it will be used the same way
-        masked_image = init_latent;
+        // Not actually masked, we're just hijacking the masked_latent variable since it will be used the same way
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            masked_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments);
+        } else {
+            masked_latent = init_latent;
+        }
     } else {
         // LOG_WARN("Inpainting with a base model is not great");
-        masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
-        for (int ix = 0; ix < masked_image->ne[0]; ix++) {
-            for (int iy = 0; iy < masked_image->ne[1]; iy++) {
+        masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
+        for (int ix = 0; ix < masked_latent->ne[0]; ix++) {
+            for (int iy = 0; iy < masked_latent->ne[1]; iy++) {
                 int mx = ix * 8;
                 int my = iy * 8;
                 float m = ggml_tensor_get_f32(mask_img, mx, my);
-                ggml_tensor_set_f32(masked_image, m, ix, iy);
+                ggml_tensor_set_f32(masked_latent, m, ix, iy);
             }
         }
     }
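
Note: for VERSION_FLUX_FILL, each 8x8 block of the full-resolution mask is flattened into 64 extra channels at the corresponding latent position, mirroring the einops pattern "b (h 8) (w 8) -> b (8 8) h w" quoted in the comments (the in-diff TODO notes the x * 8 + y channel order is unverified). A standalone toy version of that flattening, with made-up array names:

    #include <cstdio>

    int main() {
        const int W = 16, H = 16; // toy full-resolution mask (multiples of 8)
        float mask[H][W];
        for (int y = 0; y < H; y++)
            for (int x = 0; x < W; x++)
                mask[y][x] = (x < 8) ? 1.0f : 0.0f; // left half masked

        const int LW = W / 8, LH = H / 8; // latent-resolution grid
        float flat[LH][LW][64];           // 64 mask channels per latent cell
        for (int iy = 0; iy < LH; iy++)
            for (int ix = 0; ix < LW; ix++)
                for (int x = 0; x < 8; x++)
                    for (int y = 0; y < 8; y++)
                        // same channel indexing as the diff: x * 8 + y
                        flat[iy][ix][x * 8 + y] = mask[iy * 8 + y][ix * 8 + x];

        printf("cell(0,0) ch0 = %.0f, cell(0,1) ch0 = %.0f\n",
               flat[0][0][0], flat[0][1][0]); // masked vs unmasked block
        return 0;
    }
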
@@ -1849,7 +1881,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
                               slg_scale,
                               skip_layer_start,
                               skip_layer_end,
-                              masked_image);
+                              masked_latent);
 
     size_t t2 = ggml_time_ms();
 
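Note: the "fix pix2pix latent inputs" part of the commit title appears to be the combination of the img2img hunk above and the new get_first_stage_encoding_mode: the instruct-pix2pix image-conditioning latent is now the posterior mode rather than a fresh sample, so it no longer shifts with the RNG between runs. A toy illustration of why that matters (not the library's API):

    #include <cmath>
    #include <cstdio>
    #include <random>

    int main() {
        std::mt19937 rng(42);
        std::normal_distribution<float> gauss(0.0f, 1.0f);

        float mean = 0.3f, logvar = -1.0f;       // one latent element's moments
        float std_dev = std::exp(0.5f * logvar); // std = exp(logvar / 2)

        for (int i = 0; i < 3; i++) {
            float sampled = mean + std_dev * gauss(rng); // differs every draw
            printf("sample: %f   mode: %f\n", sampled, mean);
        }
        return 0;
    }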