-
Notifications
You must be signed in to change notification settings - Fork 6
/
dataset.py
52 lines (40 loc) · 1.46 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
import torch
import torch.utils.data as data
from torchvision import datasets, transforms
import os
from PIL import Image, ImageOps
import cv2
import os
import cv2
import numpy as np
import torch
import torch.utils.data
# PyTorch dataset of image/mask pairs; data augmentation (e.g. an albumentations pipeline) is supplied via the `transform` argument.
class Dataset(torch.utils.data.Dataset):
    """PyTorch dataset of (image, mask) pairs loaded from disk with OpenCV.

    Each item is read, resized to ``size``, optionally augmented, converted to
    channels-first layout and scaled by 1/255.

    Args:
        images_list: Indexable sequence of image file paths.
        masks_list: Indexable sequence of mask file paths, aligned
            index-by-index with ``images_list``.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` keys (the albumentations calling
            convention). It must return NumPy arrays in HxWxC / HxW layout,
            because the result is transposed with NumPy semantics afterwards.
        size: ``(width, height)`` target passed to ``cv2.resize``.
            Defaults to ``(224, 224)``.

    Raises:
        ValueError: if ``images_list`` and ``masks_list`` differ in length.
    """

    def __init__(self, images_list, masks_list, transform=None, size=(224, 224)):
        # Catch misaligned inputs up front instead of failing mid-epoch with
        # an IndexError (or silently ignoring extra masks).
        if len(images_list) != len(masks_list):
            raise ValueError(
                "images_list and masks_list must have the same length "
                f"({len(images_list)} != {len(masks_list)})"
            )
        self.images_list = images_list
        self.masks_list = masks_list
        self.transform = transform
        self.size = tuple(size)

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images_list)

    @staticmethod
    def _read(path, flags):
        """Read ``path`` with ``cv2.imread``; raise if it cannot be decoded."""
        array = cv2.imread(path, flags)
        if array is None:
            # cv2.imread signals a missing/unreadable file by returning None
            # rather than raising, which would otherwise surface later as an
            # obscure cv2.resize error.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return array

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for index ``idx``.

        ``image`` has shape (3, H, W) and ``mask`` has shape (1, H, W); both
        are NumPy arrays divided by 255.0. For uint8 inputs that yields
        float64, so cast (e.g. ``.float()``) if the model expects float32.
        The image channels are in OpenCV's BGR order.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        # IMREAD_COLOR is cv2.imread's default flag: always 3 channels.
        image = self._read(self.images_list[idx], cv2.IMREAD_COLOR)
        image = cv2.resize(image, self.size)

        mask = self._read(self.masks_list[idx], cv2.IMREAD_GRAYSCALE)
        # Nearest-neighbour keeps only label values present in the source mask;
        # the default bilinear filter would blend labels along region borders.
        mask = cv2.resize(mask, self.size, interpolation=cv2.INTER_NEAREST)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        image = image.transpose(2, 0, 1)  # HxWxC -> CxHxW
        mask = mask[np.newaxis, ...]      # HxW -> 1xHxW
        return image / 255.0, mask / 255.0