Commit

remove lion
lucidrains committed Nov 18, 2023
1 parent 03d4d8c commit a97961b
Showing 2 changed files with 7 additions and 17 deletions.
23 changes: 7 additions & 16 deletions imagen_pytorch/trainer.py
@@ -12,7 +12,6 @@
 import torch.nn.functional as F
 from torch.utils.data import random_split, DataLoader
 from torch.optim import Adam
-from lion_pytorch import Lion
 from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
 from torch.cuda.amp import autocast, GradScaler

@@ -254,7 +253,6 @@ def __init__(
         checkpoint_fs = None,
         fs_kwargs: dict = None,
         max_checkpoints_keep = 20,
-        use_lion = False,
         **kwargs
     ):
         super().__init__()
@@ -337,20 +335,13 @@ def __init__(
 
         for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
 
-            if use_lion:
-                optimizer = Lion(
-                    unet.parameters(),
-                    lr = unet_lr,
-                    betas = (beta1, beta2)
-                )
-            else:
-                optimizer = Adam(
-                    unet.parameters(),
-                    lr = unet_lr,
-                    eps = unet_eps,
-                    betas = (beta1, beta2),
-                    **kwargs
-                )
+            optimizer = Adam(
+                unet.parameters(),
+                lr = unet_lr,
+                eps = unet_eps,
+                betas = (beta1, beta2),
+                **kwargs
+            )
 
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
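For context, after this commit the per-unet optimizer construction in ImagenTrainer always reduces to plain torch.optim.Adam. A minimal standalone sketch of the remaining path (the hyperparameter values below are illustrative placeholders, not necessarily the trainer's defaults):

from torch import nn
from torch.optim import Adam

unet = nn.Conv2d(3, 3, 3)       # stand-in for one of the Imagen unets

optimizer = Adam(
    unet.parameters(),
    lr = 1e-4,                  # per-unet learning rate (unet_lr)
    eps = 1e-8,                 # per-unet epsilon (unet_eps)
    betas = (0.9, 0.99)         # (beta1, beta2)
)

With the use_lion flag gone, there is no Lion-specific branch left; any extra keyword arguments passed to the trainer are forwarded straight to Adam via **kwargs.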
1 change: 0 additions & 1 deletion setup.py
@@ -34,7 +34,6 @@
   'ema-pytorch>=0.0.3',
   'fsspec',
   'kornia',
-  'lion-pytorch',
   'numpy',
   'packaging',
   'pillow',
