diff --git a/.gitignore b/.gitignore
index 47cc256..8155a34 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ __pycache__/
 .ipynb_checkpoints/
 .idea
 ckpt
+data
\ No newline at end of file
diff --git a/README.md b/README.md
index 4ac40b1..071714e 100644
--- a/README.md
+++ b/README.md
@@ -11,16 +11,16 @@ This repo is the implementation of
 The dehazing results can be found at [Google Drive](https://drive.google.com/drive/folders/1ZVBI_3Y2NthVLeK7ODMIB5vRjmN9payF?usp=sharing).
 
-## Installation
+## Installation & Preparation
 
 Make sure you have `Python>=3.6` installed on your machine.
 
 **Environment setup:**
 
-1. create conda environment
+1. Create a conda environment
 
-       conda create -n midline
-       conda activate midline
+       conda create -n dm2f
+       conda activate dm2f
 
 2. Install dependencies:
 
@@ -30,6 +30,12 @@ Make sure you have `Python>=3.6` installed on your machine.
 
        pip install -r requirements.txt
 
+3. Prepare the dataset (RESIDE)
+
+   1. Download the RESIDE dataset from the [official webpage](https://sites.google.com/site/boyilics/website-builder/reside).
+
+   2. Make a directory `./data` and create a symbolic link `./data/RESIDE` pointing to the uncompressed RESIDE data.
+
 ## Training
 
 1. Set the path of pretrained ResNeXt model in resnext/config.py
diff --git a/config.py b/config.py
index a1bb535..73222de 100644
--- a/config.py
+++ b/config.py
@@ -1,41 +1,15 @@
 # coding: utf-8
 import os
 
-root = '/home/jqxu/data/RESIDE'
-train_a_root = os.path.join(root, 'TrainA')
-test_a_root = os.path.join(root, 'TestA')
-test_b_root = os.path.join(root, 'nature')
-
-train_its_root = os.path.join(root, 'ITS_v2')
-test_hsts_root = os.path.join(root, 'HSTS', 'synthetic', '')
-# test_sots_root = os.path.join(root, 'SOTS', 'nyuhaze500')
-test_sots_root = os.path.join(root, 'SOTS', 'outdoor')
-
-ohaze_ori_root = os.path.join(root, 'O-HAZE')
-train_ohaze_ori_root = os.path.join(ohaze_ori_root, 'train')
-test_ohaze_ori_root = os.path.join(ohaze_ori_root, 'test')
-
-ohaze_resize_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'O-HAZE')
-train_ohaze_resize_root = os.path.join(ohaze_resize_root, 'train')
-test_ohaze_resize_root = os.path.join(ohaze_resize_root, 'test')
-
-ihaze_resize_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'I-HAZE')
-train_ihaze_resize_root = os.path.join(ihaze_resize_root, 'train')
-test_ihaze_resize_root = os.path.join(ihaze_resize_root, 'test')
-
-ohaze_hd_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'O-HAZE-HD')
-train_ohaze_hd_root = os.path.join(ohaze_hd_root, 'train')
-test_ohaze_hd_root = os.path.join(ohaze_hd_root, 'test')
-
-ihaze_hd_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'I-HAZE-HD')
-train_ihaze_hd_root = os.path.join(ihaze_hd_root, 'train')
-test_ihaze_hd_root = os.path.join(ihaze_hd_root, 'test')
-
-ihaze_root = os.path.join('/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze', 'I-HAZE')
-train_ihaze_root = os.path.join(ihaze_root, 'train')
-test_ihaze_root = os.path.join(ihaze_root, 'test')
-
-train_ots_root = '/media/b3-542/4edbfae9-f11c-447b-b07c-585bc4017092/DataSets/dehaze/OTS'
-test_natural2_root = '/home/b3-542/文档/DataSets/dehaze/natural2'
-
-vgg_path = '/media/b3-542/454BAA0333169FE1/Packages/PyTorch Pretrained/VggNet/vgg16-397923af.pth'
+root = os.path.dirname(os.path.abspath(__file__))
+
+# # TestA
+# TRAIN_A_ROOT = os.path.join(root, 'TrainA')
+# TEST_A_ROOT = os.path.join(root, 'TestA')
+# 
TEST_B_ROOT = os.path.join(root, 'nature') + +# RESIDE +TRAIN_ITS_ROOT = os.path.join(root, 'data', 'RESIDE', 'ITS_v2') # ITS +TEST_SOTS_ROOT = os.path.join(root, 'data', 'RESIDE', 'SOTS', 'nyuhaze500') # SOTS indoor +# TEST_SOTS_ROOT = os.path.join(root, 'SOTS', 'outdoor') # SOTS outdoor +# TEST_HSTS_ROOT = os.path.join(root, 'HSTS', 'synthetic') # HSTS diff --git a/datasets.py b/datasets.py index e3574ee..fdb2b33 100644 --- a/datasets.py +++ b/datasets.py @@ -5,10 +5,9 @@ import torch.utils.data as data import scipy.io as sio from PIL import Image -from torchvision.transforms import ToTensor, Resize +from torchvision.transforms import ToTensor import random from torchvision import transforms -import numpy as np def make_dataset(root): diff --git a/infer.py b/infer.py index df657af..0ff3cbc 100644 --- a/infer.py +++ b/infer.py @@ -1,19 +1,16 @@ # coding: utf-8 import os -import cv2 import numpy as np import torch -from PIL import Image from torch.autograd import Variable from torchvision import transforms -from config import test_sots_root, test_hsts_root -from misc import check_mkdir, crf_refine +from config import TEST_SOTS_ROOT +from misc import check_mkdir from model import ours from datasets import ImageFolder2 from torch.utils.data import DataLoader -from torch import nn from skimage.metrics import peak_signal_noise_ratio os.environ['CUDA_VISIBLE_DEVICES'] = '0' @@ -27,7 +24,7 @@ 'snapshot': 'iter_40000_loss_0.01221_lr_0.000000' } -to_test = {'SOTS': test_sots_root} +to_test = {'SOTS': TEST_SOTS_ROOT} to_pil = transforms.ToPILImage() diff --git a/misc.py b/misc.py index 5f3bbfb..c4db1d1 100644 --- a/misc.py +++ b/misc.py @@ -1,13 +1,5 @@ -import numpy as np import os -import torch import random -from torch import nn -from torchvision import models -from config import vgg_path -import torch.nn.functional as F -from math import ceil -# import pydensecrf.densecrf as dcrf class AvgMeter(object): @@ -32,134 +24,6 @@ def check_mkdir(dir_name): os.mkdir(dir_name) -def _sigmoid(x): - return 1 / (1 + np.exp(-x)) - - -def crf_refine(img, annos): - assert img.dtype == np.uint8 - assert annos.dtype == np.uint8 - assert img.shape[:2] == annos.shape - - # img and annos should be np array with data type uint8 - - EPSILON = 1e-8 - - M = 2 # salient or not - tau = 1.05 - # Setup the CRF model - d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], M) - - anno_norm = annos / 255. - - n_energy = -np.log((1.0 - anno_norm + EPSILON)) / (tau * _sigmoid(1 - anno_norm)) - p_energy = -np.log(anno_norm + EPSILON) / (tau * _sigmoid(anno_norm)) - - U = np.zeros((M, img.shape[0] * img.shape[1]), dtype='float32') - U[0, :] = n_energy.flatten() - U[1, :] = p_energy.flatten() - - d.setUnaryEnergy(U) - - d.addPairwiseGaussian(sxy=3, compat=3) - d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5) - - # Do the inference - infer = np.array(d.inference(1)).astype('float32') - res = infer[1, :] - - res = res * 255 - res = res.reshape(img.shape[:2]) - return res.astype('uint8') - - -def sliding_forward(net, x, crop_size=2048): - n, c, h, w = x.size() - if h <= crop_size and w <= crop_size: - return net(x) - else: - result = torch.zeros(n, 3, h, w).cuda() - count = torch.zeros(n, 3, h, w).cuda() - stride = int(crop_size / 3.) 
- - h_steps = 2 + max(h - crop_size, 0) / stride - w_steps = 2 + max(w - crop_size, 0) / stride - - for h_idx in range(h_steps): - for w_idx in range(w_steps): - h_slice = slice(h_idx * stride, min(crop_size + h_idx * stride, h)) - w_slice = slice(w_idx * stride, min(crop_size + w_idx * stride, w)) - if h_idx == h_steps - 1: - h_slice = slice(max(h - crop_size, 0), h) - if w_idx == w_steps - 1: - w_slice = slice(max(w - crop_size, 0), w) - result[:, :, h_slice, w_slice] += net(x[:, :, h_slice, w_slice].contiguous()) - count[:, :, h_slice, w_slice] += 1 - assert torch.min(count) > 0 - result = result / count - return result - - -def sliding_forward2(net, x, crop_size=(2048, 2048)): - ch, cw = crop_size - n, c, h, w = x.size() - if h <= ch and w <= cw: - return net(x) - else: - result = torch.zeros_like(x).cuda() - count = torch.zeros_like(x).cuda() - stride_h = int(ch / 3.) - stride_w = int(cw / 3.) - - h_steps = 2 + max(h - ch, 0) / stride_h - w_steps = 2 + max(w - cw, 0) / stride_w - - for h_idx in range(h_steps): - for w_idx in range(w_steps): - h_slice = slice(h_idx * stride_h, min(ch + h_idx * stride_h, h)) - w_slice = slice(w_idx * stride_w, min(cw + w_idx * stride_w, w)) - if h_idx == h_steps - 1: - h_slice = slice(max(h - ch, 0), h) - if w_idx == w_steps - 1: - w_slice = slice(max(w - cw, 0), w) - result[:, :, h_slice, w_slice] += net(x[:, :, h_slice, w_slice]) - count[:, :, h_slice, w_slice] += 1 - assert torch.min(count) > 0 - result = result / count - return result - - -def sliding_forward3(net, x, shrink_factor, crop_size=1536): - n, c, h, w = x.size() - if h <= crop_size and w <= crop_size: - x_sm = F.upsample(x, size=(int(h * shrink_factor), int(w * shrink_factor)), mode='bilinear') - return net(x_sm, x) - else: - result = torch.zeros(n, c, h, w).cuda() - count = torch.zeros(n, 1, h, w).cuda() - stride = int(crop_size / 3.) 
- - h_steps = 1 + int(ceil(float(max(h - crop_size, 0)) / stride)) - w_steps = 1 + int(ceil(float(max(w - crop_size, 0)) / stride)) - - for h_idx in range(h_steps): - for w_idx in range(w_steps): - ws0, ws1 = w_idx * stride, crop_size + w_idx * stride - hs0, hs1 = h_idx * stride, crop_size + h_idx * stride - if h_idx == h_steps - 1: - hs0, hs1 = max(h - crop_size, 0), h - if w_idx == w_steps - 1: - ws0, ws1 = max(w - crop_size, 0), w - x_patch = x[:, :, hs0: hs1, ws0: ws1] - patch_h, patch_w = x_patch.size()[2:] - x_patch_sm = F.upsample(x_patch, size=(int(patch_h * shrink_factor), int(patch_w * shrink_factor)), mode='bilinear') - result[:, :, hs0: hs1, ws0: ws1] += net(x_patch_sm, x_patch) - count[:, :, hs0: hs1, ws0: ws1] += 1 - assert torch.min(count) > 0 - result = result / count - return result - - def random_crop(size, x): h, w = x.size()[2:] size_h = min(size, h) @@ -168,38 +32,3 @@ def random_crop(size, x): x1 = random.randint(0, w - size_w) y1 = random.randint(0, h - size_h) return x[:, :, y1: y1 + size_h, x1: x1 + size_w] - - -class PerceptualLoss(nn.Module): - def __init__(self, order): - super(PerceptualLoss, self).__init__() - assert order in [1, 2] - vgg = models.vgg16() - vgg.load_state_dict(torch.load(vgg_path)) - self.vgg = nn.Sequential(*(list(vgg.features.children())[: 9])).eval() - - self.criterion = nn.L1Loss() if order == 1 else nn.MSELoss() - - self.mean = torch.zeros(1, 3, 1, 1) - self.std = torch.zeros(1, 3, 1, 1) - self.mean[0, 0, 0, 0] = 0.485 - self.mean[0, 1, 0, 0] = 0.456 - self.mean[0, 2, 0, 0] = 0.406 - self.std[0, 0, 0, 0] = 0.229 - self.std[0, 1, 0, 0] = 0.224 - self.std[0, 2, 0, 0] = 0.225 - - self.mean = nn.Parameter(self.mean) - self.std = nn.Parameter(self.std) - - for m in self.vgg.modules(): - if isinstance(m, nn.ReLU): - m.inplace = True - - for param in self.parameters(): - param.requires_grad = False - - def forward(self, input, target): - input = (input - self.mean) / self.std - target = (target - self.mean) / self.std - return self.criterion(self.vgg(input), self.vgg(target).detach()) diff --git a/model.py b/model.py index 849a7a4..1f17bfd 100644 --- a/model.py +++ b/model.py @@ -709,15 +709,15 @@ def forward(self, x0): x_fusion = torch.cat((torch.sum(F.softmax(attention_fusion[:, : 4, :, :], 1) * torch.stack( (x_p0[:, 0, :, :], x_p1[:, 0, :, :], x_p2[:, 0, :, :], x_p3[:, 0, :, :]), 1), 1, True), torch.sum(F.softmax(attention_fusion[:, 4: 8, :, :], 1) * torch.stack((x_p0[:, 1, :, :], - x_p1[:, 1, :, :], - x_p2[:, 1, :, :], - x_p3[:, 1, :, :]), - 1), 1, True), + x_p1[:, 1, :, :], + x_p2[:, 1, :, :], + x_p3[:, 1, :, :]), + 1), 1, True), torch.sum(F.softmax(attention_fusion[:, 8:, :, :], 1) * torch.stack((x_p0[:, 2, :, :], - x_p1[:, 2, :, :], - x_p2[:, 2, :, :], - x_p3[:, 2, :, :]), - 1), 1, True)), + x_p1[:, 2, :, :], + x_p2[:, 2, :, :], + x_p3[:, 2, :, :]), + 1), 1, True)), 1).clamp(min=0, max=1) if self.training: diff --git a/requirements.txt b/requirements.txt index 61f14c1..bcddc4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ torchvision>=0.9.0 numpy>=1.20.2 Pillow>=8.2.0 scipy>=1.7.1 +scikit-image>=0.18.3 diff --git a/resnext/resnext_101_32x4d_syn.py b/resnext/resnext_101_32x4d_syn.py deleted file mode 100644 index 8eb52d5..0000000 --- a/resnext/resnext_101_32x4d_syn.py +++ /dev/null @@ -1,690 +0,0 @@ -import torch -import torch.nn as nn -from torch.autograd import Variable -from functools import reduce -# from sync_batchnorm import SynchronizedBatchNorm2d - - -class LambdaBase(nn.Sequential): - def __init__(self, fn, 
*args): - super(LambdaBase, self).__init__(*args) - self.lambda_func = fn - - def forward_prepare(self, input): - output = [] - for module in self._modules.values(): - output.append(module(input)) - return output if output else input - - -class Lambda(LambdaBase): - def forward(self, input): - return self.lambda_func(self.forward_prepare(input)) - - -class LambdaMap(LambdaBase): - def forward(self, input): - return list(map(self.lambda_func, self.forward_prepare(input))) - - -class LambdaReduce(LambdaBase): - def forward(self, input): - return reduce(self.lambda_func, self.forward_prepare(input)) - - -resnext_101_32x4d = nn.Sequential( # Sequential, - nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3), 1, 1, bias=False), - SynchronizedBatchNorm2d(64), - nn.ReLU(), - nn.MaxPool2d((3, 3), (2, 2), (1, 1)), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(64, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(128), - nn.ReLU(), - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(128), - nn.ReLU(), - ), - nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(256), - ), - nn.Sequential( # Sequential, - nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(256), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(128), - nn.ReLU(), - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(128), - nn.ReLU(), - ), - nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(256), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(128), - nn.ReLU(), - nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(128), - nn.ReLU(), - ), - nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(256), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(256), - nn.ReLU(), - nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(256), - nn.ReLU(), - ), - nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - ), - nn.Sequential( # Sequential, - nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - 
SynchronizedBatchNorm2d(256), - nn.ReLU(), - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(256), - nn.ReLU(), - ), - nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(256), - nn.ReLU(), - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(256), - nn.ReLU(), - ), - nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(256), - nn.ReLU(), - nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(256), - nn.ReLU(), - ), - nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - nn.Sequential( # Sequential, - nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - 
SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - 
SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - 
SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - 
nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(512), - nn.ReLU(), - ), - nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - ), - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - nn.ReLU(), - nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(1024), - nn.ReLU(), - ), - nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(2048), - ), - nn.Sequential( # Sequential, - nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(2048), - ), - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - nn.ReLU(), - nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(1024), - nn.ReLU(), - ), - nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(2048), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - nn.Sequential( # Sequential, - LambdaMap(lambda x: x, # ConcatTable, - nn.Sequential( # Sequential, - nn.Sequential( # Sequential, - nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(1024), - nn.ReLU(), - nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 32, bias=False), - SynchronizedBatchNorm2d(1024), - nn.ReLU(), - ), - nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), - SynchronizedBatchNorm2d(2048), - ), - Lambda(lambda x: x), # Identity, - ), - LambdaReduce(lambda x, y: x + y), # CAddTable, - nn.ReLU(), - ), - ), - nn.AvgPool2d((7, 7), (1, 1)), - Lambda(lambda x: x.view(x.size(0), -1)), # View, - nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(2048, 1000)), # Linear, -) diff --git a/resnext/resnext_regular.py b/resnext/resnext_regular.py index 0127cf4..3b8d52c 100644 --- a/resnext/resnext_regular.py +++ b/resnext/resnext_regular.py @@ -1,7 +1,6 @@ import torch from torch import nn -# from . import resnext_101_32x4d_, resnext_101_32x4d_syn from . 
import resnext_101_32x4d_
 from .config import resnext_101_32_path
diff --git a/train.py b/train.py
index ec851bf..fa8a312 100644
--- a/train.py
+++ b/train.py
@@ -9,7 +9,7 @@
 from torch.backends import cudnn
 from torch.utils.data import DataLoader
 
-from config import train_its_root, test_sots_root
+from config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT
 from datasets import ITS, ImageFolder2
 from misc import AvgMeter, check_mkdir
 from model import ours
@@ -35,9 +35,9 @@
     'crop_size': 256
 }
 
-train_set = ITS(train_its_root, True, args['crop_size'])
+train_set = ITS(TRAIN_ITS_ROOT, True, args['crop_size'])
 train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=16, shuffle=True, drop_last=True)
-val_set = ImageFolder2(test_sots_root)
+val_set = ImageFolder2(TEST_SOTS_ROOT)
 val_loader = DataLoader(val_set, batch_size=8)
 
 criterion = nn.L1Loss().cuda()
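
With the new config.py, every dataset path resolves relative to the repository root, so the ./data/RESIDE symlink described in the README is the only machine-specific setup left. Below is a minimal sanity check, assuming it is run from the repository root after the symlink exists; the snippet is illustrative and not part of the diff.

import os

from config import TRAIN_ITS_ROOT, TEST_SOTS_ROOT

# Both constants point under ./data/RESIDE and should resolve to existing
# directories once the symlink from the README step has been created.
for name, path in [('ITS (train)', TRAIN_ITS_ROOT), ('SOTS indoor (test)', TEST_SOTS_ROOT)]:
    print(f'{name}: {path} exists={os.path.isdir(path)}')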
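
scikit-image is added to requirements.txt because infer.py now imports peak_signal_noise_ratio from skimage.metrics for evaluation. A minimal sketch of that call, using placeholder arrays rather than real dehazing outputs:

import numpy as np
from skimage.metrics import peak_signal_noise_ratio

# Placeholder ground truth and prediction in [0, 1]; in infer.py the inputs are
# the restored image and its haze-free reference.
gt = np.random.rand(256, 256, 3).astype(np.float32)
pred = np.clip(gt + 0.01 * np.random.randn(*gt.shape), 0, 1)
print('PSNR:', peak_signal_noise_ratio(gt, pred, data_range=1))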
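
The model.py hunk only re-aligns the indentation of the attention-fusion expression; its behaviour is unchanged. For reference, the wrapped computation can be restated as below, a sketch assuming x_p0..x_p3 are N x 3 x H x W candidate outputs and attention_fusion is N x 12 x H x W, as the hunk suggests; the helper name fuse is hypothetical.

import torch
import torch.nn.functional as F

def fuse(attention_fusion, x_p0, x_p1, x_p2, x_p3):
    # Stack the four candidate predictions: N x 4 x 3 x H x W.
    candidates = torch.stack((x_p0, x_p1, x_p2, x_p3), dim=1)
    fused = []
    for c in range(3):  # blend each RGB channel with its own softmax weights
        weights = F.softmax(attention_fusion[:, 4 * c: 4 * (c + 1)], dim=1)  # N x 4 x H x W
        fused.append(torch.sum(weights * candidates[:, :, c], dim=1, keepdim=True))
    return torch.cat(fused, dim=1).clamp(min=0, max=1)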