-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_Sony.py
102 lines (78 loc) · 3.35 KB
/
test_Sony.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import time
import numpy as np
import rawpy
import glob
import torch
import torch.nn as nn
import torch.optim as optim
torch.cuda.empty_cache()
from PIL import Image
from model import SeeInDark
input_dir = '/Sony/short/'
gt_dir = '/Sony/long/'
m_path = 'Sony_test/saved_model/'
m_name = 'checkpoint_sony_e4000.pth'
result_dir = '/Sony_subset/test_results_Sony/'

# Inference is pinned to the CPU. To use a GPU when one is available, replace
# this with torch.device('cuda:0' if torch.cuda.is_available() else 'cpu').
device = torch.device('cpu')


def scene_id(path):
    """Return the integer scene ID stored in the first five characters of a file name.

    Example: '/Sony/long/10003_00_10s.ARW' -> 10003.
    """
    return int(os.path.basename(path)[:5])


# Test scenes are the ground-truth files whose names start with '1'.
# glob's result order depends on the filesystem, so sort it to get a
# deterministic processing order.
test_fns = sorted(glob.glob(os.path.join(gt_dir, '1*.ARW')))
test_ids = [scene_id(fn) for fn in test_fns]
def pack_raw(raw, black_level=512, white_level=16383):
    """Pack a 2-D Bayer mosaic into a half-resolution, 4-channel array.

    The black level is subtracted (negative results clamp to 0), and the
    result is divided by ``white_level - black_level``.

    The four output channels are taken, in order, from these pixel positions:
    (even row, even col), (even row, odd col), (odd row, odd col) and
    (odd row, even col). Assuming an RGGB mosaic, that is R, G, B, G.
    TODO confirm the mosaic pattern for the camera.

    Args:
        raw: 2-D array of raw sensor values, shape (H, W).
        black_level: Value subtracted before normalising. The default 512 is
            presumably the Sony sensor's black level.
        white_level: Saturation value used for normalisation. The default
            16383 is presumably the Sony sensor's white level.

    Returns:
        Array of shape (H // 2, W // 2, 4). If H or W is odd, the last
        row/column is dropped so all four channels have the same size.
    """
    im = np.maximum(raw - black_level, 0) / (white_level - black_level)
    im = np.expand_dims(im, axis=2)
    # Crop to even dimensions. With an odd size, the even-offset slices would
    # be one element longer than the odd-offset ones and concatenate would fail.
    H = im.shape[0] - im.shape[0] % 2
    W = im.shape[1] - im.shape[1] % 2
    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out
# Build the network, load the trained weights (mapped to the CPU), and move
# the model to the inference device.
model = SeeInDark()
model.load_state_dict(torch.load(os.path.join(m_path, m_name), map_location=torch.device('cpu')))
model = model.to(device)
# Put the module in inference mode. This matters if SeeInDark contains
# train-only layers such as dropout or batch norm; it is harmless otherwise.
model.eval()

# Directory for the result images. exist_ok=True avoids the race in a
# separate isdir() check followed by makedirs().
os.makedirs(result_dir, exist_ok=True)
def _to_srgb(raw):
    """Demosaic an open rawpy handle and return it as float32 scaled to [0, 1].

    Uses the camera white balance, full resolution, no auto-brightening and
    16-bit output, then divides by 65535.
    """
    rgb = raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
    return np.float32(rgb / 65535.0)


def _save_png(img, path):
    """Write a float image in [0, 1] to *path* as an 8-bit PNG.

    Values are clipped first, because astype('uint8') wraps anything above
    255 back to small values instead of saturating.
    """
    Image.fromarray((np.clip(img, 0.0, 1.0) * 255).astype('uint8')).save(path)


for test_id in test_ids:
    # Every input exposure for this scene whose file name starts with '<id>_00'.
    in_files = glob.glob(input_dir + '%05d_00*.ARW' % test_id)
    if not in_files:
        continue

    # The ground truth depends only on the scene, so read and demosaic it once
    # rather than once per input file.
    gt_path = glob.glob(gt_dir + '%05d_00*.ARW' % test_id)[0]
    gt_fn = os.path.basename(gt_path)
    # NOTE(review): names are assumed to look like '10003_00_10s.ARW', so
    # characters [9:-5] hold the exposure time in seconds. Confirm against the dataset.
    gt_exposure = float(gt_fn[9:-5])
    with rawpy.imread(gt_path) as gt_raw:
        gt_full = _to_srgb(gt_raw)
    gt_mean = np.mean(gt_full)

    for in_path in in_files:
        in_fn = os.path.basename(in_path)
        print(in_fn)
        in_exposure = float(in_fn[9:-5])
        # Amplification applied to the packed raw input, capped at 300.
        ratio = min(gt_exposure / in_exposure, 300)

        with rawpy.imread(in_path) as raw:
            # astype() copies, so the Bayer data stays valid after the handle closes.
            bayer = raw.raw_image_visible.astype(np.float32)
            origin_full = _to_srgb(raw)

        input_full = np.minimum(np.expand_dims(pack_raw(bayer), axis=0) * ratio, 1.0)
        # NHWC -> NCHW for the network, then back to HWC for saving.
        in_img = torch.from_numpy(input_full).permute(0, 3, 1, 2).to(device)
        with torch.no_grad():
            out_img = model(in_img)
        output = np.clip(out_img.permute(0, 2, 3, 1).cpu().numpy(), 0, 1)[0]

        # Scale the demosaiced low-light image to the ground truth's mean
        # brightness. Skip the scaling for an all-black image to avoid
        # dividing by zero.
        in_mean = np.mean(origin_full)
        scale_full = origin_full * (gt_mean / in_mean) if in_mean > 0 else origin_full

        prefix = result_dir + '%05d_00_%d' % (test_id, ratio)
        _save_png(origin_full, prefix + '_ori.png')
        _save_png(output, prefix + '_out.png')
        _save_png(scale_full, prefix + '_scale.png')
        _save_png(gt_full, prefix + '_gt.png')