add MTTS
DaeyeolKim committed Aug 2, 2021
1 parent 61fcd27 commit d0d2dd7
Showing 5 changed files with 117 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -4,6 +4,7 @@
- [x] [DeepPhys : DeepPhys: Video-Based Physiological Measurement Using Convolutional Attention Networks](https://arxiv.org/abs/1805.07888)
- [ ] [MTTS : Multi-Task Temporal Shift Attention Networks for On-Device Contactless Vitals Measurement](https://papers.nips.cc/paper/2020/file/e1228be46de6a0234ac22ded31417bc7-Paper.pdf)
+ needs verification
- [ ] DeepPhys + LSTM
- [x] [3D physNet : Remote Photoplethysmograph Signal Measurement from Facial Videos Using Spatio-Temporal Networks](https://arxiv.org/abs/1905.02419)

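The pieces added in this commit compose roughly as follows. A hypothetical training step, assuming an MSE loss per vital sign and prediction heads that emit one value per sample (the `train_step` helper, loss, and optimizer are assumptions, not part of the commit):

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, inputs, targets):
    """One optimization step; inputs (N, 2, C, H, W), targets (N, 2, 1)."""
    hr_out, rr_out = model(inputs)  # MTTS returns [hr_out, rr_out]
    # HR label sits at targets[:, 0], RR at targets[:, 1], matching the
    # stacking order in MTTSDataset.__getitem__.
    loss = F.mse_loss(hr_out, targets[:, 0]) + F.mse_loss(rr_out, targets[:, 1])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```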
34 changes: 34 additions & 0 deletions dataset/MTTSDataset.py
@@ -0,0 +1,34 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class MTTSDataset(Dataset):
    def __init__(self, appearance_data, motion_data, hr_target, rr_target):
self.transform = transforms.Compose([transforms.ToTensor()])
        self.a = appearance_data  # appearance frames, HWC layout
        self.m = motion_data      # motion frames, HWC layout
self.hr_label = hr_target.reshape(-1, 1)
self.rr_label = rr_target.reshape(-1, 1)

def __getitem__(self, index):
if torch.is_tensor(index):
index = index.tolist()

appearance_data = torch.tensor(np.transpose(self.a[index], (2, 0, 1)), dtype=torch.float32)
motion_data = torch.tensor(np.transpose(self.m[index], (2, 0, 1)), dtype=torch.float32)
hr_target = torch.tensor(self.hr_label[index], dtype=torch.float32)
rr_target = torch.tensor(self.rr_label[index], dtype=torch.float32)

        # stack appearance/motion -> (2, 3, H, W) and HR/RR targets -> (2, 1)
        inputs = torch.stack([appearance_data, motion_data], dim=0)
        targets = torch.stack([hr_target, rr_target], dim=0)

        # moving tensors to the GPU inside __getitem__ requires num_workers=0
        if torch.cuda.is_available():
            inputs = inputs.to('cuda')
            targets = targets.to('cuda')

return inputs, targets

def __len__(self):
return len(self.hr_label)
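As a sanity check, a minimal usage sketch for the new dataset (the random arrays and the 36x36 crop size are assumptions for illustration; only the class itself is part of this commit):

```python
import numpy as np
from torch.utils.data import DataLoader

from dataset.MTTSDataset import MTTSDataset

# Hypothetical preprocessed clips: N frames of 36x36 RGB crops in HWC
# layout, with one HR and one RR value per frame.
N = 128
appearance = np.random.rand(N, 36, 36, 3).astype(np.float32)
motion = np.random.rand(N, 36, 36, 3).astype(np.float32)
hr = np.random.rand(N).astype(np.float32)
rr = np.random.rand(N).astype(np.float32)

dataset = MTTSDataset(appearance, motion, hr, rr)
# shuffle=False keeps frames in temporal order for the TSM blocks, and
# num_workers must stay 0 because __getitem__ moves tensors to CUDA.
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)

inputs, targets = next(iter(loader))
print(inputs.shape)   # torch.Size([32, 2, 3, 36, 36])
print(targets.shape)  # torch.Size([32, 2, 1])
```

A common alternative is to return CPU tensors from `__getitem__` and move whole batches to the GPU in the training loop, which also works with multi-worker loading.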
4 changes: 2 additions & 2 deletions nets/blocks/blocks.py
@@ -92,11 +92,11 @@ def __call__(self, input, n_frame=4, fold_div=3):


class TSM_Block(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size):
+    def __init__(self, in_channels, out_channels, kernel_size, padding):
super().__init__()
self.tsm1 = TSM()
self.t_conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
-                                       padding=1)
+                                       padding=padding)

def forward(self, input, n_frame=2, fold_div=3):
t = self.tsm1(input, n_frame, fold_div)
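The `padding` parameter added here lets callers alternate between 'same' and 'valid' convolutions inside the block. For readers unfamiliar with the temporal shift itself, below is a self-contained sketch of the standard TSM operation from Lin et al. (a generic illustration, not a copy of this repository's `TSM` class):

```python
import torch

def temporal_shift(x: torch.Tensor, n_frame: int = 2, fold_div: int = 3) -> torch.Tensor:
    # Standard TSM: shift 1/fold_div of the channels one step backward in
    # time, another 1/fold_div one step forward, leave the rest unshifted.
    nt, c, h, w = x.shape
    x = x.view(nt // n_frame, n_frame, c, h, w)
    fold = c // fold_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # take from next frame
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # take from previous frame
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # untouched channels
    return out.view(nt, c, h, w)

x = torch.randn(8, 32, 18, 18)  # 4 clips of n_frame=2, 32 channels, 18x18
print(temporal_shift(x, n_frame=2, fold_div=3).shape)  # torch.Size([8, 32, 18, 18])
```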
42 changes: 42 additions & 0 deletions nets/models/MTTS.py
@@ -0,0 +1,42 @@
import torch

from nets.models.sub_models.AppearanceModel import AppearanceModel_2D
from nets.models.sub_models.LinearModel import LinearModel
from nets.models.sub_models.MotionModel import MotionModel_TS


class MTTS(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = 3
self.out_channels = 32
self.kernel_size = 3
self.attention_mask1 = None
self.attention_mask2 = None

self.appearance_model = AppearanceModel_2D(in_channels=self.in_channels, out_channels=self.out_channels,
kernel_size=self.kernel_size)
self.motion_model = MotionModel_TS(in_channels=self.in_channels, out_channels=self.out_channels,
kernel_size=self.kernel_size)

self.hr_linear_model = LinearModel()
self.rr_linear_model = LinearModel()

def forward(self, inputs):
"""
:param inputs:
inputs[0] : appearance_input
inputs[1] : motion_input
:return:
original 2d model
"""
        # split the stacked input into appearance and motion streams
        inputs = torch.chunk(inputs, 2, dim=1)
        self.attention_mask1, self.attention_mask2 = self.appearance_model(torch.squeeze(inputs[0], 1))
        motion_output = self.motion_model(torch.squeeze(inputs[1], 1), self.attention_mask1, self.attention_mask2)
        hr_out = self.hr_linear_model(motion_output)
        rr_out = self.rr_linear_model(motion_output)

        return [hr_out, rr_out]

def get_attention_mask(self):
return self.attention_mask1, self.attention_mask2
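A quick shape-level smoke test for the new model (the 36x36 input size follows the TS-CAN paper and is an assumption here; the batch size must be divisible by the TSM's n_frame):

```python
import torch

from nets.models.MTTS import MTTS

# MotionModel_TS allocates its mask buffer with .to('cuda'), so as written
# this model needs a GPU.
model = MTTS().to('cuda')

# Batch of 4: appearance and motion frames stacked on dim 1,
# exactly the layout MTTSDataset emits.
dummy = torch.randn(4, 2, 3, 36, 36, device='cuda')
hr_out, rr_out = model(dummy)
print(hr_out.shape, rr_out.shape)
```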
39 changes: 38 additions & 1 deletion nets/models/sub_models/MotionModel.py
@@ -1,5 +1,7 @@
import torch

from nets.blocks.blocks import TSM_Block


class MotionModel(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
@@ -29,7 +31,7 @@ def forward(self, inputs, mask1, mask2):
M2 = self.m_batch_Normalization2(self.m_conv2(M1))
# element wise multiplication Mask1
ones = torch.ones(size=M2.shape).to('cuda')
-        g1 = torch.tanh(torch.mul(ones@mask1, M2))
+        g1 = torch.tanh(torch.mul(ones @ mask1, M2))
M3 = self.m_dropout1(g1)
# pooling
M4 = self.m_avg1(M3)
@@ -43,3 +45,38 @@ def forward(self, inputs, mask1, mask2):
out = torch.tanh(M8)

return out


class MotionModel_TS(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super().__init__()
# Motion model
        # TSM_Block requires kernel_size; string padding needs PyTorch >= 1.9
        self.m_tsm_conv1 = TSM_Block(in_channels, out_channels, kernel_size, padding='same')
        self.m_tsm_conv2 = TSM_Block(out_channels, out_channels, kernel_size, padding='valid')
        self.m_avg1 = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.m_dropout1 = torch.nn.Dropout2d(p=0.50)
        self.m_tsm_conv3 = TSM_Block(out_channels, out_channels * 2, kernel_size, padding='same')
        self.m_tsm_conv4 = TSM_Block(out_channels * 2, out_channels * 2, kernel_size, padding='valid')
        self.m_avg2 = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.m_dropout2 = torch.nn.Dropout2d(p=0.50)

def forward(self, inputs, mask1, mask2):
M1 = torch.tanh(self.m_tsm_conv1(inputs))
M2 = torch.tanh(self.m_tsm_conv2(M1))
# element wise multiplication Mask1
ones = torch.ones(size=M2.shape).to('cuda')
g1 = torch.tanh(torch.mul(ones @ mask1, M2))
        # pooling, then dropout
        M3 = self.m_avg1(g1)
        M4 = self.m_dropout1(M3)
M5 = torch.tanh(self.m_tsm_conv3(M4))
M6 = torch.tanh(self.m_tsm_conv4(M5))
# element wise multiplication Mask2
g2 = torch.tanh(torch.mul(1 * mask2, M6))
M7 = self.m_avg2(g2)
M8 = self.m_dropout2(M7)
out = torch.tanh(M8)

return out
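For intuition, the gating step in its simplest form: DeepPhys/TS-CAN-style soft attention gates the motion features by an elementwise broadcast multiply with the single-channel appearance mask. A sketch under assumed shapes; note that the code above additionally applies a `ones @ mask1` matrix product rather than a plain broadcast:

```python
import torch

# Assumed shapes for illustration.
motion_feat = torch.randn(4, 32, 36, 36)  # motion branch features (M2)
mask = torch.rand(4, 1, 36, 36)           # appearance attention mask (mask1)

# Broadcast the 1-channel mask across all 32 feature channels and gate
# each spatial location of the motion features.
gated = torch.tanh(mask * motion_feat)
print(gated.shape)  # torch.Size([4, 32, 36, 36])
```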
