-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset_processing.py
66 lines (50 loc) · 1.64 KB
/
dataset_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import cv2
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import albumentations as A
import yaml
with open('config.yaml', 'r') as file:
config = yaml.safe_load(file)
CSV_FILE = config['CSV_FILE']
IMG_SIZE = config['IMG_SIZE']
DATA_DIR = config['DATA_DIR']
df = pd.read_csv(CSV_FILE)
train_df,valid_df = train_test_split(df,test_size=0.20,random_state=42)
def get_train_augs():
return A.Compose([
A.Resize(IMG_SIZE,IMG_SIZE),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5)
])
def get_valid_augs():
return A.Compose([
A.Resize(IMG_SIZE,IMG_SIZE)
])
class SegmentationDataset(Dataset):
def __init__(self,df,augmentations):
self.df = df
self.augmentations = augmentations
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = df.iloc[idx]
image_path = DATA_DIR + row.images
mask_path = DATA_DIR + row.masks
image = cv2.imread(image_path)
image = cv2.cvtColor(image , cv2.COLOR_BGR2RGB)
mask = cv2.imread(mask_path,cv2.IMREAD_GRAYSCALE) #(h,w)
mask = np.expand_dims(mask,axis=-1) #(h,w,c)
if self.augmentations:
data = self.augmentations(image=image,mask=mask)
image = data['image']
mask = data['mask']
image = np.transpose(image,(2,0,1)).astype(np.float32) #(c,h,w)
mask = np.transpose(mask,(2,0,1)).astype(np.float32)
image = torch.Tensor(image)/255.0
mask = torch.round(torch.Tensor(mask)/255.0)
return image,mask
trainset = SegmentationDataset(train_df,get_train_augs())
validset = SegmentationDataset(valid_df,get_valid_augs())