@@ -747,6 +747,16 @@ class StableDiffusionGGML {
747747 denoiser->scheduler = std::make_shared<GITSSchedule>();
748748 denoiser->scheduler ->version = version;
749749 break ;
750+ case SGM_UNIFORM:
751+ LOG_INFO (" Running with SGM Uniform schedule" );
752+ denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
753+ denoiser->scheduler ->version = version;
754+ break ;
755+ case SIMPLE:
756+ LOG_INFO (" Running with Simple schedule" );
757+ denoiser->scheduler = std::make_shared<SimpleSchedule>();
758+ denoiser->scheduler ->version = version;
759+ break ;
750760 case SMOOTHSTEP:
751761 LOG_INFO (" Running with SmoothStep scheduler" );
752762 denoiser->scheduler = std::make_shared<SmoothStepSchedule>();
@@ -1033,6 +1043,7 @@ class StableDiffusionGGML {
10331043 float control_strength,
10341044 sd_guidance_params_t guidance,
10351045 float eta,
1046+ int shifted_timestep,
10361047 sample_method_t method,
10371048 const std::vector<float >& sigmas,
10381049 int start_merge_step,
@@ -1042,6 +1053,10 @@ class StableDiffusionGGML {
10421053 ggml_tensor* denoise_mask = NULL ,
10431054 ggml_tensor* vace_context = NULL ,
10441055 float vace_strength = 1 .f) {
1056+ if (shifted_timestep > 0 && !sd_version_is_sdxl (version)) {
1057+ LOG_WARN (" timestep shifting is only supported for SDXL models!" );
1058+ shifted_timestep = 0 ;
1059+ }
10451060 std::vector<int > skip_layers (guidance.slg .layers , guidance.slg .layers + guidance.slg .layer_count );
10461061
10471062 float cfg_scale = guidance.txt_cfg ;
@@ -1102,7 +1117,17 @@ class StableDiffusionGGML {
11021117 float c_in = scaling[2 ];
11031118
11041119 float t = denoiser->sigma_to_t (sigma);
1105- std::vector<float > timesteps_vec (1 , t); // [N, ]
1120+ std::vector<float > timesteps_vec;
1121+ if (shifted_timestep > 0 && sd_version_is_sdxl (version)) {
1122+ float shifted_t_float = t * (float (shifted_timestep) / float (TIMESTEPS));
1123+ int64_t shifted_t = static_cast <int64_t >(roundf (shifted_t_float));
1124+ shifted_t = std::max ((int64_t )0 , std::min ((int64_t )(TIMESTEPS - 1 ), shifted_t ));
1125+ LOG_DEBUG (" shifting timestep from %.2f to %" PRId64 " (sigma: %.4f)" , t, shifted_t , sigma);
1126+ timesteps_vec.assign (1 , (float )shifted_t );
1127+ } else {
1128+ timesteps_vec.assign (1 , t);
1129+ }
1130+
11061131 timesteps_vec = process_timesteps (timesteps_vec, init_latent, denoise_mask);
11071132 auto timesteps = vector_to_ggml_tensor (work_ctx, timesteps_vec);
11081133 std::vector<float > guidance_vec (1 , guidance.distilled_guidance );
@@ -1200,6 +1225,19 @@ class StableDiffusionGGML {
12001225 float * vec_input = (float *)input->data ;
12011226 float * positive_data = (float *)out_cond->data ;
12021227 int ne_elements = (int )ggml_nelements (denoised);
1228+
1229+ if (shifted_timestep > 0 && sd_version_is_sdxl (version)) {
1230+ int64_t shifted_t_idx = static_cast <int64_t >(roundf (timesteps_vec[0 ]));
1231+ float shifted_sigma = denoiser->t_to_sigma ((float )shifted_t_idx);
1232+ std::vector<float > shifted_scaling = denoiser->get_scalings (shifted_sigma);
1233+ float shifted_c_skip = shifted_scaling[0 ];
1234+ float shifted_c_out = shifted_scaling[1 ];
1235+ float shifted_c_in = shifted_scaling[2 ];
1236+
1237+ c_skip = shifted_c_skip * c_in / shifted_c_in;
1238+ c_out = shifted_c_out;
1239+ }
1240+
12031241 for (int i = 0 ; i < ne_elements; i++) {
12041242 float latent_result = positive_data[i];
12051243 if (has_unconditioned) {
@@ -1222,6 +1260,7 @@ class StableDiffusionGGML {
12221260 // denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
12231261 vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
12241262 }
1263+
12251264 int64_t t1 = ggml_time_us ();
12261265 if (step > 0 ) {
12271266 pretty_progress (step, (int )steps, (t1 - t0) / 1000000 .f );
@@ -1588,6 +1627,8 @@ const char* schedule_to_str[] = {
15881627 " exponential" ,
15891628 " ays" ,
15901629 " gits" ,
1630+ " sgm_uniform" ,
1631+ " simple" ,
15911632 " smoothstep" ,
15921633};
15931634
@@ -1720,7 +1761,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
17201761 " scheduler: %s, "
17211762 " sample_method: %s, "
17221763 " sample_steps: %d, "
1723- " eta: %.2f)" ,
1764+ " eta: %.2f, "
1765+ " shifted_timestep: %d)" ,
17241766 sample_params->guidance .txt_cfg ,
17251767 sample_params->guidance .img_cfg ,
17261768 sample_params->guidance .distilled_guidance ,
@@ -1731,7 +1773,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
17311773 sd_schedule_name (sample_params->scheduler ),
17321774 sd_sample_method_name (sample_params->sample_method ),
17331775 sample_params->sample_steps ,
1734- sample_params->eta );
1776+ sample_params->eta ,
1777+ sample_params->shifted_timestep );
17351778
17361779 return buf;
17371780}
@@ -1863,6 +1906,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
18631906 int clip_skip,
18641907 sd_guidance_params_t guidance,
18651908 float eta,
1909+ int shifted_timestep,
18661910 int width,
18671911 int height,
18681912 enum sample_method_t sample_method,
@@ -2101,6 +2145,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
21012145 control_strength,
21022146 guidance,
21032147 eta,
2148+ shifted_timestep,
21042149 sample_method,
21052150 sigmas,
21062151 start_merge_step,
@@ -2394,6 +2439,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23942439 sd_img_gen_params->clip_skip ,
23952440 sd_img_gen_params->sample_params .guidance ,
23962441 sd_img_gen_params->sample_params .eta ,
2442+ sd_img_gen_params->sample_params .shifted_timestep ,
23972443 width,
23982444 height,
23992445 sample_method,
@@ -2734,6 +2780,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
27342780 0 ,
27352781 sd_vid_gen_params->high_noise_sample_params .guidance ,
27362782 sd_vid_gen_params->high_noise_sample_params .eta ,
2783+ sd_vid_gen_params->high_noise_sample_params .shifted_timestep ,
27372784 sd_vid_gen_params->high_noise_sample_params .sample_method ,
27382785 high_noise_sigmas,
27392786 -1 ,
@@ -2769,6 +2816,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
27692816 0 ,
27702817 sd_vid_gen_params->sample_params .guidance ,
27712818 sd_vid_gen_params->sample_params .eta ,
2819+ sd_vid_gen_params->sample_params .shifted_timestep ,
27722820 sd_vid_gen_params->sample_params .sample_method ,
27732821 sigmas,
27742822 -1 ,
0 commit comments