Skip to content

Commit b314d80

Browse files
wbruna and leejet authored
feat: turn flow_shift into a generation parameter (leejet#1289)
* feat: turn flow_shift into a generation parameter
* format code
* simplify set_shift/set_parameters
* fix sd_sample_params_to_str
* remove unused variable
* update docs

---------

Co-authored-by: leejet <leejet714@gmail.com>
1 parent c9cd497 commit b314d80

6 files changed

Lines changed: 48 additions & 78 deletions

File tree

examples/cli/README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@ Context Options:
4444
CPU physical cores
4545
--chroma-t5-mask-pad <int> t5 mask pad size of chroma
4646
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
47-
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
4847
--vae-tiling process vae in tiles to reduce memory usage
4948
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
5049
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
@@ -109,6 +108,7 @@ Generation Options:
109108
--skip-layer-start <float> SLG enabling point (default: 0.01)
110109
--skip-layer-end <float> SLG disabling point (default: 0.2)
111110
--eta <float> eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
111+
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
112112
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
113113
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
114114
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)

examples/common/common.hpp

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -581,10 +581,6 @@ struct SDContextParams {
581581
"--vae-tile-overlap",
582582
"tile overlap for vae tiling, in fraction of tile size (default: 0.5)",
583583
&vae_tiling_params.target_overlap},
584-
{"",
585-
"--flow-shift",
586-
"shift value for Flow models like SD3.x or WAN (default: auto)",
587-
&flow_shift},
588584
};
589585

590586
options.bool_options = {
@@ -903,7 +899,6 @@ struct SDContextParams {
903899
<< " photo_maker_path: \"" << photo_maker_path << "\",\n"
904900
<< " rng_type: " << sd_rng_type_name(rng_type) << ",\n"
905901
<< " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
906-
<< " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
907902
<< " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
908903
<< " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
909904
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
@@ -986,7 +981,6 @@ struct SDContextParams {
986981
chroma_use_t5_mask,
987982
chroma_t5_mask_pad,
988983
qwen_image_zero_cond_t,
989-
flow_shift,
990984
};
991985
return sd_ctx_params;
992986
}
@@ -1206,6 +1200,10 @@ struct SDGenerationParams {
12061200
"--eta",
12071201
"eta in DDIM, only for DDIM and TCD (default: 0)",
12081202
&sample_params.eta},
1203+
{"",
1204+
"--flow-shift",
1205+
"shift value for Flow models like SD3.x or WAN (default: auto)",
1206+
&sample_params.flow_shift},
12091207
{"",
12101208
"--high-noise-cfg-scale",
12111209
"(high noise) unconditional guidance scale: (default: 7.0)",
@@ -1606,6 +1604,7 @@ struct SDGenerationParams {
16061604
load_if_exists("cfg_scale", sample_params.guidance.txt_cfg);
16071605
load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg);
16081606
load_if_exists("guidance", sample_params.guidance.distilled_guidance);
1607+
load_if_exists("flow_shift", sample_params.flow_shift);
16091608

16101609
auto load_sampler_if_exists = [&](const char* key, enum sample_method_t& out) {
16111610
if (j.contains(key) && j[key].is_string()) {

examples/server/README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,6 @@ Context Options:
3636
CPU physical cores
3737
--chroma-t5-mask-pad <int> t5 mask pad size of chroma
3838
--vae-tile-overlap <float> tile overlap for vae tiling, in fraction of tile size (default: 0.5)
39-
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
4039
--vae-tiling process vae in tiles to reduce memory usage
4140
--force-sdxl-vae-conv-scale force use of conv scale on sdxl vae
4241
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
@@ -101,6 +100,7 @@ Default Generation Options:
101100
--skip-layer-start <float> SLG enabling point (default: 0.01)
102101
--skip-layer-end <float> SLG disabling point (default: 0.2)
103102
--eta <float> eta in DDIM, only for DDIM/TCD/res_multistep/res_2s (default: 0)
103+
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
104104
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
105105
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
106106
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)

include/stable-diffusion.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,6 @@ typedef struct {
201201
bool chroma_use_t5_mask;
202202
int chroma_t5_mask_pad;
203203
bool qwen_image_zero_cond_t;
204-
float flow_shift;
205204
} sd_ctx_params_t;
206205

207206
typedef struct {
@@ -235,6 +234,7 @@ typedef struct {
235234
int shifted_timestep;
236235
float* custom_sigmas;
237236
int custom_sigmas_count;
237+
float flow_shift;
238238
} sd_sample_params_t;
239239

240240
typedef struct {

src/denoiser.hpp

Lines changed: 9 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -657,9 +657,8 @@ struct DiscreteFlowDenoiser : public Denoiser {
657657

658658
float sigma_data = 1.0f;
659659

660-
DiscreteFlowDenoiser(float shift = 3.0f)
661-
: shift(shift) {
662-
set_parameters();
660+
DiscreteFlowDenoiser(float shift = 3.0f) {
661+
set_shift(shift);
663662
}
664663

665664
void set_parameters() {
@@ -668,6 +667,11 @@ struct DiscreteFlowDenoiser : public Denoiser {
668667
}
669668
}
670669

670+
void set_shift(float shift) {
671+
this->shift = shift;
672+
set_parameters();
673+
}
674+
671675
float sigma_min() override {
672676
return sigmas[0];
673677
}
@@ -710,34 +714,8 @@ float flux_time_shift(float mu, float sigma, float t) {
710714
return ::expf(mu) / (::expf(mu) + ::powf((1.0f / t - 1.0f), sigma));
711715
}
712716

713-
struct FluxFlowDenoiser : public Denoiser {
714-
float sigmas[TIMESTEPS];
715-
float shift = 1.15f;
716-
717-
float sigma_data = 1.0f;
718-
719-
FluxFlowDenoiser(float shift = 1.15f) {
720-
set_parameters(shift);
721-
}
722-
723-
void set_shift(float shift) {
724-
this->shift = shift;
725-
}
726-
727-
void set_parameters(float shift) {
728-
set_shift(shift);
729-
for (int i = 0; i < TIMESTEPS; i++) {
730-
sigmas[i] = t_to_sigma(static_cast<float>(i));
731-
}
732-
}
733-
734-
float sigma_min() override {
735-
return sigmas[0];
736-
}
737-
738-
float sigma_max() override {
739-
return sigmas[TIMESTEPS - 1];
740-
}
717+
struct FluxFlowDenoiser : public DiscreteFlowDenoiser {
718+
FluxFlowDenoiser() = default;
741719

742720
float sigma_to_t(float sigma) override {
743721
return sigma;
@@ -747,26 +725,6 @@ struct FluxFlowDenoiser : public Denoiser {
747725
t = t + 1;
748726
return flux_time_shift(shift, 1.0f, t / TIMESTEPS);
749727
}
750-
751-
std::vector<float> get_scalings(float sigma) override {
752-
float c_skip = 1.0f;
753-
float c_out = -sigma;
754-
float c_in = 1.0f;
755-
return {c_skip, c_out, c_in};
756-
}
757-
758-
// this function will modify noise/latent
759-
ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) override {
760-
ggml_ext_tensor_scale_inplace(noise, sigma);
761-
ggml_ext_tensor_scale_inplace(latent, 1.0f - sigma);
762-
ggml_ext_tensor_add_inplace(latent, noise);
763-
return latent;
764-
}
765-
766-
ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) override {
767-
ggml_ext_tensor_scale_inplace(latent, 1.0f / (1.0f - sigma));
768-
return latent;
769-
}
770728
};
771729

772730
struct Flux2FlowDenoiser : public FluxFlowDenoiser {

src/stable-diffusion.cpp

Lines changed: 31 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ class StableDiffusionGGML {
115115
int n_threads = -1;
116116
float scale_factor = 0.18215f;
117117
float shift_factor = 0.f;
118+
float default_flow_shift = INFINITY;
118119

119120
std::shared_ptr<Conditioner> cond_stage_model;
120121
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
@@ -881,7 +882,6 @@ class StableDiffusionGGML {
881882
// init denoiser
882883
{
883884
prediction_t pred_type = sd_ctx_params->prediction;
884-
float flow_shift = sd_ctx_params->flow_shift;
885885

886886
if (pred_type == PREDICTION_COUNT) {
887887
if (sd_version_is_sd2(version)) {
@@ -906,22 +906,19 @@ class StableDiffusionGGML {
906906
sd_version_is_qwen_image(version) ||
907907
sd_version_is_z_image(version)) {
908908
pred_type = FLOW_PRED;
909-
if (flow_shift == INFINITY) {
910-
if (sd_version_is_wan(version)) {
911-
flow_shift = 5.f;
912-
} else {
913-
flow_shift = 3.f;
914-
}
909+
if (sd_version_is_wan(version)) {
910+
default_flow_shift = 5.f;
911+
} else {
912+
default_flow_shift = 3.f;
915913
}
916914
} else if (sd_version_is_flux(version)) {
917915
pred_type = FLUX_FLOW_PRED;
918916

919-
if (flow_shift == INFINITY) {
920-
flow_shift = 1.0f; // TODO: validate
921-
for (const auto& [name, tensor_storage] : tensor_storage_map) {
922-
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
923-
flow_shift = 1.15f;
924-
}
917+
default_flow_shift = 1.0f; // TODO: validate
918+
for (const auto& [name, tensor_storage] : tensor_storage_map) {
919+
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
920+
default_flow_shift = 1.15f;
921+
break;
925922
}
926923
}
927924
} else if (sd_version_is_flux2(version)) {
@@ -945,12 +942,12 @@ class StableDiffusionGGML {
945942
break;
946943
case FLOW_PRED: {
947944
LOG_INFO("running in FLOW mode");
948-
denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift);
945+
denoiser = std::make_shared<DiscreteFlowDenoiser>();
949946
break;
950947
}
951948
case FLUX_FLOW_PRED: {
952949
LOG_INFO("running in Flux FLOW mode");
953-
denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift);
950+
denoiser = std::make_shared<FluxFlowDenoiser>();
954951
break;
955952
}
956953
case FLUX2_FLOW_PRED: {
@@ -2711,6 +2708,16 @@ class StableDiffusionGGML {
27112708
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
27122709
return result;
27132710
}
2711+
2712+
void set_flow_shift(float flow_shift = INFINITY) {
2713+
auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
2714+
if (flow_denoiser) {
2715+
if (flow_shift == INFINITY) {
2716+
flow_shift = default_flow_shift;
2717+
}
2718+
flow_denoiser->set_shift(flow_shift);
2719+
}
2720+
}
27142721
};
27152722

27162723
/*================================================= SD API ==================================================*/
@@ -2931,7 +2938,6 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
29312938
sd_ctx_params->chroma_use_dit_mask = true;
29322939
sd_ctx_params->chroma_use_t5_mask = false;
29332940
sd_ctx_params->chroma_t5_mask_pad = 1;
2934-
sd_ctx_params->flow_shift = INFINITY;
29352941
}
29362942

29372943
char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
@@ -3023,6 +3029,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
30233029
sample_params->sample_steps = 20;
30243030
sample_params->custom_sigmas = nullptr;
30253031
sample_params->custom_sigmas_count = 0;
3032+
sample_params->flow_shift = INFINITY;
30263033
}
30273034

30283035
char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
@@ -3043,7 +3050,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
30433050
"sample_method: %s, "
30443051
"sample_steps: %d, "
30453052
"eta: %.2f, "
3046-
"shifted_timestep: %d)",
3053+
"shifted_timestep: %d, "
3054+
"flow_shift: %.2f)",
30473055
sample_params->guidance.txt_cfg,
30483056
std::isfinite(sample_params->guidance.img_cfg)
30493057
? sample_params->guidance.img_cfg
@@ -3057,7 +3065,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
30573065
sd_sample_method_name(sample_params->sample_method),
30583066
sample_params->sample_steps,
30593067
sample_params->eta,
3060-
sample_params->shifted_timestep);
3068+
sample_params->shifted_timestep,
3069+
sample_params->flow_shift);
30613070

30623071
return buf;
30633072
}
@@ -3528,6 +3537,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
35283537

35293538
size_t t0 = ggml_time_ms();
35303539

3540+
sd_ctx->sd->set_flow_shift(sd_img_gen_params->sample_params.flow_shift);
3541+
35313542
// Apply lora
35323543
sd_ctx->sd->apply_loras(sd_img_gen_params->loras, sd_img_gen_params->lora_count);
35333544

@@ -3803,6 +3814,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38033814
}
38043815
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
38053816

3817+
sd_ctx->sd->set_flow_shift(sd_vid_gen_params->sample_params.flow_shift);
3818+
38063819
enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
38073820
if (sample_method == SAMPLE_METHOD_COUNT) {
38083821
sample_method = sd_get_default_sample_method(sd_ctx);

0 commit comments

Comments
 (0)