@@ -115,6 +115,7 @@ class StableDiffusionGGML {
115115 int n_threads = -1 ;
116116 float scale_factor = 0 .18215f ;
117117 float shift_factor = 0 .f;
118+ float default_flow_shift = INFINITY;
118119
119120 std::shared_ptr<Conditioner> cond_stage_model;
120121 std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
@@ -881,7 +882,6 @@ class StableDiffusionGGML {
881882 // init denoiser
882883 {
883884 prediction_t pred_type = sd_ctx_params->prediction ;
884- float flow_shift = sd_ctx_params->flow_shift ;
885885
886886 if (pred_type == PREDICTION_COUNT) {
887887 if (sd_version_is_sd2 (version)) {
@@ -906,22 +906,19 @@ class StableDiffusionGGML {
906906 sd_version_is_qwen_image (version) ||
907907 sd_version_is_z_image (version)) {
908908 pred_type = FLOW_PRED;
909- if (flow_shift == INFINITY) {
910- if (sd_version_is_wan (version)) {
911- flow_shift = 5 .f ;
912- } else {
913- flow_shift = 3 .f ;
914- }
909+ if (sd_version_is_wan (version)) {
910+ default_flow_shift = 5 .f ;
911+ } else {
912+ default_flow_shift = 3 .f ;
915913 }
916914 } else if (sd_version_is_flux (version)) {
917915 pred_type = FLUX_FLOW_PRED;
918916
919- if (flow_shift == INFINITY) {
920- flow_shift = 1 .0f ; // TODO: validate
921- for (const auto & [name, tensor_storage] : tensor_storage_map) {
922- if (starts_with (name, " model.diffusion_model.guidance_in.in_layer.weight" )) {
923- flow_shift = 1 .15f ;
924- }
917+ default_flow_shift = 1 .0f ; // TODO: validate
918+ for (const auto & [name, tensor_storage] : tensor_storage_map) {
919+ if (starts_with (name, " model.diffusion_model.guidance_in.in_layer.weight" )) {
920+ default_flow_shift = 1 .15f ;
921+ break ;
925922 }
926923 }
927924 } else if (sd_version_is_flux2 (version)) {
@@ -945,12 +942,12 @@ class StableDiffusionGGML {
945942 break ;
946943 case FLOW_PRED: {
947944 LOG_INFO (" running in FLOW mode" );
948- denoiser = std::make_shared<DiscreteFlowDenoiser>(flow_shift );
945+ denoiser = std::make_shared<DiscreteFlowDenoiser>();
949946 break ;
950947 }
951948 case FLUX_FLOW_PRED: {
952949 LOG_INFO (" running in Flux FLOW mode" );
953- denoiser = std::make_shared<FluxFlowDenoiser>(flow_shift );
950+ denoiser = std::make_shared<FluxFlowDenoiser>();
954951 break ;
955952 }
956953 case FLUX2_FLOW_PRED: {
@@ -2711,6 +2708,16 @@ class StableDiffusionGGML {
27112708 ggml_ext_tensor_clamp_inplace (result, 0 .0f , 1 .0f );
27122709 return result;
27132710 }
2711+
2712+ void set_flow_shift (float flow_shift = INFINITY) {
2713+ auto flow_denoiser = std::dynamic_pointer_cast<DiscreteFlowDenoiser>(denoiser);
2714+ if (flow_denoiser) {
2715+ if (flow_shift == INFINITY) {
2716+ flow_shift = default_flow_shift;
2717+ }
2718+ flow_denoiser->set_shift (flow_shift);
2719+ }
2720+ }
27142721};
27152722
27162723/* ================================================= SD API ==================================================*/
@@ -2931,7 +2938,6 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
29312938 sd_ctx_params->chroma_use_dit_mask = true ;
29322939 sd_ctx_params->chroma_use_t5_mask = false ;
29332940 sd_ctx_params->chroma_t5_mask_pad = 1 ;
2934- sd_ctx_params->flow_shift = INFINITY;
29352941}
29362942
29372943char * sd_ctx_params_to_str (const sd_ctx_params_t * sd_ctx_params) {
@@ -3023,6 +3029,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
30233029 sample_params->sample_steps = 20 ;
30243030 sample_params->custom_sigmas = nullptr ;
30253031 sample_params->custom_sigmas_count = 0 ;
3032+ sample_params->flow_shift = INFINITY;
30263033}
30273034
30283035char * sd_sample_params_to_str (const sd_sample_params_t * sample_params) {
@@ -3043,7 +3050,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
30433050 " sample_method: %s, "
30443051 " sample_steps: %d, "
30453052 " eta: %.2f, "
3046- " shifted_timestep: %d)" ,
3053+ " shifted_timestep: %d, "
3054+ " flow_shift: %.2f)" ,
30473055 sample_params->guidance .txt_cfg ,
30483056 std::isfinite (sample_params->guidance .img_cfg )
30493057 ? sample_params->guidance .img_cfg
@@ -3057,7 +3065,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
30573065 sd_sample_method_name (sample_params->sample_method ),
30583066 sample_params->sample_steps ,
30593067 sample_params->eta ,
3060- sample_params->shifted_timestep );
3068+ sample_params->shifted_timestep ,
3069+ sample_params->flow_shift );
30613070
30623071 return buf;
30633072}
@@ -3528,6 +3537,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
35283537
35293538 size_t t0 = ggml_time_ms ();
35303539
3540+ sd_ctx->sd ->set_flow_shift (sd_img_gen_params->sample_params .flow_shift );
3541+
35313542 // Apply lora
35323543 sd_ctx->sd ->apply_loras (sd_img_gen_params->loras , sd_img_gen_params->lora_count );
35333544
@@ -3803,6 +3814,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38033814 }
38043815 LOG_INFO (" generate_video %dx%dx%d" , width, height, frames);
38053816
3817+ sd_ctx->sd ->set_flow_shift (sd_vid_gen_params->sample_params .flow_shift );
3818+
38063819 enum sample_method_t sample_method = sd_vid_gen_params->sample_params .sample_method ;
38073820 if (sample_method == SAMPLE_METHOD_COUNT) {
38083821 sample_method = sd_get_default_sample_method (sd_ctx);
0 commit comments