Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
cindydeng1991 authored Feb 26, 2021
1 parent c43c0b4 commit 14b9b36
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch.utils.data as data
from glob import glob
import matplotlib.pyplot as plt
from torchvision import transforms
import cv2
from PIL import Image
import random
import os
import numpy as np
import torch

class cudataset(data.Dataset):
def __init__(self):
super(cudataset, self).__init__()
self.depth = np.load('trainset_denoise/data_x_25.npy')#(batch,height,width,c)
#print(self.depth.shape)
self.depth = np.transpose(self.depth, (0, 3, 1, 2))
self.depth_t = torch.from_numpy(self.depth)

self.gt = np.load('trainset_denoise/label_25.npy') # (batch,height,width,c)
self.gt = np.transpose(self.gt, (0, 3, 1, 2))
self.gt_t = torch.from_numpy(self.gt)

self.rgb = np.load('trainset_denoise/data_y_25.npy') # (batch,height,width,c)
self.rgb = np.transpose(self.rgb, (0, 3, 1, 2))
self.rgb_t = torch.from_numpy(self.rgb)

def __getitem__(self, item):
img_depth = self.depth_t[item]
img_gt = self.gt_t[item]
img_rgb = self.rgb_t[item]

return (img_depth, img_gt,img_rgb)

def __len__(self):
return len(self.depth)

if __name__ =='__main__':
dataset=cudataset()
dataloader=data.DataLoader(dataset,batch_size=1)
for b1,(img_L,img_H,img_RGB) in enumerate(dataloader):
print(b1)
print(img_L.shape,img_H.shape,img_RGB.shape)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch.utils.data as data
from glob import glob
import matplotlib.pyplot as plt
from torchvision import transforms
import cv2
from PIL import Image
import random
import os
import numpy as np
import torch


class cudatatest(data.Dataset):
def __init__(self):
super(cudatatest, self).__init__()
# self.depth = np.load('testset/data_d1.npy',allow_pickle=True)#(batch,height,width,c)
self.depth = np.load('noise_testset/data_n.npy', allow_pickle=True) # (batch,height,width,c)
self.depth = np.transpose(self.depth, (0, 3, 1, 2))
self.depth_t = torch.from_numpy(self.depth)

# self.gt = np.load('testset/data_g1.npy',allow_pickle=True) # (batch,height,width,c)
self.gt = np.load('noise_testset/gt.npy', allow_pickle=True) # (batch,height,width,c)
self.gt = np.transpose(self.gt, (0, 3, 1, 2))
self.gt_t = torch.from_numpy(self.gt)

# self.rgb = np.load('testset/data_c1.npy',allow_pickle=True) # (batch,height,width,c)
self.rgb = np.load('noise_testset/data_f.npy', allow_pickle=True) # (batch,height,width,c)
self.rgb = np.transpose(self.rgb, (0, 3, 1, 2))
self.rgb_t = torch.from_numpy(self.rgb)

def __getitem__(self, item):
img_depth = self.depth_t[item]
img_gt = self.gt_t[item]
img_rgb = self.rgb_t[item]

return (img_depth, img_gt, img_rgb)

def __len__(self):
return len(self.depth)




if __name__ == '__main__':
dataset = cudatatest()
dataloader = data.DataLoader(dataset, batch_size=1)
for b1, (img_L, img_H, img_RGB) in enumerate(dataloader):
print(b1)
print(img_L.shape, img_H.shape, img_RGB.shape)

0 comments on commit 14b9b36

Please sign in to comment.