Skip to content

Commit

Permalink
add PhysNet_ConvLSTM_2DCNN
Browse files Browse the repository at this point in the history
  • Loading branch information
Jinsoo Kim committed Aug 10, 2021
1 parent 3121c35 commit 666b4a4
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 31 deletions.
66 changes: 48 additions & 18 deletions nets/blocks/cnn_blocks.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,65 @@
import torch.nn

from nets.blocks.blocks import ConvBlock2D
from nets.blocks.blocks import ConvBlock3D


class cnn_blocks(torch.nn.Module):
def __init__(self):
super(cnn_blocks, self).__init__()
self.cnn_blocks = torch.nn.Sequential(
ConvBlock3D(3, 16, [1, 5, 5], [1, 1, 1], [0, 2, 2]),
torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
ConvBlock3D(16, 32, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
ConvBlock3D(32, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# torch.nn.AdaptiveMaxPool3d(1)
ConvBlock2D(3, 16, [5, 5], [1, 1], [2, 2]),
torch.nn.MaxPool2d((2, 2), stride=(2, 2)),
ConvBlock2D(16, 32, [3, 3], [1, 1], [1, 1]),
ConvBlock2D(32, 64, [3, 3], [1, 1], [1, 1]),
torch.nn.MaxPool2d((2, 2), stride=(2, 2)),
ConvBlock2D(64, 64, [3, 3], [1, 1], [1, 1]),
ConvBlock2D(64, 64, [3, 3], [1, 1], [1, 1]),
torch.nn.MaxPool2d((2, 2), stride=(2, 2)),
ConvBlock2D(64, 64, [3, 3], [1, 1], [1, 1]),
ConvBlock2D(64, 64, [3, 3], [1, 1], [1, 1]),
torch.nn.MaxPool2d((2, 2), stride=(2, 2)),
ConvBlock2D(64, 64, [3, 3], [1, 1], [1, 1]),
ConvBlock2D(64, 64, [3, 3], [1, 1], [1, 1]),
torch.nn.AdaptiveMaxPool2d(1)
)

def forward(self, x):
[batch, channel, length, width, height] = x.shape
# x = x.reshape(batch * length, channel, width, height)
# x = self.cnn_blocks(x)
# x = x.reshape(batch,length,-1,1,1)
x = x.view(batch * length, channel, width, height)
x = self.cnn_blocks(x)
x = x.view(batch,length,-1,1,1)

return x

'''
Conv3D 1x3x3(paper architecture)
'''
# class cnn_blocks(torch.nn.Module):
# def __init__(self):
# super(cnn_blocks, self).__init__()
# self.cnn_blocks = torch.nn.Sequential(
# ConvBlock3D(3, 16, [1, 5, 5], [1, 1, 1], [0, 2, 2]),
# torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
# ConvBlock3D(16, 32, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# ConvBlock3D(32, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
# ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
# ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# torch.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
# ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# ConvBlock3D(64, 64, [1, 3, 3], [1, 1, 1], [1, 1, 1]),
# # torch.nn.AdaptiveMaxPool3d(1)
# )
#
# def forward(self, x):
# [batch, channel, length, width, height] = x.shape
# # x = x.reshape(batch * length, channel, width, height)
# # x = self.cnn_blocks(x)
# # x = x.reshape(batch,length,-1,1,1)
# x = self.cnn_blocks(x)
#
# return x

15 changes: 4 additions & 11 deletions nets/models/PhysNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,14 @@ def __init__(self, frame=32):
super(PhysNet_2DCNN_LSTM, self).__init__()
self.physnet_lstm = torch.nn.ModuleDict({
'cnn_blocks' : cnn_blocks(),
# 'lstm' : torch.nn.LSTM(input_size=64, hidden_size=64, num_layers=2, bidirectional=True, batch_first=True),
'spatial_global_avgpool' : torch.nn.AdaptiveMaxPool3d((frame, 1, 1)),
'cov_lstm' : ConvLSTM(64,[1,1,64],(1,1),num_layers=3, batch_first=True,bias=True, return_all_layers=False),
'cov_lstm' : ConvLSTM(64,[1,1,64],(1,1),num_layers=3, batch_first=True, bias=True, return_all_layers=False),
'cnn_flatten' : torch.nn.Conv3d(64, 1, [1, 1, 1], stride=1, padding=0)

})

def forward(self, x):
[batch, channel, length, width, height] = x.shape
x = self.physnet_lstm['cnn_blocks'](x)
# x,(_,_) = self.physnet_lstm['lstm'](x)
x = self.physnet_lstm['spatial_global_avgpool'](x)
x = x.reshape(batch, length, -1, 1, 1)
x = self.physnet_lstm['cov_lstm'](x)
# x = x.reshape(batch, channel, length, 1, 1)
x = torch.permute(x[0][0], (0, 2, 1, 3, 4))
x,_ = self.physnet_lstm['cov_lstm'](x)
x = torch.permute(x[0], (0, 2, 1, 3, 4))
x = self.physnet_lstm['cnn_flatten'](x)
return x.reshape(-1, length)
return x.view(-1, length)
4 changes: 2 additions & 2 deletions params.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
"train_ratio_comment" : "generate train dataset using train_ratio",
"validation_ratio": 0.9,
"validation_ratio_comment" : "split train dataset using validation_ratio",
"train_batch_size" : 16,
"train_batch_size" : 32,
"train_batch_size_comment" :
[
"PhysNet_LSTM : 8"
],
"train_shuffle" : 0,
"test_batch_size" : 16,
"test_batch_size" : 32,
"test_shuffle" : 0
},
"hyper_params":
Expand Down

0 comments on commit 666b4a4

Please sign in to comment.