Skip to content

Commit

Permalink
Add RTNetDataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DaeyeolKim committed Sep 2, 2021
1 parent bb3c728 commit c5aacf3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
31 changes: 31 additions & 0 deletions dataset/RTNetDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class RTNetDataset(Dataset):
def __init__(self, face_data, mask_data, target):
self.transform = transforms.Compose([transforms.ToTensor()])
self.face = face_data
self.mask = mask_data
self.label = target.reshape(-1, 1)

def __getitem__(self, index):
if torch.is_tensor(index):
index = index.tolist()

appearance_data = torch.tensor(np.transpose(self.face[index], (2, 0, 1)), dtype=torch.float32)
motion_data = torch.tensor(np.transpose(self.mask[index], (2, 0, 1)), dtype=torch.float32)
target = torch.tensor(self.label[index], dtype=torch.float32)

inputs = torch.stack([appearance_data,motion_data],dim=0)

if torch.cuda.is_available():
inputs = inputs.to('cuda')
target = target.to('cuda')

return inputs, target

def __len__(self):
return len(self.label)
15 changes: 15 additions & 0 deletions dataset/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,19 @@ def dataset_loader(save_root_path: str = "/media/hdd1/dy_dataset/",
dbp=np.asarray(dbp),
hr=np.asarray(hr))

elif model_name in ["RTNet"]:
face_data = []
mask_data = []
target_data = []

for key in hpy_file.keys():
face_data.extend(hpy_file[key]['preprocessed_video'][:, :, :, -3:])
mask_data.extend(hpy_file[key]['preprocessed_video'][:, :, :, :3])
target_data.extend(hpy_file[key]['preprocessed_label'])
hpy_file.close()

dataset = PPNetDataset(face_data=np.asarray(face_data),
mask_data=np.asarray(mask_data),
target=np.asarray(target_data))

return dataset

0 comments on commit c5aacf3

Please sign in to comment.