From d6f231f8a3f1a3c5ca877b30d8778b9bb9b76668 Mon Sep 17 00:00:00 2001 From: jiaqixuac Date: Sun, 10 Oct 2021 11:54:05 +0800 Subject: [PATCH] Update test --- requirements.txt | 2 +- test.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index b47a587..cf30b0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ numpy>=1.20.2 Pillow>=8.2.0 scipy>=1.7.1 scikit-image>=0.18.3 -tqdm>=4.62.3 +tqdm>=4.62.2 diff --git a/test.py b/test.py index bff3864..efce9cf 100644 --- a/test.py +++ b/test.py @@ -3,10 +3,11 @@ import numpy as np import torch +from torch import nn from torchvision import transforms from tools.config import TEST_SOTS_ROOT, OHAZE_ROOT -from tools.utils import check_mkdir, sliding_forward +from tools.utils import AvgMeter, check_mkdir, sliding_forward from model import DM2FNet, DM2FNet_woPhy from datasets import SotsDataset, OHazeDataset from torch.utils.data import DataLoader @@ -36,6 +37,8 @@ def main(): with torch.no_grad(): + criterion = nn.L1Loss().cuda() + for name, root in to_test.items(): if 'SOTS' in name: net = DM2FNet().cuda() @@ -56,13 +59,15 @@ def main(): dataloader = DataLoader(dataset, batch_size=1) psnrs, ssims = [], [] + loss_record = AvgMeter() for idx, data in enumerate(dataloader): # haze_image, _, _, _, fs = data haze, gts, fs = data - # print(haze_image.shape, gts.shape) + # print(haze.shape, gts.shape) - check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']))) + check_mkdir(os.path.join(ckpt_path, exp_name, + '(%s) %s_%s' % (exp_name, name, args['snapshot']))) haze = haze.cuda() @@ -71,6 +76,9 @@ def main(): else: res = net(haze).detach() + loss = criterion(res, gts.cuda()) + loss_record.update(loss.item(), haze.size(0)) + for i in range(len(fs)): r = res[i].cpu().numpy().transpose([1, 2, 0]) gt = gts[i].cpu().numpy().transpose([1, 2, 0]) @@ -87,7 +95,7 @@ def main(): os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot']), '%s.png' % f)) - print(f"[{name}] PSNR: {np.mean(psnrs):.6f}, SSIM: {np.mean(ssims):.6f}") + print(f"[{name}] L1: {loss_record.avg:.6f}, PSNR: {np.mean(psnrs):.6f}, SSIM: {np.mean(ssims):.6f}") if __name__ == '__main__':