Skip to content

Commit fa56e2f

Browse files
committed
prepare for other pix2pix-like models
1 parent 8d5cf8f commit fa56e2f

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1560,7 +1560,7 @@ SDVersion ModelLoader::get_sd_version() {
15601560
return VERSION_SD1_INPAINT;
15611561
}
15621562
if(is_ip2p) {
1563-
return VERSION_INSTRUCT_PIX2PIX;
1563+
return VERSION_SD1_PIX2PIX;
15641564
}
15651565
return VERSION_SD1;
15661566
} else if (token_embedding_weight.ne[0] == 1024) {

model.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
enum SDVersion {
2222
VERSION_SD1,
2323
VERSION_SD1_INPAINT,
24-
VERSION_INSTRUCT_PIX2PIX,
24+
VERSION_SD1_PIX2PIX,
2525
VERSION_SD2,
2626
VERSION_SD2_INPAINT,
2727
VERSION_SDXL,
@@ -48,7 +48,7 @@ static inline bool sd_version_is_sd3(SDVersion version) {
4848
}
4949

5050
static inline bool sd_version_is_sd1(SDVersion version) {
51-
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_INSTRUCT_PIX2PIX) {
51+
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) {
5252
return true;
5353
}
5454
return false;
@@ -82,8 +82,12 @@ static inline bool sd_version_is_dit(SDVersion version) {
8282
return false;
8383
}
8484

85+
static inline bool sd_version_is_edit(SDVersion version) {
86+
return version == VERSION_SD1_PIX2PIX;
87+
}
88+
8589
static bool sd_version_use_concat(SDVersion version) {
86-
return version == VERSION_INSTRUCT_PIX2PIX || sd_version_is_inpaint(version);
90+
return sd_version_is_edit(version) || sd_version_is_inpaint(version);
8791
}
8892

8993
enum PMVersion {

stable-diffusion.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,7 +1422,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14221422
sd_ctx->sd->diffusion_model->get_adm_in_channels());
14231423

14241424
SDCondition uncond;
1425-
if (cfg_scale != 1.0 || sd_ctx->sd->version == VERSION_INSTRUCT_PIX2PIX && cfg_scale != guidance) {
1425+
if (cfg_scale != 1.0 || sd_version_use_concat(sd_ctx->sd->version) && cfg_scale != guidance) {
14261426
bool force_zero_embeddings = false;
14271427
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) {
14281428
force_zero_embeddings = true;
@@ -1493,7 +1493,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14931493
cond.c_concat = masked_latent;
14941494
uncond.c_concat = empty_latent;
14951495
// noise_mask = masked_latent;
1496-
} else if (sd_ctx->sd->version == VERSION_INSTRUCT_PIX2PIX) {
1496+
} else if (sd_version_is_edit(sd_ctx->sd->version)) {
14971497
cond.c_concat = masked_latent;
14981498
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], masked_latent->ne[2], masked_latent->ne[3]);
14991499
ggml_set_f32(empty_latent, 0);
@@ -1825,7 +1825,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18251825
}
18261826
}
18271827
}
1828-
} else if (sd_ctx->sd->version == VERSION_INSTRUCT_PIX2PIX) {
1828+
} else if (sd_version_is_edit(sd_ctx->sd->version)) {
18291829
// Not actually masked, we're just hijacking the masked_latent variable since it will be used the same way
18301830
if (!sd_ctx->sd->use_tiny_autoencoder) {
18311831
masked_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments);

unet.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class UnetModelBlock : public GGMLBlock {
207207
}
208208
if (sd_version_is_inpaint(version)) {
209209
in_channels = 9;
210-
} else if (version == VERSION_INSTRUCT_PIX2PIX) {
210+
} else if (version == VERSION_SD1_PIX2PIX) {
211211
in_channels = 8;
212212
}
213213

0 commit comments

Comments
 (0)