
use nn.Sequential to remove python control flow from autoencoder up/downsampling #33

Open · wants to merge 3 commits into base: main
90 changes: 54 additions & 36 deletions flux/modules/autoencoder.py
@@ -105,6 +105,16 @@ def forward(self, x: Tensor):
x = self.conv(x)
return x

class DownBlock(nn.Module):
def __init__(self, block: list, downsample: nn.Module) -> None:
super().__init__()
# we're doing this instead of a flat nn.Sequential to preserve the state dict keys "block" and "downsample"
self.block = nn.Sequential(*block)
self.downsample = downsample

def forward(self, x: Tensor) -> Tensor:
return self.downsample(self.block(x))
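
A quick way to see why this wrapper keeps existing checkpoints loadable: the attribute names "block" and "downsample" become the same state_dict prefixes that the old per-level nn.Module produced. A minimal sketch, with plain Conv2d layers standing in for ResnetBlock and Downsample purely to inspect key names (DownBlock itself is assumed to come from this patch):

from torch import nn

# Conv2d stand-ins for ResnetBlock / Downsample; only the key names matter here
blk = DownBlock(
    block=[nn.Conv2d(64, 64, kernel_size=3, padding=1)],
    downsample=nn.Conv2d(64, 64, kernel_size=3, stride=2),
)
print(list(blk.state_dict().keys()))
# ['block.0.weight', 'block.0.bias', 'downsample.weight', 'downsample.bias']
# i.e. the same "block.<i>." / "downsample." prefixes as before, so the full
# down.<level>.block.<i>.* and down.<level>.downsample.* keys still line up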


class Encoder(nn.Module):
def __init__(
@@ -128,23 +138,25 @@ def __init__(
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
down_layers = []
block_in = self.ch
# ideally, this would all append to a single flat nn.Sequential
# we cannot do this due to the existing state dict keys
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
block_layers = []
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
block_layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out  # the next ResnetBlock's in_channels track this block's out_channels
# originally this provided for attn layers, but those are never actually created
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
else:
downsample = nn.Identity()
down_layers.append(DownBlock(block_layers, downsample))
self.down = nn.Sequential(*down_layers)

# middle
self.mid = nn.Module()
@@ -158,18 +170,10 @@ def __init__(

def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
h = self.conv_in(x)
h = self.down(h)

# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
@@ -179,6 +183,15 @@ def forward(self, x: Tensor) -> Tensor:
h = self.conv_out(h)
return h
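
One way to sanity-check the encoder half of this change (a sketch, not part of the PR): a fresh Encoder built from the refactored class should accept an encoder state_dict saved under the old layout with strict=True, since only the containers changed, not the parameter names or the math. The constructor arguments and checkpoint path below are illustrative assumptions, not values taken from the repo's config:

import torch

# hypothetical values; use the repo's actual autoencoder params and checkpoint
enc = Encoder(
    resolution=256, in_channels=3, ch=128, ch_mult=[1, 2, 4, 4],
    num_res_blocks=2, z_channels=16,
)
old_sd = torch.load("encoder_old_layout.pt")  # hypothetical path
enc.load_state_dict(old_sd, strict=True)      # should report no missing / unexpected keys
with torch.no_grad():
    out = enc(torch.randn(1, 3, 256, 256))
print(out.shape)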

class UpBlock(nn.Module):
def __init__(self, block: list, upsample: nn.Module) -> None:
super().__init__()
self.block = nn.Sequential(*block)
self.upsample = upsample

def forward(self, x: Tensor) -> Tensor:
return self.upsample(self.block(x))


class Decoder(nn.Module):
def __init__(
@@ -214,26 +227,37 @@ def __init__(
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

# upsampling
self.up = nn.ModuleList()
up_blocks = []
# 3, 2, 1, 0, descending order
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
level_blocks = []
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
level_blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
else:
upsample = nn.Identity()
# 0, 1, 2, 3, ascending order
up_blocks.insert(0, UpBlock(level_blocks, upsample)) # prepend to get consistent order
self.up = nn.Sequential(*up_blocks)

# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

# this is a hack to get something like a property that is only evaluated once
# we do it this way so that up_descending doesn't end up in the state_dict keys
# and without adding anything conditional to the main forward path
def __getattr__(self, name):
if name == "up_descending":
self.up_descending = nn.Sequential(*reversed(self.up))
Decoder.__getattr__ = nn.Module.__getattr__
return self.up_descending
return super().__getattr__(name)
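
The trick above relies on Python only calling __getattr__ when normal attribute lookup fails: on first access the reversed view of self.up is built (sharing the same modules, no parameter copies), cached on the instance, and the class-level __getattr__ is then restored so later accesses go through nn.Module as usual. A self-contained toy version of the same pattern (not flux code), in case the control flow is hard to follow in context:

from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # stored in ascending order, like self.up, to keep checkpoint keys stable
        self.up = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

    def __getattr__(self, name):
        if name == "up_descending":
            # first access: build the reversed view, cache it, restore default lookup
            self.up_descending = nn.Sequential(*reversed(self.up))
            Toy.__getattr__ = nn.Module.__getattr__
            return self.up_descending
        return super().__getattr__(name)

t = Toy()
print(any(k.startswith("up_descending") for k in t.state_dict()))  # False: nothing built yet
_ = t.up_descending                      # first access builds and caches the reversed Sequential
print(t.up_descending[0] is t.up[1])     # True: same module objects, reversed order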

def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
@@ -244,13 +268,7 @@ def forward(self, z: Tensor) -> Tensor:
h = self.mid.block_2(h)

# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
h = self.up_descending(h)

# end
h = self.norm_out(h)