Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
dragonbook committed Nov 25, 2018
1 parent ef9b63e commit 9db504a
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch.nn as nn
import torch.nn.functional as F

# TODO: custom weight initialization


class Basic3DBlock(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size):
super(Basic3DBlock, self).__init__()
self.block = nn.Sequential(
nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size-1)//2)),
nn.BatchNorm3d(out_planes),
nn.ReLU(True)
)

def forward(self, x):
return self.block(x)


class Res3DBlock(nn.Module):
def __init__(self, in_planes, out_planes):
super(Res3DBlock, self).__init__()
self.res_branch = nn.Sequential(
nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1),
nn.BatchNorm3d(out_planes),
nn.ReLU(True),
nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1),
nn.BatchNorm3d(out_planes)
)

if in_planes == out_planes:
self.skip_con = nn.Sequential()
else:
self.skip_con = nn.Sequential(
nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0),
nn.BatchNorm3d(out_planes)
)

def forward(self, x):
res = self.res_branch(x)
skip = self.skip_con(x)
return F.relu(res + skip, True)


class Upsample3DBlock(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride):
super(Upsample3DBlock, self).__init__()
self.block = nn.Sequential(
nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, output_padding=stride-1),
nn.BatchNorm3d(out_planes),
nn.ReLU(True)
)

def forward(self, x):
return self.block(x)

0 comments on commit 9db504a

Please sign in to comment.