From 75f1081c77d8addeff11d0360da1746216d45499 Mon Sep 17 00:00:00 2001 From: Srameo Date: Mon, 22 Jan 2024 23:18:23 +0800 Subject: [PATCH 1/5] [UPDATE] guidance on using image_process.py with other archs --- README.md | 2 ++ docs/demo.md | 2 ++ 2 files changed, 4 insertions(+) diff --git a/README.md b/README.md index 6ffcf24..bd5247b 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,8 @@ Or you can just use the following pretrained LED module for custumizing on your We provide a script for testing your own RAW images in [image_process.py](scripts/image_process.py).
You could run `python scripts/image_process.py --help` to get detailed information of this scripts. > If your camera model is one of {Sony A7S2, Nikon D850}, you can found our pretrained model in [pretrained-models.md](docs/pretrained-models.md). +> +> **Notice that**, if you wish to use the model from release v0.1.1, you need to add the `-opt` parameter: For NAFNet, add `-opt options/base/network_g/nafnet.yaml`. For Restormer, add `-opt options/base/network_g/restormer.yaml`. ```bash usage: image_process.py [-h] -p PRETRAINED_NETWORK --data_path DATA_PATH [--save_path SAVE_PATH] [-opt NETWORK_OPTIONS] [--ratio RATIO] [--target_exposure TARGET_EXPOSURE] [--bps BPS] [--led] diff --git a/docs/demo.md b/docs/demo.md index 6b353c7..580d50d 100644 --- a/docs/demo.md +++ b/docs/demo.md @@ -61,6 +61,8 @@ python scripts/cutomized_denoiser.py -t [TAG] \ We provide a script for testing your own RAW images in [image_process.py](/scripts/image_process.py).
You could run `python scripts/image_process.py --help` to get detailed information of this scripts. > If your camera model is one of {Sony A7S2, Nikon D850}, you can found our pretrained model in [pretrained-models.md](/docs/pretrained-models.md). +> +> **Notice that**, if you wish to use the model from release v0.1.1, you need to add the `-opt` parameter: For NAFNet, add `-opt options/base/network_g/nafnet.yaml`. For Restormer, add `-opt options/base/network_g/restormer.yaml`. ```bash usage: image_process.py [-h] -p PRETRAINED_NETWORK --data_path DATA_PATH [--save_path SAVE_PATH] [-opt NETWORK_OPTIONS] [--ratio RATIO] [--bps BPS] [--led] From 4a0dcd475f88365be5ce8c84eb514bac7bc6b45a Mon Sep 17 00:00:00 2001 From: Lollikit Date: Thu, 14 Mar 2024 23:01:46 +0800 Subject: [PATCH 2/5] Fix the bug related to 'repnr' during the finetuning process --- led/archs/repnr_utils.py | 194 +++++++++++++++++++++++++-------------- 1 file changed, 126 insertions(+), 68 deletions(-) diff --git a/led/archs/repnr_utils.py b/led/archs/repnr_utils.py index 4e48c16..98a66b4 100644 --- a/led/archs/repnr_utils.py +++ b/led/archs/repnr_utils.py @@ -7,13 +7,18 @@ from torch.nn.init import kaiming_normal_, kaiming_uniform_ from torch.nn.modules.utils import _reverse_repeat_tuple + def zero_init_(x, a=None): x.zero_() + def build_repnr_arch_from_base(base_arch, **repnr_kwargs): base_arch = deepcopy(base_arch) - dont_convert_module = [] if 'dont_convert_module' not in repnr_kwargs \ - else repnr_kwargs.pop('dont_convert_module') + dont_convert_module = ( + [] + if "dont_convert_module" not in repnr_kwargs + else repnr_kwargs.pop("dont_convert_module") + ) def recursive_converter(base_arch, **repnr_kwargs): if isinstance(base_arch, nn.Conv2d): @@ -30,14 +35,16 @@ def recursive_converter(base_arch, **repnr_kwargs): class RepNRConv2d(nn.Module): - def _init_conv(self, init_type='kaiming_uniform_'): + def _init_conv(self, init_type="kaiming_uniform_"): weight = torch.zeros_like(self.main_weight) init_func = eval(init_type) init_func(weight, a=math.sqrt(5)) fan_in, _ = init._calculate_fan_in_and_fan_out(weight) - bias = torch.zeros((weight.size(0), ), - dtype=self.main_weight.dtype, - device=self.main_weight.device) + bias = torch.zeros( + (weight.size(0),), + dtype=self.main_weight.dtype, + device=self.main_weight.device, + ) if fan_in != 0: bound = 1 / math.sqrt(fan_in) init.uniform_(bias, -bound, bound) @@ -49,7 +56,9 @@ def _init_from_conv2d(self, conv2d: nn.Conv2d): self.kernel_size = conv2d.kernel_size self.stride = conv2d.stride self.padding = _reverse_repeat_tuple(conv2d.padding, 2) - self.padding_mode = 'constant' if conv2d.padding_mode == 'zeros' else conv2d.padding_mode + self.padding_mode = ( + "constant" if conv2d.padding_mode == "zeros" else conv2d.padding_mode + ) self.dilation = conv2d.dilation self.groups = conv2d.groups self.bias = conv2d.bias is not None @@ -57,12 +66,16 @@ def _init_from_conv2d(self, conv2d: nn.Conv2d): main_weight = conv2d.weight.data.clone() main_bias = conv2d.bias.data.clone() if self.bias else None self.main_weight = nn.Parameter(main_weight, requires_grad=True) - self.main_bias = nn.Parameter(main_bias, requires_grad=True) if main_bias is not None else None + self.main_bias = ( + nn.Parameter(main_bias, requires_grad=True) + if main_bias is not None + else None + ) def _init_alignments(self, align_opts): - align_init_weight = align_opts.get('init_weight', 1.0) - if 'init_bias' in align_opts: - align_init_bias = align_opts['init_bias'] + align_init_weight = align_opts.get("init_weight", 1.0) + if "init_bias" in align_opts: + align_init_bias = align_opts["init_bias"] align_bias = True else: align_bias = False @@ -70,35 +83,57 @@ def _init_alignments(self, align_opts): self.align_weights = nn.ParameterList( [ - nn.Parameter(torch.ones((1, align_channles, 1, 1), - dtype=self.main_weight.dtype, - device=self.main_weight.device) * align_init_weight, - requires_grad=True) + nn.Parameter( + torch.ones( + (1, align_channles, 1, 1), + dtype=self.main_weight.dtype, + device=self.main_weight.device, + ) + * align_init_weight, + requires_grad=True, + ) for _ in range(self.branch_num + 1) ] ) - self.align_biases = nn.ParameterList( - [ - nn.Parameter(torch.ones((1, align_channles, 1, 1), - dtype=self.main_weight.dtype, - device=self.main_weight.device) * align_init_bias, - requires_grad=True) - for _ in range(self.branch_num + 1) - ] - ) if align_bias else None + self.align_biases = ( + nn.ParameterList( + [ + nn.Parameter( + torch.ones( + (1, align_channles, 1, 1), + dtype=self.main_weight.dtype, + device=self.main_weight.device, + ) + * align_init_bias, + requires_grad=True, + ) + for _ in range(self.branch_num + 1) + ] + ) + if align_bias + else None + ) def _init_aux_conv(self, aux_conv_opts): - if 'init' not in aux_conv_opts: - aux_conv_opts['init'] = 'kaiming_normal_' + if "init" not in aux_conv_opts: + aux_conv_opts["init"] = "kaiming_normal_" - aux_weight, aux_bias = self._init_conv(init_type=aux_conv_opts['init']) + aux_weight, aux_bias = self._init_conv(init_type=aux_conv_opts["init"]) self.aux_weight = nn.Parameter(aux_weight, requires_grad=True) - self.aux_bias = nn.Parameter(torch.zeros_like(aux_bias), requires_grad=True) \ - if aux_conv_opts.get('bias', False) else torch.zeros_like(aux_bias) + self.aux_bias = ( + nn.Parameter(torch.zeros_like(aux_bias), requires_grad=True) + if aux_conv_opts.get("bias", False) + else torch.zeros_like(aux_bias) + ) - def __init__(self, conv2d, branch_num, - align_opts, aux_conv_opts=None, - forward_type='reparameterize'): + def __init__( + self, + conv2d, + branch_num, + align_opts, + aux_conv_opts=None, + forward_type="reparameterize", + ): super().__init__() self.branch_num = branch_num self.forward_type = forward_type @@ -111,17 +146,20 @@ def __init__(self, conv2d, branch_num, self._init_aux_conv(aux_conv_opts) self.cur_branch = -1 - self.forward = self._reparameterize_forward if forward_type == 'reparameterize' \ + self.forward = ( + self._reparameterize_forward + if forward_type == "reparameterize" else self._trivial_forward + ) def switch_forward_type(self, *, trivial=False, reparameterize=False): assert not (trivial and reparameterize) if trivial: self.forward = self._trivial_forward - self.forward_type = 'trivial' + self.forward_type = "trivial" elif reparameterize: self.forward = self._reparameterize_forward - self.forward_type = 'reparameterize' + self.forward_type = "reparameterize" @staticmethod def _sequential_reparamterize(k1, b1, k2, b2): @@ -129,15 +167,15 @@ def _sequential_reparamterize(k1, b1, k2, b2): def depthwise_to_normal(k, padding): k = k.reshape(-1) k = torch.diag(k).unsqueeze(-1).unsqueeze(-1) - k = F.pad(k, _reverse_repeat_tuple(padding, 2), mode='constant', value=0.0) + k = F.pad(k, _reverse_repeat_tuple(padding, 2), mode="constant", value=0.0) return k def bias_pad(b, padding): - return F.pad(b, _reverse_repeat_tuple(padding, 2), mode='replicate') + return F.pad(b, _reverse_repeat_tuple(padding, 2), mode="replicate") padding = (k2.shape[-2] - 1) // 2, (k2.shape[-1] - 1) // 2 k1 = depthwise_to_normal(k1, padding) - k = F.conv2d(k2, k1, stride=1, padding='same') + k = F.conv2d(k2, k1, stride=1, padding="same") b = F.conv2d(bias_pad(b1, padding), k2, bias=b2, stride=1).reshape(-1) return k, b @@ -149,15 +187,20 @@ def _parallel_reparamterize(k1, b1, k2, b2): def _weight_and_bias(self): index = self.cur_branch align_weight = self.align_weights[index] - align_bias = self.align_biases[index] if self.align_biases is not None \ + align_bias = ( + self.align_biases[index] + if self.align_biases is not None else torch.zeros_like(align_weight) + ) main_weight, main_bias = self._sequential_reparamterize( - align_weight, align_bias, self.main_weight, self.main_bias) + align_weight, align_bias, self.main_weight, self.main_bias + ) - if hasattr(self, 'aux_weight'): + if hasattr(self, "aux_weight"): main_weight, main_bias = self._parallel_reparamterize( - main_weight, main_bias, self.aux_weight, self.aux_bias) + main_weight, main_bias, self.aux_weight, self.aux_bias + ) return main_weight, main_bias @@ -170,16 +213,23 @@ def _reparameterize_forward(self, x): def _trivial_forward(self, x): index = self.cur_branch align_weight = self.align_weights[index] - align_bias = self.align_biases[index] if self.align_biases is not None \ + align_bias = ( + self.align_biases[index] + if self.align_biases is not None else torch.zeros_like(align_weight) + ) aligned_x = x * align_weight + align_bias padded_aligned_x = F.pad(aligned_x, self.padding, self.padding_mode, value=0.0) - main_x = F.conv2d(padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride) + main_x = F.conv2d( + padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride + ) - if hasattr(self, 'aux_weight'): + if hasattr(self, "aux_weight"): padded_x = F.pad(x, self.padding, self.padding_mode, value=0.0) - aux_x = F.conv2d(padded_x, self.aux_weight, self.aux_bias, stride=self.stride) + aux_x = F.conv2d( + padded_x, self.aux_weight, self.aux_bias, stride=self.stride + ) main_x = main_x + aux_x return main_x @@ -188,24 +238,29 @@ def _trivial_forward_with_intermediate_features(self, x): index = self.cur_branch features = {} align_weight = self.align_weights[index] - align_bias = self.align_biases[index] if self.align_biases is not None \ + align_bias = ( + self.align_biases[index] + if self.align_biases is not None else torch.zeros_like(align_weight) + ) - features['in_feat'] = x + features["in_feat"] = x aligned_x = x * align_weight + align_bias - features['aligned_feat'] = aligned_x + features["aligned_feat"] = aligned_x padded_aligned_x = F.pad(aligned_x, self.padding, self.padding_mode, value=0.0) - main_x = F.conv2d(padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride) - features['main_feat'] = main_x + main_x = F.conv2d( + padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride + ) + features["main_feat"] = main_x - if hasattr(self, 'aux_weight'): + if hasattr(self, "aux_weight"): padded_x = F.pad(x, self.padding, self.padding_mode, value=0.0) aux_x = F.conv2d(padded_x, self.aux_weight, self.aux_bias, stride=1) - features['aux_feat'] = aux_x + features["aux_feat"] = aux_x main_x = main_x + aux_x - features['out_feat'] = aux_x + features["out_feat"] = aux_x self.current_intermediate_features = features return main_x @@ -215,30 +270,30 @@ def __repr__(self): extra_repr = self.extra_repr() # empty string will be split into list [''] if extra_repr: - extra_lines = extra_repr.split('\n') + extra_lines = extra_repr.split("\n") child_lines = [] lines = extra_lines + child_lines - main_str = self._get_name() + '(' + main_str = self._get_name() + "(" if lines: # simple one-liner info, which most builtin Modules will use if len(extra_lines) == 1 and not child_lines: main_str += extra_lines[0] else: - main_str += '\n ' + '\n '.join(lines) + '\n' + main_str += "\n " + "\n ".join(lines) + "\n" - main_str += ')' + main_str += ")" return main_str def extra_repr(self) -> str: s = ( - '{in_channels}, {out_channels}, kernel_size={kernel_size}, ' - 'stride={stride}, padding={padding}, padding_mode={padding_mode},\n' - 'branch_num={branch_num}, align_opts={align_opts}\n' - 'forward_type={forward_type}' + "{in_channels}, {out_channels}, kernel_size={kernel_size}, " + "stride={stride}, padding={padding}, padding_mode={padding_mode},\n" + "branch_num={branch_num}, align_opts={align_opts}\n" + "forward_type={forward_type}" ) - if hasattr(self, 'aux_weight'): - s += ', aux_conv_opts={aux_conv_opts}' + if hasattr(self, "aux_weight"): + s += ", aux_conv_opts={aux_conv_opts}" return s.format(**self.__dict__) @@ -287,7 +342,9 @@ def generalize_align_conv(align_conv: RepNRConv2d): generalize_align_conv(m) @staticmethod - def _set_requires_grad(align_conv: RepNRConv2d, *, pretrain=True, finetune=False, aux=True): + def _set_requires_grad( + align_conv: RepNRConv2d, *, pretrain=True, finetune=False, aux=True + ): assert not (pretrain and finetune) for i in range(align_conv.branch_num): align_conv.align_weights[i].requires_grad_(pretrain) @@ -303,7 +360,7 @@ def _set_requires_grad(align_conv: RepNRConv2d, *, pretrain=True, finetune=False align_conv.main_bias.requires_grad_(pretrain) # aux weight and bias - if hasattr(align_conv, 'aux_weight') and aux: + if hasattr(align_conv, "aux_weight") and aux: align_conv.aux_weight.requires_grad_(finetune) if isinstance(align_conv.aux_bias, nn.Parameter): align_conv.aux_bias.requires_grad_(finetune) @@ -319,7 +376,7 @@ def finetune(self, *, aux=False): self._set_requires_grad(m, pretrain=False, finetune=True, aux=aux) def __repr__(self): - return f'{self._get_name()}: {str(self.repnr_module)}' + return f"{self._get_name()}: {str(self.repnr_module)}" @torch.no_grad() def deploy(self): @@ -333,7 +390,8 @@ def deploy(m_bak, m): if isinstance(m, RepNRConv2d): weight, bias = m._weight_and_bias m_bak.weight.data = weight - m_bak.bias.data = bias + if m_bak.bias is not None: + m_bak.bias.data = bias return if not has_repnr_conv(m): m_bak.load_state_dict(m.state_dict()) From 3e8470281ada4b5bf2c37d92a97c5d2d17682736 Mon Sep 17 00:00:00 2001 From: natsunoshion Date: Thu, 14 Mar 2024 23:46:37 +0800 Subject: [PATCH 3/5] Fix the bug related to 'repnr' during the finetuning process --- led/archs/repnr_utils.py | 191 ++++++++++++++------------------------- 1 file changed, 67 insertions(+), 124 deletions(-) diff --git a/led/archs/repnr_utils.py b/led/archs/repnr_utils.py index 98a66b4..94b1354 100644 --- a/led/archs/repnr_utils.py +++ b/led/archs/repnr_utils.py @@ -7,18 +7,13 @@ from torch.nn.init import kaiming_normal_, kaiming_uniform_ from torch.nn.modules.utils import _reverse_repeat_tuple - def zero_init_(x, a=None): x.zero_() - def build_repnr_arch_from_base(base_arch, **repnr_kwargs): base_arch = deepcopy(base_arch) - dont_convert_module = ( - [] - if "dont_convert_module" not in repnr_kwargs - else repnr_kwargs.pop("dont_convert_module") - ) + dont_convert_module = [] if 'dont_convert_module' not in repnr_kwargs \ + else repnr_kwargs.pop('dont_convert_module') def recursive_converter(base_arch, **repnr_kwargs): if isinstance(base_arch, nn.Conv2d): @@ -35,16 +30,14 @@ def recursive_converter(base_arch, **repnr_kwargs): class RepNRConv2d(nn.Module): - def _init_conv(self, init_type="kaiming_uniform_"): + def _init_conv(self, init_type='kaiming_uniform_'): weight = torch.zeros_like(self.main_weight) init_func = eval(init_type) init_func(weight, a=math.sqrt(5)) fan_in, _ = init._calculate_fan_in_and_fan_out(weight) - bias = torch.zeros( - (weight.size(0),), - dtype=self.main_weight.dtype, - device=self.main_weight.device, - ) + bias = torch.zeros((weight.size(0), ), + dtype=self.main_weight.dtype, + device=self.main_weight.device) if fan_in != 0: bound = 1 / math.sqrt(fan_in) init.uniform_(bias, -bound, bound) @@ -56,9 +49,7 @@ def _init_from_conv2d(self, conv2d: nn.Conv2d): self.kernel_size = conv2d.kernel_size self.stride = conv2d.stride self.padding = _reverse_repeat_tuple(conv2d.padding, 2) - self.padding_mode = ( - "constant" if conv2d.padding_mode == "zeros" else conv2d.padding_mode - ) + self.padding_mode = 'constant' if conv2d.padding_mode == 'zeros' else conv2d.padding_mode self.dilation = conv2d.dilation self.groups = conv2d.groups self.bias = conv2d.bias is not None @@ -66,16 +57,12 @@ def _init_from_conv2d(self, conv2d: nn.Conv2d): main_weight = conv2d.weight.data.clone() main_bias = conv2d.bias.data.clone() if self.bias else None self.main_weight = nn.Parameter(main_weight, requires_grad=True) - self.main_bias = ( - nn.Parameter(main_bias, requires_grad=True) - if main_bias is not None - else None - ) + self.main_bias = nn.Parameter(main_bias, requires_grad=True) if main_bias is not None else None def _init_alignments(self, align_opts): - align_init_weight = align_opts.get("init_weight", 1.0) - if "init_bias" in align_opts: - align_init_bias = align_opts["init_bias"] + align_init_weight = align_opts.get('init_weight', 1.0) + if 'init_bias' in align_opts: + align_init_bias = align_opts['init_bias'] align_bias = True else: align_bias = False @@ -83,57 +70,35 @@ def _init_alignments(self, align_opts): self.align_weights = nn.ParameterList( [ - nn.Parameter( - torch.ones( - (1, align_channles, 1, 1), - dtype=self.main_weight.dtype, - device=self.main_weight.device, - ) - * align_init_weight, - requires_grad=True, - ) + nn.Parameter(torch.ones((1, align_channles, 1, 1), + dtype=self.main_weight.dtype, + device=self.main_weight.device) * align_init_weight, + requires_grad=True) for _ in range(self.branch_num + 1) ] ) - self.align_biases = ( - nn.ParameterList( - [ - nn.Parameter( - torch.ones( - (1, align_channles, 1, 1), - dtype=self.main_weight.dtype, - device=self.main_weight.device, - ) - * align_init_bias, - requires_grad=True, - ) - for _ in range(self.branch_num + 1) - ] - ) - if align_bias - else None - ) + self.align_biases = nn.ParameterList( + [ + nn.Parameter(torch.ones((1, align_channles, 1, 1), + dtype=self.main_weight.dtype, + device=self.main_weight.device) * align_init_bias, + requires_grad=True) + for _ in range(self.branch_num + 1) + ] + ) if align_bias else None def _init_aux_conv(self, aux_conv_opts): - if "init" not in aux_conv_opts: - aux_conv_opts["init"] = "kaiming_normal_" + if 'init' not in aux_conv_opts: + aux_conv_opts['init'] = 'kaiming_normal_' - aux_weight, aux_bias = self._init_conv(init_type=aux_conv_opts["init"]) + aux_weight, aux_bias = self._init_conv(init_type=aux_conv_opts['init']) self.aux_weight = nn.Parameter(aux_weight, requires_grad=True) - self.aux_bias = ( - nn.Parameter(torch.zeros_like(aux_bias), requires_grad=True) - if aux_conv_opts.get("bias", False) - else torch.zeros_like(aux_bias) - ) + self.aux_bias = nn.Parameter(torch.zeros_like(aux_bias), requires_grad=True) \ + if aux_conv_opts.get('bias', False) else torch.zeros_like(aux_bias) - def __init__( - self, - conv2d, - branch_num, - align_opts, - aux_conv_opts=None, - forward_type="reparameterize", - ): + def __init__(self, conv2d, branch_num, + align_opts, aux_conv_opts=None, + forward_type='reparameterize'): super().__init__() self.branch_num = branch_num self.forward_type = forward_type @@ -146,20 +111,17 @@ def __init__( self._init_aux_conv(aux_conv_opts) self.cur_branch = -1 - self.forward = ( - self._reparameterize_forward - if forward_type == "reparameterize" + self.forward = self._reparameterize_forward if forward_type == 'reparameterize' \ else self._trivial_forward - ) def switch_forward_type(self, *, trivial=False, reparameterize=False): assert not (trivial and reparameterize) if trivial: self.forward = self._trivial_forward - self.forward_type = "trivial" + self.forward_type = 'trivial' elif reparameterize: self.forward = self._reparameterize_forward - self.forward_type = "reparameterize" + self.forward_type = 'reparameterize' @staticmethod def _sequential_reparamterize(k1, b1, k2, b2): @@ -167,15 +129,15 @@ def _sequential_reparamterize(k1, b1, k2, b2): def depthwise_to_normal(k, padding): k = k.reshape(-1) k = torch.diag(k).unsqueeze(-1).unsqueeze(-1) - k = F.pad(k, _reverse_repeat_tuple(padding, 2), mode="constant", value=0.0) + k = F.pad(k, _reverse_repeat_tuple(padding, 2), mode='constant', value=0.0) return k def bias_pad(b, padding): - return F.pad(b, _reverse_repeat_tuple(padding, 2), mode="replicate") + return F.pad(b, _reverse_repeat_tuple(padding, 2), mode='replicate') padding = (k2.shape[-2] - 1) // 2, (k2.shape[-1] - 1) // 2 k1 = depthwise_to_normal(k1, padding) - k = F.conv2d(k2, k1, stride=1, padding="same") + k = F.conv2d(k2, k1, stride=1, padding='same') b = F.conv2d(bias_pad(b1, padding), k2, bias=b2, stride=1).reshape(-1) return k, b @@ -187,20 +149,15 @@ def _parallel_reparamterize(k1, b1, k2, b2): def _weight_and_bias(self): index = self.cur_branch align_weight = self.align_weights[index] - align_bias = ( - self.align_biases[index] - if self.align_biases is not None + align_bias = self.align_biases[index] if self.align_biases is not None \ else torch.zeros_like(align_weight) - ) main_weight, main_bias = self._sequential_reparamterize( - align_weight, align_bias, self.main_weight, self.main_bias - ) + align_weight, align_bias, self.main_weight, self.main_bias) - if hasattr(self, "aux_weight"): + if hasattr(self, 'aux_weight'): main_weight, main_bias = self._parallel_reparamterize( - main_weight, main_bias, self.aux_weight, self.aux_bias - ) + main_weight, main_bias, self.aux_weight, self.aux_bias) return main_weight, main_bias @@ -213,23 +170,16 @@ def _reparameterize_forward(self, x): def _trivial_forward(self, x): index = self.cur_branch align_weight = self.align_weights[index] - align_bias = ( - self.align_biases[index] - if self.align_biases is not None + align_bias = self.align_biases[index] if self.align_biases is not None \ else torch.zeros_like(align_weight) - ) aligned_x = x * align_weight + align_bias padded_aligned_x = F.pad(aligned_x, self.padding, self.padding_mode, value=0.0) - main_x = F.conv2d( - padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride - ) + main_x = F.conv2d(padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride) - if hasattr(self, "aux_weight"): + if hasattr(self, 'aux_weight'): padded_x = F.pad(x, self.padding, self.padding_mode, value=0.0) - aux_x = F.conv2d( - padded_x, self.aux_weight, self.aux_bias, stride=self.stride - ) + aux_x = F.conv2d(padded_x, self.aux_weight, self.aux_bias, stride=self.stride) main_x = main_x + aux_x return main_x @@ -238,29 +188,24 @@ def _trivial_forward_with_intermediate_features(self, x): index = self.cur_branch features = {} align_weight = self.align_weights[index] - align_bias = ( - self.align_biases[index] - if self.align_biases is not None + align_bias = self.align_biases[index] if self.align_biases is not None \ else torch.zeros_like(align_weight) - ) - features["in_feat"] = x + features['in_feat'] = x aligned_x = x * align_weight + align_bias - features["aligned_feat"] = aligned_x + features['aligned_feat'] = aligned_x padded_aligned_x = F.pad(aligned_x, self.padding, self.padding_mode, value=0.0) - main_x = F.conv2d( - padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride - ) - features["main_feat"] = main_x + main_x = F.conv2d(padded_aligned_x, self.main_weight, self.main_bias, stride=self.stride) + features['main_feat'] = main_x - if hasattr(self, "aux_weight"): + if hasattr(self, 'aux_weight'): padded_x = F.pad(x, self.padding, self.padding_mode, value=0.0) aux_x = F.conv2d(padded_x, self.aux_weight, self.aux_bias, stride=1) - features["aux_feat"] = aux_x + features['aux_feat'] = aux_x main_x = main_x + aux_x - features["out_feat"] = aux_x + features['out_feat'] = aux_x self.current_intermediate_features = features return main_x @@ -270,30 +215,30 @@ def __repr__(self): extra_repr = self.extra_repr() # empty string will be split into list [''] if extra_repr: - extra_lines = extra_repr.split("\n") + extra_lines = extra_repr.split('\n') child_lines = [] lines = extra_lines + child_lines - main_str = self._get_name() + "(" + main_str = self._get_name() + '(' if lines: # simple one-liner info, which most builtin Modules will use if len(extra_lines) == 1 and not child_lines: main_str += extra_lines[0] else: - main_str += "\n " + "\n ".join(lines) + "\n" + main_str += '\n ' + '\n '.join(lines) + '\n' - main_str += ")" + main_str += ')' return main_str def extra_repr(self) -> str: s = ( - "{in_channels}, {out_channels}, kernel_size={kernel_size}, " - "stride={stride}, padding={padding}, padding_mode={padding_mode},\n" - "branch_num={branch_num}, align_opts={align_opts}\n" - "forward_type={forward_type}" + '{in_channels}, {out_channels}, kernel_size={kernel_size}, ' + 'stride={stride}, padding={padding}, padding_mode={padding_mode},\n' + 'branch_num={branch_num}, align_opts={align_opts}\n' + 'forward_type={forward_type}' ) - if hasattr(self, "aux_weight"): - s += ", aux_conv_opts={aux_conv_opts}" + if hasattr(self, 'aux_weight'): + s += ', aux_conv_opts={aux_conv_opts}' return s.format(**self.__dict__) @@ -342,9 +287,7 @@ def generalize_align_conv(align_conv: RepNRConv2d): generalize_align_conv(m) @staticmethod - def _set_requires_grad( - align_conv: RepNRConv2d, *, pretrain=True, finetune=False, aux=True - ): + def _set_requires_grad(align_conv: RepNRConv2d, *, pretrain=True, finetune=False, aux=True): assert not (pretrain and finetune) for i in range(align_conv.branch_num): align_conv.align_weights[i].requires_grad_(pretrain) @@ -360,7 +303,7 @@ def _set_requires_grad( align_conv.main_bias.requires_grad_(pretrain) # aux weight and bias - if hasattr(align_conv, "aux_weight") and aux: + if hasattr(align_conv, 'aux_weight') and aux: align_conv.aux_weight.requires_grad_(finetune) if isinstance(align_conv.aux_bias, nn.Parameter): align_conv.aux_bias.requires_grad_(finetune) @@ -376,7 +319,7 @@ def finetune(self, *, aux=False): self._set_requires_grad(m, pretrain=False, finetune=True, aux=aux) def __repr__(self): - return f"{self._get_name()}: {str(self.repnr_module)}" + return f'{self._get_name()}: {str(self.repnr_module)}' @torch.no_grad() def deploy(self): From e59de91166e8dbe905daadcd041caefc15f8cc36 Mon Sep 17 00:00:00 2001 From: Xin Jin Date: Sun, 24 Mar 2024 12:28:49 +0800 Subject: [PATCH 4/5] [UPDATE] MIPI@2024 has finished. --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index bd5247b..bd290ab 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,11 @@ -#
Let's Prepare for MIPI@2024! \[Starting-Kit\]
+

- ICCV23_LED_LOGO
+ ICCV23_LED_LOGO

-##
Homepage | Paper | Google Drive | Baidu Cloud | 知乎 | Poster | Slides | Video
+##
Homepage | Paper | Google Drive | Baidu Cloud | 知乎 | MIPI Starting-Kit +
@@ -18,7 +19,7 @@ This repository contains the official implementation of the following papers: > Lighting Every Darkness in Two Pairs: A Calibration-Free Pipeline for RAW Denoising
> [Xin Jin](https://srameo.github.io)\*, [Jia-Wen Xiao](https://github.com/schuy1er)\*, [Ling-Hao Han](https://scholar.google.com/citations?user=0ooNdgUAAAAJ&hl=en), [Chunle Guo](https://mmcheng.net/clguo/)\#, [Ruixun Zhang](https://www.math.pku.edu.cn/teachers/ZhangRuixun%20/index.html), [Xialei Liu](https://mmcheng.net/xliu/), [Chongyi Li](https://li-chongyi.github.io/)
> (\* denotes equal contribution. \# denotes the corresponding author.)
-> In ICCV 2023, \[[Paper Link](https://arxiv.org/abs/2308.03448v1)\] +> In ICCV 2023, \[[Paper Link](https://arxiv.org/abs/2308.03448v1)\], \[[Poster](https://github.com/Srameo/LED/files/12733867/iccv23_poster.pdf)\], \[[Slides](https://srameo.github.io/projects/led-iccv23/assets/slides/iccv23_slides_en.pdf)\], \[[Video](https://youtu.be/Jo8OTAnUYkU)\] > Make Explicit Calibration Implicit: Calibrate Denoiser Instead of the Noise Model
> [Xin Jin](https://srameo.github.io), [Jia-Wen Xiao](https://github.com/schuy1er), [Ling-Hao Han](https://scholar.google.com/citations?user=0ooNdgUAAAAJ&hl=en), [Chunle Guo](https://mmcheng.net/clguo/)\#, [Xialei Liu](https://mmcheng.net/xliu/), [Chongyi Li](https://li-chongyi.github.io/), [Ming-Ming Cheng](https://mmcheng.net/cmm/)\#
From cfa13e46ebf9678d91c11c98c528523b90667a3d Mon Sep 17 00:00:00 2001 From: Xin Jin Date: Sun, 24 Mar 2024 12:30:00 +0800 Subject: [PATCH 5/5] [FIX] the link of MIPI starting-kit --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bd290ab..74d6a1a 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ICCV23_LED_LOGO

-##