diff --git a/datasets.py b/datasets.py
index a5e59aa..d4753f5 100644
--- a/datasets.py
+++ b/datasets.py
@@ -42,6 +42,16 @@ def make_dataset_ots(root):
     return items
 
 
+def make_dataset_ohaze(root: str, mode: str):
+    img_list = []
+    for img_name in os.listdir(os.path.join(root, mode, 'img')):
+        gt_name = img_name.replace('hazy', 'GT')
+        assert os.path.exists(os.path.join(root, mode, 'gt', gt_name))
+        img_list.append([os.path.join(root, mode, 'img', img_name),
+                         os.path.join(root, mode, 'gt', gt_name)])
+    return img_list
+
+
 def make_dataset_oihaze_train(root, suffix):
     items = []
     for img_name in os.listdir(os.path.join(root, 'haze' + suffix)):
@@ -242,6 +252,34 @@ def __len__(self):
         return len(self.imgs)
 
 
+class OHazeDataset(data.Dataset):
+    def __init__(self, root, mode):
+        self.root = root
+        self.mode = mode
+        self.imgs = make_dataset_ohaze(root, mode)
+
+    def __getitem__(self, index):
+        img_path, gt_path = self.imgs[index]
+        name = os.path.splitext(os.path.split(img_path)[1])[0]
+
+        img = Image.open(img_path).convert('RGB')
+        gt = Image.open(gt_path).convert('RGB')
+
+        if 'train' in self.mode:
+            # img, gt = random_crop(416, img, gt)
+            if random.random() < 0.5:  # random horizontal flip
+                img = img.transpose(Image.FLIP_LEFT_RIGHT)
+                gt = gt.transpose(Image.FLIP_LEFT_RIGHT)
+
+            rotate_degree = np.random.choice([-90, 0, 90, 180])  # random rotation in 90-degree steps
+            img, gt = img.rotate(rotate_degree, Image.BILINEAR), gt.rotate(rotate_degree, Image.BILINEAR)
+
+        return to_tensor(img), to_tensor(gt), name
+
+    def __len__(self):
+        return len(self.imgs)
+
+
 class OIHaze(data.Dataset):
     def __init__(self, root, mode, suffix=None, flip=False, crop=None):
         assert mode in ['train', 'test']
diff --git a/train.py b/train.py
index b75a11e..adee1ba 100644
--- a/train.py
+++ b/train.py
@@ -30,6 +30,7 @@ def parse_args():
 
 
 cfgs = {
+    'use_physical': True,
     'iter_num': 40000,
     'train_batch_size': 16,
     'last_iter': 0,
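
Usage note (not part of the patch): a minimal sketch of how the new OHazeDataset could be driven from a PyTorch DataLoader. The dataset root, worker count, and the paired random_crop helper below are illustrative assumptions, not defined by this diff; random_crop only mirrors the commented-out random_crop(416, img, gt) call in __getitem__.

import random

from torch.utils.data import DataLoader

from datasets import OHazeDataset


def random_crop(size, img, gt):
    # Hypothetical paired crop: take the same size x size window from the
    # hazy image and its ground truth so they stay pixel-aligned.
    w, h = img.size
    x = random.randint(0, w - size)
    y = random.randint(0, h - size)
    box = (x, y, x + size, y + size)
    return img.crop(box), gt.crop(box)


if __name__ == '__main__':
    # Illustrative root layout: <root>/train/img and <root>/train/gt.
    train_set = OHazeDataset('./data/O-HAZE', 'train')

    # batch_size=1 because, with the crop commented out, full-resolution
    # O-HAZE images may differ in size and cannot be stacked into one batch;
    # with a fixed crop, cfgs['train_batch_size'] (16) would apply instead.
    train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=4)

    for haze, gt, names in train_loader:
        print(names[0], haze.shape, gt.shape)  # tensors of shape (1, 3, H, W)
        break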