diff --git a/README.md b/README.md index 071714e..bc303d6 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ The dehazing results can be found at ## Installation & Preparation -Make sure you have `Python>=3.6` installed on your machine. +Make sure you have `Python>=3.7` installed on your machine. **Environment setup:** @@ -22,7 +22,7 @@ Make sure you have `Python>=3.6` installed on your machine. conda create -n dm2f conda activate dm2f -2. Install dependencies: +2. Install dependencies (test with PyTorch 1.8.0): 1. Install pytorch==1.8.0 torchvision==0.9.0 (via conda, recommend). @@ -42,22 +42,24 @@ Make sure you have `Python>=3.6` installed on your machine. 2. Set the path of datasets in config.py 3. Run by ```python train.py``` -The pretrained ResNeXt model is ported from the [official](https://github.com/facebookresearch/ResNeXt) torch version, +~~The pretrained ResNeXt model is ported from the [official](https://github.com/facebookresearch/ResNeXt) torch version, using the [convertor](https://github.com/clcarwin/convert_torch_to_pytorch) provided by clcarwin. -You can directly [download](https://drive.google.com/open?id=1dnH-IHwmu9xFPlyndqI6MfF4LvH6JKNQ) the pretrained model ported by me. +You can directly [download](https://drive.google.com/open?id=1dnH-IHwmu9xFPlyndqI6MfF4LvH6JKNQ) the pretrained model ported by me.~~ -*Hyper-parameters* of training were gathered at the beginning of *train.py* and you can conveniently +Use pretrained ResNeXt (resnext101_32x8d) from torchvision. + +*Hyper-parameters* of training were set at the top of *train.py*, and you can conveniently change them as you need. -Training a model on a single GTX 1080Ti GPU takes about 4 hours. +Training a model on a single ~~GTX 1080Ti~~ TITAN RTX GPU takes about ~~4~~ 5 hours. ## Testing 1. Set the path of five benchmark datasets in config.py. 2. Put the trained model in `./ckpt/`. -2. Run by ```python infer.py``` +2. Run by ```python test.py``` -*Settings* of testing were gathered at the beginning of *infer.py* and you can conveniently +*Settings* of testing were set at the top of `test.py`, and you can conveniently change them as you need. ## License diff --git a/model.py b/model.py index 4615022..4998d94 100644 --- a/model.py +++ b/model.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from torch import nn +import torchvision.models as models from resnext import ResNeXt101 @@ -727,16 +728,23 @@ def forward(self, x0): class DM2FNet(Base): - def __init__(self, num_features=128): + def __init__(self, num_features=128, arch='resnext101_32x8d'): super(DM2FNet, self).__init__() self.num_features = num_features - resnext = ResNeXt101() - self.layer0 = resnext.layer0 - self.layer1 = resnext.layer1 - self.layer2 = resnext.layer2 - self.layer3 = resnext.layer3 - self.layer4 = resnext.layer4 + # resnext = ResNeXt101() + # + # self.layer0 = resnext.layer0 + # self.layer1 = resnext.layer1 + # self.layer2 = resnext.layer2 + # self.layer3 = resnext.layer3 + # self.layer4 = resnext.layer4 + + assert arch in ['resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] + backbone = models.__dict__[arch](pretrained=True) + del backbone.fc + self.backbone = backbone self.down1 = nn.Sequential( nn.Conv2d(256, num_features, kernel_size=1), nn.SELU() @@ -826,11 +834,23 @@ def __init__(self, num_features=128): def forward(self, x0, x0_hd=None): x = (x0 - self.mean) / self.std - layer0 = self.layer0(x) - layer1 = self.layer1(layer0) - layer2 = self.layer2(layer1) - layer3 = self.layer3(layer2) - layer4 = self.layer4(layer3) + backbone = self.backbone + + layer0 = backbone.conv1(x) + layer0 = backbone.bn1(layer0) + layer0 = backbone.relu(layer0) + layer0 = backbone.maxpool(layer0) + + layer1 = backbone.layer1(layer0) + layer2 = backbone.layer2(layer1) + layer3 = backbone.layer3(layer2) + layer4 = backbone.layer4(layer3) + + # layer0 = self.layer0(x) + # layer1 = self.layer1(layer0) + # layer2 = self.layer2(layer1) + # layer3 = self.layer3(layer2) + # layer4 = self.layer4(layer3) down1 = self.down1(layer1) down2 = self.down2(layer2) diff --git a/resnext/resnext_101_32x4d_.py b/resnext/resnext_101_32x4d_.py index 979e640..e4fe42a 100644 --- a/resnext/resnext_101_32x4d_.py +++ b/resnext/resnext_101_32x4d_.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -from torch.autograd import Variable class LambdaBase(nn.Sequential): diff --git a/test.py b/test.py index 8e15b8a..132d5f7 100644 --- a/test.py +++ b/test.py @@ -20,7 +20,7 @@ ckpt_path = './ckpt' exp_name = 'RESIDE_ITS' args = { - 'snapshot': 'iter_40000_loss_0.01256_lr_0.000000' + 'snapshot': 'iter_40000_loss_0.01230_lr_0.000000' } to_test = {'SOTS': TEST_SOTS_ROOT}