-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
122 lines (94 loc) · 4.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from __future__ import print_function

import argparse
import os
import sys
from glob import glob

import numpy as np
from PIL import Image
import tensorflow as tf

from extra.model import lowlight_enhance
from extra.utils import *
# Command-line configuration. Parsed once at import time; the functions below
# read the resulting module-level ``args``.
parser = argparse.ArgumentParser(description='')
# Default is CPU (0); pass --use_gpu 1 to run on a GPU.
parser.add_argument('--use_gpu', dest='use_gpu', type=int, default=0,
                    help='gpu flag, 1 for GPU and 0 for CPU')
parser.add_argument('--gpu_idx', dest='gpu_idx', default="0",
                    help='GPU idx')
parser.add_argument('--gpu_mem', dest='gpu_mem', type=float, default=0.5,
                    help="0 to 1, gpu memory usage")
# Default phase is 'train'; pass --phase test to run the test phase.
parser.add_argument('--phase', dest='phase', default='train',
                    help='train or test')
parser.add_argument('--epoch', dest='epoch', type=int, default=100,
                    help='number of total epoches')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=16,
                    help='number of samples in one batch')
parser.add_argument('--patch_size', dest='patch_size', type=int, default=48,
                    help='patch size')
parser.add_argument('--start_lr', dest='start_lr', type=float, default=0.001,
                    help='initial learning rate for adam')
# type=int: without it a value supplied on the command line arrives as a str,
# while the default is an int.
parser.add_argument('--eval_every_epoch', dest='eval_every_epoch', type=int, default=20,
                    help='evaluating and saving checkpoints every # epoch')
parser.add_argument('--checkpoint_dir', dest='ckpt_dir', default='./checkpoint',
                    help='directory for checkpoints')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample',
                    help='directory for evaluating outputs')
parser.add_argument('--save_dir', dest='save_dir', default='./test_results',
                    help='directory for testing outputs')
parser.add_argument('--test_dir', dest='test_dir', default='./data/test/low',
                    help='directory for testing inputs')
args = parser.parse_args()
def lowlight_train(lowlight_enhance):
    """Train the model's "Denoising" phase on the images under ./data.

    Collects paired low/high training image paths from ./data/train/low and
    ./data/train/high (paired by sorted filename order), eagerly loads the
    evaluation images from ./data/eval/low with ``load_images``, builds a
    per-epoch learning-rate array, and calls ``lowlight_enhance.train``.
    Checkpoints go to ``<ckpt_dir>/Denoising``.

    Args:
        lowlight_enhance: the model instance created in ``main`` (the name
            shadows the imported class); must provide ``train(...)``.

    Raises:
        ValueError: if the number of low and high training images differs.
    """
    for directory in (args.ckpt_dir, args.sample_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Step-decay schedule: start_lr for epochs [0, 20), start_lr / 10 for
    # [20, 50), start_lr / 100 from epoch 50 onward. Derived from start_lr
    # directly (not from lr[0] / lr[20]) so runs with <= 20 epochs don't index
    # past the end of the array.
    lr = args.start_lr * np.ones([args.epoch])
    lr[20:50] = args.start_lr / 10.0
    lr[50:] = args.start_lr / 10.0 / 10.0

    # Training images: only paths are collected here; low/high are paired by
    # their position after sorting.
    train_low_data_names = sorted(glob('./data/train/low/*.*'))
    train_high_data_names = sorted(glob('./data/train/high/*.*'))
    if len(train_low_data_names) != len(train_high_data_names):
        raise ValueError('Mismatched training pairs: %d low images vs %d high images'
                         % (len(train_low_data_names), len(train_high_data_names)))
    print('[*] Number of training data: %d' % len(train_low_data_names))

    # Evaluation images are loaded up front. glob returns names in arbitrary
    # order, so sort for a deterministic evaluation order across runs.
    eval_low_data_names = sorted(glob('./data/eval/low/*.*'))
    eval_low_data = [load_images(name) for name in eval_low_data_names]

    lowlight_enhance.train(train_low_data_names, train_high_data_names, eval_low_data,
                           batch_size=args.batch_size, patch_size=args.patch_size,
                           epoch=args.epoch, lr=lr, sample_dir=args.sample_dir,
                           ckpt_dir=os.path.join(args.ckpt_dir, 'Denoising'),
                           eval_every_epoch=args.eval_every_epoch,
                           train_phase="Denoising")
def lowlight_test(lowlight_enhance):
    """Run the model's ``test`` method on every image in ``args.test_dir``.

    Loads each matching file with ``load_images`` and passes the images, an
    empty high-light list and the source paths to ``lowlight_enhance.test``,
    which writes its outputs under ``args.save_dir`` (created if missing).

    Args:
        lowlight_enhance: the model instance created in ``main`` (the name
            shadows the imported class); must provide ``test(...)``.

    Exits with status 1 if no test directory is configured.
    """
    if args.test_dir is None:
        print("[!] please provide --test_dir")
        sys.exit(1)  # a missing input is a failure, not a clean exit
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Sorted because glob order is arbitrary; keeps output order reproducible.
    test_low_data_name = sorted(glob(os.path.join(args.test_dir, '*.*')))
    test_low_data = [load_images(name) for name in test_low_data_name]
    # This script supplies no reference (high-light) images; passed empty.
    test_high_data = []
    lowlight_enhance.test(test_low_data, test_high_data, test_low_data_name,
                          save_dir=args.save_dir)
def main(_):
    """Entry point called by ``tf.app.run``; the argv argument is unused.

    Chooses a session configuration from ``args.use_gpu``, builds the model in
    training or inference mode according to ``args.phase``, and dispatches to
    ``lowlight_train`` / ``lowlight_test``. Exits with status 1 on an
    unrecognised phase.
    """
    if args.use_gpu:
        print("[*] GPU\n")
        # Set before the session is created so it takes effect for this run.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_idx
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)
        config = tf.ConfigProto(gpu_options=gpu_options)
    else:
        print("[*] CPU\n")
        # Expose no GPU devices to the session so --use_gpu 0 really runs on
        # the CPU even on a machine that has a GPU.
        config = tf.ConfigProto(device_count={'GPU': 0})

    with tf.Session(config=config) as sess:
        if args.phase == 'train':
            model = lowlight_enhance(sess, is_training=True)
            lowlight_train(model)
        elif args.phase == 'test':
            model = lowlight_enhance(sess, is_training=False)
            lowlight_test(model)
        else:
            print('[!] Unknown phase')
            sys.exit(1)  # signal the bad argument to the calling shell
if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, which calls main(argv).
    # main is passed explicitly instead of being looked up on __main__.
    tf.app.run(main=main)