-
Notifications
You must be signed in to change notification settings - Fork 119
/
Copy path: train_srresnet.py
152 lines (120 loc) · 5.35 KB
/
train_srresnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import time
import torch.backends.cudnn as cudnn
import torch
from torch import nn
from models import SRResNet
from datasets import SRDataset
from utils import *
# Data parameters
data_folder = './' # folder with JSON data files (produced by the dataset-creation utility)
crop_size = 96 # crop size of target HR images; LR inputs are crop_size / scaling_factor
scaling_factor = 4 # the scaling factor for the generator; the input LR images will be downsampled from the target HR images by this factor
# Model parameters
large_kernel_size = 9 # kernel size of the first and last convolutions which transform the inputs and outputs
small_kernel_size = 3 # kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks
n_channels = 64 # number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks
n_blocks = 16 # number of residual blocks
# Learning parameters
checkpoint = None # path to model checkpoint, None if none
batch_size = 16 # batch size
start_epoch = 0 # start at this epoch (overwritten when resuming from a checkpoint)
iterations = 1e6 # number of training iterations (a float; converted to an epoch count in main())
workers = 4 # number of workers for loading data in the DataLoader
print_freq = 500 # print training status once every __ batches
lr = 1e-4 # learning rate
grad_clip = None # clip if gradients are exploding
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True # auto-tune cuDNN kernels; fine here since input sizes are fixed
def main():
    """
    Train the SRResNet.

    Initializes the model and optimizer (or resumes them from ``checkpoint``),
    builds the training dataloader, runs enough epochs to cover the requested
    number of ``iterations``, and saves a checkpoint after every epoch.
    """
    global start_epoch, epoch, checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                         n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
        # Initialize the optimizer over trainable parameters only
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),
                                     lr=lr)
    else:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
        # machine (and vice versa); without it torch.load raises on mismatch
        checkpoint = torch.load(checkpoint, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Custom dataloaders
    train_dataset = SRDataset(data_folder,
                              split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                               num_workers=workers, pin_memory=True)

    # Total number of epochs needed to cover the requested iteration count
    epochs = int(iterations // len(train_loader) + 1)

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch)

        # Save checkpoint. NOTE(review): this pickles the whole model/optimizer
        # objects (not state_dicts), so checkpoints are tied to this code layout.
        torch.save({'epoch': epoch,
                    'model': model,
                    'optimizer': optimizer},
                   'checkpoint_srresnet.pth.tar')
def train(train_loader, model, criterion, optimizer, epoch):
    """
    Run a single epoch of training.

    :param train_loader: DataLoader for training data
    :param model: model
    :param criterion: content loss function (Mean Squared-Error loss)
    :param optimizer: optimizer
    :param epoch: epoch number
    """
    # Training mode so batch normalization uses batch statistics
    model.train()

    # Meters tracking timing and loss across the epoch
    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    tic = time.time()

    for batch_idx, (lr_imgs, hr_imgs) in enumerate(train_loader):
        data_time.update(time.time() - tic)

        # Move batch to the default device
        lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed
        hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96), in [-1, 1]

        # Forward pass and content loss
        sr_imgs = model(lr_imgs)  # (N, 3, 96, 96), in [-1, 1]
        loss = criterion(sr_imgs, hr_imgs)  # scalar

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients, if necessary
        if grad_clip is not None:
            clip_gradient(optimizer, grad_clip)

        # Update model
        optimizer.step()

        # Bookkeeping: loss (weighted by batch size) and wall-clock time
        losses.update(loss.item(), lr_imgs.size(0))
        batch_time.update(time.time() - tic)
        tic = time.time()

        # Print status periodically
        if batch_idx % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]----'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(epoch, batch_idx, len(train_loader),
                                                                batch_time=batch_time,
                                                                data_time=data_time, loss=losses))
    del lr_imgs, hr_imgs, sr_imgs  # free some memory since their histories may be stored
# Script entry point: only train when executed directly, not when imported
if __name__ == '__main__':
    main()