From 48312e7ed70e62a7d5e67cea9f334cda7f813c8d Mon Sep 17 00:00:00 2001 From: jiaqixuac Date: Fri, 1 Oct 2021 11:10:45 +0800 Subject: [PATCH] Update model&train&test --- datasets.py | 2 +- model.py | 293 ++++++++------------------------------------ test.py | 82 +++++-------- train.py | 102 ++++++++------- misc.py => utils.py | 0 5 files changed, 131 insertions(+), 348 deletions(-) rename misc.py => utils.py (100%) diff --git a/datasets.py b/datasets.py index 12efb09..a5e59aa 100644 --- a/datasets.py +++ b/datasets.py @@ -217,7 +217,7 @@ def __len__(self): class SotsDataset(data.Dataset): - def __init__(self, root, mode='train'): + def __init__(self, root, mode=None): self.root = root self.imgs = make_dataset(root) self.mode = mode diff --git a/model.py b/model.py index 1f17bfd..4615022 100644 --- a/model.py +++ b/model.py @@ -726,15 +726,12 @@ def forward(self, x0): return x_fusion -# from resnext import ResNeXt50 - - -class ours_R50(Base): +class DM2FNet(Base): def __init__(self, num_features=128): - super(ours_R50, self).__init__() + super(DM2FNet, self).__init__() self.num_features = num_features - resnext = ResNeXt50() + resnext = ResNeXt101() self.layer0 = resnext.layer0 self.layer1 = resnext.layer1 self.layer2 = resnext.layer2 @@ -764,17 +761,13 @@ def __init__(self, num_features=128): nn.Conv2d(num_features, num_features, kernel_size=1), nn.SELU(), nn.Conv2d(num_features, 1, kernel_size=1), nn.Sigmoid() ) + self.attention_phy = nn.Sequential( nn.Conv2d(num_features * 4, num_features, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features, num_features * 4, kernel_size=1) ) - self.attention0 = nn.Sequential( - nn.Conv2d(num_features * 4, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features * 4, kernel_size=1) - ) self.attention1 = nn.Sequential( nn.Conv2d(num_features * 4, num_features, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), @@ -790,197 +783,7 @@ def __init__(self, num_features=128): nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features, num_features * 4, kernel_size=1) ) - - self.refine = nn.Sequential( - nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features, kernel_size=1) - ) - - self.p0 = nn.Sequential( - nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features // 2, 3, kernel_size=1) - ) - self.p1 = nn.Sequential( - nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features // 2, 3, kernel_size=1) - ) - self.p2 = nn.Sequential( - nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features // 2, 3, kernel_size=1) - ) - self.p3 = nn.Sequential( - nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features // 2, 3, kernel_size=1) - ) - - self.attentional_fusion = nn.Sequential( - nn.Conv2d(num_features * 4, num_features, kernel_size=1), nn.SELU(), - nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features // 2, num_features // 2, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features // 2, 15, kernel_size=1) - ) - - for m in self.modules(): - if isinstance(m, nn.SELU) or isinstance(m, nn.ReLU): - m.inplace = True - - def forward(self, x0, x0_hd=None): - x = (x0 - self.mean) / self.std - - layer0 = self.layer0(x) - layer1 = self.layer1(layer0) - layer2 = self.layer2(layer1) - layer3 = self.layer3(layer2) - layer4 = self.layer4(layer3) - - down1 = self.down1(layer1) - down2 = self.down2(layer2) - down3 = self.down3(layer3) - down4 = self.down4(layer4) - - down2 = F.upsample(down2, size=down1.size()[2:], mode='bilinear') - down3 = F.upsample(down3, size=down1.size()[2:], mode='bilinear') - down4 = F.upsample(down4, size=down1.size()[2:], mode='bilinear') - - concat = torch.cat((down1, down2, down3, down4), 1) - - n, c, h, w = down1.size() - - attention_phy = self.attention_phy(concat) - attention_phy = F.softmax(attention_phy.view(n, 4, c, h, w), 1) - f_phy = down1 * attention_phy[:, 0, :, :, :] + down2 * attention_phy[:, 1, :, :, :] + \ - down3 * attention_phy[:, 2, :, :, :] + down4 * attention_phy[:, 3, :, :, :] - f_phy = self.refine(f_phy) + f_phy - - attention0 = self.attention0(concat) - attention0 = F.softmax(attention0.view(n, 4, c, h, w), 1) - f0 = down1 * attention0[:, 0, :, :, :] + down2 * attention0[:, 1, :, :, :] + \ - down3 * attention0[:, 2, :, :, :] + down4 * attention0[:, 3, :, :, :] - f0 = self.refine(f0) + f0 - - attention1 = self.attention1(concat) - attention1 = F.softmax(attention1.view(n, 4, c, h, w), 1) - f1 = down1 * attention1[:, 0, :, :, :] + down2 * attention1[:, 1, :, :, :] + \ - down3 * attention1[:, 2, :, :, :] + down4 * attention1[:, 3, :, :, :] - f1 = self.refine(f1) + f1 - - attention2 = self.attention2(concat) - attention2 = F.softmax(attention2.view(n, 4, c, h, w), 1) - f2 = down1 * attention2[:, 0, :, :, :] + down2 * attention2[:, 1, :, :, :] + \ - down3 * attention2[:, 2, :, :, :] + down4 * attention2[:, 3, :, :, :] - f2 = self.refine(f2) + f2 - - attention3 = self.attention3(concat) - attention3 = F.softmax(attention3.view(n, 4, c, h, w), 1) - f3 = down1 * attention3[:, 0, :, :, :] + down2 * attention3[:, 1, :, :, :] + \ - down3 * attention3[:, 2, :, :, :] + down4 * attention3[:, 3, :, :, :] - f3 = self.refine(f3) + f3 - - if x0_hd is not None: - x0 = x0_hd - x = (x0 - self.mean) / self.std - - log_x0 = torch.log(x0.clamp(min=1e-8)) - log_log_x0_inverse = torch.log(torch.log(1 / x0.clamp(min=1e-8, max=(1 - 1e-8)))) - - a = self.a(f_phy) - t = F.upsample(self.t(f_phy), size=x0.size()[2:], mode='bilinear') - x_phy = ((x0 - a * (1 - t)) / t.clamp(min=1e-8)).clamp(min=0, max=1) - - p0 = F.upsample(self.p0(f0), size=x0.size()[2:], mode='bilinear') - x_p0 = torch.exp(log_x0 + p0).clamp(min=0, max=1) - - p1 = F.upsample(self.p1(f1), size=x0.size()[2:], mode='bilinear') - x_p1 = ((x + p1) * self.std + self.mean).clamp(min=0, max=1) - - p2 = F.upsample(self.p2(f2), size=x0.size()[2:], mode='bilinear') - x_p2 = torch.exp(-torch.exp(log_log_x0_inverse + p2)).clamp(min=0, max=1) - - p3 = F.upsample(self.p3(f3), size=x0.size()[2:], mode='bilinear') - # x_p3 = (torch.log(1 + p3 * x0)).clamp(min=0, max=1) - x_p3 = (torch.log(1 + torch.exp(log_x0 + p3))).clamp(min=0, max=1) - - attention_fusion = F.upsample(self.attentional_fusion(concat), size=x0.size()[2:], mode='bilinear') - x_fusion = torch.cat((torch.sum(F.softmax(attention_fusion[:, : 5, :, :], 1) * torch.stack( - (x_phy[:, 0, :, :], x_p0[:, 0, :, :], x_p1[:, 0, :, :], x_p2[:, 0, :, :], x_p3[:, 0, :, :]), 1), 1, True), - torch.sum(F.softmax(attention_fusion[:, 5: 10, :, :], 1) * torch.stack((x_phy[:, 1, :, :], - x_p0[:, 1, :, :], - x_p1[:, 1, :, :], - x_p2[:, 1, :, :], - x_p3[:, 1, :, :]), - 1), 1, True), - torch.sum(F.softmax(attention_fusion[:, 10:, :, :], 1) * torch.stack((x_phy[:, 2, :, :], - x_p0[:, 2, :, :], - x_p1[:, 2, :, :], - x_p2[:, 2, :, :], - x_p3[:, 2, :, :]), - 1), 1, True)), - 1).clamp(min=0, max=1) - - if self.training: - return x_fusion, x_phy, x_p0, x_p1, x_p2, x_p3, t, a.view(x.size(0), -1) - else: - return x_fusion - - -class ours(Base): - def __init__(self, num_features=128): - super(ours, self).__init__() - self.num_features = num_features - - resnext = ResNeXt101() - self.layer0 = resnext.layer0 - self.layer1 = resnext.layer1 - self.layer2 = resnext.layer2 - self.layer3 = resnext.layer3 - self.layer4 = resnext.layer4 - - self.down1 = nn.Sequential( - nn.Conv2d(256, num_features, kernel_size=1), nn.SELU() - ) - self.down2 = nn.Sequential( - nn.Conv2d(512, num_features, kernel_size=1), nn.SELU() - ) - self.down3 = nn.Sequential( - nn.Conv2d(1024, num_features, kernel_size=1), nn.SELU() - ) - self.down4 = nn.Sequential( - nn.Conv2d(2048, num_features, kernel_size=1), nn.SELU() - ) - - self.t = nn.Sequential( - nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features // 2, 1, kernel_size=1), nn.Sigmoid() - ) - self.a = nn.Sequential( - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(num_features, num_features, kernel_size=1), nn.SELU(), - nn.Conv2d(num_features, num_features, kernel_size=1), nn.SELU(), - nn.Conv2d(num_features, 1, kernel_size=1), nn.Sigmoid() - ) - self.attention_phy = nn.Sequential( - nn.Conv2d(num_features * 4, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features * 4, kernel_size=1) - ) - - self.attention0 = nn.Sequential( - nn.Conv2d(num_features * 4, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features * 4, kernel_size=1) - ) - self.attention1 = nn.Sequential( - nn.Conv2d(num_features * 4, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features * 4, kernel_size=1) - ) - self.attention2 = nn.Sequential( - nn.Conv2d(num_features * 4, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), - nn.Conv2d(num_features, num_features * 4, kernel_size=1) - ) - self.attention3 = nn.Sequential( + self.attention4 = nn.Sequential( nn.Conv2d(num_features * 4, num_features, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features, num_features * 4, kernel_size=1) @@ -992,24 +795,24 @@ def __init__(self, num_features=128): nn.Conv2d(num_features, num_features, kernel_size=1) ) - self.p0 = nn.Sequential( + self.j1 = nn.Sequential( nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features // 2, 3, kernel_size=1) ) - self.p1 = nn.Sequential( + self.j2 = nn.Sequential( nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features // 2, 3, kernel_size=1) ) - self.p2 = nn.Sequential( + self.j3 = nn.Sequential( nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features // 2, 3, kernel_size=1) ) - self.p3 = nn.Sequential( + self.j4 = nn.Sequential( nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features // 2, 3, kernel_size=1) ) - self.attentional_fusion = nn.Sequential( + self.attention_fusion = nn.Sequential( nn.Conv2d(num_features * 4, num_features, kernel_size=1), nn.SELU(), nn.Conv2d(num_features, num_features // 2, kernel_size=3, padding=1), nn.SELU(), nn.Conv2d(num_features // 2, num_features // 2, kernel_size=3, padding=1), nn.SELU(), @@ -1048,12 +851,6 @@ def forward(self, x0, x0_hd=None): down3 * attention_phy[:, 2, :, :, :] + down4 * attention_phy[:, 3, :, :, :] f_phy = self.refine(f_phy) + f_phy - attention0 = self.attention0(concat) - attention0 = F.softmax(attention0.view(n, 4, c, h, w), 1) - f0 = down1 * attention0[:, 0, :, :, :] + down2 * attention0[:, 1, :, :, :] + \ - down3 * attention0[:, 2, :, :, :] + down4 * attention0[:, 3, :, :, :] - f0 = self.refine(f0) + f0 - attention1 = self.attention1(concat) attention1 = F.softmax(attention1.view(n, 4, c, h, w), 1) f1 = down1 * attention1[:, 0, :, :, :] + down2 * attention1[:, 1, :, :, :] + \ @@ -1072,6 +869,12 @@ def forward(self, x0, x0_hd=None): down3 * attention3[:, 2, :, :, :] + down4 * attention3[:, 3, :, :, :] f3 = self.refine(f3) + f3 + attention4 = self.attention4(concat) + attention4 = F.softmax(attention4.view(n, 4, c, h, w), 1) + f4 = down1 * attention4[:, 0, :, :, :] + down2 * attention4[:, 1, :, :, :] + \ + down3 * attention4[:, 2, :, :, :] + down4 * attention4[:, 3, :, :, :] + f4 = self.refine(f4) + f4 + if x0_hd is not None: x0 = x0_hd x = (x0 - self.mean) / self.std @@ -1079,41 +882,41 @@ def forward(self, x0, x0_hd=None): log_x0 = torch.log(x0.clamp(min=1e-8)) log_log_x0_inverse = torch.log(torch.log(1 / x0.clamp(min=1e-8, max=(1 - 1e-8)))) + # J0 = (I - A0 * (1 - T0)) / T0 a = self.a(f_phy) t = F.upsample(self.t(f_phy), size=x0.size()[2:], mode='bilinear') - x_phy = ((x0 - a * (1 - t)) / t.clamp(min=1e-8)).clamp(min=0, max=1) - - p0 = F.upsample(self.p0(f0), size=x0.size()[2:], mode='bilinear') - x_p0 = torch.exp(log_x0 + p0).clamp(min=0, max=1) - - p1 = F.upsample(self.p1(f1), size=x0.size()[2:], mode='bilinear') - x_p1 = ((x + p1) * self.std + self.mean).clamp(min=0, max=1) - - p2 = F.upsample(self.p2(f2), size=x0.size()[2:], mode='bilinear') - x_p2 = torch.exp(-torch.exp(log_log_x0_inverse + p2)).clamp(min=0, max=1) - - p3 = F.upsample(self.p3(f3), size=x0.size()[2:], mode='bilinear') - # x_p3 = (torch.log(1 + p3 * x0)).clamp(min=0, max=1) - x_p3 = (torch.log(1 + torch.exp(log_x0 + p3))).clamp(min=0, max=1) - - attention_fusion = F.upsample(self.attentional_fusion(concat), size=x0.size()[2:], mode='bilinear') - x_fusion = torch.cat((torch.sum(F.softmax(attention_fusion[:, : 5, :, :], 1) * torch.stack( - (x_phy[:, 0, :, :], x_p0[:, 0, :, :], x_p1[:, 0, :, :], x_p2[:, 0, :, :], x_p3[:, 0, :, :]), 1), 1, True), - torch.sum(F.softmax(attention_fusion[:, 5: 10, :, :], 1) * torch.stack((x_phy[:, 1, :, :], - x_p0[:, 1, :, :], - x_p1[:, 1, :, :], - x_p2[:, 1, :, :], - x_p3[:, 1, :, :]), - 1), 1, True), - torch.sum(F.softmax(attention_fusion[:, 10:, :, :], 1) * torch.stack((x_phy[:, 2, :, :], - x_p0[:, 2, :, :], - x_p1[:, 2, :, :], - x_p2[:, 2, :, :], - x_p3[:, 2, :, :]), - 1), 1, True)), - 1).clamp(min=0, max=1) + x_phy = ((x0 - a * (1 - t)) / t.clamp(min=1e-8)).clamp(min=0., max=1.) + + # J2 = I * exp(R2) + r1 = F.upsample(self.j1(f1), size=x0.size()[2:], mode='bilinear') + x_j1 = torch.exp(log_x0 + r1).clamp(min=0., max=1.) + + # J2 = I + R2 + r2 = F.upsample(self.j2(f2), size=x0.size()[2:], mode='bilinear') + x_j2 = ((x + r2) * self.std + self.mean).clamp(min=0., max=1.) + + # + r3 = F.upsample(self.j3(f3), size=x0.size()[2:], mode='bilinear') + x_j3 = torch.exp(-torch.exp(log_log_x0_inverse + r3)).clamp(min=0., max=1.) + + # J4 = log(1 + I * R4) + r4 = F.upsample(self.j4(f4), size=x0.size()[2:], mode='bilinear') + # x_j4 = (torch.log(1 + r4 * x0)).clamp(min=0, max=1) + x_j4 = (torch.log(1 + torch.exp(log_x0 + r4))).clamp(min=0., max=1.) + + attention_fusion = F.upsample(self.attention_fusion(concat), size=x0.size()[2:], mode='bilinear') + x_f0 = torch.sum(F.softmax(attention_fusion[:, :5, :, :], 1) * + torch.stack((x_phy[:, 0, :, :], x_j1[:, 0, :, :], x_j2[:, 0, :, :], + x_j3[:, 0, :, :], x_j4[:, 0, :, :]), 1), 1, True) + x_f1 = torch.sum(F.softmax(attention_fusion[:, 5: 10, :, :], 1) * + torch.stack((x_phy[:, 1, :, :], x_j1[:, 1, :, :], x_j2[:, 1, :, :], + x_j3[:, 1, :, :], x_j4[:, 1, :, :]), 1), 1, True) + x_f2 = torch.sum(F.softmax(attention_fusion[:, 10:, :, :], 1) * + torch.stack((x_phy[:, 2, :, :], x_j1[:, 2, :, :], x_j2[:, 2, :, :], + x_j3[:, 2, :, :], x_j4[:, 2, :, :]), 1), 1, True) + x_fusion = torch.cat((x_f0, x_f1, x_f2), 1).clamp(min=0., max=1.) if self.training: - return x_fusion, x_phy, x_p0, x_p1, x_p2, x_p3, t, a.view(x.size(0), -1) + return x_fusion, x_phy, x_j1, x_j2, x_j3, x_j4, t, a.view(x.size(0), -1) else: return x_fusion diff --git a/test.py b/test.py index e544aac..8e15b8a 100644 --- a/test.py +++ b/test.py @@ -3,12 +3,11 @@ import numpy as np import torch -from torch.autograd import Variable from torchvision import transforms from config import TEST_SOTS_ROOT -from misc import check_mkdir -from model import ours +from utils import check_mkdir +from model import DM2FNet from datasets import SotsDataset from torch.utils.data import DataLoader from skimage.metrics import peak_signal_noise_ratio @@ -19,9 +18,9 @@ torch.cuda.set_device(0) ckpt_path = './ckpt' -exp_name = '(ablation8 its) ours' +exp_name = 'RESIDE_ITS' args = { - 'snapshot': 'iter_40000_loss_0.01221_lr_0.000000' + 'snapshot': 'iter_40000_loss_0.01256_lr_0.000000' } to_test = {'SOTS': TEST_SOTS_ROOT} @@ -30,7 +29,7 @@ def main(): - net = ours().cuda() + net = DM2FNet().cuda() # net = nn.DataParallel(net) if len(args['snapshot']) > 0: @@ -40,50 +39,33 @@ def main(): net.eval() psnrs = [] - # with torch.no_grad(): - # for name, root in to_test.iteritems(): - # dataset = ImageFolder2(root, 'test') - # dataloader = DataLoader(dataset, batch_size=8) - # - # for idx, data in enumerate(dataloader): - # print 'predicting for %s: %d / %d' % (name, idx + 1, len(dataloader)) - # check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']))) - # - # # haze_image, _, _, _, fs = data - # haze_image, fs = data - # - # img_var = Variable(haze_image).cuda() - # res = net(img_var).data - # res[res > 1] = 1 - # res[res < 0] = 0 - # - # for r, f in zip(res.cpu(), fs): - # to_pil(r).save(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), '%s.png' % f)) - for name, root in to_test.items(): - dataset = SotsDataset(root, 'train') - dataloader = DataLoader(dataset, batch_size=1) - - for idx, data in enumerate(dataloader): - # haze_image, _, _, _, fs = data - haze_image, gts, fs = data - # print(haze_image.shape, gts.shape) - - print('predicting for %s [%s]: %d / %d' % (name, fs, idx + 1, len(dataloader))) - check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']))) - - img_var = Variable(haze_image, volatile=True).cuda() - res = net(img_var).data - - for i in range(len(fs)): - r = res[i].cpu().numpy().transpose([1, 2, 0]) - gt = gts[i].cpu().numpy().transpose([1, 2, 0]) - psnr = peak_signal_noise_ratio(gt, r) - psnrs.append(psnr) - - for r, f in zip(res.cpu(), fs): - to_pil(r).save( - os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), '%s.png' % f)) - print(f"PSNR: {np.mean(psnrs):.6f}") + with torch.no_grad(): + for name, root in to_test.items(): + dataset = SotsDataset(root) + dataloader = DataLoader(dataset, batch_size=1) + + for idx, data in enumerate(dataloader): + # haze_image, _, _, _, fs = data + haze, gts, fs = data + # print(haze_image.shape, gts.shape) + + check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']))) + + haze = haze.cuda() + res = net(haze).data + + for i in range(len(fs)): + r = res[i].cpu().numpy().transpose([1, 2, 0]) + gt = gts[i].cpu().numpy().transpose([1, 2, 0]) + psnr = peak_signal_noise_ratio(gt, r) + psnrs.append(psnr) + print('predicting for {} ({}/{}) [{}]: {:.4f}'.format(name, idx + 1, len(dataloader), fs[i], psnr)) + + for r, f in zip(res.cpu(), fs): + to_pil(r).save( + os.path.join(ckpt_path, exp_name, + '(%s) %s_%s' % (exp_name, name, args['snapshot']), '%s.png' % f)) + print(f"PSNR for {name}: {np.mean(psnrs):.6f}") if __name__ == '__main__': diff --git a/train.py b/train.py index 86e6c8b..b16d48d 100644 --- a/train.py +++ b/train.py @@ -5,14 +5,13 @@ import torch from torch import nn from torch import optim -from torch.autograd import Variable from torch.backends import cudnn from torch.utils.data import DataLoader from config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT from datasets import ItsDataset, SotsDataset -from misc import AvgMeter, check_mkdir -from model import ours +from utils import AvgMeter, check_mkdir +from model import DM2FNet os.environ['CUDA_VISIBLE_DEVICES'] = '0' @@ -21,7 +20,7 @@ torch.cuda.set_device(0) ckpt_path = './ckpt' -exp_name = '(ablation8 its) ours' +exp_name = 'RESIDE_ITS' args = { 'iter_num': 40000, @@ -48,7 +47,7 @@ def main(): - net = ours().cuda().train() + net = DM2FNet().cuda().train() # net = nn.DataParallel(net) optimizer = optim.Adam([ @@ -76,60 +75,66 @@ def train(net, optimizer): while curr_iter <= args['iter_num']: train_loss_record = AvgMeter() - loss_x_fusion_record, loss_x_phy_record, loss_x_p0_record, loss_x_p1_record, loss_x_p2_record, loss_x_p3_record, loss_t_record, loss_a_record = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter() + loss_x_jf_record, loss_x_j0_record = AvgMeter(), AvgMeter() + loss_x_j1_record, loss_x_j2_record = AvgMeter(), AvgMeter() + loss_x_j3_record, loss_x_j4_record = AvgMeter(), AvgMeter() + loss_t_record, loss_a_record = AvgMeter(), AvgMeter() + for data in train_loader: - optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num'] - ) ** args['lr_decay'] - optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num'] - ) ** args['lr_decay'] + optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (1 - float(curr_iter) / args['iter_num']) \ + ** args['lr_decay'] + optimizer.param_groups[1]['lr'] = args['lr'] * (1 - float(curr_iter) / args['iter_num']) \ + ** args['lr_decay'] - haze_image, gt_trans_map, gt_ato, gt, _ = data + haze, gt_trans_map, gt_ato, gt, _ = data - batch_size = haze_image.size(0) + batch_size = haze.size(0) - haze_image = Variable(haze_image).cuda() - gt_trans_map = Variable(gt_trans_map).cuda() - gt_ato = Variable(gt_ato).cuda() - gt = Variable(gt).cuda() + haze = haze.cuda() + gt_trans_map = gt_trans_map.cuda() + gt_ato = gt_ato.cuda() + gt = gt.cuda() optimizer.zero_grad() - x_fusion, x_phy, x_p0, x_p1, x_p2, x_p3, t, a = net(haze_image) + x_jf, x_j0, x_j1, x_j2, x_j3, x_j4, t, a = net(haze) - loss_x_fusion = criterion(x_fusion, gt) - loss_x_phy = criterion(x_phy, gt) - loss_x_p0 = criterion(x_p0, gt) - loss_x_p1 = criterion(x_p1, gt) - loss_x_p2 = criterion(x_p2, gt) - loss_x_p3 = criterion(x_p3, gt) + loss_x_jf = criterion(x_jf, gt) + loss_x_j0 = criterion(x_j0, gt) + loss_x_j1 = criterion(x_j1, gt) + loss_x_j2 = criterion(x_j2, gt) + loss_x_j3 = criterion(x_j3, gt) + loss_x_j4 = criterion(x_j4, gt) loss_t = criterion(t, gt_trans_map) loss_a = criterion(a, gt_ato) - loss = loss_x_fusion + loss_x_p0 + loss_x_p1 + loss_x_p2 + loss_x_p3 + loss_x_phy + 10 * loss_t + loss_a + loss = loss_x_jf + loss_x_j0 + loss_x_j1 + loss_x_j2 + loss_x_j3 + loss_x_j4 \ + + 10 * loss_t + loss_a loss.backward() optimizer.step() - train_loss_record.update(loss.data, batch_size) + # update recorder + train_loss_record.update(loss.item(), batch_size) - loss_x_fusion_record.update(loss_x_fusion.data, batch_size) - loss_x_phy_record.update(loss_x_phy.data, batch_size) - loss_x_p0_record.update(loss_x_p0.data, batch_size) - loss_x_p1_record.update(loss_x_p1.data, batch_size) - loss_x_p2_record.update(loss_x_p2.data, batch_size) - loss_x_p3_record.update(loss_x_p3.data, batch_size) + loss_x_jf_record.update(loss_x_jf.item(), batch_size) + loss_x_j0_record.update(loss_x_j0.item(), batch_size) + loss_x_j1_record.update(loss_x_j1.item(), batch_size) + loss_x_j2_record.update(loss_x_j2.item(), batch_size) + loss_x_j3_record.update(loss_x_j3.item(), batch_size) + loss_x_j4_record.update(loss_x_j4.item(), batch_size) - loss_t_record.update(loss_t.data, batch_size) - loss_a_record.update(loss_a.data, batch_size) + loss_t_record.update(loss_t.item(), batch_size) + loss_a_record.update(loss_a.item(), batch_size) curr_iter += 1 - log = '[iter %d], [train loss %.5f], [loss_x_fusion %.5f], [loss_x_phy %.5f], [loss_x_p0 %.5f], ' \ - '[loss_x_p1 %.5f], [loss_x_p2 %.5f], [loss_x_p3 %.5f], [loss_t %.5f], [loss_a %.5f], ' \ + log = '[iter %d], [train loss %.5f], [loss_x_fusion %.5f], [loss_x_phy %.5f], [loss_x_j1 %.5f], ' \ + '[loss_x_j2 %.5f], [loss_x_j3 %.5f], [loss_x_j4 %.5f], [loss_t %.5f], [loss_a %.5f], ' \ '[lr %.13f]' % \ - (curr_iter, train_loss_record.avg, loss_x_fusion_record.avg, loss_x_phy_record.avg, - loss_x_p0_record.avg, loss_x_p1_record.avg, loss_x_p2_record.avg, loss_x_p3_record.avg, + (curr_iter, train_loss_record.avg, loss_x_jf_record.avg, loss_x_j0_record.avg, + loss_x_j1_record.avg, loss_x_j2_record.avg, loss_x_j3_record.avg, loss_x_j4_record.avg, loss_t_record.avg, loss_a_record.avg, optimizer.param_groups[1]['lr']) print(log) open(log_path, 'a').write(log + '\n') @@ -137,6 +142,9 @@ def train(net, optimizer): if (curr_iter + 1) % args['val_freq'] == 0: validate(net, curr_iter, optimizer) + if curr_iter > args['iter_num']: + break + def validate(net, curr_iter, optimizer): print('validating...') @@ -146,25 +154,15 @@ def validate(net, curr_iter, optimizer): with torch.no_grad(): for i, data in enumerate(val_loader): - haze_image, gt, _ = data + haze, gt, _ = data - haze_image = Variable(haze_image).cuda() - gt = Variable(gt).cuda() + haze = haze.cuda() + gt = gt.cuda() - dehaze = net(haze_image) + dehaze = net(haze) loss = criterion(dehaze, gt) - loss_record.update(loss.data, haze_image.size(0)) - # for i, data in enumerate(val_loader): - # haze_image, gt, _ = data - # - # haze_image = Variable(haze_image, volatile=True).cuda() - # gt = Variable(gt, volatile=True).cuda() - # - # dehaze = net(haze_image) - # - # loss = criterion(dehaze, gt) - # loss_record.update(loss.data, haze_image.size(0)) + loss_record.update(loss.data, haze.size(0)) snapshot_name = 'iter_%d_loss_%.5f_lr_%.6f' % (curr_iter + 1, loss_record.avg, optimizer.param_groups[1]['lr']) print('[validate]: [iter %d], [loss %.5f]' % (curr_iter + 1, loss_record.avg)) diff --git a/misc.py b/utils.py similarity index 100% rename from misc.py rename to utils.py