-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathdata.py
42 lines (32 loc) · 1.69 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from os import listdir
from os.path import join
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image, ImageFilter
def is_image_file(filename):
    """Return True if ``filename`` has a supported image extension.

    Supported extensions are ``.png``, ``.jpg`` and ``.jpeg``. The comparison
    is case-insensitive, so ``photo.JPG`` is accepted as well.
    """
    # endswith accepts a tuple, so all extensions are checked in one call.
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))
def load_img(filepath):
    """Load the image at ``filepath`` and return its luminance (Y) channel.

    The image is converted to the YCbCr colour space and only the Y band is
    kept. The Cb and Cr bands are discarded. The result is a single-band PIL
    image.
    """
    # The context manager closes the file handle deterministically.
    # convert() reads the pixel data and split() returns independent band
    # images, so the Y band stays valid after the file is closed.
    with Image.open(filepath) as img:
        y, _, _ = img.convert('YCbCr').split()
    return y
CROP_SIZE = 32  # default side length, in pixels, of the square centre crop


class DatasetFromFolder(Dataset):
    """Dataset of (degraded input, ground-truth target) pairs for super-resolution.

    Each image file found in ``image_dir`` is loaded as its luminance (Y)
    channel via ``load_img``. The target is a centre crop of that channel.
    The input is the same crop, first downscaled by ``zoom_factor`` and then
    bicubically upscaled back to the crop size. Both tensors therefore have
    the same spatial size.
    """

    def __init__(self, image_dir, zoom_factor, crop_size=CROP_SIZE):
        """Index the image files in ``image_dir`` and build both transform pipelines.

        Args:
            image_dir: Directory scanned (non-recursively) for image files.
            zoom_factor: Downscaling factor used to degrade the input.
            crop_size: Requested side length of the centre crop. It is
                rounded down to a multiple of ``zoom_factor``. Defaults to
                ``CROP_SIZE``.

        Raises:
            ValueError: If ``zoom_factor`` is not between 1 and ``crop_size``.
                Otherwise the crop size would be zero, or the modulo below
                would divide by zero.
        """
        super().__init__()
        if not 1 <= zoom_factor <= crop_size:
            raise ValueError(
                f"zoom_factor must be between 1 and crop_size ({crop_size}), "
                f"got {zoom_factor!r}"
            )
        # os.listdir order is arbitrary, so sort it to keep the
        # index -> file mapping reproducible across runs and machines.
        self.image_filenames = [
            join(image_dir, name)
            for name in sorted(listdir(image_dir))
            if is_image_file(name)
        ]
        # Round the crop down to a multiple of zoom_factor. Downscaling and
        # then upscaling then lands exactly back on the crop size.
        valid_size = crop_size - (crop_size % zoom_factor)
        self.input_transform = transforms.Compose([
            transforms.CenterCrop(valid_size),
            # Downscale by zoom_factor (Resize's default interpolation).
            transforms.Resize(valid_size // zoom_factor),
            # Bicubic upsampling back to the original crop size.
            transforms.Resize(valid_size, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ])
        self.target_transform = transforms.Compose([
            # The target keeps its full resolution.
            transforms.CenterCrop(valid_size),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Return the ``(input, target)`` tensor pair for the image at ``index``."""
        y_channel = load_img(self.image_filenames[index])
        # Both pipelines start with CenterCrop, which returns a new image, so
        # they can share the loaded image without copying it first.
        degraded = self.input_transform(y_channel)
        target = self.target_transform(y_channel)
        return degraded, target

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_filenames)