Update.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import math

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class DatasetSplit(Dataset):
    """Wrap a dataset so that only the samples at the given indices are visible."""

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

class LocalUpdate(object):
    """Plain local SGD update for one federated client."""

    def __init__(self, args, dataset=None, idxs=None, pretrain=False):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs),
                                    batch_size=self.args.local_bs, shuffle=True)
        self.pretrain = pretrain

    def train(self, net, idx=-1, lr=0.1):
        net.train()
        # Train and update.
        optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.5)

        epoch_loss = []
        # Pretraining rounds may run a different number of local epochs.
        if self.pretrain:
            local_eps = self.args.local_ep_pretrain
        else:
            local_eps = self.args.local_ep
        for ep in range(local_eps):  # renamed from `iter`, which shadows the built-in
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
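
# --- Usage sketch (illustrative addition, not in the original file) ----------
# A minimal sketch of driving LocalUpdate, assuming an `args` namespace that
# carries the fields read above (local_bs, local_ep, local_ep_pretrain,
# device); `net` and `dataset` stand in for any model and dataset:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(local_bs=10, local_ep=1, local_ep_pretrain=1,
#                          device='cpu')
#   local = LocalUpdate(args=args, dataset=dataset, idxs=range(100))
#   w_local, avg_loss = local.train(net=net.to(args.device), lr=0.1)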

class LocalUpdateMTL(object):
    """Local training with a multi-task relationship regularizer.

    In addition to the cross-entropy loss, the shared weights of all clients
    (stacked column-wise in W_glob, one column per client) are penalized with
    ||W||^2 + tr(W_i * omega * W_i^T), computed over row chunks W_i of W.
    """

    def __init__(self, args, dataset=None, idxs=None, pretrain=False):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs),
                                    batch_size=self.args.local_bs, shuffle=True)
        self.pretrain = pretrain

    def train(self, net, lr=0.1, omega=None, W_glob=None, idx=None, w_glob_keys=None):
        net.train()
        # Train and update.
        optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.5)

        epoch_loss = []
        if self.pretrain:
            local_eps = self.args.local_ep_pretrain
        else:
            local_eps = self.args.local_ep
        for ep in range(local_eps):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)

                # Overwrite this client's column of W with its live shared
                # weights (keep_vars=True keeps them attached to autograd).
                W = W_glob.clone()
                W_local = [net.state_dict(keep_vars=True)[key].flatten()
                           for key in w_glob_keys]
                W_local = torch.cat(W_local)
                W[:, idx] = W_local

                loss_regularizer = 0
                loss_regularizer += W.norm() ** 2

                # Task-relationship penalty tr(x * omega * x^T), computed over
                # row chunks of size k to bound the intermediate matrices; any
                # remainder rows beyond the last full chunk are skipped.
                k = 4000
                for i in range(W.shape[0] // k):
                    x = W[i * k:(i + 1) * k, :]
                    loss_regularizer += x.mm(omega).mm(x.T).trace()
                # Down-weight the penalty by 10^-(digits(W.shape[0]) + 1).
                f = int(math.log10(W.shape[0]) + 1) + 1
                loss_regularizer *= 10 ** (-f)

                loss = loss + loss_regularizer
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
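
# --- Minimal smoke test (illustrative addition, not in the original file) ----
# A hedged, self-contained sketch of how LocalUpdateMTL might be driven: the
# `args` fields, toy linear model, synthetic data, and identity `omega` below
# are assumptions for demonstration, not part of the original API.
if __name__ == '__main__':
    from types import SimpleNamespace

    from torch.utils.data import TensorDataset

    args = SimpleNamespace(local_bs=10, local_ep=1, local_ep_pretrain=1,
                           device='cpu')
    # Toy task: 100 random 2-d points with binary labels, linear classifier.
    dataset = TensorDataset(torch.randn(100, 2), torch.randint(0, 2, (100,)))
    net = nn.Linear(2, 2)

    # Stack each client's flattened shared weights as one column of W_glob;
    # omega encodes pairwise task relationships (identity for two clients).
    num_clients = 2
    w_glob_keys = list(net.state_dict().keys())
    w_col = torch.cat([net.state_dict()[k].flatten() for k in w_glob_keys])
    W_glob = w_col.unsqueeze(1).repeat(1, num_clients)
    omega = torch.eye(num_clients)

    local_mtl = LocalUpdateMTL(args=args, dataset=dataset, idxs=range(100))
    w_local, avg_loss = local_mtl.train(net=net, lr=0.1, omega=omega,
                                        W_glob=W_glob, idx=0,
                                        w_glob_keys=w_glob_keys)
    print('LocalUpdateMTL average loss: %.4f' % avg_loss)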