example_model.py
from collections import OrderedDict

from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.models.sr_model import SRModel
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class ExampleModel(SRModel):
    """Example model based on the SRModel class.

    In this example model, we want to implement a new model that trains with both L1 and L2 loss.

    Newly defined functions:
        init_training_settings(self)
        feed_data(self, data)
        optimize_parameters(self, current_iter)

    Inherited functions:
        __init__(self, opt)
        setup_optimizers(self)
        test(self)
        dist_validation(self, dataloader, current_iter, tb_logger, save_img)
        nondist_validation(self, dataloader, current_iter, tb_logger, save_img)
        _log_validation_metric_values(self, current_iter, dataset_name, tb_logger)
        get_current_visuals(self)
        save(self, epoch, current_iter)

    An illustrative sketch of the `train` options this model expects is given after the class.
    """

    def init_training_settings(self):
        self.net_g.train()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define losses
        self.l1_pix = build_loss(train_opt['l1_opt']).to(self.device)
        self.l2_pix = build_loss(train_opt['l2_opt']).to(self.device)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def feed_data(self, data):
        # move the LQ (and, if available, GT) tensors of the current batch to the model device
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_total = 0
        loss_dict = OrderedDict()
        # l1 loss
        l_l1 = self.l1_pix(self.output, self.gt)
        l_total += l_l1
        loss_dict['l_l1'] = l_l1
        # l2 loss
        l_l2 = self.l2_pix(self.output, self.gt)
        l_total += l_l2
        loss_dict['l_l2'] = l_l2

        l_total.backward()
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)
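

# ----------------------------------------------------------------------------
# For reference only: a minimal sketch of the `train` options read by
# init_training_settings() above. The loss types and numeric values are
# illustrative assumptions (the standard basicsr L1Loss / MSELoss), not copied
# from this repository's YAML configs; in practice these options come from the
# experiment option file, where `model_type: ExampleModel` selects this model
# through the registry.
example_train_opt = {
    'ema_decay': 0.999,  # set to 0 to disable the EMA copy of net_g
    'l1_opt': {          # passed to build_loss() -> L1 pixel loss
        'type': 'L1Loss',
        'loss_weight': 1.0,
        'reduction': 'mean',
    },
    'l2_opt': {          # passed to build_loss() -> L2 (MSE) pixel loss
        'type': 'MSELoss',
        'loss_weight': 1.0,
        'reduction': 'mean',
    },
    # optimizer and scheduler options consumed by the inherited
    # setup_optimizers() / setup_schedulers() are omitted here.
}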