-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathisicloader.py
103 lines (75 loc) · 3.29 KB
/
isicloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
from torchvision.utils import save_image
IMG_FOLDER_NAME = "A"
IMG_POST_FOLDER_NAME = 'B'
LIST_FOLDER_NAME = 'list'
ANNOT_FOLDER_NAME = "label"
label_suffix = ".png"
def load_img_name_list(dataset_path):
img_name_list = np.loadtxt(dataset_path, dtype=np.str)
if img_name_list.ndim == 2:
return img_name_list[:, 0]
return img_name_list
def get_img_post_path(root_dir,img_name):
return os.path.join(root_dir, IMG_POST_FOLDER_NAME, img_name)
def get_img_path(root_dir, img_name):
return os.path.join(root_dir, IMG_FOLDER_NAME, img_name)
def get_label_path(root_dir, img_name):
return os.path.join(root_dir, ANNOT_FOLDER_NAME, img_name) #.replace('.jpg', label_suffix)
class ISICDataset(Dataset):
# def __init__(self, args, data_path , transform = None, mode = 'Training',plane = False):
def __init__(self, args, data_path, transform = None, mode = 'train', plane=False):
self.root_dir = "/media/lscsc/nas/yihan/ddpm_1/GCD/WHU"
self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split+'.txt')
self.img_name_list = load_img_name_list(self.list_path)
self.dataset_len = len(self.img_name_list)
# if self.data_len <= 0:
# self.data_len = self.dataset_len
# else:
self.data_len = self.dataset_len
self.transform = transform
def __len__(self):
return self.data_len
def __getitem__(self, index):
"""Get the images"""
A_path = get_img_path(self.root_dir, self.img_name_list[index % self.data_len])
B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.data_len])
img_A = Image.open(A_path).convert("RGB")
img_B = Image.open(B_path).convert("RGB")
L_path = get_label_path(self.root_dir, self.img_name_list[index % self.data_len])
img_lbl = Image.open(L_path).convert("L")
# name = self.name_list[index]+'.jpg'
# img_path = os.path.join(self.data_path, 'ISBI2016_ISIC_Part3B_'+ self.mode +'_Data',name)
# mask_name = name.split('.')[0] + '_Segmentation.png'
# msk_path = os.path.join(self.data_path, 'ISBI2016_ISIC_Part3B_'+ self.mode +'_Data',mask_name)
# img = Image.open(img_path).convert('RGB')
# mask = Image.open(msk_path).convert('L')
# if self.mode == 'Training':
# label = 0 if self.label_list[index] == 'benign' else 1
# else:
# label = int(self.label_list[index])
if self.transform:
state = torch.get_rng_state()
img_A = self.transform(img_A)
img_B = self.transform(img_B)
torch.set_rng_state(state)
img_lbl = self.transform(img_lbl)
name = self.img_name_list[index % self.data_len]
print(name)
# if self.mode == 'Training':
return (img_A, img_B, img_lbl, self.img_name_list[index % self.data_len])
# else:
# return (img, mask, name)