@@ -3091,7 +3091,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
30913091 "--loss_type" ,
30923092 type = str ,
30933093 default = "l2" ,
3094- choices = ["l2" , "huber" , "huber_scheduled" ],
3094+ choices = ["l2" , "huber" , "huber_scheduled" , "smooth_l1" , "smooth_l1_scheduled" ],
30953095 help = "The type of loss to use and whether it's scheduled based on the timestep"
30963096 )
30973097 parser .add_argument (
@@ -4608,7 +4608,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest
46084608
46094609 #TODO: if a huber or smooth_l1 loss is selected, a single random timestep is drawn and
46104610 # repeated for every sample in the batch. In the future there may be a smarter way
4611- if args .loss_type == 'huber_scheduled' :
4611+ if args .loss_type == 'huber_scheduled' or args . loss_type == 'smooth_l1_scheduled' : #NOTE: Will unify scheduled and vanilla soon
46124612 timesteps = torch .randint (
46134613 min_timestep , max_timestep , (1 ,), device = 'cpu'
46144614 )
@@ -4617,7 +4617,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest
46174617 alpha = - math .log (args .huber_c ) / num_train_timesteps
46184618 huber_c = math .exp (- alpha * timestep )
46194619 timesteps = timesteps .repeat (b_size ).to (device )
4620- elif args .loss_type == 'huber' :
4620+ elif args .loss_type == 'huber' or args . loss_type == 'smooth_l1' :
46214621 # for fairness in comparison
46224622 timesteps = torch .randint (
46234623 min_timestep , max_timestep , (1 ,), device = 'cpu'
@@ -4670,6 +4670,12 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str
46704670 loss = torch .mean (loss )
46714671 elif reduction == "sum" :
46724672 loss = torch .sum (loss )
4673+ elif loss_type == 'smooth_l1' or loss_type == 'smooth_l1_scheduled' : # NOTE: Will unify in the next commits
4674+ loss = 2 * (torch .sqrt ((model_pred - target ) ** 2 + huber_c ** 2 ) - huber_c )
4675+ if reduction == "mean" :
4676+ loss = torch .mean (loss )
4677+ elif reduction == "sum" :
4678+ loss = torch .sum (loss )
46734679 else :
46744680 raise NotImplementedError (f'Unsupported Loss Type { loss_type } ' )
46754681 return loss
0 commit comments