Skip to content

Commit

Permalink
Update test
Browse files Browse the repository at this point in the history
  • Loading branch information
jiaqixuac committed Oct 10, 2021
1 parent 04cd00b commit d6f231f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ numpy>=1.20.2
Pillow>=8.2.0
scipy>=1.7.1
scikit-image>=0.18.3
tqdm>=4.62.3
tqdm>=4.62.2
16 changes: 12 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import numpy as np
import torch
from torch import nn
from torchvision import transforms

from tools.config import TEST_SOTS_ROOT, OHAZE_ROOT
from tools.utils import check_mkdir, sliding_forward
from tools.utils import AvgMeter, check_mkdir, sliding_forward
from model import DM2FNet, DM2FNet_woPhy
from datasets import SotsDataset, OHazeDataset
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -36,6 +37,8 @@

def main():
with torch.no_grad():
criterion = nn.L1Loss().cuda()

for name, root in to_test.items():
if 'SOTS' in name:
net = DM2FNet().cuda()
Expand All @@ -56,13 +59,15 @@ def main():
dataloader = DataLoader(dataset, batch_size=1)

psnrs, ssims = [], []
loss_record = AvgMeter()

for idx, data in enumerate(dataloader):
# haze_image, _, _, _, fs = data
haze, gts, fs = data
# print(haze_image.shape, gts.shape)
# print(haze.shape, gts.shape)

check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
check_mkdir(os.path.join(ckpt_path, exp_name,
'(%s) %s_%s' % (exp_name, name, args['snapshot'])))

haze = haze.cuda()

Expand All @@ -71,6 +76,9 @@ def main():
else:
res = net(haze).detach()

loss = criterion(res, gts.cuda())
loss_record.update(loss.item(), haze.size(0))

for i in range(len(fs)):
r = res[i].cpu().numpy().transpose([1, 2, 0])
gt = gts[i].cpu().numpy().transpose([1, 2, 0])
Expand All @@ -87,7 +95,7 @@ def main():
os.path.join(ckpt_path, exp_name,
'(%s) %s_%s' % (exp_name, name, args['snapshot']), '%s.png' % f))

print(f"[{name}] PSNR: {np.mean(psnrs):.6f}, SSIM: {np.mean(ssims):.6f}")
print(f"[{name}] L1: {loss_record.avg:.6f}, PSNR: {np.mean(psnrs):.6f}, SSIM: {np.mean(ssims):.6f}")


if __name__ == '__main__':
Expand Down

0 comments on commit d6f231f

Please sign in to comment.