diff --git a/README.md b/README.md
index bc303d6..92e847e 100644
--- a/README.md
+++ b/README.md
@@ -30,16 +30,18 @@ Make sure you have `Python>=3.7` installed on your machine.
 
    pip install -r requirements.txt
 
-3. Prepare the dataset (RESIDE)
+* Prepare the dataset
 
-   1. Download the RESIDE dataset from the [official webpage](https://sites.google.com/site/boyilics/website-builder/reside).
+   * Download the RESIDE dataset from the [official webpage](https://sites.google.com/site/boyilics/website-builder/reside).
 
-   2. Make a directory `./data` and create a symbolic link for uncompressed data `./data/RESIDE`.
+   * Download the O-Haze dataset from the [official webpage](https://data.vision.ee.ethz.ch/cvl/ntire18//o-haze/).
+
+   * Make a directory `./data` and create a symbolic link for uncompressed data, e.g., `./data/RESIDE`.
 
 ## Training
 
-1. Set the path of pretrained ResNeXt model in resnext/config.py
-2. Set the path of datasets in config.py
+1. ~~Set the path of pretrained ResNeXt model in resnext/config.py~~
+2. Set the path of datasets in tools/config.py
 3. Run by ```python train.py```
 
 ~~The pretrained ResNeXt model is ported from the [official](https://github.com/facebookresearch/ResNeXt) torch version,
@@ -55,7 +57,7 @@ Training a model on a single ~~GTX 1080Ti~~ TITAN RTX GPU takes about ~~4~~ 5 hours.
 
 ## Testing
 
-1. Set the path of five benchmark datasets in config.py.
+1. Set the path of five benchmark datasets in tools/config.py.
 2. Put the trained model in `./ckpt/`.
 3. Run by ```python test.py```
diff --git a/requirements.txt b/requirements.txt
index bcddc4f..b47a587 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ numpy>=1.20.2
 Pillow>=8.2.0
 scipy>=1.7.1
 scikit-image>=0.18.3
+tqdm>=4.62.3
diff --git a/test.py b/test.py
index 132d5f7..07bf333 100644
--- a/test.py
+++ b/test.py
@@ -5,12 +5,12 @@ import torch
 from torchvision import transforms
 
-from config import TEST_SOTS_ROOT
-from utils import check_mkdir
+from tools.config import TEST_SOTS_ROOT
+from tools.utils import check_mkdir
 from model import DM2FNet
 from datasets import SotsDataset
 from torch.utils.data import DataLoader
-from skimage.metrics import peak_signal_noise_ratio
+from skimage.metrics import peak_signal_noise_ratio, structural_similarity
 
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
@@ -37,7 +37,7 @@ def main():
     net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
     net.eval()
 
-    psnrs = []
+    psnrs, ssims = [], []
 
     with torch.no_grad():
         for name, root in to_test.items():
@@ -59,13 +59,16 @@
                     gt = gts[i].cpu().numpy().transpose([1, 2, 0])
                     psnr = peak_signal_noise_ratio(gt, r)
                     psnrs.append(psnr)
-                    print('predicting for {} ({}/{}) [{}]: {:.4f}'.format(name, idx + 1, len(dataloader), fs[i], psnr))
+                    ssim = structural_similarity(gt, r, multichannel=True)
+                    ssims.append(ssim)
+                    print('predicting for {} ({}/{}) [{}]: PSNR {:.4f}, SSIM {:.4f}'
+                          .format(name, idx + 1, len(dataloader), fs[i], psnr, ssim))
 
                 for r, f in zip(res.cpu(), fs):
                     to_pil(r).save(
                         os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), '%s.png' % f))
 
-            print(f"PSNR for {name}: {np.mean(psnrs):.6f}")
+            print(f"[{name}] PSNR: {np.mean(psnrs):.6f}, SSIM: {np.mean(ssims):.6f}")
 
 
 if __name__ == '__main__':
diff --git a/tools/__init__.py b/tools/__init__.py
new file mode 100644
index 0000000..e69de29
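A note on the SSIM call added to test.py: `multichannel=True` works on the `scikit-image>=0.18.3` pinned in requirements.txt, but the flag is deprecated since 0.19 in favor of `channel_axis`. Also, `structural_similarity` on float inputs assumes a data range of 2 unless `data_range` is given (newer releases refuse float inputs without it), which understates SSIM for images in [0, 1]. A minimal sketch of the forward-compatible calls, not part of the patch, with random arrays standing in for the `gt` and `r` crops from test.py:

```python
# Sketch only: metric calls that also work on scikit-image >= 0.19.
# The random arrays stand in for the float [0, 1], (H, W, 3) crops in test.py.
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

gt = np.random.rand(64, 64, 3)                                 # fake ground truth
r = np.clip(gt + 0.02 * np.random.randn(64, 64, 3), 0.0, 1.0)  # fake dehazed result

# data_range=1 pins both metrics to the true [0, 1] range; without it,
# SSIM on floats assumes a range of 2 (or raises on newer releases).
psnr = peak_signal_noise_ratio(gt, r, data_range=1)
ssim = structural_similarity(gt, r, channel_axis=-1, data_range=1)  # replaces multichannel=True
print('PSNR {:.4f}, SSIM {:.4f}'.format(psnr, ssim))
```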
diff --git a/config.py b/tools/config.py
similarity index 61%
rename from config.py
rename to tools/config.py
index 73222de..77440f8 100644
--- a/config.py
+++ b/tools/config.py
@@ -8,8 +8,11 @@
 # TEST_A_ROOT = os.path.join(root, 'TestA')
 # TEST_B_ROOT = os.path.join(root, 'nature')
 
+# O-Haze
+OHAZE_ROOT = os.path.join(root, '../data', 'O-Haze')
+
 # RESIDE
-TRAIN_ITS_ROOT = os.path.join(root, 'data', 'RESIDE', 'ITS_v2')  # ITS
-TEST_SOTS_ROOT = os.path.join(root, 'data', 'RESIDE', 'SOTS', 'nyuhaze500')  # SOTS indoor
+TRAIN_ITS_ROOT = os.path.join(root, '../data', 'RESIDE', 'ITS_v2')  # ITS
+TEST_SOTS_ROOT = os.path.join(root, '../data', 'RESIDE', 'SOTS', 'nyuhaze500')  # SOTS indoor
 # TEST_SOTS_ROOT = os.path.join(root, 'SOTS', 'outdoor')  # SOTS outdoor
 # TEST_HSTS_ROOT = os.path.join(root, 'HSTS', 'synthetic')  # HSTS
diff --git a/tools/preprocess_ohaze_data.py b/tools/preprocess_ohaze_data.py
new file mode 100644
index 0000000..c25c864
--- /dev/null
+++ b/tools/preprocess_ohaze_data.py
@@ -0,0 +1,56 @@
+import os
+from PIL import Image
+
+from config import OHAZE_ROOT
+from math import ceil
+from tqdm import tqdm
+
+if __name__ == '__main__':
+    ohaze_root = OHAZE_ROOT
+    crop_size = 512
+
+    ori_root = os.path.join(ohaze_root, '# O-HAZY NTIRE 2018')
+    ori_haze_root = os.path.join(ori_root, 'hazy')
+    ori_gt_root = os.path.join(ori_root, 'GT')
+
+    patch_root = os.path.join(ohaze_root, 'train_crop_{}'.format(crop_size))
+    patch_haze_path = os.path.join(patch_root, 'img')
+    patch_gt_path = os.path.join(patch_root, 'gt')
+
+    os.makedirs(patch_root, exist_ok=True)
+    os.makedirs(patch_haze_path, exist_ok=True)
+    os.makedirs(patch_gt_path, exist_ok=True)
+
+    # first 35 images for training
+    train_list = [img_name for img_name in os.listdir(ori_haze_root)
+                  if int(img_name.split('_')[0]) <= 35]
+
+    for idx, img_name in enumerate(tqdm(train_list)):
+        img_f_name, img_l_name = os.path.splitext(img_name)
+        gt_f_name = '{}GT'.format(img_f_name[: -4])
+
+        img = Image.open(os.path.join(ori_haze_root, img_name))
+        gt = Image.open(os.path.join(ori_gt_root, gt_f_name + img_l_name))
+
+        assert img.size == gt.size
+
+        w, h = img.size
+        stride = int(crop_size / 3.)
+        h_steps = 1 + int(ceil(float(max(h - crop_size, 0)) / stride))
+        w_steps = 1 + int(ceil(float(max(w - crop_size, 0)) / stride))
+
+        for h_idx in range(h_steps):
+            for w_idx in range(w_steps):
+                ws0 = w_idx * stride
+                ws1 = crop_size + ws0
+                hs0 = h_idx * stride
+                hs1 = crop_size + hs0
+                if h_idx == h_steps - 1:
+                    hs0, hs1 = max(h - crop_size, 0), h
+                if w_idx == w_steps - 1:
+                    ws0, ws1 = max(w - crop_size, 0), w
+                img_crop = img.crop((ws0, hs0, ws1, hs1))
+                gt_crop = gt.crop((ws0, hs0, ws1, hs1))
+
+                img_crop.save(os.path.join(patch_haze_path, '{}_h_{}_w_{}.png'.format(img_f_name, h_idx, w_idx)))
+                gt_crop.save(os.path.join(patch_gt_path, '{}_h_{}_w_{}.png'.format(gt_f_name, h_idx, w_idx)))
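The crop loop above tiles each training image with 512 px patches at a stride of `crop_size / 3` (170 px, so adjacent patches overlap by roughly two thirds), snapping the last row and column of patches to the image border so every crop stays full-size. A standalone sanity check of how many patches that grid yields; the image size below is hypothetical, since O-Haze resolutions vary:

```python
# Sanity check of the patch-grid arithmetic in preprocess_ohaze_data.py.
from math import ceil

def n_steps(length, crop_size=512):
    # Same formula as the script: stride = crop_size / 3,
    # with the final crop snapped to the image border.
    stride = int(crop_size / 3.)  # 170 px
    return 1 + int(ceil(float(max(length - crop_size, 0)) / stride))

w, h = 3000, 2000  # hypothetical image size
print(n_steps(w) * n_steps(h), 'patches')  # 16 * 10 = 160 patches per image
```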
diff --git a/utils.py b/tools/utils.py
similarity index 76%
rename from utils.py
rename to tools/utils.py
index c4db1d1..d787d2a 100644
--- a/utils.py
+++ b/tools/utils.py
@@ -4,13 +4,16 @@
 class AvgMeter(object):
     def __init__(self):
-        self.reset()
+        self.val = 0.
+        self.avg = 0.
+        self.sum = 0.
+        self.count = 0.
 
     def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
+        self.val = 0.
+        self.avg = 0.
+        self.sum = 0.
+        self.count = 0.
 
     def update(self, val, n=1):
         self.val = val
diff --git a/train.py b/train.py
index b16d48d..a61c771 100644
--- a/train.py
+++ b/train.py
@@ -8,9 +8,9 @@
 from torch.backends import cudnn
 from torch.utils.data import DataLoader
 
-from config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT
+from tools.config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT
 from datasets import ItsDataset, SotsDataset
-from utils import AvgMeter, check_mkdir
+from tools.utils import AvgMeter, check_mkdir
 from model import DM2FNet
 
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
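With config.py and utils.py relocated under tools/ (the empty tools/__init__.py makes the directory a package), train.py and test.py now import them as `tools.config` and `tools.utils`, so both scripts should be launched from the repository root. A minimal smoke test of the new layout; it assumes `AvgMeter.update` keeps the usual running-average semantics, since the hunk above shows only the start of its body:

```python
# Run from the repository root after applying this patch.
from tools.config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT
from tools.utils import AvgMeter

meter = AvgMeter()
for loss in (0.5, 1.5):
    meter.update(loss)  # assumed: sum += val * n; count += n; avg = sum / count
print(meter.avg)        # 1.0 under the assumed semantics
print(TRAIN_ITS_ROOT)   # resolves relative to tools/, i.e. tools/../data/RESIDE/ITS_v2
```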