examples/imagenette.py: 34 changes (14 additions & 20 deletions)
@@ -75,14 +75,13 @@ class Pooling(str, Enum):
     maxblurpool = 'MaxBlurPool'
 
 class OptimizerChoice(str, Enum):
-    adan = 'adan'
-    adam = 'adam'
-    lamb = 'lamb'
-    lion = 'lion'
-    ranger = 'ranger'
-    sgd = 'sgd'
-    sophia = 'sophia'
-    stableadam = 'stableadam'
+    adan = 'adan'
+    adam = 'adam'
+    lamb = 'lamb'
+    lion = 'lion'
+    ranger = 'ranger'
+    sgd = 'sgd'
+    stableadamw = 'stableadamw'
 
 class Scheduler(str, Enum):
     onecycle = 'one_cycle'
@@ -96,8 +95,8 @@ class WarmMode(str, Enum):
     auto = 'auto'
 
 class WarmSched(str, Enum):
-    SchedCos = 'SchedCos'
-    SchedLin = 'SchedLin'
+    SchedCos = 'cosine'
+    SchedLin = 'linear'
 
 class ResizeMode(str, Enum):
     batch = 'batch'
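
Note: the enum hunks above only swap the CLI-facing strings (sophia dropped, stableadam renamed to stableadamw, and the warmup schedules exposed as 'cosine'/'linear'); the Typer pattern itself is unchanged. Below is a minimal, self-contained sketch of that pattern, a str-valued Enum driving a case-insensitive Typer option. The app, the two-member enum, and the echo call are illustrative only, not taken from imagenette.py.

from enum import Enum

import typer

app = typer.Typer()

class OptimizerChoice(str, Enum):
    # Illustrative subset of the choices in the hunk above.
    ranger = 'ranger'
    stableadamw = 'stableadamw'

@app.command()
def train(
    optimizer: OptimizerChoice = typer.Option(
        OptimizerChoice.ranger, case_sensitive=False,
        help="Which optimizer to use."),
):
    # Typer matches the given string against the enum values, so
    # --optimizer StableAdamW resolves to OptimizerChoice.stableadamw.
    typer.echo(f"Selected optimizer: {optimizer.value}")

if __name__ == "__main__":
    app()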
@@ -345,18 +344,15 @@ def train(ctx:typer.Context, # Typer Context to grab config for --verbose and pa
     # Optimizer
     optimizer:OptimizerChoice=typer.Option(OptimizerChoice.ranger, show_default=OptimizerChoice.ranger.value, help="Which optimizer to use. Make sure to set learning rate if changed.", case_sensitive=False, rich_help_panel="Optimizer"),
     weight_decay:Optional[float]=typer.Option(None, help="Weight decay for Optimizer. If None, use optimizer's default.", rich_help_panel="Optimizer"),
-    decouple_wd:bool=typer.Option(True, "--true-wd/--l2-wd", help="Apply true (decoupled) weight decay or L2 regularization. Doesn't apply to Adan, Lion, or Sophia.", rich_help_panel="Optimizer"),
-    fused_opt:bool=typer.Option(True, "--fused/--standard", help="Use faster For Each fused Optimizer or slower standard fastai Optimizer.", rich_help_panel="Optimizer"),
+    decouple_wd:bool=typer.Option(True, "--decouple-wd/--l2-reg", help="Apply decoupled weight decay or L2 regularization. Doesn't apply to Adan, Lion, or StableAdamW.", rich_help_panel="Optimizer"),
+    decouple_lr:bool=typer.Option(False, "--decouple-lr", help="Apply fully decoupled weight decay. Doesn't apply to Adan, Lion, or StableAdamW.", rich_help_panel="Optimizer"),
     eight_bit:bool=typer.Option(False, "--eight-bit", help="Use bitsandbytes 8-bit optimizer. Available for Adam, LAMB, Lion, & SGD with Momentum.", rich_help_panel="Optimizer"),
     mom:Optional[float]=typer.Option(None, help="Gradient moving average (β1) coefficient. If None, uses optimizer's default.", rich_help_panel="Optimizer"),
     sqr_mom:Optional[float]=typer.Option(None, help="Gradient squared moving average (β2) coefficient. If None, use optimizer's default.", rich_help_panel="Optimizer"),
     beta1:Optional[float]=typer.Option(None, help="Adan: Gradient moving average (β1) coefficient. Lion: Update gradient moving average (β1) coefficient. If None, use optimizer's default.", rich_help_panel="Optimizer"),
     beta2:Optional[float]=typer.Option(None, help="Adan: Gradient difference moving average (β2) coefficient. Lion: Gradient moving average (β2) coefficient. If None, use optimizer's default.", rich_help_panel="Optimizer"),
     beta3:Optional[float]=typer.Option(None, help="Adan: Gradient squared moving average (β3) coefficient. If None, use optimizer's default.", rich_help_panel="Optimizer"),
-    hess_mom:Optional[float]=typer.Option(None, help="Sophia: Hessian moving average (β2) coefficient. If None, use optimizer's default.", rich_help_panel="Optimizer"),
-    rho:Optional[float]=typer.Option(None, help="Sophia: Maximum update size, set higher for more agressive updates. If None, use optimizer's default.", rich_help_panel="Optimizer"),
     eps:Optional[float]=typer.Option(None, help="Added for numerical stability. If None, uses optimizer's default.", rich_help_panel="Optimizer"),
-    paper_init:bool=typer.Option(False, "--paperinit/--zeroinit", help="Adan: Initialize prior gradient with current gradient per paper or zeroes.", rich_help_panel="Optimizer"),
     # Scheduler
     scheduler:Scheduler=typer.Option(Scheduler.flatcos, show_default=Scheduler.flatcos.value, help="Which fastai or fastxtend scheduler to use. fit_one_cycle, fit_flat_cos, fit_flat_warmup, or fit_cos_anneal.", case_sensitive=False, rich_help_panel="Scheduler"),
     epochs:int=typer.Option(20, help="Number of epochs to train for.", rich_help_panel="Scheduler"),
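
Note: most hyperparameter options above default to None so that the selected optimizer's own defaults apply when a flag is omitted. A sketch of that convention follows, under the assumption that unset values are simply dropped before being forwarded as opt_kwargs; the helper below is hypothetical, not the script's actual code.

def build_opt_kwargs(**hyperparams):
    "Keep only the hyperparameters the user explicitly set on the command line."
    return {name: value for name, value in hyperparams.items() if value is not None}

# Example: only weight decay and beta2 were passed; everything else stays at the
# selected optimizer's defaults.
opt_kwargs = build_opt_kwargs(wd=1e-2, mom=None, sqr_mom=None, beta1=None,
                              beta2=0.99, beta3=None, eps=None)
assert opt_kwargs == {'wd': 0.01, 'beta2': 0.99}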
@@ -457,6 +453,7 @@ def train(ctx:typer.Context, # Typer Context to grab config for --verbose and pa
 
     if profile:
         from fastxtend.callback import profiler
+        from fastxtend.callback.profiler import ProfileMode
 
     if torch_compile:
         from fastxtend.callback import compiler
@@ -499,9 +496,6 @@ def train(ctx:typer.Context, # Typer Context to grab config for --verbose and pa
     elif cutmixup:
         cbs += [CutMixUp(mixup_alpha=mixup_alpha, cutmix_alpha=cutmix_alpha, mixup_ratio=mixup_ratio,
                          cutmix_ratio=cutmix_ratio, element=elementwise, interp_label=False)]
-    if optimizer.value=='sophia':
-        print('add sophia callback')
-        cbs += [SophiaCallback()]
 
     # Create the dataloaders
     with less_random(seed):
@@ -537,10 +531,10 @@ def train(ctx:typer.Context, # Typer Context to grab config for --verbose and pa
     # Create Learner
     with less_random(seed):
         learn = Learner(dls, arch(), loss_func=nn.CrossEntropyLoss(label_smoothing=label_smoothing),
-                        opt_func=opt(foreach=fused_opt, **opt_kwargs), metrics=Accuracy(), cbs=cbs)
+                        opt_func=opt(**opt_kwargs), metrics=Accuracy(), cbs=cbs)
     learn.to_channelslast() if channels_last else learn.to_fp16()
     if profile:
-        learn.profile()
+        learn.profile(mode=ProfileMode.Simple)
     if torch_compile:
         learn.compile(backend=backend)
 
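Note: condensed, the profiler-related additions in this diff amount to the usage below. The imports and the explicit mode argument are taken from the diff itself; learn is assumed to be the Learner created above, and the comment on the first import is an inference from how the script uses it.

from fastxtend.callback import profiler              # existing import the script makes before calling learn.profile()
from fastxtend.callback.profiler import ProfileMode  # added in this diff

learn.profile(mode=ProfileMode.Simple)  # explicitly request the simple profiling mode instead of the default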