# train.py (forked from mperezcarrasco/PyTorch-Deep-SVDD)
import torch
from torch import optim
import torch.nn.functional as F
import numpy as np
from barbar import Bar

from model import autoencoder, network
from utils.utils import weights_init_normal
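
# Deep SVDD training procedure (Ruff et al., 2018) as implemented below:
#   1. pretrain():  train a conventional autoencoder on the training data.
#   2. set_c():     fix the hypersphere center c as the mean of the embeddings.
#   3. train():     initialize the network with the encoder weights and
#                   minimize the mean squared distance of the embeddings to c.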


class TrainerDeepSVDD:
    def __init__(self, args, data, device):
        self.args = args
        self.train_loader, self.test_loader = data
        self.device = device

    def pretrain(self):
        """Pretrain the weights of the Deep SVDD network using an autoencoder."""
        ae = autoencoder(self.args.latent_dim).to(self.device)
        ae.apply(weights_init_normal)
        optimizer = optim.Adam(ae.parameters(), lr=self.args.lr_ae,
                               weight_decay=self.args.weight_decay_ae)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=self.args.lr_milestones,
                                                   gamma=0.1)

        ae.train()
        for epoch in range(self.args.num_epochs_ae):
            total_loss = 0
            for x, _ in Bar(self.train_loader):
                x = x.float().to(self.device)

                optimizer.zero_grad()
                x_hat = ae(x)
                # Squared reconstruction error per sample (summed over all
                # non-batch dimensions), averaged over the batch.
                reconst_loss = torch.mean(
                    torch.sum((x_hat - x) ** 2, dim=tuple(range(1, x_hat.dim()))))
                reconst_loss.backward()
                optimizer.step()

                total_loss += reconst_loss.item()
            scheduler.step()
            print('Pretraining Autoencoder... Epoch: {}, Loss: {:.3f}'.format(
                epoch, total_loss / len(self.train_loader)))
        self.save_weights_for_DeepSVDD(ae, self.train_loader)

    def save_weights_for_DeepSVDD(self, model, dataloader):
        """Initialize Deep SVDD weights using the encoder weights of the pretrained autoencoder."""
        c = self.set_c(model, dataloader)
        net = network(self.args.latent_dim).to(self.device)
        state_dict = model.state_dict()
        # strict=False: only the keys shared with the Deep SVDD network
        # (the encoder weights) are copied; the decoder weights are ignored.
        net.load_state_dict(state_dict, strict=False)
        # Assumes the 'weights/' directory already exists.
        torch.save({'center': c.cpu().data.numpy().tolist(),
                    'net_dict': net.state_dict()}, 'weights/pretrained_parameters.pth')

    def set_c(self, model, dataloader, eps=0.1):
        """Initialize the hypersphere center c as the mean of the encoded training data."""
        model.eval()
        z_ = []
        with torch.no_grad():
            for x, _ in dataloader:
                x = x.float().to(self.device)
                z = model.encode(x)
                z_.append(z.detach())
        z_ = torch.cat(z_)
        c = torch.mean(z_, dim=0)
        # Push coordinates that are too close to zero away from it, so the
        # network cannot trivially minimize the objective by mapping
        # everything onto a (near-)zero center.
        c[(abs(c) < eps) & (c < 0)] = -eps
        c[(abs(c) < eps) & (c > 0)] = eps
        return c
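
    # Illustrative example (not in the original file) of the eps adjustment
    # in set_c: with eps = 0.1,
    #     c = [0.03, -0.50, -0.02]  ->  [0.10, -0.50, -0.10]
    # Coordinates with magnitude >= eps are left untouched.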

    def train(self):
        """Train the Deep SVDD model."""
        net = network(self.args.latent_dim).to(self.device)

        if self.args.pretrain:
            state_dict = torch.load('weights/pretrained_parameters.pth')
            net.load_state_dict(state_dict['net_dict'])
            c = torch.Tensor(state_dict['center']).to(self.device)
        else:
            net.apply(weights_init_normal)
            c = torch.randn(self.args.latent_dim).to(self.device)

        optimizer = optim.Adam(net.parameters(), lr=self.args.lr,
                               weight_decay=self.args.weight_decay)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=self.args.lr_milestones,
                                                   gamma=0.1)

        net.train()
        for epoch in range(self.args.num_epochs):
            total_loss = 0
            for x, _ in Bar(self.train_loader):
                x = x.float().to(self.device)

                optimizer.zero_grad()
                z = net(x)
                # One-class Deep SVDD objective: mean squared distance of the
                # embeddings to the fixed center c.
                loss = torch.mean(torch.sum((z - c) ** 2, dim=1))
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
            scheduler.step()
            print('Training Deep SVDD... Epoch: {}, Loss: {:.3f}'.format(
                epoch, total_loss / len(self.train_loader)))
        self.net = net
        self.c = c
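

# Minimal usage sketch (not part of the original file). The `args` fields
# below are exactly those read by TrainerDeepSVDD; their values and the
# data-loading call are illustrative assumptions.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(latent_dim=32, num_epochs_ae=100, lr_ae=1e-3,
                     weight_decay_ae=5e-7, num_epochs=50, lr=1e-3,
                     weight_decay=5e-7, lr_milestones=[50], pretrain=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hypothetical loaders: any (train_loader, test_loader) pair whose
    # batches are (x, label) tuples will work here.
    # train_loader, test_loader = get_mnist(args)
    # trainer = TrainerDeepSVDD(args, (train_loader, test_loader), device)
    # if args.pretrain:
    #     trainer.pretrain()
    # trainer.train()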