Update metrics & O-Haze preprocess
jiaqixuac committed Oct 7, 2021
1 parent 81e1fe1 commit 456dc79
Showing 8 changed files with 89 additions and 21 deletions.
14 changes: 8 additions & 6 deletions README.md
@@ -30,16 +30,18 @@ Make sure you have `Python>=3.7` installed on your machine.

pip install -r requirements.txt

3. Prepare the dataset (RESIDE)
* Prepare the dataset

1. Download the RESIDE dataset from the [official webpage](https://sites.google.com/site/boyilics/website-builder/reside).
* Download the RESIDE dataset from the [official webpage](https://sites.google.com/site/boyilics/website-builder/reside).

2. Make a directory `./data` and create a symbolic link for uncompressed data `./data/RESIDE`.
* Download the O-Haze dataset from the [official webpage](https://data.vision.ee.ethz.ch/cvl/ntire18//o-haze/).

* Make a directory `./data` and create a symbolic link for uncompressed data, e.g., `./data/RESIDE`.

## Training

1. Set the path of pretrained ResNeXt model in resnext/config.py
2. Set the path of datasets in config.py
1. ~~Set the path of pretrained ResNeXt model in resnext/config.py~~
2. Set the path of datasets in tools/config.py
3. Run by ```python train.py```

~~The pretrained ResNeXt model is ported from the [official](https://github.com/facebookresearch/ResNeXt) torch version,
@@ -55,7 +57,7 @@ Training a model on a single ~~GTX 1080Ti~~ TITAN RTX GPU takes about ~~4~~ 5 hours

## Testing

1. Set the path of five benchmark datasets in config.py.
1. Set the path of five benchmark datasets in tools/config.py.
2. Put the trained model in `./ckpt/`.
3. Run by ```python test.py```

1 change: 1 addition & 0 deletions requirements.txt
@@ -4,3 +4,4 @@ numpy>=1.20.2
Pillow>=8.2.0
scipy>=1.7.1
scikit-image>=0.18.3
tqdm>=4.62.3
15 changes: 9 additions & 6 deletions test.py
@@ -5,12 +5,12 @@
import torch
from torchvision import transforms

from config import TEST_SOTS_ROOT
from utils import check_mkdir
from tools.config import TEST_SOTS_ROOT
from tools.utils import check_mkdir
from model import DM2FNet
from datasets import SotsDataset
from torch.utils.data import DataLoader
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

@@ -37,7 +37,7 @@ def main():
net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))

net.eval()
psnrs = []
psnrs, ssims = [], []

with torch.no_grad():
for name, root in to_test.items():
@@ -59,13 +59,16 @@
gt = gts[i].cpu().numpy().transpose([1, 2, 0])
psnr = peak_signal_noise_ratio(gt, r)
psnrs.append(psnr)
print('predicting for {} ({}/{}) [{}]: {:.4f}'.format(name, idx + 1, len(dataloader), fs[i], psnr))
ssim = structural_similarity(gt, r, multichannel=True)
ssims.append(ssim)
print('predicting for {} ({}/{}) [{}]: PSNR {:.4f}, SSIM {:.4f}'
.format(name, idx + 1, len(dataloader), fs[i], psnr, ssim))

for r, f in zip(res.cpu(), fs):
to_pil(r).save(
os.path.join(ckpt_path, exp_name,
'(%s) %s_%s' % (exp_name, name, args['snapshot']), '%s.png' % f))
print(f"PSNR for {name}: {np.mean(psnrs):.6f}")
print(f"[{name}] PSNR: {np.mean(psnrs):.6f}, SSIM: {np.mean(ssims):.6f}")


if __name__ == '__main__':
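For reference, a minimal sketch of the per-image metric computation that the updated test loop performs, using synthetic arrays (the image shape and value range here are assumptions for illustration, not taken from the repository):

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# Hypothetical float images in [0, 1] with shape (H, W, 3).
gt = np.random.rand(64, 64, 3)
pred = np.clip(gt + 0.05 * np.random.randn(64, 64, 3), 0.0, 1.0)

psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)
# multichannel=True mirrors the call added in test.py; scikit-image >= 0.19
# prefers channel_axis=-1 for the same behaviour.
ssim = structural_similarity(gt, pred, multichannel=True, data_range=1.0)
print('PSNR {:.4f}, SSIM {:.4f}'.format(psnr, ssim))
```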
Empty file added tools/__init__.py
7 changes: 5 additions & 2 deletions config.py → tools/config.py
@@ -8,8 +8,11 @@
# TEST_A_ROOT = os.path.join(root, 'TestA')
# TEST_B_ROOT = os.path.join(root, 'nature')

# O-Haze
OHAZE_ROOT = os.path.join(root, '../data', 'O-Haze')

# RESIDE
TRAIN_ITS_ROOT = os.path.join(root, 'data', 'RESIDE', 'ITS_v2') # ITS
TEST_SOTS_ROOT = os.path.join(root, 'data', 'RESIDE', 'SOTS', 'nyuhaze500') # SOTS indoor
TRAIN_ITS_ROOT = os.path.join(root, '../data', 'RESIDE', 'ITS_v2') # ITS
TEST_SOTS_ROOT = os.path.join(root, '../data', 'RESIDE', 'SOTS', 'nyuhaze500') # SOTS indoor
# TEST_SOTS_ROOT = os.path.join(root, 'SOTS', 'outdoor') # SOTS outdoor
# TEST_HSTS_ROOT = os.path.join(root, 'HSTS', 'synthetic') # HSTS
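Since the config now lives under `tools/`, the dataset roots climb one directory with `'../data'`. A minimal sketch of how such a path resolves, assuming `root` is derived from the config file's own location (the visible hunk does not show how `root` is defined):

```python
import os

# Hypothetical stand-in for the `root` used in tools/config.py,
# assumed to be the directory containing the config module.
root = os.path.dirname(os.path.abspath(__file__))

TRAIN_ITS_ROOT = os.path.join(root, '../data', 'RESIDE', 'ITS_v2')

# normpath collapses the '..', so the path points at <repo>/data/RESIDE/ITS_v2
print(os.path.normpath(TRAIN_ITS_ROOT))
```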
56 changes: 56 additions & 0 deletions tools/preprocess_ohaze_data.py
@@ -0,0 +1,56 @@
import os
from PIL import Image

from config import OHAZE_ROOT
from math import ceil
from tqdm import tqdm

if __name__ == '__main__':
ohaze_root = OHAZE_ROOT
crop_size = 512

ori_root = os.path.join(ohaze_root, '# O-HAZY NTIRE 2018')
ori_haze_root = os.path.join(ori_root, 'hazy')
ori_gt_root = os.path.join(ori_root, 'GT')

patch_root = os.path.join(ohaze_root, 'train_crop_{}'.format(crop_size))
patch_haze_path = os.path.join(patch_root, 'img')
patch_gt_path = os.path.join(patch_root, 'gt')

os.makedirs(patch_root, exist_ok=True)
os.makedirs(patch_haze_path, exist_ok=True)
os.makedirs(patch_gt_path, exist_ok=True)

# first 35 images for training
train_list = [img_name for img_name in os.listdir(ori_haze_root)
if int(img_name.split('_')[0]) <= 35]

for idx, img_name in enumerate(tqdm(train_list)):
img_f_name, img_l_name = os.path.splitext(img_name)
gt_f_name = '{}GT'.format(img_f_name[: -4])

img = Image.open(os.path.join(ori_haze_root, img_name))
gt = Image.open(os.path.join(ori_gt_root, gt_f_name + img_l_name))

assert img.size == gt.size

w, h = img.size
stride = int(crop_size / 3.)
h_steps = 1 + int(ceil(float(max(h - crop_size, 0)) / stride))
w_steps = 1 + int(ceil(float(max(w - crop_size, 0)) / stride))

for h_idx in range(h_steps):
for w_idx in range(w_steps):
ws0 = w_idx * stride
ws1 = crop_size + ws0
hs0 = h_idx * stride
hs1 = crop_size + hs0
if h_idx == h_steps - 1:
hs0, hs1 = max(h - crop_size, 0), h
if w_idx == w_steps - 1:
ws0, ws1 = max(w - crop_size, 0), w
img_crop = img.crop((ws0, hs0, ws1, hs1))
gt_crop = gt.crop((ws0, hs0, ws1, hs1))

img_crop.save(os.path.join(patch_haze_path, '{}_h_{}_w_{}.png'.format(img_f_name, h_idx, w_idx)))
gt_crop.save(os.path.join(patch_gt_path, '{}_h_{}_w_{}.png'.format(gt_f_name, h_idx, w_idx)))
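A quick sanity check of the patch grid the cropping loop above produces, for a hypothetical 1600×1200 image (O-Haze images vary in size; these dimensions are an assumption for illustration):

```python
from math import ceil

# Hypothetical image size; actual O-Haze scenes are larger and vary per image.
crop_size, w, h = 512, 1600, 1200
stride = int(crop_size / 3.)  # 170 px, so neighbouring patches overlap by ~2/3

w_steps = 1 + int(ceil(float(max(w - crop_size, 0)) / stride))  # 8
h_steps = 1 + int(ceil(float(max(h - crop_size, 0)) / stride))  # 6

# The last row/column of patches is snapped to the image border by the
# max(h - crop_size, 0) / max(w - crop_size, 0) overrides in the loop.
print(w_steps * h_steps)  # 48 overlapping 512x512 patches for this image
```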
13 changes: 8 additions & 5 deletions utils.py → tools/utils.py
@@ -4,13 +4,16 @@

class AvgMeter(object):
def __init__(self):
self.reset()
self.val = 0.
self.avg = 0.
self.sum = 0.
self.count = 0.

def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.val = 0.
self.avg = 0.
self.sum = 0.
self.count = 0.

def update(self, val, n=1):
self.val = val
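A hedged usage sketch of `AvgMeter` as it is typically used in the training loop; this assumes `update(val, n)` keeps an n-weighted running mean, which matches the usual meter pattern (the update body is only partially visible in this hunk):

```python
from tools.utils import AvgMeter

loss_meter = AvgMeter()

# Hypothetical per-batch losses and batch sizes.
for batch_loss, batch_size in [(0.52, 16), (0.47, 16), (0.44, 8)]:
    loss_meter.update(batch_loss, n=batch_size)

print(loss_meter.avg)  # running mean, weighted by n
loss_meter.reset()     # cleared between epochs or validation passes
```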
4 changes: 2 additions & 2 deletions train.py
@@ -8,9 +8,9 @@
from torch.backends import cudnn
from torch.utils.data import DataLoader

from config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT
from tools.config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT
from datasets import ItsDataset, SotsDataset
from utils import AvgMeter, check_mkdir
from tools.utils import AvgMeter, check_mkdir
from model import DM2FNet

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
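With config.py and utils.py relocated into the `tools` package (the empty `tools/__init__.py` added above makes it importable), the scripts are presumably run from the repository root so that these absolute imports resolve; a minimal check:

```python
# Run from the repository root, e.g. `python train.py`, so that the
# `tools` package next to the script is importable.
from tools.config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT
from tools.utils import AvgMeter, check_mkdir

print(TRAIN_ITS_ROOT, TEST_SOTS_ROOT)
```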
