Commit dd22958

add option for smooth l1 (huber / delta)
1 parent c6495de commit dd22958

1 file changed: +9 −3 lines

library/train_util.py

Lines changed: 9 additions & 3 deletions
@@ -3091,7 +3091,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--loss_type",
         type=str,
         default="l2",
-        choices=["l2", "huber", "huber_scheduled"],
+        choices=["l2", "huber", "huber_scheduled", "smooth_l1", "smooth_l1_scheduled"],
         help="The type of loss to use and whether it's scheduled based on the timestep"
     )
     parser.add_argument(
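
As a quick sanity check of the new choices, here is a minimal, self-contained argparse sketch. It is not the full add_training_arguments parser, and the --huber_c flag and its default shown here are assumptions that mirror args.huber_c used in the hunks below:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--loss_type",
        type=str,
        default="l2",
        choices=["l2", "huber", "huber_scheduled", "smooth_l1", "smooth_l1_scheduled"],
        help="The type of loss to use and whether it's scheduled based on the timestep",
    )
    # Assumed companion flag; the real definition and default live elsewhere in train_util.py.
    parser.add_argument("--huber_c", type=float, default=0.1)

    args = parser.parse_args(["--loss_type", "smooth_l1_scheduled"])
    print(args.loss_type, args.huber_c)  # smooth_l1_scheduled 0.1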
@@ -4608,7 +4608,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest
 
     #TODO: if a huber loss is selected, it will use constant timesteps for each batch
     # as. In the future there may be a smarter way
-    if args.loss_type == 'huber_scheduled':
+    if args.loss_type == 'huber_scheduled' or args.loss_type == 'smooth_l1_scheduled': #NOTE: Will unify scheduled and vanilla soon
         timesteps = torch.randint(
             min_timestep, max_timestep, (1,), device='cpu'
         )
@@ -4617,7 +4617,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest
         alpha = - math.log(args.huber_c) / num_train_timesteps
         huber_c = math.exp(-alpha * timestep)
         timesteps = timesteps.repeat(b_size).to(device)
-    elif args.loss_type == 'huber':
+    elif args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
         # for fairness in comparison
         timesteps = torch.randint(
             min_timestep, max_timestep, (1,), device='cpu'
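
For reference, the scheduled branch decays huber_c exponentially from 1.0 at timestep 0 down to args.huber_c at the final training timestep. A minimal standalone sketch of that schedule (the function name and example values are illustrative, not part of the commit):

    import math

    def scheduled_huber_c(huber_c_arg: float, timestep: int, num_train_timesteps: int) -> float:
        # Same form as the hunk above: alpha is chosen so that
        # exp(-alpha * num_train_timesteps) == huber_c_arg.
        alpha = -math.log(huber_c_arg) / num_train_timesteps
        return math.exp(-alpha * timestep)

    print(scheduled_huber_c(0.1, 0, 1000))     # 1.0
    print(scheduled_huber_c(0.1, 500, 1000))   # ~0.316
    print(scheduled_huber_c(0.1, 1000, 1000))  # ~0.1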
@@ -4670,6 +4670,12 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
+    elif loss_type == 'smooth_l1' or loss_type == 'smooth_l1_scheduled': # NOTE: Will unify in the next commits
+        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+        if reduction == "mean":
+            loss = torch.mean(loss)
+        elif reduction == "sum":
+            loss = torch.sum(loss)
     else:
         raise NotImplementedError(f'Unsupported Loss Type {loss_type}')
     return loss
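
The new branch is a smooth, pseudo-Huber-style penalty: with d = model_pred - target, the term 2 * (sqrt(d**2 + huber_c**2) - huber_c) behaves like d**2 / huber_c when |d| is much smaller than huber_c and like 2 * (|d| - huber_c) for large residuals. A minimal standalone sketch of just this branch (names and example values are illustrative; the committed code lives in conditional_loss above):

    import torch

    def smooth_l1_term(model_pred: torch.Tensor, target: torch.Tensor, huber_c: float) -> torch.Tensor:
        diff = model_pred - target
        # Elementwise pseudo-Huber: ~ diff**2 / huber_c near zero, ~ 2 * (|diff| - huber_c) for large |diff|.
        return 2 * (torch.sqrt(diff ** 2 + huber_c ** 2) - huber_c)

    pred = torch.tensor([0.0, 0.01, 1.0])
    target = torch.zeros(3)
    print(smooth_l1_term(pred, target, huber_c=0.1))
    # roughly [0.0, 0.0010, 1.81]: quadratic for the small residual, near-linear for the large one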
