Skip to content

Commit

Permalink
Update create_dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Linfeng-Tang authored Jun 1, 2023
1 parent 4d17939 commit 4f4881c
Showing 1 changed file with 2 additions and 79 deletions.
81 changes: 2 additions & 79 deletions create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,76 +11,6 @@
from utils import randrot,randfilp
from natsort import natsorted


class RandomScaleCrop(object):
    """Randomly crop a window of an image and resize it back to full size.

    The same crop window is applied to the image, its label map, its depth
    map and its surface normals, so the four outputs stay pixel-aligned.

    Credit to Jialong Wu from https://github.com/lorenmt/mtan/issues/34.
    """

    def __init__(self, scale=None):
        """Store the candidate zoom factors.

        Args:
            scale: sequence of zoom factors to sample from. Defaults to
                ``[1.0, 1.2, 1.5]``. Factors below 1 make the crop larger
                than the input, which makes ``random.randint`` raise
                ``ValueError`` in ``__call__``.
        """
        # Build the default per instance instead of sharing one mutable list
        # object between every instance through the function default.
        self.scale = [1.0, 1.2, 1.5] if scale is None else scale

    def __call__(self, img, label, depth, normal):
        """Apply one random scale-crop to all four inputs.

        Args:
            img: 3-D tensor indexed as (channel, row, col).
            label: 2-D tensor indexed as (row, col).
            depth: 3-D tensor indexed as (channel, row, col).
            normal: 3-D tensor indexed as (channel, row, col).

        Returns:
            Tuple ``(img, label, depth, normal)`` resized to the spatial size
            of ``img``; the depth values are additionally divided by the
            sampled zoom factor.
        """
        height, width = img.shape[-2:]
        sc = self.scale[random.randint(0, len(self.scale) - 1)]
        h, w = int(height / sc), int(width / sc)

        # Top-left corner of the crop window (randint is inclusive on both ends).
        i = random.randint(0, height - h)
        j = random.randint(0, width - w)

        # Image and normals are interpolated bilinearly; label and depth use
        # nearest-neighbour so no new (blended) values are invented.
        img_ = F.interpolate(
            img[None, :, i:i + h, j:j + w],
            size=(height, width), mode='bilinear', align_corners=True,
        ).squeeze(0)
        label_ = F.interpolate(
            label[None, None, i:i + h, j:j + w],
            size=(height, width), mode='nearest',
        ).squeeze(0).squeeze(0)
        depth_ = F.interpolate(
            depth[None, :, i:i + h, j:j + w],
            size=(height, width), mode='nearest',
        ).squeeze(0)
        normal_ = F.interpolate(
            normal[None, :, i:i + h, j:j + w],
            size=(height, width), mode='bilinear', align_corners=True,
        ).squeeze(0)

        # NOTE(review): depth is rescaled by the zoom factor, presumably so
        # that it stays consistent with the magnified view - confirm.
        return img_, label_, depth_ / sc, normal_


class NYUv2(Dataset):
    """NYUv2 samples stored as pre-processed ``.npy`` files.

    Each item is read from ``<root>/<split>/{image,label,depth,normal}/<index>.npy``
    where ``<split>`` is ``train`` or ``val``.

    We could further improve the performance with the data augmentation of NYUv2 defined in:
        [1] PAD-Net: Multi-Tasks Guided Prediction-and-Distillation Network for Simultaneous Depth Estimation and Scene Parsing
        [2] Pattern affinitive propagation across depth, surface normal and semantic segmentation
        [3] Mti-net: Multiscale task interaction networks for multi-task learning
        1. Random scale in a selected ratio 1.0, 1.2, and 1.5.
        2. Random horizontal flip.
    Please note that: all baselines and MTAN did NOT apply data augmentation in the original paper.
    """

    def __init__(self, root, train=True, augmentation=False):
        """Locate the split directory and count its samples.

        Args:
            root: dataset root directory; a leading ``~`` is expanded.
            train: read the ``train`` split when true, otherwise ``val``.
            augmentation: apply random scale-crop and horizontal flip in
                ``__getitem__`` when true.
        """
        self.train = train
        self.root = os.path.expanduser(root)
        self.augmentation = augmentation

        # Build the split path from the *expanded* root so that '~/...' roots
        # resolve to the real directory instead of a literal '~' folder.
        self.data_path = os.path.join(self.root, 'train' if train else 'val')

        # The dataset length is the number of .npy files in the image folder.
        image_dir = os.path.join(self.data_path, 'image')
        self.data_len = len(fnmatch.filter(os.listdir(image_dir), '*.npy'))

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, semantic, depth, normal)`` as float tensors."""
        # Load data from the pre-processed npy files; image, depth and normal
        # have their last axis moved to the front, the label is used as stored.
        image = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/image/{:d}.npy'.format(index)), -1, 0))
        semantic = torch.from_numpy(np.load(self.data_path + '/label/{:d}.npy'.format(index)))
        depth = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/depth/{:d}.npy'.format(index)), -1, 0))
        normal = torch.from_numpy(np.moveaxis(np.load(self.data_path + '/normal/{:d}.npy'.format(index)), -1, 0))

        # Apply data augmentation if required.
        if self.augmentation:
            image, semantic, depth, normal = RandomScaleCrop()(image, semantic, depth, normal)
            if torch.rand(1) < 0.5:
                # Flip along the last axis of every map.
                image = torch.flip(image, dims=[2])
                semantic = torch.flip(semantic, dims=[1])
                depth = torch.flip(depth, dims=[2])
                normal = torch.flip(normal, dims=[2])
                # NOTE(review): channel 0 is negated after the flip, presumably
                # because it holds the horizontal normal component - confirm.
                normal[0, :, :] = - normal[0, :, :]

        return image.float(), semantic.float(), depth.float(), normal.float()

    def __len__(self):
        """Return the number of samples found in the image folder."""
        return self.data_len


class MSRSData(torch.utils.data.Dataset):
"""
Load dataset with infrared folder path and visible folder path
Expand All @@ -100,14 +30,7 @@ def __init__(self, opts, is_train=True, crop=lambda x: x):
else:
self.vis_folder = os.path.join(opts.dataroot, 'test', 'vi')
self.ir_folder = os.path.join(opts.dataroot, 'test', 'ir')
self.label_folder = os.path.join(opts.dataroot, 'test', 'label')
# self.vis_folder = os.path.join(opts.dataroot, 'train', 'vi')
# self.ir_folder = os.path.join(opts.dataroot, 'train', 'ir')
# self.label_folder = os.path.join(opts.dataroot, 'train', 'label')

# self.vis_folder = '/data/timer/Segmentation/SegFormer/datasets/MSRS/RGB'
# self.ir_folder = '/data/timer/Segmentation/SegFormer/datasets/MSRS/Thermal'
# self.label_folder = '/data/timer/Segmentation/SegFormer/datasets/MSRS/Label'
self.label_folder = os.path.join(opts.dataroot, 'test', 'label')
self.crop = torchvision.transforms.RandomCrop(256)
# gain infrared and visible images list
self.ir_list = natsorted(os.listdir(self.label_folder))
Expand Down Expand Up @@ -236,4 +159,4 @@ def imread(path, label=False, vis_flage=True):
new_height = height - (height % 32)
img = img.resize((new_width, new_height))
im_ts = TF.to_tensor(img).unsqueeze(0)
return im_ts, width, height
return im_ts, width, height

0 comments on commit 4f4881c

Please sign in to comment.