Skip to content

Commit

Permalink
Update O-Haze dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jiaqixuac committed Oct 7, 2021
1 parent 696c557 commit cfbc89d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
38 changes: 38 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def make_dataset_ots(root):
return items


def make_dataset_ohaze(root: str, mode: str):
img_list = []
for img_name in os.listdir(os.path.join(root, mode, 'img')):
gt_name = img_name.replace('hazy', 'GT')
assert os.path.exist(os.path.join(root, mode, 'gt', gt_name))
img_list.append([os.path.join(root, mode, 'img', img_name),
os.path.join(root, mode, 'gt', gt_name)])
return img_list


def make_dataset_oihaze_train(root, suffix):
items = []
for img_name in os.listdir(os.path.join(root, 'haze' + suffix)):
Expand Down Expand Up @@ -242,6 +252,34 @@ def __len__(self):
return len(self.imgs)


class OHazeDataset(data.Dataset):
def __init__(self, root, mode):
self.root = root
self.mode = mode
self.imgs = make_dataset_ohaze(root, mode)

def __getitem__(self, index):
img_path, gt_path = self.imgs[index]
name = os.path.splitext(os.path.split(haze_path)[1])[0]

img = Image.open(img_path).convert('RGB')
gt = Image.open(gt_path).convert('RGB')

if 'train' in self.mode:
# img, gt = random_crop(416, img, gt)
if random.random() < 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
gt = gt.transpose(Image.FLIP_LEFT_RIGHT)

rotate_degree = np.random.choice([-90, 0, 90, 180])
img, gt = img.rotate(rotate_degree, Image.BILINEAR), gt.rotate(rotate_degree, Image.BILINEAR)

return to_tensor(img), to_tensor(gt), name

def __len__(self):
return len(self.imgs)


class OIHaze(data.Dataset):
def __init__(self, root, mode, suffix=None, flip=False, crop=None):
assert mode in ['train', 'test']
Expand Down
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def parse_args():


cfgs = {
'use_physical': True,
'iter_num': 40000,
'train_batch_size': 16,
'last_iter': 0,
Expand Down

0 comments on commit cfbc89d

Please sign in to comment.