Custom optimizer #132

Merged · 21 commits · Nov 19, 2024
Changes from 14 commits
2 changes: 2 additions & 0 deletions README.md
@@ -567,6 +567,7 @@ model.tune()
- [**Callbacks**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/callbacks/README.md): Allow custom code to be executed at different stages of training.
- [**Optimizers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#optimizer): Control how the model's weights are updated.
- [**Schedulers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#scheduler): Adjust the learning rate during training.
- [**Training Strategy**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#training-strategy): Specify a custom combination of optimizer and scheduler to tailor the training process for specific use cases.

**Creating Custom Components:**

@@ -581,6 +582,7 @@ Registered components can be referenced in the config file. Custom components ne
- **Callbacks** - [`lightning.pytorch.callbacks.Callback`](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), requires manual registration to the `CALLBACKS` registry
- **Optimizers** - [`torch.optim.Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), requires manual registration to the `OPTIMIZERS` registry
- **Schedulers** - [`torch.optim.lr_scheduler.LRScheduler`](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate), requires manual registration to the `SCHEDULERS` registry
- **Training Strategy** - [`BaseTrainingStrategy`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/strategies/base_strategy.py)
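
For instance, a custom optimizer could be registered roughly like this. This is a sketch, not part of this PR: `MyOptimizer` is an assumed name, the update rule is plain gradient descent for brevity, and it assumes the `OPTIMIZERS` registry is importable from `luxonis_train.utils.registry`, the same module `STRATEGIES` is imported from in `base_strategy.py` below.

```python
import torch
from torch.optim import Optimizer

from luxonis_train.utils.registry import OPTIMIZERS


class MyOptimizer(Optimizer):
    """Plain gradient descent; exists only to illustrate registration."""

    def __init__(self, params, lr: float = 0.01):
        super().__init__(params, {"lr": lr})

    def step(self, closure=None):
        loss = closure() if closure is not None else None
        with torch.no_grad():
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        # Basic SGD update: p <- p - lr * grad
                        p.add_(p.grad, alpha=-group["lr"])
        return loss


# Manual registration so the optimizer can be referenced by name in the config.
OPTIMIZERS.register_module(module=MyOptimizer)
```

Once registered, it could be referenced from the config by name, just like the built-in components.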

**Examples:**

31 changes: 31 additions & 0 deletions configs/README.md
@@ -376,6 +376,37 @@ trainer:
eta_min: 0
```

### Training Strategy

Defines a custom training strategy that takes over optimizer and scheduler setup. Currently, only the `TripleLRSGDStrategy` is supported; more strategies will be added in the future. When a training strategy is specified, any separately configured `optimizer` or `scheduler` is ignored. All keys other than `name` are passed to the strategy through its `params` field.

| Key | Type | Default value | Description |
| ----------------- | ------- | ----------------------- | ---------------------------------------------- |
| `name` | `str` | `"TripleLRSGDStrategy"` | Name of the training strategy |
| `warmup_epochs` | `int` | `3` | Number of epochs for the warmup phase |
| `warmup_bias_lr` | `float` | `0.1` | Learning rate for bias during the warmup phase |
| `warmup_momentum` | `float` | `0.8` | Momentum value during the warmup phase |
| `lr` | `float` | `0.02` | Initial learning rate |
| `lre` | `float` | `0.0002` | End learning rate |
| `momentum` | `float` | `0.937` | Momentum for the optimizer |
| `weight_decay` | `float` | `0.0005` | Weight decay value |
| `nesterov`        | `bool`  | `true`                  | Whether to use Nesterov momentum               |

**Example:**

```yaml
training_strategy:
  name: "TripleLRSGDStrategy"
  params:
    warmup_epochs: 3
    warmup_bias_lr: 0.1
    warmup_momentum: 0.8
    lr: 0.02
    lre: 0.0002
    momentum: 0.937
    weight_decay: 0.0005
    nesterov: true
```
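
Since `params` defaults to an empty dictionary and every hyperparameter in the table has a default value, a minimal configuration that only names the strategy should also work (a sketch relying entirely on those defaults):

```yaml
training_strategy:
  name: "TripleLRSGDStrategy"
```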

## Exporter

Here you can define configuration for exporting.
1 change: 1 addition & 0 deletions luxonis_train/__init__.py
@@ -10,6 +10,7 @@
from .nodes import *
from .optimizers import *
from .schedulers import *
from .strategies import *
from .utils import *
except ImportError as e:
warnings.warn(
3 changes: 3 additions & 0 deletions luxonis_train/callbacks/__init__.py
@@ -25,6 +25,7 @@
from .metadata_logger import MetadataLogger
from .module_freezer import ModuleFreezer
from .test_on_train_end import TestOnTrainEnd
from .training_manager import TrainingManager
from .upload_checkpoint import UploadCheckpoint

CALLBACKS.register_module(module=EarlyStopping)
@@ -38,6 +39,7 @@
CALLBACKS.register_module(module=ModelPruning)
CALLBACKS.register_module(module=GradCamCallback)
CALLBACKS.register_module(module=EMACallback)
CALLBACKS.register_module(module=TrainingManager)


__all__ = [
@@ -53,4 +55,5 @@
"GPUStatsMonitor",
"GradCamCallback",
"EMACallback",
"TrainingManager",
]
28 changes: 28 additions & 0 deletions luxonis_train/callbacks/training_manager.py
@@ -0,0 +1,28 @@
import pytorch_lightning as pl

from luxonis_train.strategies.base_strategy import BaseTrainingStrategy


class TrainingManager(pl.Callback):
def __init__(self, strategy: BaseTrainingStrategy | None = None):
"""Training manager callback that updates the parameters of the
training strategy.

@type strategy: BaseTrainingStrategy
@param strategy: The strategy to be used.
"""
self.strategy = strategy

def on_after_backward(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
):
"""PyTorch Lightning hook that is called after the backward
pass.

@type trainer: pl.Trainer
@param trainer: The trainer object.
@type pl_module: pl.LightningModule
@param pl_module: The pl_module object.
"""
if self.strategy is not None:
self.strategy.update_parameters(pl_module)
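
Roughly speaking, the manager only needs to be constructed with a strategy instance and handed to the Lightning `Trainer`. Inside luxonis-train this wiring happens automatically (see the `luxonis_lightning.py` changes further down); the snippet below is a hypothetical manual sketch, with `build_trainer` being an assumed helper name.

```python
import pytorch_lightning as pl

from luxonis_train.callbacks import TrainingManager
from luxonis_train.strategies.base_strategy import BaseTrainingStrategy


def build_trainer(strategy: BaseTrainingStrategy | None) -> pl.Trainer:
    callbacks: list[pl.Callback] = []
    if strategy is not None:
        # The manager calls strategy.update_parameters() after every backward pass.
        callbacks.append(TrainingManager(strategy=strategy))
    return pl.Trainer(max_epochs=10, callbacks=callbacks)
```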

6 changes: 6 additions & 0 deletions luxonis_train/config/config.py
@@ -336,6 +336,11 @@ class SchedulerConfig(BaseModelExtraForbid):
params: Params = {}


class TrainingStrategyConfig(BaseModelExtraForbid):
name: str
params: Params = {}


class TrainerConfig(BaseModelExtraForbid):
preprocessing: PreprocessingConfig = PreprocessingConfig()
use_rich_progress_bar: bool = True
@@ -382,6 +387,7 @@ class TrainerConfig(BaseModelExtraForbid):

optimizer: OptimizerConfig = OptimizerConfig()
scheduler: SchedulerConfig = SchedulerConfig()
training_strategy: TrainingStrategyConfig | None = None

@model_validator(mode="after")
def validate_deterministic(self) -> Self:
31 changes: 30 additions & 1 deletion luxonis_train/models/luxonis_lightning.py
@@ -25,7 +25,11 @@
combine_visualizations,
get_denormalized_images,
)
from luxonis_train.callbacks import BaseLuxonisProgressBar, ModuleFreezer
from luxonis_train.callbacks import (
BaseLuxonisProgressBar,
ModuleFreezer,
TrainingManager,
)
from luxonis_train.config import AttachedModuleConfig, Config
from luxonis_train.nodes import BaseNode
from luxonis_train.utils import (
@@ -42,6 +46,7 @@
CALLBACKS,
OPTIMIZERS,
SCHEDULERS,
STRATEGIES,
Registry,
)

@@ -268,6 +273,24 @@

self.load_checkpoint(self.cfg.model.weights)

if self.cfg.trainer.training_strategy is not None:
if self.cfg.trainer.optimizer is not None:

logger.warning(
"Training strategy is active; the specified optimizer will be ignored."
)
if self.cfg.trainer.scheduler is not None:

logger.warning(
"Training strategy is active; the specified scheduler will be ignored."
)
self.training_strategy = STRATEGIES.get(

self.cfg.trainer.training_strategy.name
)(
pl_module=self,
params=self.cfg.trainer.training_strategy.params, # type: ignore
)
else:
self.training_strategy = None

@property
def core(self) -> "luxonis_train.core.LuxonisModel":
"""Returns the core model."""
@@ -849,6 +872,9 @@
CALLBACKS.get(callback.name)(**callback.params)
)

if self.training_strategy is not None:
callbacks.append(TrainingManager(strategy=self.training_strategy)) # type: ignore

return callbacks

def configure_optimizers(
@@ -858,6 +884,9 @@
list[torch.optim.lr_scheduler.LRScheduler],
]:
"""Configures model optimizers and schedulers."""
if self.training_strategy is not None:
return self.training_strategy.configure_optimizers()

cfg_optimizer = self.cfg.trainer.optimizer
cfg_scheduler = self.cfg.trainer.scheduler

14 changes: 14 additions & 0 deletions luxonis_train/nodes/backbones/efficientrep/efficientrep.py
@@ -125,9 +125,23 @@ def __init__(
)
)

self.initialize_weights()

if download_weights and var.weights_path:
self.load_checkpoint(var.weights_path)

def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
pass
elif isinstance(m, nn.BatchNorm2d):
m.eps = 0.001
m.momentum = 0.03
elif isinstance(
m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
):
m.inplace = True

def set_export_mode(self, mode: bool = True) -> None:
"""Reparametrizes instances of L{RepVGGBlock} in the network.

13 changes: 13 additions & 0 deletions luxonis_train/nodes/blocks/blocks.py
@@ -56,6 +56,19 @@ def __init__(self, n_classes: int, in_channels: int):

prior_prob = 1e-2
self._initialize_weights_and_biases(prior_prob)
self.initialize_weights()

def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
pass
elif isinstance(m, nn.BatchNorm2d):
m.eps = 0.001
m.momentum = 0.03
elif isinstance(
m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
):
m.inplace = True

def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
out_feature = self.decoder(x)
14 changes: 14 additions & 0 deletions luxonis_train/nodes/heads/efficient_bbox_head.py
@@ -95,12 +95,26 @@ def __init__(
f"output{i+1}_yolov6r2" for i in range(self.n_heads)
]

self.initialize_weights()

if download_weights:
# TODO: Handle variants of head in a nicer way
if self.in_channels == [32, 64, 128]:
weights_path = "https://github.com/luxonis/luxonis-train/releases/download/v0.1.0-beta/efficientbbox_head_n_coco.ckpt"
self.load_checkpoint(weights_path, strict=False)

def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
pass
elif isinstance(m, nn.BatchNorm2d):
m.eps = 0.001
m.momentum = 0.03
elif isinstance(
m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
):
m.inplace = True

def forward(
self, inputs: list[Tensor]
) -> tuple[list[Tensor], list[Tensor], list[Tensor]]:
14 changes: 14 additions & 0 deletions luxonis_train/nodes/necks/reppan_neck/reppan_neck.py
@@ -165,9 +165,23 @@ def __init__(
out_channels = channels_list_down_blocks[2 * i + 1]
curr_n_repeats = n_repeats_down_blocks[i]

self.initialize_weights()

if download_weights and var.weights_path:
self.load_checkpoint(var.weights_path)

def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
pass
elif isinstance(m, nn.BatchNorm2d):
m.eps = 0.001
m.momentum = 0.03
elif isinstance(
m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
):
m.inplace = True

def forward(self, inputs: list[Tensor]) -> list[Tensor]:
x = inputs[-1]
up_block_outs: list[Tensor] = []
7 changes: 7 additions & 0 deletions luxonis_train/strategies/__init__.py
@@ -0,0 +1,7 @@
from .base_strategy import BaseTrainingStrategy
from .triple_lr_sgd import TripleLRScheduler

__all__ = [
"TripleLRScheduler",
"BaseTrainingStrategy",
]
28 changes: 28 additions & 0 deletions luxonis_train/strategies/base_strategy.py
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod

import pytorch_lightning as pl
from luxonis_ml.utils.registry import AutoRegisterMeta
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from luxonis_train.utils.registry import STRATEGIES


class BaseTrainingStrategy(
ABC,
metaclass=AutoRegisterMeta,
register=False,
registry=STRATEGIES,
):
def __init__(self, pl_module: pl.LightningModule):
self.pl_module = pl_module

@abstractmethod
def configure_optimizers(
self,
) -> tuple[list[Optimizer], list[LRScheduler]]:
pass

@abstractmethod
def update_parameters(self, *args, **kwargs):
pass
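
To make the interface concrete, a subclass would implement `configure_optimizers` and `update_parameters` along these lines. This is a minimal sketch, not the PR's actual `TripleLRSGDStrategy`: the class name, hyperparameters, and schedule are assumptions, while the constructor signature mirrors how `luxonis_lightning.py` instantiates strategies (`pl_module=..., params=...`). Subclasses should be picked up by the `STRATEGIES` registry automatically through `AutoRegisterMeta`, so they can then be referenced by `name` under `training_strategy` in the config.

```python
# Minimal sketch of a concrete strategy (assumed names and values).
import pytorch_lightning as pl
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

from luxonis_train.strategies.base_strategy import BaseTrainingStrategy


class ConstantLRSGDStrategy(BaseTrainingStrategy):
    def __init__(self, pl_module: pl.LightningModule, params: dict | None = None):
        super().__init__(pl_module)
        params = params or {}
        self.lr = params.get("lr", 0.01)
        self.momentum = params.get("momentum", 0.9)

    def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
        optimizer = torch.optim.SGD(
            self.pl_module.parameters(), lr=self.lr, momentum=self.momentum
        )
        # A real strategy would build its warmup/decay schedule here; this one
        # keeps the learning rate constant.
        scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1.0)
        return [optimizer], [scheduler]

    def update_parameters(self, pl_module: pl.LightningModule) -> None:
        # Called by TrainingManager after every backward pass; nothing to
        # adjust for a constant learning rate.
        pass
```

With the class importable, setting `name: "ConstantLRSGDStrategy"` under `training_strategy` in the config would select it; the name is assumed and used only for illustration.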