Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom optimizer #132

Merged
merged 21 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions luxonis_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ class CallbackConfig(BaseModelExtraForbid):

class OptimizerConfig(BaseModelExtraForbid):
name: str = "Adam"
apply_custom_lr: bool = False
params: Params = {}


Expand Down
129 changes: 124 additions & 5 deletions luxonis_train/models/luxonis_lightning.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import math
from collections import defaultdict
from collections.abc import Mapping
from logging import getLogger
from pathlib import Path
from typing import Literal, cast

import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.utilities import rank_zero_only # type: ignore
Expand Down Expand Up @@ -857,14 +859,83 @@
list[torch.optim.Optimizer],
list[torch.optim.lr_scheduler.LRScheduler],
]:
"""Configures model optimizers and schedulers."""
"""Configures model optimizers and schedulers with optional
custom learning rates and warm-up logic."""

cfg_optimizer = self.cfg.trainer.optimizer
cfg_scheduler = self.cfg.trainer.scheduler

optim_params = cfg_optimizer.params | {
"params": filter(lambda p: p.requires_grad, self.parameters()),
}
optimizer = OPTIMIZERS.get(cfg_optimizer.name)(**optim_params)
apply_custom_lr = cfg_optimizer.apply_custom_lr

if apply_custom_lr:
assert cfg_optimizer.name == "SGD", (
"Custom learning rates are supported only for SGD optimizer. "
f"Got {cfg_optimizer.name}."
)
self.max_stepnum = math.ceil(
len(self._core.loaders["train"]) / self.cfg.trainer.batch_size

Check failure on line 876 in luxonis_train/models/luxonis_lightning.py

View workflow job for this annotation

GitHub Actions / type-check

"loaders" is not a known attribute of "None" (reportOptionalMemberAccess)
kozlov721 marked this conversation as resolved.
Show resolved Hide resolved
)
self.warmup_stepnum = max(
round(
self.cfg.trainer.optimizer.params["warmup_epochs"]
* self.max_stepnum
),
1000,
)
self.step = 0
batch_norm_weights, regular_weights, biases = [], [], []
for module in self.modules():
if hasattr(module, "bias") and isinstance(
module.bias, torch.nn.Parameter
):
biases.append(module.bias)
if isinstance(module, torch.nn.BatchNorm2d):
batch_norm_weights.append(module.weight)
elif hasattr(module, "weight") and isinstance(
module.weight, torch.nn.Parameter
):
regular_weights.append(module.weight)

optimizer = torch.optim.SGD(
[
{
"params": batch_norm_weights,
"lr": cfg_optimizer.params["lr"],
"momentum": cfg_optimizer.params["momentum"],
"nesterov": True,
},
{
"params": regular_weights,
"weight_decay": cfg_optimizer.params["weight_decay"],
},
{"params": biases},
],
lr=cfg_optimizer.params["lr"],
momentum=cfg_optimizer.params["momentum"],
nesterov=cfg_optimizer.params["nesterov"],
)

lrf = (
self.cfg.trainer.optimizer.params["lre"]
/ self.cfg.trainer.optimizer.params["lr"]
)
self.lf = (
lambda x: (
(1 - math.cos(x * math.pi / self.cfg.trainer.epochs)) / 2
)
* (lrf - 1)
+ 1
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer, lr_lambda=self.lf
)
return [optimizer], [scheduler]

else:
optim_params = cfg_optimizer.params | {
"params": filter(lambda p: p.requires_grad, self.parameters()),
}
optimizer = OPTIMIZERS.get(cfg_optimizer.name)(**optim_params)

def get_scheduler(scheduler_cfg, optimizer):
scheduler_class = SCHEDULERS.get(
Expand Down Expand Up @@ -895,6 +966,54 @@

return [optimizer], [scheduler]

def on_after_backward(self):
"""Custom logic to adjust learning rates and momentum after
loss.backward."""
if self.cfg.trainer.optimizer.apply_custom_lr:
self.custom_logic()

def custom_logic(self):
"""Custom logic to adjust learning rates and momentum after
loss.backward."""

# Increment step counter
self.step = (
self.step % self.max_stepnum
) # Reset step counter after each epoch
curr_step = self.step + self.max_stepnum * self.current_epoch

# Warm-up phase adjustments
if curr_step <= self.warmup_stepnum:
optimizer = self.optimizers()
for k, param in enumerate(optimizer.param_groups):

Check failure on line 988 in luxonis_train/models/luxonis_lightning.py

View workflow job for this annotation

GitHub Actions / type-check

Cannot access attribute "param_groups" for class "List[LightningOptimizer]"   Attribute "param_groups" is unknown (reportAttributeAccessIssue)
warmup_bias_lr = (
self.cfg.trainer.optimizer.params["warmup_bias_lr"]
if k == 2
else 0.0
)
param["lr"] = np.interp(
curr_step,
[0, self.warmup_stepnum],
[
warmup_bias_lr,
self.cfg.trainer.optimizer.params["lr"]
* self.lf(self.current_epoch),
],
)
if "momentum" in param:
param["momentum"] = np.interp(
curr_step,
[0, self.warmup_stepnum],
[
self.cfg.trainer.optimizer.params[
"warmup_momentum"
],
self.cfg.trainer.optimizer.params["momentum"],
],
)

self.step += 1

def load_checkpoint(self, path: str | Path | None) -> None:
"""Loads checkpoint weights from provided path.

Expand Down
13 changes: 13 additions & 0 deletions luxonis_train/nodes/backbones/efficientrep/efficientrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,22 @@
)
)

self.initialize_weights()

if download_weights and var.weights_path:
self.load_checkpoint(var.weights_path)

def initialize_weights(self):
for m in self.modules():
t = type(m)
if t is nn.Conv2d:
pass
elif t is nn.BatchNorm2d:
m.eps = 1e-3

Check failure on line 139 in luxonis_train/nodes/backbones/efficientrep/efficientrep.py

View workflow job for this annotation

GitHub Actions / type-check

Argument of type "float" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "float" is not assignable to type "Tensor | Module"     "float" is not assignable to "Tensor"     "float" is not assignable to "Module" (reportArgumentType)
m.momentum = 0.03

Check failure on line 140 in luxonis_train/nodes/backbones/efficientrep/efficientrep.py

View workflow job for this annotation

GitHub Actions / type-check

Argument of type "float" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "float" is not assignable to type "Tensor | Module"     "float" is not assignable to "Tensor"     "float" is not assignable to "Module" (reportArgumentType)
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True

Check failure on line 142 in luxonis_train/nodes/backbones/efficientrep/efficientrep.py

View workflow job for this annotation

GitHub Actions / type-check

Argument of type "Literal[True]" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "Literal[True]" is not assignable to type "Tensor | Module"     "Literal[True]" is not assignable to "Tensor"     "Literal[True]" is not assignable to "Module" (reportArgumentType)

def set_export_mode(self, mode: bool = True) -> None:
"""Reparametrizes instances of L{RepVGGBlock} in the network.

Expand Down
12 changes: 12 additions & 0 deletions luxonis_train/nodes/blocks/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@

prior_prob = 1e-2
self._initialize_weights_and_biases(prior_prob)
self.initialize_weights()

def initialize_weights(self):
for m in self.modules():
t = type(m)
if t is nn.Conv2d:
pass
elif t is nn.BatchNorm2d:
m.eps = 1e-3

Check failure on line 67 in luxonis_train/nodes/blocks/blocks.py

View workflow job for this annotation

GitHub Actions / type-check

Argument of type "float" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "float" is not assignable to type "Tensor | Module"     "float" is not assignable to "Tensor"     "float" is not assignable to "Module" (reportArgumentType)
m.momentum = 0.03

Check failure on line 68 in luxonis_train/nodes/blocks/blocks.py

View workflow job for this annotation

GitHub Actions / type-check

Argument of type "float" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "float" is not assignable to type "Tensor | Module"     "float" is not assignable to "Tensor"     "float" is not assignable to "Module" (reportArgumentType)
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True

Check failure on line 70 in luxonis_train/nodes/blocks/blocks.py

View workflow job for this annotation

GitHub Actions / type-check

Argument of type "Literal[True]" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "Literal[True]" is not assignable to type "Tensor | Module"     "Literal[True]" is not assignable to "Tensor"     "Literal[True]" is not assignable to "Module" (reportArgumentType)

def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
out_feature = self.decoder(x)
Expand Down
13 changes: 13 additions & 0 deletions luxonis_train/nodes/heads/efficient_bbox_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,25 @@
f"output{i+1}_yolov6r2" for i in range(self.n_heads)
]

self.initialize_weights()

if download_weights:
# TODO: Handle variants of head in a nicer way
if self.in_channels == [32, 64, 128]:
weights_path = "https://github.com/luxonis/luxonis-train/releases/download/v0.1.0-beta/efficientbbox_head_n_coco.ckpt"
self.load_checkpoint(weights_path, strict=False)

def initialize_weights(self):
for m in self.modules():
t = type(m)
if t is nn.Conv2d:
pass
elif t is nn.BatchNorm2d:
m.eps = 1e-3

Check failure on line 112 in luxonis_train/nodes/heads/efficient_bbox_head.py

View workflow job for this annotation

GitHub Actions / type-check

Argument of type "float" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "float" is not assignable to type "Tensor | Module"     "float" is not assignable to "Tensor"     "float" is not assignable to "Module" (reportArgumentType)
m.momentum = 0.03

Check failure on line 113 in luxonis_train/nodes/heads/efficient_bbox_head.py

View workflow job for this annotation

GitHub Actions / type-check

Argument of type "float" cannot be assigned to parameter "value" of type "Tensor | Module" in function "__setattr__"   Type "float" is not assignable to type "Tensor | Module"     "float" is not assignable to "Tensor"     "float" is not assignable to "Module" (reportArgumentType)
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True

def forward(
self, inputs: list[Tensor]
) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
Expand Down
13 changes: 13 additions & 0 deletions luxonis_train/nodes/necks/reppan_neck/reppan_neck.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,22 @@ def __init__(
out_channels = channels_list_down_blocks[2 * i + 1]
curr_n_repeats = n_repeats_down_blocks[i]

self.initialize_weights()

if download_weights and var.weights_path:
self.load_checkpoint(var.weights_path)

def initialize_weights(self):
for m in self.modules():
t = type(m)
if t is nn.Conv2d:
pass
elif t is nn.BatchNorm2d:
m.eps = 1e-3
m.momentum = 0.03
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True

def forward(self, inputs: list[Tensor]) -> list[Tensor]:
x = inputs[-1]
up_block_outs: list[Tensor] = []
Expand Down