-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Doc] Add Chinese doc for tutorials04_customized_models_md #630
Changes from 15 commits
8a9deea
5369c21
04d2e67
178194e
5d9055a
60cbc6e
da616df
6adb79e
12afb80
6ddb19d
42908c5
e5a0265
0b942b1
faec079
fb7ef56
ea8d79e
3ee7416
04e80b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1 +1,230 @@ | ||||||
# 教程 4: 自定义模型 | ||||||
|
||||||
## 自定义优化器 (optimizer) | ||||||
|
||||||
假设您想增加一个新的叫 `MyOptimizer` 的优化器,它的参数分别为 `a`, `b`, 和 `c`。 | ||||||
您首先需要在一个文件里实施这个新优化器,例如在 `mmseg/core/optimizer/my_optimizer.py` 里面: | ||||||
|
||||||
```python | ||||||
from mmcv.runner import OPTIMIZERS | ||||||
from torch.optim import Optimizer | ||||||
|
||||||
|
||||||
@OPTIMIZERS.register_module | ||||||
class MyOptimizer(Optimizer): | ||||||
|
||||||
def __init__(self, a, b, c) | ||||||
|
||||||
``` | ||||||
|
||||||
然后增加这个模块到 `mmseg/core/optimizer/__init__.py` 里面,这样注册表 (registry) 将会发现这个新的模块并添加它: | ||||||
|
||||||
```python | ||||||
from .my_optimizer import MyOptimizer | ||||||
``` | ||||||
|
||||||
之后您可以在配置文件的 `optimizer` 域里使用 `MyOptimizer`, | ||||||
如下所示,在配置文件里,优化器被 `optimizer` 域所定义: | ||||||
|
||||||
```python | ||||||
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) | ||||||
``` | ||||||
|
||||||
为了使用您自己的优化器,域可以被修改为: | ||||||
|
||||||
```python | ||||||
optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value) | ||||||
``` | ||||||
|
||||||
我们已经支持了 PyTorch 自带的全部优化器,唯一修改的地方是在配置文件里的 `optimizer` 域。例如,如果您想使用 `ADAM`,尽管数值表现会掉点,还是可以如下修改: | ||||||
|
||||||
```python | ||||||
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001) | ||||||
``` | ||||||
|
||||||
使用者可以直接按照 PyTorch [文档教程](https://pytorch.org/docs/stable/optim.html?highlight=optim#module-torch.optim) 去设置参数。 | ||||||
|
||||||
## 定制优化器的构造器 (optimizer constructor) | ||||||
|
||||||
对于优化,一些模型可能会有一些特别定义的参数,例如 批归一化 (BatchNorm) 层里面的权重衰减 (weight decay)。 | ||||||
使用者可以通过定制优化器的构造器来微调这些细粒度的优化器参数。 | ||||||
|
||||||
``` | ||||||
from mmcv.utils import build_from_cfg | ||||||
|
||||||
from mmcv.runner import OPTIMIZER_BUILDERS | ||||||
from .cocktail_optimizer import CocktailOptimizer | ||||||
|
||||||
|
||||||
@OPTIMIZER_BUILDERS.register_module | ||||||
class CocktailOptimizerConstructor(object): | ||||||
|
||||||
def __init__(self, optimizer_cfg, paramwise_cfg=None): | ||||||
|
||||||
def __call__(self, model): | ||||||
|
||||||
return my_optimizer | ||||||
|
||||||
``` | ||||||
|
||||||
## 开发新的组件 | ||||||
|
||||||
MMSegmentation 里主要有2种组件: | ||||||
|
||||||
- 骨架 (backbone): 通常是卷积网络的堆叠,来做特征提取,例如 ResNet, HRNet。 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
- 头 (head): 用于语义分割图的解码的组件。 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
### 添加新的骨架 | ||||||
|
||||||
这里我们以 MobileNet 为例,展示如何增加新的骨架组件: | ||||||
|
||||||
1. 创建一个新的文件 `mmseg/models/backbones/mobilenet.py`. | ||||||
|
||||||
```python | ||||||
import torch.nn as nn | ||||||
|
||||||
from ..registry import BACKBONES | ||||||
|
||||||
|
||||||
@BACKBONES.register_module | ||||||
class MobileNet(nn.Module): | ||||||
|
||||||
def __init__(self, arg1, arg2): | ||||||
pass | ||||||
|
||||||
def forward(self, x): # should return a tuple | ||||||
pass | ||||||
|
||||||
def init_weights(self, pretrained=None): | ||||||
pass | ||||||
``` | ||||||
|
||||||
2. 在 `mmseg/models/backbones/__init__.py` 里面导入模块。 | ||||||
|
||||||
```python | ||||||
from .mobilenet import MobileNet | ||||||
``` | ||||||
|
||||||
3. 在您的配置文件里使用它。 | ||||||
|
||||||
```python | ||||||
model = dict( | ||||||
... | ||||||
backbone=dict( | ||||||
type='MobileNet', | ||||||
arg1=xxx, | ||||||
arg2=xxx), | ||||||
... | ||||||
``` | ||||||
|
||||||
### 增加新的头组件 | ||||||
|
||||||
在 MMSegmentation 里面,对于所有的分割头,我们提供一个基础的 [BaseDecodeHead](https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/decode_head.py) 。 | ||||||
所有创新的解码头都应该由它生成。这里我们以 [PSPNet](https://arxiv.org/abs/1612.01105) 为例, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
展示如何开发一个新的头组件: | ||||||
|
||||||
首先,在 `mmseg/models/decode_heads/psp_head.py` 里添加一个新的解码头。 | ||||||
PSPNet 为了语义分割的解码,已经设置了一个解码头,基于这个已有的解码头我们想添加如下包含3个函数的新模块: | ||||||
|
||||||
```python | ||||||
@HEADS.register_module() | ||||||
class PSPHead(BaseDecodeHead): | ||||||
|
||||||
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs): | ||||||
super(PSPHead, self).__init__(**kwargs) | ||||||
|
||||||
def init_weights(self): | ||||||
|
||||||
def forward(self, inputs): | ||||||
|
||||||
``` | ||||||
|
||||||
接着,使用者需要在 `mmseg/models/decode_heads/__init__.py` 里面添加这个模块,这样对应的注册表 (registry) 可以查找并加载它们。 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
PSPNet的配置文件如下所示: | ||||||
|
||||||
```python | ||||||
norm_cfg = dict(type='SyncBN', requires_grad=True) | ||||||
model = dict( | ||||||
type='EncoderDecoder', | ||||||
pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth', | ||||||
backbone=dict( | ||||||
type='ResNetV1c', | ||||||
depth=50, | ||||||
num_stages=4, | ||||||
out_indices=(0, 1, 2, 3), | ||||||
dilations=(1, 1, 2, 4), | ||||||
strides=(1, 2, 1, 1), | ||||||
norm_cfg=norm_cfg, | ||||||
norm_eval=False, | ||||||
style='pytorch', | ||||||
contract_dilation=True), | ||||||
decode_head=dict( | ||||||
type='PSPHead', | ||||||
in_channels=2048, | ||||||
in_index=3, | ||||||
channels=512, | ||||||
pool_scales=(1, 2, 3, 6), | ||||||
dropout_ratio=0.1, | ||||||
num_classes=19, | ||||||
norm_cfg=norm_cfg, | ||||||
align_corners=False, | ||||||
loss_decode=dict( | ||||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))) | ||||||
|
||||||
``` | ||||||
|
||||||
### 增加新的损失函数 | ||||||
|
||||||
假设您想添加一个新的损失函数 `MyLoss` 到语义分割解码器里。 | ||||||
为了添加一个新的损失函数,使用者需要在 `mmseg/models/losses/my_loss.py` 里面去实施它。 | ||||||
`weighted_loss` 可以对计算损失时的每个成分做加权。 | ||||||
|
||||||
```python | ||||||
import torch | ||||||
import torch.nn as nn | ||||||
|
||||||
from ..builder import LOSSES | ||||||
from .utils import weighted_loss | ||||||
|
||||||
@weighted_loss | ||||||
def my_loss(pred, target): | ||||||
assert pred.size() == target.size() and target.numel() > 0 | ||||||
loss = torch.abs(pred - target) | ||||||
return loss | ||||||
|
||||||
@LOSSES.register_module | ||||||
class MyLoss(nn.Module): | ||||||
|
||||||
def __init__(self, reduction='mean', loss_weight=1.0): | ||||||
super(MyLoss, self).__init__() | ||||||
self.reduction = reduction | ||||||
self.loss_weight = loss_weight | ||||||
|
||||||
def forward(self, | ||||||
pred, | ||||||
target, | ||||||
weight=None, | ||||||
avg_factor=None, | ||||||
reduction_override=None): | ||||||
assert reduction_override in (None, 'none', 'mean', 'sum') | ||||||
reduction = ( | ||||||
reduction_override if reduction_override else self.reduction) | ||||||
loss = self.loss_weight * my_loss( | ||||||
pred, target, weight, reduction=reduction, avg_factor=avg_factor) | ||||||
return loss | ||||||
``` | ||||||
|
||||||
然后使用者需要在 `mmseg/models/losses/__init__.py` 里面添加它: | ||||||
|
||||||
```python | ||||||
from .my_loss import MyLoss, my_loss | ||||||
|
||||||
``` | ||||||
|
||||||
为了使用它,修改 `loss_xxx` 域。之后您需要在头组件里修改 `loss_decode` 域。 | ||||||
`loss_weight` 可以被用来对不同的损失函数做加权。 | ||||||
|
||||||
```python | ||||||
loss_decode=dict(type='MyLoss', loss_weight=1.0)) | ||||||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.