Update O-Haze train

jiaqixuac committed Oct 9, 2021
1 parent cfbc89d commit 04cd00b
Showing 8 changed files with 506 additions and 64 deletions.
20 changes: 11 additions & 9 deletions datasets.py
@@ -1,13 +1,15 @@
import os
import os.path

+import random
+import numpy as np
+from PIL import Image
+import scipy.io as sio

import torch
import torch.utils.data as data
-import scipy.io as sio
-from PIL import Image
-from torchvision.transforms import ToTensor
-import random
from torchvision import transforms
+from torchvision.transforms import ToTensor

to_tensor = ToTensor()

@@ -44,10 +46,10 @@ def make_dataset_ots(root):

def make_dataset_ohaze(root: str, mode: str):
img_list = []
-    for img_name in os.listdir(os.path.join(root, mode, 'img')):
+    for img_name in os.listdir(os.path.join(root, mode, 'hazy')):
gt_name = img_name.replace('hazy', 'GT')
-        assert os.path.exist(os.path.join(root, mode, 'gt', gt_name))
-        img_list.append([os.path.join(root, mode, 'img', img_name),
+        assert os.path.exists(os.path.join(root, mode, 'gt', gt_name))
+        img_list.append([os.path.join(root, mode, 'hazy', img_name),
os.path.join(root, mode, 'gt', gt_name)])
return img_list
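For orientation, make_dataset_ohaze pairs every hazy image with its ground truth by filename ('hazy' replaced by 'GT'). A minimal usage sketch, with an assumed O-HAZE layout (the paths and filenames below are illustrative, not part of the commit):

```python
# Assumed layout (illustrative):
#   ./data/O-HAZE/train/hazy/01_outdoor_hazy.jpg
#   ./data/O-HAZE/train/gt/01_outdoor_GT.jpg
pairs = make_dataset_ohaze('./data/O-HAZE', 'train')
haze_path, gt_path = pairs[0]  # each entry is a [hazy_path, gt_path] pair
```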

@@ -259,10 +261,10 @@ def __init__(self, root, mode):
self.imgs = make_dataset_ohaze(root, mode)

def __getitem__(self, index):
-        img_path, gt_path = self.imgs[index]
-        name = os.path.splitext(os.path.split(img_path)[1])[0]
+        haze_path, gt_path = self.imgs[index]
+        name = os.path.splitext(os.path.split(haze_path)[1])[0]

-        img = Image.open(img_path).convert('RGB')
+        img = Image.open(haze_path).convert('RGB')
gt = Image.open(gt_path).convert('RGB')

if 'train' in self.mode:
232 changes: 207 additions & 25 deletions model.py
@@ -9,19 +9,10 @@
class Base(nn.Module):
def __init__(self):
super(Base, self).__init__()
-        self.mean = torch.zeros(1, 3, 1, 1)
-        self.std = torch.zeros(1, 3, 1, 1)
-        self.mean[0, 0, 0, 0] = 0.485
-        self.mean[0, 1, 0, 0] = 0.456
-        self.mean[0, 2, 0, 0] = 0.406
-        self.std[0, 0, 0, 0] = 0.229
-        self.std[0, 1, 0, 0] = 0.224
-        self.std[0, 2, 0, 0] = 0.225
-
-        self.mean = nn.Parameter(self.mean)
-        self.std = nn.Parameter(self.std)
-        self.mean.requires_grad = False
-        self.std.requires_grad = False
+        rgb_mean = (0.485, 0.456, 0.406)
+        self.mean = nn.Parameter(torch.Tensor(rgb_mean).view(1, 3, 1, 1), requires_grad=False)
+        rgb_std = (0.229, 0.224, 0.225)
+        self.std = nn.Parameter(torch.Tensor(rgb_std).view(1, 3, 1, 1), requires_grad=False)


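For reference: (0.485, 0.456, 0.406) and (0.229, 0.224, 0.225) are the standard ImageNet per-channel mean and std that torchvision's pretrained backbones expect. Wrapping them as nn.Parameter with requires_grad=False keeps them fixed during training while still letting them move with the model between devices and appear in its state_dict.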
class BaseA(nn.Module):
@@ -63,18 +54,15 @@ def __init__(self):
class Base_OHAZE(nn.Module):
def __init__(self):
super(Base_OHAZE, self).__init__()
-        self.mean = torch.zeros(1, 3, 1, 1)
-        self.std = torch.zeros(1, 3, 1, 1)
-        self.mean[0, 0, 0, 0] = 0.47421
-        self.mean[0, 1, 0, 0] = 0.50878
-        self.mean[0, 2, 0, 0] = 0.56789
-        self.std[0, 0, 0, 0] = 0.10168
-        self.std[0, 1, 0, 0] = 0.10488
-        self.std[0, 2, 0, 0] = 0.11524
-        self.mean = nn.Parameter(self.mean)
-        self.std = nn.Parameter(self.std)
-        self.mean.requires_grad = False
-        self.std.requires_grad = False
+        rgb_mean = (0.47421, 0.50878, 0.56789)
+        self.mean_in = nn.Parameter(torch.Tensor(rgb_mean).view(1, 3, 1, 1), requires_grad=False)
+        rgb_std = (0.10168, 0.10488, 0.11524)
+        self.std_in = nn.Parameter(torch.Tensor(rgb_std).view(1, 3, 1, 1), requires_grad=False)
+
+        rgb_mean = (0.35851, 0.35316, 0.34425)
+        self.mean_out = nn.Parameter(torch.Tensor(rgb_mean).view(1, 3, 1, 1), requires_grad=False)
+        rgb_std = (0.16391, 0.16174, 0.17148)
+        self.std_out = nn.Parameter(torch.Tensor(rgb_std).view(1, 3, 1, 1), requires_grad=False)


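Base_OHAZE now carries two sets of statistics: mean_in/std_in normalize the hazy inputs, while mean_out/std_out map predictions back toward the ground-truth color distribution. A rough sketch of how such per-channel constants could be computed (the helper below is an assumption for illustration, not code from the repository):

```python
import numpy as np
from PIL import Image

def channel_stats(paths):
    # Per-channel mean/std over all pixels of the given images, scaled to [0, 1].
    pixels = np.concatenate(
        [np.asarray(Image.open(p).convert('RGB'), dtype=np.float64).reshape(-1, 3) / 255.0
         for p in paths],
        axis=0)
    return pixels.mean(axis=0), pixels.std(axis=0)
```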
class J0(Base):
@@ -940,3 +928,197 @@ def forward(self, x0, x0_hd=None):
return x_fusion, x_phy, x_j1, x_j2, x_j3, x_j4, t, a.view(x.size(0), -1)
else:
return x_fusion


class DM2FNet_woPhy(Base_OHAZE):
def __init__(self, num_features=64, arch='resnext101_32x8d'):
super(DM2FNet_woPhy, self).__init__()
self.num_features = num_features

# resnext = ResNeXt101Syn()
# self.layer0 = resnext.layer0
# self.layer1 = resnext.layer1
# self.layer2 = resnext.layer2
# self.layer3 = resnext.layer3
# self.layer4 = resnext.layer4

assert arch in ['resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
backbone = models.__dict__[arch](pretrained=True)
del backbone.fc  # the ImageNet classifier head is unused
self.backbone = backbone

# Project each backbone stage (64/256/512/1024/2048 channels)
# down to a common width of num_features channels.
self.down0 = nn.Sequential(
nn.Conv2d(64, num_features, kernel_size=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU()
)
self.down1 = nn.Sequential(
nn.Conv2d(256, num_features, kernel_size=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU()
)
self.down2 = nn.Sequential(
nn.Conv2d(512, num_features, kernel_size=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU()
)
self.down3 = nn.Sequential(
nn.Conv2d(1024, num_features, kernel_size=1), nn.SELU()
)
self.down4 = nn.Sequential(
nn.Conv2d(2048, num_features, kernel_size=1), nn.SELU()
)

# Top-down fusion: each block refines the upsampled coarser feature
# with an attention-gated skip connection from the matching stage.
self.fuse3 = nn.Sequential(
nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
)
self.fuse2 = nn.Sequential(
nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
)
self.fuse1 = nn.Sequential(
nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
)
self.fuse0 = nn.Sequential(
nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
)

self.fuse3_attention = nn.Sequential(
nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=1), nn.Sigmoid()
)
self.fuse2_attention = nn.Sequential(
nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=1), nn.Sigmoid()
)
self.fuse1_attention = nn.Sequential(
nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=1), nn.Sigmoid()
)
self.fuse0_attention = nn.Sequential(
nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features, num_features, kernel_size=1), nn.Sigmoid()
)

# Prediction heads for the four candidate dehazing branches; each
# outputs a 3-channel residual map (p2 and p3 use two stages).
self.p0 = nn.Sequential(
nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features // 2, 3, kernel_size=1)
)
self.p1 = nn.Sequential(
nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features // 2, 3, kernel_size=1)
)
self.p2_0 = nn.Sequential(
nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features // 2, 3, kernel_size=1)
)
self.p2_1 = nn.Sequential(
nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features // 2, 3, kernel_size=1)
)
self.p3_0 = nn.Sequential(
nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features // 2, 3, kernel_size=1)
)
self.p3_1 = nn.Sequential(
nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features // 2, 3, kernel_size=1)
)

# Predicts 12 = 3 (RGB) x 4 (branches) per-pixel fusion weight maps.
self.attentional_fusion = nn.Sequential(
nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features // 2, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
nn.Conv2d(num_features // 2, 12, kernel_size=3, padding=1)
)

# self.vgg = VGGF()

# Run activations in place to reduce memory use.
for m in self.modules():
    if isinstance(m, nn.SELU) or isinstance(m, nn.ReLU):
        m.inplace = True

def forward(self, x0):
# Normalize the hazy input with the O-HAZE hazy-image statistics.
x = (x0 - self.mean_in) / self.std_in

backbone = self.backbone

layer0 = backbone.conv1(x)
layer0 = backbone.bn1(layer0)
layer0 = backbone.relu(layer0)
layer0 = backbone.maxpool(layer0)

layer1 = backbone.layer1(layer0)
layer2 = backbone.layer2(layer1)
layer3 = backbone.layer3(layer2)
layer4 = backbone.layer4(layer3)

down0 = self.down0(layer0)
down1 = self.down1(layer1)
down2 = self.down2(layer2)
down3 = self.down3(layer3)
down4 = self.down4(layer4)

down4 = F.upsample(down4, size=down3.size()[2:], mode='bilinear')
fuse3_attention = self.fuse3_attention(torch.cat((down4, down3), 1))
f = down4 + self.fuse3(torch.cat((down4, fuse3_attention * down3), 1))

f = F.upsample(f, size=down2.size()[2:], mode='bilinear')
fuse2_attention = self.fuse2_attention(torch.cat((f, down2), 1))
f = f + self.fuse2(torch.cat((f, fuse2_attention * down2), 1))

f = F.upsample(f, size=down1.size()[2:], mode='bilinear')
fuse1_attention = self.fuse1_attention(torch.cat((f, down1), 1))
f = f + self.fuse1(torch.cat((f, fuse1_attention * down1), 1))

f = F.upsample(f, size=down0.size()[2:], mode='bilinear')
fuse0_attention = self.fuse0_attention(torch.cat((f, down0), 1))
f = f + self.fuse0(torch.cat((f, fuse0_attention * down0), 1))

# The clamps guard the logarithms against values of exactly 0 or 1.
log_x0 = torch.log(x0.clamp(min=1e-8))
log_log_x0_inverse = torch.log(torch.log(1 / x0.clamp(min=1e-8, max=(1 - 1e-8))))

# Branch 0: multiplicative refinement in log space, x_p0 = x0 * exp(r0).
x_p0 = torch.exp(log_x0 + F.upsample(self.p0(f), size=x0.size()[2:], mode='bilinear')).clamp(min=0, max=1)

# Branch 1: additive residual in normalized space, de-normalized with the GT statistics.
x_p1 = ((x + F.upsample(self.p1(f), size=x0.size()[2:], mode='bilinear')) * self.std_out + self.mean_out)\
    .clamp(min=0., max=1.)

# Branch 2: as branch 1, followed by a second refinement in log space.
log_x_p2_0 = torch.log(
    ((x + F.upsample(self.p2_0(f), size=x0.size()[2:], mode='bilinear')) * self.std_out + self.mean_out)
    .clamp(min=1e-8))
x_p2 = torch.exp(log_x_p2_0 + F.upsample(self.p2_1(f), size=x0.size()[2:], mode='bilinear'))\
    .clamp(min=0., max=1.)

# Branch 3: refinement in double-log space, i.e. on log(log(1/x0)); since
# x_p3 = x0 ** exp(r3_0) * exp(r3_1), this acts like a learned gamma correction.
log_x_p3_0 = torch.exp(log_log_x0_inverse + F.upsample(self.p3_0(f), size=x0.size()[2:], mode='bilinear'))
x_p3 = torch.exp(-log_x_p3_0 + F.upsample(self.p3_1(f), size=x0.size()[2:], mode='bilinear')).clamp(min=0, max=1)

# Blend the four candidates per RGB channel: softmax over each group of
# four weight maps gives per-pixel convex combination weights.
attention_fusion = F.upsample(self.attentional_fusion(f), size=x0.size()[2:], mode='bilinear')
x_fusion = torch.cat((
    torch.sum(F.softmax(attention_fusion[:, :4, :, :], 1) * torch.stack(
        (x_p0[:, 0, :, :], x_p1[:, 0, :, :], x_p2[:, 0, :, :], x_p3[:, 0, :, :]), 1), 1, True),
    torch.sum(F.softmax(attention_fusion[:, 4:8, :, :], 1) * torch.stack(
        (x_p0[:, 1, :, :], x_p1[:, 1, :, :], x_p2[:, 1, :, :], x_p3[:, 1, :, :]), 1), 1, True),
    torch.sum(F.softmax(attention_fusion[:, 8:, :, :], 1) * torch.stack(
        (x_p0[:, 2, :, :], x_p1[:, 2, :, :], x_p2[:, 2, :, :], x_p3[:, 2, :, :]), 1), 1, True),
), 1).clamp(min=0, max=1)

if self.training:
return x_fusion, x_p0, x_p1, x_p2, x_p3
else:
return x_fusion
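For orientation, a minimal smoke test of the new class (illustrative, not part of the commit; assumes model.py's imports are available):

```python
import torch

model = DM2FNet_woPhy().eval()  # loads ImageNet-pretrained ResNeXt weights
with torch.no_grad():
    hazy = torch.rand(1, 3, 512, 512)  # stand-in for a hazy image in [0, 1]
    dehazed = model(hazy)              # eval mode returns only x_fusion
print(dehazed.shape)  # torch.Size([1, 3, 512, 512])
```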
