Commit 48079d1: Update Code

jiaqixuac committed Sep 29, 2021
1 parent aff4eff commit 48079d1

Showing 11 changed files with 39 additions and 923 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@ __pycache__/
.ipynb_checkpoints/
.idea
ckpt
data
14 changes: 10 additions & 4 deletions README.md
@@ -11,16 +11,16 @@ This repo is the implementation of
The dehazing results can be found at
[Google Drive](https://drive.google.com/drive/folders/1ZVBI_3Y2NthVLeK7ODMIB5vRjmN9payF?usp=sharing).

## Installation
## Installation & Preparation

Make sure you have `Python>=3.6` installed on your machine.

**Environment setup:**

1. create conda environment
1. Create conda environment

conda create -n midline
conda activate midline
conda create -n dm2f
conda activate dm2f

2. Install dependencies:

@@ -30,6 +30,12 @@ Make sure you have `Python>=3.6` installed on your machine.

pip install -r requirements.txt

3. Prepare the dataset (RESIDE)

1. Download the RESIDE dataset from the [official webpage](https://sites.google.com/site/boyilics/website-builder/reside).

    2. Create a directory `./data` and add a symbolic link `./data/RESIDE` pointing to the uncompressed data, as shown below.
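       For example, assuming the dataset was uncompressed to `/path/to/RESIDE` (a placeholder; adjust to your machine):

           mkdir data
           ln -s /path/to/RESIDE ./data/RESIDE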

## Training

1. Set the path of the pretrained ResNeXt model in `resnext/config.py`
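For step 1, a hypothetical sketch of the path setting in `resnext/config.py` (the variable name below is assumed for illustration, not taken from the repo; check the actual file):

    # resnext/config.py -- hypothetical sketch; the real variable name may differ
    resnext_101_32_path = '/path/to/resnext_101_32x4d.pth'  # pretrained ResNeXt weights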
50 changes: 12 additions & 38 deletions config.py
@@ -1,41 +1,15 @@
# coding: utf-8
import os

root = '/home/jqxu/data/RESIDE'
train_a_root = os.path.join(root, 'TrainA')
test_a_root = os.path.join(root, 'TestA')
test_b_root = os.path.join(root, 'nature')

train_its_root = os.path.join(root, 'ITS_v2')
test_hsts_root = os.path.join(root, 'HSTS', 'synthetic', '')
# test_sots_root = os.path.join(root, 'SOTS', 'nyuhaze500')
test_sots_root = os.path.join(root, 'SOTS', 'outdoor')

ohaze_ori_root = os.path.join(root, 'O-HAZE')
train_ohaze_ori_root = os.path.join(ohaze_ori_root, 'train')
test_ohaze_ori_root = os.path.join(ohaze_ori_root, 'test')

ohaze_resize_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'O-HAZE')
train_ohaze_resize_root = os.path.join(ohaze_resize_root, 'train')
test_ohaze_resize_root = os.path.join(ohaze_resize_root, 'test')

ihaze_resize_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'I-HAZE')
train_ihaze_resize_root = os.path.join(ihaze_resize_root, 'train')
test_ihaze_resize_root = os.path.join(ihaze_resize_root, 'test')

ohaze_hd_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'O-HAZE-HD')
train_ohaze_hd_root = os.path.join(ohaze_hd_root, 'train')
test_ohaze_hd_root = os.path.join(ohaze_hd_root, 'test')

ihaze_hd_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'I-HAZE-HD')
train_ihaze_hd_root = os.path.join(ihaze_hd_root, 'train')
test_ihaze_hd_root = os.path.join(ihaze_hd_root, 'test')

ihaze_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'I-HAZE')
train_ihaze_root = os.path.join(ihaze_root, 'train')
test_ihaze_root = os.path.join(ihaze_root, 'test')

train_ots_root = '/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze/OTS'
test_natural2_root = '/home/b3-542/文档/DataSets/dehaze/natural2'

vgg_path = '/media/b3-542/454BAA0333169FE1/Packages/PyTorch Pretrained/VggNet/vgg16-397923af.pth'
root = os.path.dirname(os.path.abspath(__file__))

# # TestA
# TRAIN_A_ROOT = os.path.join(root, 'TrainA')
# TEST_A_ROOT = os.path.join(root, 'TestA')
# TEST_B_ROOT = os.path.join(root, 'nature')

# RESIDE
TRAIN_ITS_ROOT = os.path.join(root, 'data', 'RESIDE', 'ITS_v2') # ITS
TEST_SOTS_ROOT = os.path.join(root, 'data', 'RESIDE', 'SOTS', 'nyuhaze500') # SOTS indoor
# TEST_SOTS_ROOT = os.path.join(root, 'SOTS', 'outdoor') # SOTS outdoor
# TEST_HSTS_ROOT = os.path.join(root, 'HSTS', 'synthetic') # HSTS
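With the rewritten config.py, dataset paths resolve relative to the repository checkout instead of hard-coded machine-specific directories. A quick sanity check, assuming the `./data/RESIDE` symlink from the README exists:

    # illustrative check that the new relative paths resolve inside the repo
    import config
    print(config.TRAIN_ITS_ROOT)   # <repo>/data/RESIDE/ITS_v2
    print(config.TEST_SOTS_ROOT)   # <repo>/data/RESIDE/SOTS/nyuhaze500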
3 changes: 1 addition & 2 deletions datasets.py
@@ -5,10 +5,9 @@
import torch.utils.data as data
import scipy.io as sio
from PIL import Image
from torchvision.transforms import ToTensor, Resize
from torchvision.transforms import ToTensor
import random
from torchvision import transforms
import numpy as np


def make_dataset(root):
9 changes: 3 additions & 6 deletions infer.py
@@ -1,19 +1,16 @@
# coding: utf-8
import os
import cv2

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

from config import test_sots_root, test_hsts_root
from misc import check_mkdir, crf_refine
from config import TEST_SOTS_ROOT
from misc import check_mkdir
from model import ours
from datasets import ImageFolder2
from torch.utils.data import DataLoader
from torch import nn
from skimage.metrics import peak_signal_noise_ratio

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
@@ -27,7 +24,7 @@
'snapshot': 'iter_40000_loss_0.01221_lr_0.000000'
}

to_test = {'SOTS': test_sots_root}
to_test = {'SOTS': TEST_SOTS_ROOT}

to_pil = transforms.ToPILImage()

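infer.py now imports `peak_signal_noise_ratio` from scikit-image (added to requirements.txt below). A minimal, self-contained sketch of that metric call, using illustrative arrays rather than the repo's actual evaluation loop:

    # hedged sketch: PSNR between a ground-truth image and a perturbed copy
    import numpy as np
    from skimage.metrics import peak_signal_noise_ratio

    gt = np.random.rand(32, 32, 3)                                     # ground truth in [0, 1]
    pred = np.clip(gt + 0.01 * np.random.randn(32, 32, 3), 0.0, 1.0)   # fake prediction
    print(peak_signal_noise_ratio(gt, pred, data_range=1))             # higher is better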
171 changes: 0 additions & 171 deletions misc.py
@@ -1,13 +1,5 @@
import numpy as np
import os
import torch
import random
from torch import nn
from torchvision import models
from config import vgg_path
import torch.nn.functional as F
from math import ceil
# import pydensecrf.densecrf as dcrf  # left commented out; needed only by crf_refine below


class AvgMeter(object):
@@ -32,134 +24,6 @@ def check_mkdir(dir_name):
os.mkdir(dir_name)


def _sigmoid(x):
return 1 / (1 + np.exp(-x))


def crf_refine(img, annos):
assert img.dtype == np.uint8
assert annos.dtype == np.uint8
assert img.shape[:2] == annos.shape

# img and annos should be np array with data type uint8

EPSILON = 1e-8

M = 2 # salient or not
tau = 1.05
# Setup the CRF model
d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M)

anno_norm = annos / 255.

n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm))
p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm))

U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32')
U[0, :] = n_energy.flatten()
U[1, :] = p_energy.flatten()

d.setUnaryEnergy(U)

d.addPairwiseGaussian(sxy=3, compat=3)
d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

# Do the inference
infer = np.array(d.inference(1)).astype('float32')
res = infer[1, :]

res = res * 255
res = res.reshape(img.shape[:2])
return res.astype('uint8')


def sliding_forward(net, x, crop_size=2048):
n, c, h, w = x.size()
if h <= crop_size and w <= crop_size:
return net(x)
else:
result = torch.zeros(n, 3, h, w).cuda()
count = torch.zeros(n, 3, h, w).cuda()
stride = int(crop_size / 3.)

h_steps = 2 + max(h - crop_size, 0) // stride  # floor division: range() below needs ints (true division broke this under Python 3)
w_steps = 2 + max(w - crop_size, 0) // stride

for h_idx in range(h_steps):
for w_idx in range(w_steps):
h_slice = slice(h_idx * stride, min(crop_size + h_idx * stride, h))
w_slice = slice(w_idx * stride, min(crop_size + w_idx * stride, w))
if h_idx == h_steps - 1:
h_slice = slice(max(h - crop_size, 0), h)
if w_idx == w_steps - 1:
w_slice = slice(max(w - crop_size, 0), w)
result[:, :, h_slice, w_slice] += net(x[:, :, h_slice, w_slice].contiguous())
count[:, :, h_slice, w_slice] += 1
assert torch.min(count) > 0
result = result / count
return result


def sliding_forward2(net, x, crop_size=(2048, 2048)):
ch, cw = crop_size
n, c, h, w = x.size()
if h <= ch and w <= cw:
return net(x)
else:
result = torch.zeros_like(x).cuda()
count = torch.zeros_like(x).cuda()
stride_h = int(ch / 3.)
stride_w = int(cw / 3.)

h_steps = 2 + max(h - ch, 0) // stride_h  # floor division so range() receives ints
w_steps = 2 + max(w - cw, 0) // stride_w

for h_idx in range(h_steps):
for w_idx in range(w_steps):
h_slice = slice(h_idx * stride_h, min(ch + h_idx * stride_h, h))
w_slice = slice(w_idx * stride_w, min(cw + w_idx * stride_w, w))
if h_idx == h_steps - 1:
h_slice = slice(max(h - ch, 0), h)
if w_idx == w_steps - 1:
w_slice = slice(max(w - cw, 0), w)
result[:, :, h_slice, w_slice] += net(x[:, :, h_slice, w_slice])
count[:, :, h_slice, w_slice] += 1
assert torch.min(count) > 0
result = result / count
return result


def sliding_forward3(net, x, shrink_factor, crop_size=1536):
n, c, h, w = x.size()
if h <= crop_size and w <= crop_size:
x_sm = F.upsample(x, size=(int(h * shrink_factor), int(w * shrink_factor)), mode='bilinear')
return net(x_sm, x)
else:
result = torch.zeros(n, c, h, w).cuda()
count = torch.zeros(n, 1, h, w).cuda()
stride = int(crop_size / 3.)

h_steps = 1 + int(ceil(float(max(h - crop_size, 0)) / stride))
w_steps = 1 + int(ceil(float(max(w - crop_size, 0)) / stride))

for h_idx in range(h_steps):
for w_idx in range(w_steps):
ws0, ws1 = w_idx * stride, crop_size + w_idx * stride
hs0, hs1 = h_idx * stride, crop_size + h_idx * stride
if h_idx == h_steps - 1:
hs0, hs1 = max(h - crop_size, 0), h
if w_idx == w_steps - 1:
ws0, ws1 = max(w - crop_size, 0), w
x_patch = x[:, :, hs0: hs1, ws0: ws1]
patch_h, patch_w = x_patch.size()[2:]
x_patch_sm = F.upsample(x_patch, size=(int(patch_h * shrink_factor), int(patch_w * shrink_factor)), mode='bilinear')
result[:, :, hs0: hs1, ws0: ws1] += net(x_patch_sm, x_patch)
count[:, :, hs0: hs1, ws0: ws1] += 1
assert torch.min(count) > 0
result = result / count
return result


def random_crop(size, x):
h, w = x.size()[2:]
size_h = min(size, h)
@@ -168,38 +32,3 @@ def random_crop(size, x):
x1 = random.randint(0, w - size_w)
y1 = random.randint(0, h - size_h)
return x[:, :, y1: y1 + size_h, x1: x1 + size_w]


class PerceptualLoss(nn.Module):
def __init__(self, order):
super(PerceptualLoss, self).__init__()
assert order in [1, 2]
vgg = models.vgg16()
vgg.load_state_dict(torch.load(vgg_path))
self.vgg = nn.Sequential(*(list(vgg.features.children())[: 9])).eval()

self.criterion = nn.L1Loss() if order == 1 else nn.MSELoss()

self.mean = torch.zeros(1, 3, 1, 1)
self.std = torch.zeros(1, 3, 1, 1)
self.mean[0, 0, 0, 0] = 0.485
self.mean[0, 1, 0, 0] = 0.456
self.mean[0, 2, 0, 0] = 0.406
self.std[0, 0, 0, 0] = 0.229
self.std[0, 1, 0, 0] = 0.224
self.std[0, 2, 0, 0] = 0.225

self.mean = nn.Parameter(self.mean)
self.std = nn.Parameter(self.std)

for m in self.vgg.modules():
if isinstance(m, nn.ReLU):
m.inplace = True

for param in self.parameters():
param.requires_grad = False

def forward(self, input, target):
input = (input - self.mean) / self.std
target = (target - self.mean) / self.std
return self.criterion(self.vgg(input), self.vgg(target).detach())
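For intuition about the removed `sliding_forward*` helpers: they tile a large image into overlapping crops, run the network on each crop, and average the overlapping predictions. A minimal CPU sketch of that overlap-averaging idea (an illustrative re-implementation with made-up sizes, not the repo's function):

    # hedged sketch of tiled inference with overlap averaging
    import torch

    def tiled_forward(net, x, crop=4, stride=2):
        n, c, h, w = x.shape
        out = torch.zeros_like(x)
        cnt = torch.zeros_like(x)
        # start positions; the last tile is snapped to the border so no pixel is missed
        hs = list(range(0, max(h - crop, 0) + 1, stride))
        ws = list(range(0, max(w - crop, 0) + 1, stride))
        if hs[-1] != max(h - crop, 0):
            hs.append(max(h - crop, 0))
        if ws[-1] != max(w - crop, 0):
            ws.append(max(w - crop, 0))
        for h0 in hs:
            for w0 in ws:
                sl_h, sl_w = slice(h0, h0 + crop), slice(w0, w0 + crop)
                out[:, :, sl_h, sl_w] += net(x[:, :, sl_h, sl_w])
                cnt[:, :, sl_h, sl_w] += 1
        assert cnt.min() > 0  # every pixel covered at least once
        return out / cnt

    x = torch.rand(1, 3, 7, 7)
    y = tiled_forward(lambda t: t, x)  # with an identity "net", output equals input
    assert torch.allclose(x, y)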
16 changes: 8 additions & 8 deletions model.py
@@ -709,15 +709,15 @@ def forward(self, x0):
x_fusion = torch.cat(
    (torch.sum(F.softmax(attention_fusion[:, : 4, :, :], 1) * torch.stack(
        (x_p0[:, 0, :, :], x_p1[:, 0, :, :], x_p2[:, 0, :, :], x_p3[:, 0, :, :]), 1), 1, True),
     torch.sum(F.softmax(attention_fusion[:, 4: 8, :, :], 1) * torch.stack(
        (x_p0[:, 1, :, :], x_p1[:, 1, :, :], x_p2[:, 1, :, :], x_p3[:, 1, :, :]), 1), 1, True),
     torch.sum(F.softmax(attention_fusion[:, 8:, :, :], 1) * torch.stack(
        (x_p0[:, 2, :, :], x_p1[:, 2, :, :], x_p2[:, 2, :, :], x_p3[:, 2, :, :]), 1), 1, True)),
    1).clamp(min=0, max=1)

if self.training:
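The block above fuses four dehazed candidates (`x_p0`..`x_p3`) with per-channel softmax attention: for each RGB channel, `attention_fusion` supplies four weights whose softmax forms a convex combination of the candidates. A standalone sketch of that computation with made-up shapes (illustrative, not the repo's module):

    # hedged sketch of per-channel attention fusion over four candidates
    import torch
    import torch.nn.functional as F

    n, h, w = 2, 8, 8
    x_p = [torch.rand(n, 3, h, w) for _ in range(4)]  # four candidate dehazed images
    attention_fusion = torch.rand(n, 12, h, w)        # 4 weights per RGB channel

    channels = []
    for c in range(3):  # R, G, B
        weights = F.softmax(attention_fusion[:, 4 * c: 4 * (c + 1)], dim=1)  # (n, 4, h, w)
        stacked = torch.stack([xp[:, c] for xp in x_p], dim=1)               # (n, 4, h, w)
        channels.append((weights * stacked).sum(dim=1, keepdim=True))        # (n, 1, h, w)
    x_fusion = torch.cat(channels, dim=1).clamp(0, 1)                        # (n, 3, h, w)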
1 change: 1 addition & 0 deletions requirements.txt
@@ -3,3 +3,4 @@ torchvision>=0.9.0
numpy>=1.20.2
Pillow>=8.2.0
scipy>=1.7.1
scikit-image>=0.18.3