# nafnet.py
# Forked from dslisleedh/NAFNet-flax (68 lines, 2.71 KB)
from typing import List, Tuple

import jax
import jax.numpy as jnp
import flax.linen as nn
from einops import rearrange

from layers import *
class NAFNet(nn.Module):
    """NAFNet (Nonlinear-Activation-Free Network) for image restoration.

    UNet-style encoder/decoder built from ``NAFBlock``s with a global
    residual connection: the network predicts a residual image that is
    added back onto the input.

    Attributes:
        n_filters: Base channel count; doubled after every encoder stage.
        n_enc_blocks: Number of NAFBlocks per encoder stage (tuple length
            sets the number of stages).
        n_middle_blocks: Number of NAFBlocks at the bottleneck.
        n_dec_blocks: Number of NAFBlocks per decoder stage.
        dropout_rate: Dropout rate forwarded to every NAFBlock.
        train_size: Training input shape ``(N, H, W, C)``; only H and W
            are read here, to size the per-stage kh/kw passed to NAFBlock.
            # NOTE(review): presumably the TLC local window size — confirm
            # against NAFBlock in layers.py.
        base_rate: Multiplier on training H/W for that kh/kw size.
    """
    n_filters: int = 16
    # Tuples, not lists: flax Modules are frozen dataclasses, so defaults
    # must be immutable (the original `List` annotation was misleading).
    n_enc_blocks: Tuple[int, ...] = (1, 1, 1, 28)
    n_middle_blocks: int = 1
    n_dec_blocks: Tuple[int, ...] = (1, 1, 1, 1)
    dropout_rate: float = .1
    train_size: Tuple = (None, 256, 256, 3)
    base_rate: float = 1.5

    @nn.compact
    def __call__(self, x, training=False):
        """Restore image ``x`` (NHWC); returns a tensor of the same shape.

        Args:
            x: Input batch, shape (N, H, W, C). H and W must be divisible
               by 2**len(n_enc_blocks) for the skip additions to line up.
            training: Enables dropout inside the NAFBlocks when True.
        """
        n_stages = len(self.n_enc_blocks)
        # Full-resolution window extent, derived from the training size;
        # halved at each stage below to track the feature-map resolution.
        kh = int(self.train_size[1] * self.base_rate)
        kw = int(self.train_size[2] * self.base_rate)

        # Shallow feature extraction.
        features = nn.Conv(self.n_filters,
                           kernel_size=(3, 3),
                           padding='SAME'
                           )(x)

        # Encoder: NAFBlocks, save the skip, then 2x strided-conv downsample.
        enc_skip = []
        for i, n_blocks in enumerate(self.n_enc_blocks):
            for _ in range(n_blocks):
                features = NAFBlock(self.n_filters * (2 ** i),
                                    self.dropout_rate,
                                    kh // (2 ** i),
                                    kw // (2 ** i)
                                    )(features, deterministic=not training)
            enc_skip.append(features)
            features = nn.Conv(self.n_filters * (2 ** (i + 1)),
                               kernel_size=(2, 2),
                               strides=(2, 2),
                               padding='VALID'
                               )(features)
        # Deepest skip first, matching the decoder's stage order.
        enc_skip = enc_skip[::-1]

        # Bottleneck at 1/2**n_stages resolution.
        for _ in range(self.n_middle_blocks):
            features = NAFBlock(self.n_filters * (2 ** n_stages),
                                self.dropout_rate,
                                kh // (2 ** n_stages),
                                kw // (2 ** n_stages)
                                )(features, deterministic=not training)

        # Decoder: 1x1 conv then PixelShuffle(2) (channels /4, spatial x2),
        # add the encoder skip, then NAFBlocks.
        for i, n_blocks in enumerate(self.n_dec_blocks):
            features = nn.Conv(self.n_filters * (2 ** (n_stages - i)) * 2,
                               kernel_size=(1, 1),
                               padding='VALID'
                               )(features)
            features = PixelShuffle(2)(features)
            features = features + enc_skip[i]
            for _ in range(n_blocks):
                features = NAFBlock(self.n_filters * (2 ** (n_stages - (i + 1))),
                                    self.dropout_rate,
                                    kh // (2 ** (n_stages - (i + 1))),
                                    kw // (2 ** (n_stages - (i + 1)))
                                    )(features, deterministic=not training)

        # Global residual: predict a residual image with the same channel
        # count as the input and add it back. Using x.shape[-1] instead of
        # the original hard-coded 3 keeps the residual add valid for any
        # input channel count (identical behavior for RGB).
        x_res = nn.Conv(x.shape[-1],
                        kernel_size=(3, 3),
                        padding='SAME'
                        )(features)
        return x + x_res