diff --git a/datasets.py b/datasets.py
index d4753f5..ebf7468 100644
--- a/datasets.py
+++ b/datasets.py
@@ -1,13 +1,15 @@
 import os
 import os.path
+import random
+
+import numpy as np
+from PIL import Image
+import scipy.io as sio
 
 import torch
 import torch.utils.data as data
-import scipy.io as sio
-from PIL import Image
-from torchvision.transforms import ToTensor
-import random
 from torchvision import transforms
+from torchvision.transforms import ToTensor
 
 
 to_tensor = ToTensor()
@@ -44,10 +46,10 @@ def make_dataset_ots(root):
 
 def make_dataset_ohaze(root: str, mode: str):
     img_list = []
-    for img_name in os.listdir(os.path.join(root, mode, 'img')):
+    for img_name in os.listdir(os.path.join(root, mode, 'hazy')):
         gt_name = img_name.replace('hazy', 'GT')
-        assert os.path.exist(os.path.join(root, mode, 'gt', gt_name))
-        img_list.append([os.path.join(root, mode, 'img', img_name),
+        assert os.path.exists(os.path.join(root, mode, 'gt', gt_name))
+        img_list.append([os.path.join(root, mode, 'hazy', img_name),
                          os.path.join(root, mode, 'gt', gt_name)])
     return img_list
 
@@ -259,10 +261,10 @@ def __init__(self, root, mode):
         self.imgs = make_dataset_ohaze(root, mode)
 
     def __getitem__(self, index):
-        img_path, gt_path = self.imgs[index]
+        haze_path, gt_path = self.imgs[index]
         name = os.path.splitext(os.path.split(haze_path)[1])[0]
 
-        img = Image.open(img_path).convert('RGB')
+        img = Image.open(haze_path).convert('RGB')
         gt = Image.open(gt_path).convert('RGB')
 
         if 'train' in self.mode:
diff --git a/model.py b/model.py
index 4998d94..cf88cd5 100644
--- a/model.py
+++ b/model.py
@@ -9,19 +9,10 @@
 class Base(nn.Module):
     def __init__(self):
         super(Base, self).__init__()
-        self.mean = torch.zeros(1, 3, 1, 1)
-        self.std = torch.zeros(1, 3, 1, 1)
-        self.mean[0, 0, 0, 0] = 0.485
-        self.mean[0, 1, 0, 0] = 0.456
-        self.mean[0, 2, 0, 0] = 0.406
-        self.std[0, 0, 0, 0] = 0.229
-        self.std[0, 1, 0, 0] = 0.224
-        self.std[0, 2, 0, 0] = 0.225
-
-        self.mean = nn.Parameter(self.mean)
-        self.std = nn.Parameter(self.std)
-        self.mean.requires_grad = False
-        self.std.requires_grad = False
+        rgb_mean = (0.485, 0.456, 0.406)
+        self.mean = nn.Parameter(torch.Tensor(rgb_mean).view(1, 3, 1, 1), requires_grad=False)
+        rgb_std = (0.229, 0.224, 0.225)
+        self.std = nn.Parameter(torch.Tensor(rgb_std).view(1, 3, 1, 1), requires_grad=False)
 
 
 class BaseA(nn.Module):
@@ -63,18 +54,15 @@
 class Base_OHAZE(nn.Module):
     def __init__(self):
         super(Base_OHAZE, self).__init__()
-        self.mean = torch.zeros(1, 3, 1, 1)
-        self.std = torch.zeros(1, 3, 1, 1)
-        self.mean[0, 0, 0, 0] = 0.47421
-        self.mean[0, 1, 0, 0] = 0.50878
-        self.mean[0, 2, 0, 0] = 0.56789
-        self.std[0, 0, 0, 0] = 0.10168
-        self.std[0, 1, 0, 0] = 0.10488
-        self.std[0, 2, 0, 0] = 0.11524
-        self.mean = nn.Parameter(self.mean)
-        self.std = nn.Parameter(self.std)
-        self.mean.requires_grad = False
-        self.std.requires_grad = False
+        rgb_mean = (0.47421, 0.50878, 0.56789)
+        self.mean_in = nn.Parameter(torch.Tensor(rgb_mean).view(1, 3, 1, 1), requires_grad=False)
+        rgb_std = (0.10168, 0.10488, 0.11524)
+        self.std_in = nn.Parameter(torch.Tensor(rgb_std).view(1, 3, 1, 1), requires_grad=False)
+
+        rgb_mean = (0.35851, 0.35316, 0.34425)
+        self.mean_out = nn.Parameter(torch.Tensor(rgb_mean).view(1, 3, 1, 1), requires_grad=False)
+        rgb_std = (0.16391, 0.16174, 0.17148)
+        self.std_out = nn.Parameter(torch.Tensor(rgb_std).view(1, 3, 1, 1), requires_grad=False)
 
 
 class J0(Base):
@@ -940,3 +928,197 @@ def forward(self, x0, x0_hd=None):
             return x_fusion, x_phy, x_j1, x_j2, x_j3, x_j4, t, a.view(x.size(0), -1)
         else:
             return x_fusion
+
+
+class DM2FNet_woPhy(Base_OHAZE):
+    def __init__(self, num_features=64, arch='resnext101_32x8d'):
+        super(DM2FNet_woPhy, self).__init__()
+        self.num_features = num_features
+
+        # resnext = ResNeXt101Syn()
+        # self.layer0 = resnext.layer0
+        # self.layer1 = resnext.layer1
+        # self.layer2 = resnext.layer2
+        # self.layer3 = resnext.layer3
+        # self.layer4 = resnext.layer4
+
+        assert arch in ['resnet50', 'resnet101',
+                        'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
+        backbone = models.__dict__[arch](pretrained=True)
+        del backbone.fc
+        self.backbone = backbone
+
+        self.down0 = nn.Sequential(
+            nn.Conv2d(64, num_features, kernel_size=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU()
+        )
+        self.down1 = nn.Sequential(
+            nn.Conv2d(256, num_features, kernel_size=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU()
+        )
+        self.down2 = nn.Sequential(
+            nn.Conv2d(512, num_features, kernel_size=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU()
+        )
+        self.down3 = nn.Sequential(
+            nn.Conv2d(1024, num_features, kernel_size=1), nn.SELU()
+        )
+        self.down4 = nn.Sequential(
+            nn.Conv2d(2048, num_features, kernel_size=1), nn.SELU()
+        )
+
+        self.fuse3 = nn.Sequential(
+            nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
+        )
+        self.fuse2 = nn.Sequential(
+            nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
+        )
+        self.fuse1 = nn.Sequential(
+            nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
+        )
+        self.fuse0 = nn.Sequential(
+            nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
+        )
+
+        self.fuse3_attention = nn.Sequential(
+            nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=1), nn.Sigmoid()
+        )
+        self.fuse2_attention = nn.Sequential(
+            nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=1), nn.Sigmoid()
+        )
+        self.fuse1_attention = nn.Sequential(
+            nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=1), nn.Sigmoid()
+        )
+        self.fuse0_attention = nn.Sequential(
+            nn.Conv2d(num_features * 2, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features, num_features, kernel_size=1), nn.Sigmoid()
+        )
+
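+        # prediction heads: each regresses a 3-channel term consumed by one of
+        # the four candidate reconstructions (x_p0 .. x_p3) built in forward()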
+        self.p0 = nn.Sequential(
+            nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features // 2, 3, kernel_size=1)
+        )
+        self.p1 = nn.Sequential(
+            nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features // 2, 3, kernel_size=1)
+        )
+        self.p2_0 = nn.Sequential(
+            nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features // 2, 3, kernel_size=1)
+        )
+        self.p2_1 = nn.Sequential(
+            nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features // 2, 3, kernel_size=1)
+        )
+        self.p3_0 = nn.Sequential(
+            nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features // 2, 3, kernel_size=1)
+        )
+        self.p3_1 = nn.Sequential(
+            nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features // 2, 3, kernel_size=1)
+        )
+
+        self.attentional_fusion = nn.Sequential(
+            nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features // 2, num_features // 2, kernel_size=3, padding=1), nn.SELU(),
+            nn.Conv2d(num_features // 2, 12, kernel_size=3, padding=1)
+        )
+
+        # self.vgg = VGGF()
+
+        for m in self.modules():
+            if isinstance(m, nn.SELU) or isinstance(m, nn.ReLU):
+                m.inplace = True
+
+    def forward(self, x0):
+        x = (x0 - self.mean_in) / self.std_in
+
+        backbone = self.backbone
+
+        layer0 = backbone.conv1(x)
+        layer0 = backbone.bn1(layer0)
+        layer0 = backbone.relu(layer0)
+        layer0 = backbone.maxpool(layer0)
+
+        layer1 = backbone.layer1(layer0)
+        layer2 = backbone.layer2(layer1)
+        layer3 = backbone.layer3(layer2)
+        layer4 = backbone.layer4(layer3)
+
+        down0 = self.down0(layer0)
+        down1 = self.down1(layer1)
+        down2 = self.down2(layer2)
+        down3 = self.down3(layer3)
+        down4 = self.down4(layer4)
+
+        down4 = F.upsample(down4, size=down3.size()[2:], mode='bilinear')
+        fuse3_attention = self.fuse3_attention(torch.cat((down4, down3), 1))
+        f = down4 + self.fuse3(torch.cat((down4, fuse3_attention * down3), 1))
+
+        f = F.upsample(f, size=down2.size()[2:], mode='bilinear')
+        fuse2_attention = self.fuse2_attention(torch.cat((f, down2), 1))
+        f = f + self.fuse2(torch.cat((f, fuse2_attention * down2), 1))
+
+        f = F.upsample(f, size=down1.size()[2:], mode='bilinear')
+        fuse1_attention = self.fuse1_attention(torch.cat((f, down1), 1))
+        f = f + self.fuse1(torch.cat((f, fuse1_attention * down1), 1))
+
+        f = F.upsample(f, size=down0.size()[2:], mode='bilinear')
+        fuse0_attention = self.fuse0_attention(torch.cat((f, down0), 1))
+        f = f + self.fuse0(torch.cat((f, fuse0_attention * down0), 1))
+
+        log_x0 = torch.log(x0.clamp(min=1e-8))
+        log_log_x0_inverse = torch.log(torch.log(1 / x0.clamp(min=1e-8, max=(1 - 1e-8))))
+
+        x_p0 = torch.exp(log_x0 + F.upsample(self.p0(f), size=x0.size()[2:], mode='bilinear')).clamp(min=0, max=1)
+
+        x_p1 = ((x + F.upsample(self.p1(f), size=x0.size()[2:], mode='bilinear')) * self.std_out + self.mean_out)\
+            .clamp(min=0., max=1.)
+
+        log_x_p2_0 = torch.log(
+            ((x + F.upsample(self.p2_0(f), size=x0.size()[2:], mode='bilinear')) * self.std_out + self.mean_out)
+            .clamp(min=1e-8))
+        x_p2 = torch.exp(log_x_p2_0 + F.upsample(self.p2_1(f), size=x0.size()[2:], mode='bilinear'))\
+            .clamp(min=0., max=1.)
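+
+        # x_p3 perturbs the input in log(log(1/I)) space and maps back through
+        # two exponentials, acting as a learned gamma-style correction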
+        log_x_p3_0 = torch.exp(log_log_x0_inverse + F.upsample(self.p3_0(f), size=x0.size()[2:], mode='bilinear'))
+        x_p3 = torch.exp(-log_x_p3_0 + F.upsample(self.p3_1(f), size=x0.size()[2:], mode='bilinear')).clamp(min=0,
+                                                                                                            max=1)
+
+        attention_fusion = F.upsample(self.attentional_fusion(f), size=x0.size()[2:], mode='bilinear')
+        x_fusion = torch.cat((torch.sum(F.softmax(attention_fusion[:, : 4, :, :], 1) * torch.stack(
+            (x_p0[:, 0, :, :], x_p1[:, 0, :, :], x_p2[:, 0, :, :], x_p3[:, 0, :, :]), 1), 1, True),
+                              torch.sum(F.softmax(attention_fusion[:, 4: 8, :, :], 1) * torch.stack((x_p0[:, 1, :, :],
+                                                                                                     x_p1[:, 1, :, :],
+                                                                                                     x_p2[:, 1, :, :],
+                                                                                                     x_p3[:, 1, :, :]),
+                                                                                                    1), 1, True),
+                              torch.sum(F.softmax(attention_fusion[:, 8:, :, :], 1) * torch.stack((x_p0[:, 2, :, :],
+                                                                                                   x_p1[:, 2, :, :],
+                                                                                                   x_p2[:, 2, :, :],
+                                                                                                   x_p3[:, 2, :, :]),
+                                                                                                  1), 1, True)),
+                             1).clamp(min=0, max=1)
+
+        if self.training:
+            return x_fusion, x_p0, x_p1, x_p2, x_p3
+        else:
+            return x_fusion
diff --git a/test.py b/test.py
index 07bf333..bff3864 100644
--- a/test.py
+++ b/test.py
@@ -5,10 +5,10 @@
 import torch
 from torchvision import transforms
 
-from tools.config import TEST_SOTS_ROOT
-from tools.utils import check_mkdir
-from model import DM2FNet
-from datasets import SotsDataset
+from tools.config import TEST_SOTS_ROOT, OHAZE_ROOT
+from tools.utils import check_mkdir, sliding_forward
+from model import DM2FNet, DM2FNet_woPhy
+from datasets import SotsDataset, OHazeDataset
 from torch.utils.data import DataLoader
 
 from skimage.metrics import peak_signal_noise_ratio, structural_similarity
@@ -18,32 +18,45 @@
 torch.cuda.set_device(0)
 
 ckpt_path = './ckpt'
-exp_name = 'RESIDE_ITS'
+# exp_name = 'RESIDE_ITS'
+exp_name = 'O-Haze'
+
 args = {
-    'snapshot': 'iter_40000_loss_0.01230_lr_0.000000'
+    # 'snapshot': 'iter_40000_loss_0.01230_lr_0.000000',
+    'snapshot': 'iter_19000_loss_0.04261_lr_0.000014',
 }
 
-to_test = {'SOTS': TEST_SOTS_ROOT}
+to_test = {
+    # 'SOTS': TEST_SOTS_ROOT,
+    'O-Haze': OHAZE_ROOT,
+}
 
 to_pil = transforms.ToPILImage()
 
 
 def main():
-    net = DM2FNet().cuda()
-    # net = nn.DataParallel(net)
-
-    if len(args['snapshot']) > 0:
-        print('load snapshot \'%s\' for testing' % args['snapshot'])
-        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
-
-    net.eval()
-    psnrs, ssims = [], []
-
     with torch.no_grad():
         for name, root in to_test.items():
-            dataset = SotsDataset(root)
+            if 'SOTS' in name:
+                net = DM2FNet().cuda()
+                dataset = SotsDataset(root)
+            elif 'O-Haze' in name:
+                net = DM2FNet_woPhy().cuda()
+                dataset = OHazeDataset(root, 'test')
+            else:
+                raise NotImplementedError
+
+            # net = nn.DataParallel(net)
+
+            if len(args['snapshot']) > 0:
+                print('load snapshot \'%s\' for testing' % args['snapshot'])
+                net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')))
+
+            net.eval()
             dataloader = DataLoader(dataset, batch_size=1)
 
+            psnrs, ssims = [], []
+
             for idx, data in enumerate(dataloader):
                 # haze_image, _, _, _, fs = data
                 haze, gts, fs = data
@@ -52,14 +65,19 @@ def main():
                 check_mkdir(os.path.join(ckpt_path, exp_name,
                                          '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
                 haze = haze.cuda()
-                res = net(haze).data
+
+                if 'O-Haze' in name:
+                    res = sliding_forward(net, haze).detach()
+                else:
+                    res = net(haze).detach()
 
                 for i in range(len(fs)):
                     r = res[i].cpu().numpy().transpose([1, 2, 0])
                     gt = gts[i].cpu().numpy().transpose([1, 2, 0])
                     psnr = peak_signal_noise_ratio(gt, r)
                     psnrs.append(psnr)
-                    ssim = structural_similarity(gt, r, multichannel=True)
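+                    # data_range must be given for float images in [0, 1]; the other
+                    # arguments select the Gaussian-weighted SSIM of the original paper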
+                    ssim = structural_similarity(gt, r, data_range=1, multichannel=True,
+                                                 gaussian_weights=True, sigma=1.5, use_sample_covariance=False)
                     ssims.append(ssim)
                     print('predicting for {} ({}/{}) [{}]: PSNR {:.4f}, SSIM {:.4f}'
                           .format(name, idx + 1, len(dataloader), fs[i], psnr, ssim))
@@ -68,7 +86,8 @@ def main():
                     to_pil(r).save(
                         os.path.join(ckpt_path, exp_name,
                                      '(%s) %s_%s' % (exp_name, name, args['snapshot']), '%s.png' % f))
 
-    print(f"[{name}] PSNR: {np.mean(psnrs):.6f}, SSIM: {np.mean(ssims):.6f}")
+
+            print(f"[{name}] PSNR: {np.mean(psnrs):.6f}, SSIM: {np.mean(ssims):.6f}")
 
 
 if __name__ == '__main__':
diff --git a/tools/config.py b/tools/config.py
index 77440f8..c3ce961 100644
--- a/tools/config.py
+++ b/tools/config.py
@@ -1,7 +1,7 @@
 # coding: utf-8
-import os
+from os import path as osp
 
-root = os.path.dirname(os.path.abspath(__file__))
+root = osp.dirname(osp.abspath(__file__))
 
 # # TestA
 # TRAIN_A_ROOT = os.path.join(root, 'TrainA')
@@ -9,10 +9,10 @@
 # TEST_B_ROOT = os.path.join(root, 'nature')
 
 # O-Haze
-OHAZE_ROOT = os.path.join(root, '../data', 'O-Haze')
+OHAZE_ROOT = osp.abspath(osp.join(root, '../data', 'O-Haze'))
 
 # RESIDE
-TRAIN_ITS_ROOT = os.path.join(root, '../data', 'RESIDE', 'ITS_v2')  # ITS
-TEST_SOTS_ROOT = os.path.join(root, '../data', 'RESIDE', 'SOTS', 'nyuhaze500')  # SOTS indoor
+TRAIN_ITS_ROOT = osp.abspath(osp.join(root, '../data', 'RESIDE', 'ITS_v2'))  # ITS
+TEST_SOTS_ROOT = osp.abspath(osp.join(root, '../data', 'RESIDE', 'SOTS', 'nyuhaze500'))  # SOTS indoor
 # TEST_SOTS_ROOT = os.path.join(root, 'SOTS', 'outdoor')  # SOTS outdoor
 # TEST_HSTS_ROOT = os.path.join(root, 'HSTS', 'synthetic')  # HSTS
diff --git a/tools/preprocess_ohaze_data.py b/tools/preprocess_ohaze_data.py
index c25c864..0b18814 100644
--- a/tools/preprocess_ohaze_data.py
+++ b/tools/preprocess_ohaze_data.py
@@ -14,7 +14,7 @@
     ori_gt_root = os.path.join(ori_root, 'GT')
 
     patch_root = os.path.join(ohaze_root, 'train_crop_{}'.format(crop_size))
-    patch_haze_path = os.path.join(patch_root, 'img')
+    patch_haze_path = os.path.join(patch_root, 'hazy')
     patch_gt_path = os.path.join(patch_root, 'gt')
 
     os.makedirs(patch_root, exist_ok=True)
diff --git a/tools/utils.py b/tools/utils.py
index d787d2a..7f8cfad 100644
--- a/tools/utils.py
+++ b/tools/utils.py
@@ -1,5 +1,7 @@
 import os
+import math
 import random
+import torch
 
 
 class AvgMeter(object):
@@ -35,3 +37,31 @@ def random_crop(size, x):
     x1 = random.randint(0, w - size_w)
     y1 = random.randint(0, h - size_h)
     return x[:, :, y1: y1 + size_h, x1: x1 + size_w]
+
+
+def sliding_forward(net, x: torch.Tensor, crop_size=1536):
+    n, c, h, w = x.size()
+
+    if h <= crop_size and w <= crop_size:
+        return net(x)
+    else:
+        result = torch.zeros(n, c, h, w).cuda()
+        count = torch.zeros(n, 1, h, w).cuda()
+        stride = int(crop_size / 3.)
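+
+        # windows overlap by two thirds of crop_size; overlapping predictions are
+        # summed into `result` and normalised by `count` after the sweep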
+        h_steps = 1 + int(math.ceil(float(max(h - crop_size, 0)) / stride))
+        w_steps = 1 + int(math.ceil(float(max(w - crop_size, 0)) / stride))
+
+        for h_idx in range(h_steps):
+            for w_idx in range(w_steps):
+                ws0, ws1 = w_idx * stride, crop_size + w_idx * stride
+                hs0, hs1 = h_idx * stride, crop_size + h_idx * stride
+                if h_idx == h_steps - 1:
+                    hs0, hs1 = max(h - crop_size, 0), h
+                if w_idx == w_steps - 1:
+                    ws0, ws1 = max(w - crop_size, 0), w
+                result[:, :, hs0: hs1, ws0: ws1] += net(x[:, :, hs0: hs1, ws0: ws1]).data
+                count[:, :, hs0: hs1, ws0: ws1] += 1
+        assert torch.min(count) > 0
+        result = result / count
+        return result
diff --git a/train.py b/train.py
index adee1ba..bebc326 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@
 import argparse
 import os
 import datetime
+from tqdm import tqdm
 
 import torch
 from torch import nn
@@ -9,10 +10,12 @@
 from torch.backends import cudnn
 from torch.utils.data import DataLoader
 
+from model import DM2FNet
 from tools.config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT
 from datasets import ItsDataset, SotsDataset
 from tools.utils import AvgMeter, check_mkdir
-from model import DM2FNet
+
+from skimage.metrics import peak_signal_noise_ratio, structural_similarity
 
 
 def parse_args():
@@ -69,6 +72,7 @@ def main():
     check_mkdir(args.ckpt_path)
     check_mkdir(os.path.join(args.ckpt_path, args.exp_name))
     open(log_path, 'w').write(str(cfgs) + '\n\n')
+
     train(net, optimizer)
 
 
@@ -155,7 +159,7 @@ def validate(net, curr_iter, optimizer):
     loss_record = AvgMeter()
 
     with torch.no_grad():
-        for i, data in enumerate(val_loader):
+        for data in tqdm(val_loader):
             haze, gt, _ = data
 
             haze = haze.cuda()
@@ -164,7 +168,7 @@
             dehaze = net(haze)
 
             loss = criterion(dehaze, gt)
-            loss_record.update(loss.data, haze.size(0))
+            loss_record.update(loss.item(), haze.size(0))
 
     snapshot_name = 'iter_%d_loss_%.5f_lr_%.6f' % (curr_iter + 1, loss_record.avg, optimizer.param_groups[1]['lr'])
     print('[validate]: [iter %d], [loss %.5f]' % (curr_iter + 1, loss_record.avg))
diff --git a/train_ohaze.py b/train_ohaze.py
new file mode 100644
index 0000000..63bc716
--- /dev/null
+++ b/train_ohaze.py
@@ -0,0 +1,205 @@
+# coding: utf-8
+import argparse
+import os
+import datetime
+from tqdm import tqdm
+
+import torch
+from torch import nn
+from torch import optim
+from torch.backends import cudnn
+from torch.utils.data import DataLoader
+import torch.cuda.amp as amp
+
+from model import DM2FNet_woPhy
+from tools.config import OHAZE_ROOT
+from datasets import OHazeDataset
+from tools.utils import AvgMeter, check_mkdir, sliding_forward
+
+from skimage.metrics import peak_signal_noise_ratio, structural_similarity
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train a DM2FNet')
+    parser.add_argument(
+        '--gpus', type=str, default='0', help='gpus to use ')
+    parser.add_argument('--ckpt-path', default='./ckpt', help='checkpoint path')
+    parser.add_argument(
+        '--exp-name',
+        default='O-Haze',
+        help='experiment name.')
+    args = parser.parse_args()
+
+    return args
+
+
+cfgs = {
+    'use_physical': True,
+    'iter_num': 20000,
+    'train_batch_size': 16,
+    'last_iter': 0,
+    'lr': 2e-4,
+    'lr_decay': 0.9,
+    'weight_decay': 2e-5,
+    'momentum': 0.9,
+    'snapshot': '',
+    'val_freq': 2000,
+    'crop_size': 512,
+}
+
+
+def main():
+    net = DM2FNet_woPhy().cuda().train()
+    # net = DataParallel(net)
+
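+    # bias parameters are trained with twice the base LR and without weight decay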
+    optimizer = optim.Adam([
+        {'params': [param for name, param in net.named_parameters()
+                    if name[-4:] == 'bias' and param.requires_grad],
+         'lr': 2 * cfgs['lr']},
+        {'params': [param for name, param in net.named_parameters()
+                    if name[-4:] != 'bias' and param.requires_grad],
+         'lr': cfgs['lr'], 'weight_decay': cfgs['weight_decay']}
+    ])
+
+    if len(cfgs['snapshot']) > 0:
+        print('training resumes from \'%s\'' % cfgs['snapshot'])
+        net.load_state_dict(torch.load(os.path.join(args.ckpt_path,
+                                                    args.exp_name, cfgs['snapshot'] + '.pth')))
+        optimizer.load_state_dict(torch.load(os.path.join(args.ckpt_path,
+                                                          args.exp_name, cfgs['snapshot'] + '_optim.pth')))
+        optimizer.param_groups[0]['lr'] = 2 * cfgs['lr']
+        optimizer.param_groups[1]['lr'] = cfgs['lr']
+
+    check_mkdir(args.ckpt_path)
+    check_mkdir(os.path.join(args.ckpt_path, args.exp_name))
+    open(log_path, 'w').write(str(cfgs) + '\n\n')
+
+    train(net, optimizer)
+
+
+def train(net, optimizer):
+    curr_iter = cfgs['last_iter']
+    scaler = amp.GradScaler()
+    torch.cuda.empty_cache()
+
+    while curr_iter <= cfgs['iter_num']:
+        train_loss_record = AvgMeter()
+        loss_x_jf_record = AvgMeter()
+        loss_x_j1_record, loss_x_j2_record = AvgMeter(), AvgMeter()
+        loss_x_j3_record, loss_x_j4_record = AvgMeter(), AvgMeter()
+
+        for data in train_loader:
+            optimizer.param_groups[0]['lr'] = 2 * cfgs['lr'] * (1 - float(curr_iter) / cfgs['iter_num']) \
+                                              ** cfgs['lr_decay']
+            optimizer.param_groups[1]['lr'] = cfgs['lr'] * (1 - float(curr_iter) / cfgs['iter_num']) \
+                                              ** cfgs['lr_decay']
+
+            haze, gt, _ = data
+
+            batch_size = haze.size(0)
+
+            haze, gt = haze.cuda(), gt.cuda()
+
+            optimizer.zero_grad()
+
+            with amp.autocast():
+                x_jf, x_j1, x_j2, x_j3, x_j4 = net(haze)
+
+                loss_x_jf = criterion(x_jf, gt)
+                loss_x_j1 = criterion(x_j1, gt)
+                loss_x_j2 = criterion(x_j2, gt)
+                loss_x_j3 = criterion(x_j3, gt)
+                loss_x_j4 = criterion(x_j4, gt)
+
+                loss = loss_x_jf + loss_x_j1 + loss_x_j2 + loss_x_j3 + loss_x_j4
+
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+
+            # loss.backward()
+            # optimizer.step()
+
+            train_loss_record.update(loss.item(), batch_size)
+
+            loss_x_jf_record.update(loss_x_jf.item(), batch_size)
+            loss_x_j1_record.update(loss_x_j1.item(), batch_size)
+            loss_x_j2_record.update(loss_x_j2.item(), batch_size)
+            loss_x_j3_record.update(loss_x_j3.item(), batch_size)
+            loss_x_j4_record.update(loss_x_j4.item(), batch_size)
+
+            curr_iter += 1
+
+            log = '[iter %d], [train loss %.5f], [loss_x_fusion %.5f], [loss_x_j1 %.5f], ' \
+                  '[loss_x_j2 %.5f], [loss_x_j3 %.5f], [loss_x_j4 %.5f], [lr %.13f]' % \
+                  (curr_iter, train_loss_record.avg, loss_x_jf_record.avg,
+                   loss_x_j1_record.avg, loss_x_j2_record.avg, loss_x_j3_record.avg, loss_x_j4_record.avg,
+                   optimizer.param_groups[1]['lr'])
+            print(log)
+            open(log_path, 'a').write(log + '\n')
+
+            if curr_iter == 1 or (curr_iter + 1) % cfgs['val_freq'] == 0:
+                validate(net, curr_iter, optimizer)
+                torch.cuda.empty_cache()
+
+            if curr_iter > cfgs['iter_num']:
+                break
+
+
+def validate(net, curr_iter, optimizer):
+    print('validating...')
+    net.eval()
+
+    loss_record = AvgMeter()
+    psnr_record, ssim_record = AvgMeter(), AvgMeter()
+
+    with torch.no_grad():
+        for data in tqdm(val_loader):
+            haze, gt, _ = data
+            haze, gt = haze.cuda(), gt.cuda()
+
+            dehaze = sliding_forward(net, haze)
+
+            loss = criterion(dehaze, gt)
+            loss_record.update(loss.item(), haze.size(0))
+
+            for i in range(len(haze)):
+                r = dehaze[i].cpu().numpy().transpose([1, 2, 0])  # data range [0, 1]
+                g = gt[i].cpu().numpy().transpose([1, 2, 0])
+                psnr = peak_signal_noise_ratio(g, r)
+                ssim = structural_similarity(g, r, data_range=1, multichannel=True,
+                                             gaussian_weights=True, sigma=1.5, use_sample_covariance=False)
+                psnr_record.update(psnr)
+                ssim_record.update(ssim)
+
+    snapshot_name = 'iter_%d_loss_%.5f_lr_%.6f' % (curr_iter + 1, loss_record.avg, optimizer.param_groups[1]['lr'])
+    log = '[validate]: [iter {}], [loss {:.5f}] [PSNR {:.4f}] [SSIM {:.4f}]'.format(
+        curr_iter + 1, loss_record.avg, psnr_record.avg, ssim_record.avg)
+    print(log)
+    open(log_path, 'a').write(log + '\n')
+    torch.save(net.state_dict(),
+               os.path.join(args.ckpt_path, args.exp_name, snapshot_name + '.pth'))
+    torch.save(optimizer.state_dict(),
+               os.path.join(args.ckpt_path, args.exp_name, snapshot_name + '_optim.pth'))
+
+    net.train()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
+    cudnn.benchmark = True
+    torch.cuda.set_device(int(args.gpus))
+
+    train_dataset = OHazeDataset(OHAZE_ROOT, 'train_crop_512')
+    train_loader = DataLoader(train_dataset, batch_size=cfgs['train_batch_size'], num_workers=4,
+                              shuffle=True, drop_last=True)
+
+    val_dataset = OHazeDataset(OHAZE_ROOT, 'val')
+    val_loader = DataLoader(val_dataset, batch_size=1)
+
+    criterion = nn.L1Loss().cuda()
+    log_path = os.path.join(args.ckpt_path, args.exp_name, str(datetime.datetime.now()) + '.txt')
+
+    main()