Add ECBSR (#478)
* add ecbsr arch

* first run train_ECBSR_x4_m4c16_prelu

* 255 range

* clean arch

* improve datasets

* update ecbsr option files

* update readme

* update readme

* reorganize history updates

* update readme: ecbsr

* update readme

* update readme

* update license of ecbsr
xinntao authored Oct 5, 2021
1 parent 9309e26 commit a129e46
Showing 10 changed files with 591 additions and 37 deletions.
2 changes: 2 additions & 0 deletions LICENSE/README.md
@@ -13,6 +13,8 @@ This BasicSR project is released under the Apache 2.0 license.
- We use the implementation of `DropPath` and `trunc_normal_` from [pytorch-image-models](https://github.com/rwightman/pytorch-image-models/). The LICENSE is included as [LICENSE_pytorch-image-models](LICENSE/LICENSE_pytorch-image-models).
- [SwinIR](https://github.com/JingyunLiang/SwinIR)
- The arch implementation of SwinIR is from [SwinIR](https://github.com/JingyunLiang/SwinIR). The LICENSE is included as [LICENSE_SwinIR](LICENSE/LICENSE_SwinIR).
- [ECBSR](https://github.com/xindongzhang/ECBSR)
- The arch implementation of ECBSR is from [ECBSR](https://github.com/xindongzhang/ECBSR). The LICENSE of ECBSR is [Apache License 2.0](https://github.com/xindongzhang/ECBSR/blob/main/LICENSE).

## References

21 changes: 6 additions & 15 deletions README.md
@@ -13,7 +13,7 @@

:loudspeaker: **Technical exchange QQ group**: **320960100** &emsp; Answer for joining: **互帮互助共同进步**

:compass: [QR code for joining the group](#e-mail-contact) &emsp;&emsp; [Guide for joining the group (Tencent Docs)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u)
:compass: [QR code for joining the group](#e-mail-contact) (QQ, WeChat) &emsp;&emsp; [Guide for joining the group (Tencent Docs)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u)

---

@@ -33,22 +33,13 @@ BasicSR (**Basic** **S**uper **R**estoration) is an open-source image and video restoration toolbox based on PyTorch

:triangular_flag_on_post: **New Features/Updates**

- :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang):+1:. More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr)
- :white_check_mark: Oct 5, 2021. Add **ECBSR training and testing** codes: [ECBSR](https://github.com/xindongzhang/ECBSR).
> ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
- :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang). More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr)
- :white_check_mark: Aug 5, 2021. Add NIQE, which produces the same results as MATLAB (both are 5.7296 for tests/data/baboon.png).
- :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181).
- :white_check_mark: July 20, 2021. Add **dual-blind face restoration** codes: [HiFaceGAN](https://github.com/Lotayou/Face-Renovation) codes by [Lotayou](https://lotayou.github.io/).
- :white_check_mark: Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab)
- :white_check_mark: Sep 8, 2020. Add **blind face restoration** inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet).
- :white_check_mark: Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).

<details>
<summary>More</summary>
<ul>
<li> Sep 8, 2020. Add <b>blind face restoration</b> inference codes: <b>DFDNet</b>. <br> <i><font color="#DCDCDC">ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries</font></i> <br> <i><font color="#DCDCDC">Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang</font></i> </li>
<li> Aug 27, 2020. Add <b>StyleGAN2</b> training and testing codes. <br> <i><font color="#DCDCDC">CVPR20: Analyzing and Improving the Image Quality of StyleGAN</font></i> <br> <i><font color="#DCDCDC">Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila</font></i> </li>
<li>Aug 19, 2020. A <b>brand-new</b> BasicSR v1.0.0 online.</li>
</ul>
</details>
> CVPR21: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond
- **[More](docs/history_updates.md)**

:sparkles: **Projects that use BasicSR**
- [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration
21 changes: 6 additions & 15 deletions README_CN.md
@@ -13,7 +13,7 @@

:loudspeaker: **Technical exchange QQ group**: **320960100** &emsp; Answer for joining: **互帮互助共同进步**

:compass: [QR code for joining the group](#e-mail-%E8%81%94%E7%B3%BB) &emsp;&emsp; [Guide for joining the group (Tencent Docs)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u)
:compass: [QR code for joining the group](#e-mail-%E8%81%94%E7%B3%BB) (QQ, WeChat) &emsp;&emsp; [Guide for joining the group (Tencent Docs)](https://docs.qq.com/doc/DYXBSUmxOT0xBZ05u)

---

@@ -31,22 +31,13 @@ BasicSR (**Basic** **S**uper **R**estoration) is an open-source image and video restoration toolbox based on PyTorch

:triangular_flag_on_post: **New Features/Updates**

- :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang):+1:. More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr)
- :white_check_mark: Oct 5, 2021. Add **ECBSR training and testing** codes: [ECBSR](https://github.com/xindongzhang/ECBSR).
> ACMMM21: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
- :white_check_mark: Sep 2, 2021. Add **SwinIR training and testing** codes: [SwinIR](https://github.com/JingyunLiang/SwinIR) by [Jingyun Liang](https://github.com/JingyunLiang). More details are in [HOWTOs.md](docs/HOWTOs.md#how-to-train-swinir-sr)
- :white_check_mark: Aug 5, 2021. Add NIQE, which produces the same results as MATLAB (both are 5.7296 for tests/data/baboon.png).
- :white_check_mark: July 31, 2021. Add **bi-directional video super-resolution** codes: [**BasicVSR** and IconVSR](https://arxiv.org/abs/2012.02181).
- :white_check_mark: July 20, 2021. Add **dual-blind face restoration** codes: [**HiFaceGAN**](https://github.com/Lotayou/Face-Renovation) codes by [Lotayou](https://lotayou.github.io/).
- :white_check_mark: Nov 29, 2020. Add **ESRGAN** and **DFDNet** [colab demo](colab).
- :white_check_mark: Sep 8, 2020. Add **blind face restoration** inference codes: [DFDNet](https://github.com/csxmli2016/DFDNet).
- :white_check_mark: Aug 27, 2020. Add **StyleGAN2 training and testing** codes: [StyleGAN2](https://github.com/rosinality/stylegan2-pytorch).

<details>
<summary>More</summary>
<ul>
<li> Sep 8, 2020. Add <b>blind face restoration</b> inference codes: <b>DFDNet</b>. <br> <i><font color="#DCDCDC">ECCV20: Blind Face Restoration via Deep Multi-scale Component Dictionaries</font></i> <br> <i><font color="#DCDCDC">Xiaoming Li, Chaofeng Chen, Shangchen Zhou, Xianhui Lin, Wangmeng Zuo and Lei Zhang</font></i> </li>
<li> Aug 27, 2020. Add <b>StyleGAN2</b> training and testing codes. <br> <i><font color="#DCDCDC">CVPR20: Analyzing and Improving the Image Quality of StyleGAN</font></i> <br> <i><font color="#DCDCDC">Tero Karras, Samuli Laine, Miika Aittala, Janne Hellsten, Jaakko Lehtinen and Timo Aila</font></i> </li>
<li>Aug 19, 2020. A <b>brand-new</b> BasicSR v1.0.0 is online.</li>
</ul>
</details>
> CVPR21: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond
- **[More](docs/history_updates.md)**

:sparkles: **Projects that use BasicSR**
- [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration
245 changes: 245 additions & 0 deletions basicsr/archs/ecbsr_arch.py
@@ -0,0 +1,245 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.utils.registry import ARCH_REGISTRY


class SeqConv3x3(nn.Module):
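    """A sequence of two convolutions (a 1x1 conv followed by either a plain 3x3
    conv or a fixed Sobel/Laplacian edge filter) that can be collapsed into a
    single 3x3 conv at inference time via rep_params()."""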

    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier=1):
        super(SeqConv3x3, self).__init__()
        self.seq_type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes

        if self.seq_type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        elif self.seq_type == 'conv1x1-sobelx':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            bias = torch.reshape(bias, (self.out_planes, ))
            self.bias = nn.Parameter(bias)
            # init mask
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 1, 0] = 2.0
                self.mask[i, 0, 2, 0] = 1.0
                self.mask[i, 0, 0, 2] = -1.0
                self.mask[i, 0, 1, 2] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.seq_type == 'conv1x1-sobely':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            bias = torch.reshape(bias, (self.out_planes, ))
            self.bias = nn.Parameter(bias)
            # init mask
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 0, 1] = 2.0
                self.mask[i, 0, 0, 2] = 1.0
                self.mask[i, 0, 2, 0] = -1.0
                self.mask[i, 0, 2, 1] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.seq_type == 'conv1x1-laplacian':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # init scale and bias
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            bias = torch.reshape(bias, (self.out_planes, ))
            self.bias = nn.Parameter(bias)
            # init mask
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 1] = 1.0
                self.mask[i, 0, 1, 0] = 1.0
                self.mask[i, 0, 1, 2] = 1.0
                self.mask[i, 0, 2, 1] = 1.0
                self.mask[i, 0, 1, 1] = -4.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)
        else:
            raise ValueError('The type of seqconv is not supported!')

    def forward(self, x):
        if self.seq_type == 'conv1x1-conv3x3':
            # conv-1x1
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly pad with the bias value
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-3x3
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        else:
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly pad with the bias value
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # depthwise conv-3x3 with the fixed edge mask, scaled per channel
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_planes)
        return y1

    def rep_params(self):
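        # Merge the 1x1 conv (k0, b0) and the following 3x3 conv (k1, b1) into a
        # single 3x3 conv: the merged weight is k1 convolved with the transposed
        # k0 (channel mixing), and the merged bias is b0 pushed through k1, plus b1.
        # For the edge-filter variants, the fixed depthwise mask is first expanded
        # into an equivalent dense 3x3 kernel.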
        device = self.k0.get_device()
        if device < 0:
            device = None

        if self.seq_type == 'conv1x1-conv3x3':
            # re-param conv kernel
            rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
        else:
            tmp = self.scale * self.mask
            k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device)
            for i in range(self.out_planes):
                k1[i, i, :, :] = tmp[i, 0, :, :]
            b1 = self.bias
            # re-param conv kernel
            rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            rep_bias = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
        return rep_weight, rep_bias


class ECB(nn.Module):
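    """Edge-oriented Convolution Block.

    Trains with five parallel branches (a plain 3x3 conv, an expanded
    conv1x1-conv3x3, and three conv1x1-edge-filter sequences); at inference the
    branches are merged into one 3x3 conv via rep_params().
    """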

    def __init__(self, inp_planes, out_planes, depth_multiplier, act_type='prelu', with_idt=False):
        super(ECB, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type

        if with_idt and (self.inp_planes == self.out_planes):
            self.with_idt = True
        else:
            self.with_idt = False

        self.conv3x3 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1)
        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes, self.out_planes, self.depth_multiplier)
        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes)
        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes)
        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_planes)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass
        else:
            raise ValueError('The type of activation is not supported!')

    def forward(self, x):
        if self.training:
            y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
            if self.with_idt:
                y += x
        else:
            rep_weight, rep_bias = self.rep_params()
            y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def rep_params(self):
        weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
        weight1, bias1 = self.conv1x1_3x3.rep_params()
        weight2, bias2 = self.conv1x1_sbx.rep_params()
        weight3, bias3 = self.conv1x1_sby.rep_params()
        weight4, bias4 = self.conv1x1_lpl.rep_params()
        rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
            bias0 + bias1 + bias2 + bias3 + bias4)

        if self.with_idt:
            device = rep_weight.get_device()
            if device < 0:
                device = None
            weight_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device)
            for i in range(self.out_planes):
                weight_idt[i, i, 1, 1] = 1.0
            bias_idt = 0.0
            rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
        return rep_weight, rep_bias


@ARCH_REGISTRY.register()
class ECBSR(nn.Module):
"""ECBSR architecture.
Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
Ref git repo: https://github.com/xindongzhang/ECBSR
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_block (int): Block number in the trunk network.
num_channel (int): Channel number.
with_idt (bool): Whether use identity in convolution layers.
act_type (str): Activation type.
scale (int): Upsampling factor.
"""

    def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
        super(ECBSR, self).__init__()

        backbone = []
        backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        for _ in range(num_block):
            backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        backbone += [
            ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
        ]

        self.backbone = nn.Sequential(*backbone)
        self.upsampler = nn.PixelShuffle(scale)

    def forward(self, x):
        # residual connection: for a 1-channel input, x is broadcast (repeated
        # scale * scale times) across the output channels
        y = self.backbone(x) + x
        y = self.upsampler(y)
        return y
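
As a quick sanity check on the re-parameterization, the training-mode (multi-branch) and eval-mode (single merged conv) forward passes should agree up to floating-point error. A minimal sketch, assuming the arch is importable from this file and using the m4c16 configuration named in the commit message:

```python
import torch

from basicsr.archs.ecbsr_arch import ECBSR

# x4, 4 blocks, 16 channels, PReLU, 1-channel (Y) input -- cf. train_ECBSR_x4_m4c16_prelu
model = ECBSR(num_in_ch=1, num_out_ch=1, num_block=4, num_channel=16,
              with_idt=False, act_type='prelu', scale=4)
x = torch.randn(1, 1, 32, 32)

with torch.no_grad():
    model.train()
    y_train = model(x)  # sums the five branches of every ECB
    model.eval()
    y_eval = model(x)   # one 3x3 conv per ECB, built by rep_params()

print(y_train.shape)                           # torch.Size([1, 1, 128, 128])
print(torch.max(torch.abs(y_train - y_eval)))  # should be tiny (~1e-6 or less)
```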
11 changes: 10 additions & 1 deletion basicsr/data/paired_image_dataset.py
@@ -4,6 +4,7 @@
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.matlab_functions import rgb2ycbcr
from basicsr.utils.registry import DATASET_REGISTRY


@@ -87,7 +88,15 @@ def __getitem__(self, index):
        # flip, rotation
        img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # TODO: color space transform
        if self.opt['color'] == 'y':
            img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
            img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
        # TODO: it is better to update the datasets rather than force the crop here
        if self.opt['phase'] != 'train':
            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
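
The two additions above can be exercised in isolation. A rough sketch with dummy arrays (shapes are illustrative), using the `rgb2ycbcr` helper imported in this diff:

```python
import numpy as np

from basicsr.utils.matlab_functions import rgb2ycbcr

scale = 4
# dummy HWC float32 images in [0, 1]; the GT is slightly larger than
# scale * LQ, as happens with some SR benchmark datasets
img_lq = np.random.rand(32, 32, 3).astype(np.float32)
img_gt = np.random.rand(130, 130, 3).astype(np.float32)

# color space transform: keep only the Y channel, restoring a channel axis
img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]

# crop the GT so that it exactly matches scale x the LQ size
img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

print(img_gt.shape, img_lq.shape)  # (128, 128, 1) (32, 32, 1)
```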