@@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
39053905 "--huber_c" ,
39063906 type = float ,
39073907 default = 0.1 ,
3908- help = "The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1" ,
3908+ help = "The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1" ,
3909+ )
3910+
3911+ parser .add_argument (
3912+ "--huber_scale" ,
3913+ type = float ,
3914+ default = 1.0 ,
3915+ help = "The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1" ,
39093916 )
39103917
39113918 parser .add_argument (
@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
58215828 huggingface_util .upload (args , out_dir , "/" + model_name , force_sync_upload = True )
58225829
58235830
5824- def get_timesteps_and_huber_c ( args , min_timestep , max_timestep , noise_scheduler , b_size , device ):
5831+ def get_timesteps ( min_timestep , max_timestep , b_size , device ):
58255832 timesteps = torch .randint (min_timestep , max_timestep , (b_size ,), device = "cpu" )
5826-
5827- if args .loss_type == "huber" or args .loss_type == "smooth_l1" :
5828- if args .huber_schedule == "exponential" :
5829- alpha = - math .log (args .huber_c ) / noise_scheduler .config .num_train_timesteps
5830- huber_c = torch .exp (- alpha * timesteps )
5831- elif args .huber_schedule == "snr" :
5832- alphas_cumprod = torch .index_select (noise_scheduler .alphas_cumprod , 0 , timesteps )
5833- sigmas = ((1.0 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
5834- huber_c = (1 - args .huber_c ) / (1 + sigmas ) ** 2 + args .huber_c
5835- elif args .huber_schedule == "constant" :
5836- huber_c = torch .full ((b_size ,), args .huber_c )
5837- else :
5838- raise NotImplementedError (f"Unknown Huber loss schedule { args .huber_schedule } !" )
5839- huber_c = huber_c .to (device )
5840- elif args .loss_type == "l2" :
5841- huber_c = None # may be anything, as it's not used
5842- else :
5843- raise NotImplementedError (f"Unknown loss type { args .loss_type } " )
5844-
58455833 timesteps = timesteps .long ().to (device )
5846- return timesteps , huber_c
5834+ return timesteps
58475835
58485836
58495837def get_noise_noisy_latents_and_timesteps (args , noise_scheduler , latents ):
@@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
58655853 min_timestep = 0 if args .min_timestep is None else args .min_timestep
58665854 max_timestep = noise_scheduler .config .num_train_timesteps if args .max_timestep is None else args .max_timestep
58675855
5868- timesteps , huber_c = get_timesteps_and_huber_c ( args , min_timestep , max_timestep , noise_scheduler , b_size , latents .device )
5856+ timesteps = get_timesteps ( min_timestep , max_timestep , b_size , latents .device )
58695857
58705858 # Add noise to the latents according to the noise magnitude at each timestep
58715859 # (this is the forward diffusion process)
@@ -5878,32 +5866,54 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
58785866 else :
58795867 noisy_latents = noise_scheduler .add_noise (latents , noise , timesteps )
58805868
5881- return noise , noisy_latents , timesteps , huber_c
5869+ return noise , noisy_latents , timesteps
5870+
5871+
5872+ def get_huber_threshold (args , timesteps : torch .Tensor , noise_scheduler ) -> torch .Tensor :
5873+ b_size = timesteps .shape [0 ]
5874+ if args .huber_schedule == "exponential" :
5875+ alpha = - math .log (args .huber_c ) / noise_scheduler .config .num_train_timesteps
5876+ result = torch .exp (- alpha * timesteps ) * args .huber_scale
5877+ elif args .huber_schedule == "snr" :
5878+ if noise_scheduler is None or not hasattr (noise_scheduler , "alphas_cumprod" ):
5879+ raise NotImplementedError ("Huber schedule 'snr' is not supported with the current model." )
5880+ alphas_cumprod = torch .index_select (noise_scheduler .alphas_cumprod , 0 , timesteps .cpu ())
5881+ sigmas = ((1.0 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
5882+ result = (1 - args .huber_c ) / (1 + sigmas ) ** 2 + args .huber_c
5883+ result = result .to (timesteps .device )
5884+ elif args .huber_schedule == "constant" :
5885+ result = torch .full ((b_size ,), args .huber_c * args .huber_scale , device = timesteps .device )
5886+ else :
5887+ raise NotImplementedError (f"Unknown Huber loss schedule { args .huber_schedule } !" )
5888+
5889+ return result
58825890
58835891
58845892def conditional_loss (
5885- model_pred : torch .Tensor , target : torch .Tensor , reduction : str , loss_type : str , huber_c : Optional [ torch . Tensor ]
5893+ args , model_pred : torch .Tensor , target : torch .Tensor , timesteps : torch . Tensor , reduction : str , noise_scheduler
58865894):
5887- if loss_type == "l2" :
5895+ if args . loss_type == "l2" :
58885896 loss = torch .nn .functional .mse_loss (model_pred , target , reduction = reduction )
5889- elif loss_type == "l1" :
5897+ elif args . loss_type == "l1" :
58905898 loss = torch .nn .functional .l1_loss (model_pred , target , reduction = reduction )
5891- elif loss_type == "huber" :
5899+ elif args .loss_type == "huber" :
5900+ huber_c = get_huber_threshold (args , timesteps , noise_scheduler )
58925901 huber_c = huber_c .view (- 1 , 1 , 1 , 1 )
58935902 loss = 2 * huber_c * (torch .sqrt ((model_pred - target ) ** 2 + huber_c ** 2 ) - huber_c )
58945903 if reduction == "mean" :
58955904 loss = torch .mean (loss )
58965905 elif reduction == "sum" :
58975906 loss = torch .sum (loss )
5898- elif loss_type == "smooth_l1" :
5907+ elif args .loss_type == "smooth_l1" :
5908+ huber_c = get_huber_threshold (args , timesteps , noise_scheduler )
58995909 huber_c = huber_c .view (- 1 , 1 , 1 , 1 )
59005910 loss = 2 * (torch .sqrt ((model_pred - target ) ** 2 + huber_c ** 2 ) - huber_c )
59015911 if reduction == "mean" :
59025912 loss = torch .mean (loss )
59035913 elif reduction == "sum" :
59045914 loss = torch .sum (loss )
59055915 else :
5906- raise NotImplementedError (f"Unsupported Loss Type { loss_type } " )
5916+ raise NotImplementedError (f"Unsupported Loss Type: { args . loss_type } " )
59075917 return loss
59085918
59095919
0 commit comments