-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdata.py
102 lines (78 loc) · 3.04 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from albumentations import HorizontalFlip, VerticalFlip, Rotate
def create_dir(path):
    """Create *path* (including any missing parent directories) if it does not exist.

    ``exist_ok=True`` replaces a separate ``os.path.exists`` check. There is
    therefore no window in which another process can create the directory
    between the check and the ``makedirs`` call.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: if *path* already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
def load_data(path, split=0.2, random_state=42):
    """Collect image/mask file paths under *path* and split them into train/valid sets.

    Images are globbed from ``<path>/*/image/*.png`` and masks from
    ``<path>/*/mask/*.png``. Both lists are sorted, and the i-th image is paired
    with the i-th mask. This assumes corresponding files sort into the same
    order (TODO confirm against the dataset's naming scheme).

    Args:
        path: Root directory of the dataset.
        split: Fraction of samples put in the validation set. The count is
            ``int(len(images) * split)``, i.e. rounded down.
        random_state: Seed forwarded to ``train_test_split`` for a
            reproducible split.

    Returns:
        ``((train_x, train_y), (valid_x, valid_y))``: lists of file paths, where
        ``train_y[i]`` / ``valid_y[i]`` is the mask for ``train_x[i]`` /
        ``valid_x[i]``.

    Raises:
        ValueError: if the number of images and masks differs. Also raised by
            ``train_test_split`` when the computed validation size is invalid,
            e.g. 0.
    """
    images = sorted(glob(f"{path}/*/image/*.png"))
    masks = sorted(glob(f"{path}/*/mask/*.png"))

    # Pairing is positional, so a count mismatch would misalign every pair
    # after the first missing file. Fail loudly instead.
    if len(images) != len(masks):
        raise ValueError(
            f"Found {len(images)} images but {len(masks)} masks under {path!r}"
        )

    split_size = int(len(images) * split)
    # Split both lists in one call so they share a single shuffle and stay aligned.
    train_x, valid_x, train_y, valid_y = train_test_split(
        images, masks, test_size=split_size, random_state=random_state
    )
    return (train_x, train_y), (valid_x, valid_y)
def _sample_name(image_path):
    """Build the output base name ``<dir>_<stem>`` for *image_path*, where ``<dir>`` is two levels up."""
    # normpath + os.sep keeps this correct for paths using the platform separator.
    parts = os.path.normpath(image_path).split(os.sep)
    dir_name = parts[-3]
    # splitext strips only the extension, so stems containing dots do not collide.
    stem = os.path.splitext(parts[-1])[0]
    return f"{dir_name}_{stem}"


def _read_pair(image_path, mask_path):
    """Read an image and its mask with ``cv2.IMREAD_COLOR``, raising if either file is unreadable."""
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")
    return image, mask


def _augment_pair(image, mask):
    """Return the original pair followed by h-flipped, v-flipped and randomly rotated (±45°) copies."""
    pairs = [(image, mask)]
    for transform in (HorizontalFlip(p=1.0), VerticalFlip(p=1.0), Rotate(limit=45, p=1.0)):
        # Each transform is applied to the original pair, not chained.
        augmented = transform(image=image, mask=mask)
        pairs.append((augmented["image"], augmented["mask"]))
    return pairs


def _save_pair(image, mask, save_path, file_name, size):
    """Resize a pair to *size*, binarise the mask to {0, 255} and write both files."""
    width, height = size
    image = cv2.resize(image, (width, height))
    mask = cv2.resize(mask, (width, height))
    # cv2.resize interpolates bilinearly by default, so re-threshold the mask at
    # half intensity. Store it as uint8 explicitly rather than relying on
    # OpenCV's implicit conversion of an int64 array.
    mask = np.where(mask / 255.0 > 0.5, 255, 0).astype(np.uint8)

    image_file = os.path.join(save_path, "image", file_name)
    mask_file = os.path.join(save_path, "mask", file_name)
    # cv2.imwrite reports failure (e.g. a missing directory) only via its return value.
    if not cv2.imwrite(image_file, image):
        raise OSError(f"Could not write image: {image_file}")
    # NOTE(review): JPEG is lossy, so saved masks may contain values other than
    # 0/255 near edges. Consider PNG if the downstream loader allows it.
    if not cv2.imwrite(mask_file, mask):
        raise OSError(f"Could not write mask: {mask_file}")


def augment_data(images, masks, save_path, augment=True, size=(512, 512)):
    """Optionally augment image/mask pairs, resize them and write them to disk.

    For each ``(image, mask)`` path pair, both files are read. When *augment* is
    true, three extra pairs are added: horizontal flip, vertical flip, and a
    random rotation within ±45°. Every resulting pair is resized to *size*. The
    mask is binarised to {0, 255}. Both are written as JPEG into
    ``<save_path>/image/`` and ``<save_path>/mask/`` under the same file name.

    File names are ``<dir>_<stem>.jpg`` without augmentation. With augmentation
    they are ``<dir>_<stem>_<k>.jpg``, with k = 0 for the original and 1..3 for
    the flipped/rotated copies. ``<dir>`` is the directory two levels above the
    image file.

    Args:
        images: Image file paths.
        masks: Mask file paths, paired by position with *images*.
        save_path: Output root. Its ``image`` and ``mask`` subdirectories must
            already exist.
        augment: Whether to add the flipped/rotated copies.
        size: Output ``(width, height)``, in the order ``cv2.resize`` expects.

    Raises:
        ValueError: if *images* and *masks* have different lengths.
        FileNotFoundError: if an image or mask cannot be read.
        OSError: if an output file cannot be written.
    """
    # zip() would silently drop the unmatched tail.
    if len(images) != len(masks):
        raise ValueError(f"Got {len(images)} images but {len(masks)} masks")

    for image_path, mask_path in tqdm(zip(images, masks), total=len(images)):
        name = _sample_name(image_path)
        image, mask = _read_pair(image_path, mask_path)

        pairs = _augment_pair(image, mask) if augment else [(image, mask)]
        for k, (img, msk) in enumerate(pairs):
            # The index suffix is only needed to tell augmented copies apart.
            file_name = f"{name}_{k}.jpg" if len(pairs) > 1 else f"{name}.jpg"
            _save_pair(img, msk, save_path, file_name, size)
if __name__ == "__main__":
    # Gather the raw dataset and split it into training / validation file lists.
    dataset_path = os.path.join("data", "train")
    (train_x, train_y), (valid_x, valid_y) = load_data(dataset_path, split=0.2)

    for label, split_files in (("Train: ", train_x), ("Valid: ", valid_x)):
        print(label, len(split_files))

    # augment_data() writes into these folders, so they must exist first.
    output_dirs = (
        "new_data/train/image/",
        "new_data/train/mask/",
        "new_data/valid/image/",
        "new_data/valid/mask/",
    )
    for out_dir in output_dirs:
        create_dir(out_dir)

    # Only the training split gets flipped/rotated copies; validation pairs are
    # just resized and written.
    augment_data(train_x, train_y, "new_data/train/", augment=True)
    augment_data(valid_x, valid_y, "new_data/valid/", augment=False)