-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupdate_device_HEFL.py
137 lines (122 loc) · 5.11 KB
/
update_device_HEFL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch
from sklearn.cluster import KMeans
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
import copy
from torch.distributions import Categorical
import torch.optim as optim
from copy import deepcopy
import argparse
from models.Update import LocalUpdate
from models.Fed import FedAvg
from models.test import test_img
import matplotlib.pyplot as plt
#from tensorboardX import SummaryWriter
from torch.nn.utils import clip_grad_norm_
scaler_2 = MinMaxScaler(feature_range=(-5, 5))
def feature_cal(wlist, wglobal):
state_list = []
for i in range(100):
para_list = []
for (_, parameters1),(_, parameters2) in zip(wlist[i].items(),wglobal.items()):
para_list.append(torch.norm((parameters1.view(-1)- parameters2.view(-1)), p = 2,dim = 0).cpu().numpy().item())
state_list.append(torch.tensor(para_list).view(1,-1))
# para_list = torch.cat(para_list)
return torch.sum(torch.cat(state_list,0),1)
def get_state(clusteded_dic,state,i):
# if len(np.shape(state)) <3:
# state = np.expand_dims(state,0)
i = '%d'%(i)
device_idx = np.array(clusteded_dic[i])
device_num = len(np.array(clusteded_dic[i]))
return scaler_2.fit_transform(state[device_idx].reshape(-1,1)).reshape(1,-1)
def permute_device(action, clusteded_dic, cluster_index):
cluster_index = '%d'%(cluster_index)
real_index = action
device_list = clusteded_dic[cluster_index]
selected_device = device_list[real_index]
index_list = np.delete(device_list,real_index)
device_list = np.append(index_list,selected_device).tolist()
clusteded_dic[cluster_index] = device_list
return clusteded_dic
def cluster_devices(wlist):
dist_mx = []
class_index = {i: np.array([], dtype='int64') for i in range(10)}
for i in range(100):
temp = copy.deepcopy(wlist[i])
para_list1 = temp['fc2.weight']
dist_mx.append(para_list1.view(1,-1))
dist_mx = torch.cat(dist_mx,0)
kmeans = KMeans(n_clusters=10, random_state=0).fit(dist_mx.cpu())
for i in range(10):
class_index[i] = np.where(kmeans.labels_ == i)[0].tolist()
return class_index
def local_update(args,dataset_train,dataset_test,dict_users,idxs_users,epoch,net_glob,acc_list = None,cluster_acc = None, w_list = None, name = None,preset = None, coeff = None):
w_locals_temp = []
loss_locals = []
reward_n = []
if preset is None:
preset = 99
if coeff is None:
coeff = 1.2
if epoch ==0:
acc_list = [-100000]
w_list = []
elif epoch == 1:
cluster_acc = []
count = 0
for idx in idxs_users:
net_temp = copy.deepcopy(net_glob)
idx_str = '%d'%(idx)
local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx_str])
w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
result = w['fc2.weight'].view(1,-1)
w_locals_temp.append(copy.deepcopy(w))
loss_locals.append(copy.deepcopy(loss))
if epoch > 1:
net_temp.load_state_dict(copy.deepcopy(w))
acc_train, loss_train = test_img(net_temp, dataset_test, args)
# print(acc_train)
# print(max(cluster_acc[count]))
reward_n.append(np.clip(5*(acc_train.numpy().item()- max(cluster_acc[count])),-5,5)*min(1,epoch* 0.1))
# reward_n.append(np.clip((acc_train.numpy().item()- max(cluster_acc[count])),-1,1)*min(1,epoch* 0.1))
cluster_acc[count].append(acc_train.numpy().item())
elif epoch == 1:
net_temp.load_state_dict(copy.deepcopy(w))
acc_train, loss_train = test_img(net_temp, dataset_test, args)
cluster_acc.append([acc_train.numpy().item()])
count += 1
if epoch == 0:
w_list.append(w)
else:
w_list[idx] = copy.deepcopy(w)
w_glob = FedAvg(w_locals_temp)
state_matrix = feature_cal(w_list, w_glob)
net_glob.load_state_dict(w_glob)
# print loss
loss_avg = sum(loss_locals) / len(loss_locals)
# acc_test, loss_test = test_img(net_glob, dataset_test, args)
acc_train, loss_train = test_img(net_glob, dataset_test, args)
acc_list.append(acc_train)
print('Round {:3d}, Average loss {:.3f}, Training accuracy {:.3f}'.format(epoch, loss_avg,acc_train))
if name is not None:
file_handle=open(name,mode='a')
file_handle.write("Round:%.03f; Average loss: %.03f; Training accuracy:%.03f \n"%(epoch,loss_avg,acc_train))
file_handle.close()
if acc_train < preset :
done_n = [0] * 10
else:
done_n = [1] * 10
if epoch == 0 :
clusteded_dic = cluster_devices(w_list)
return state_matrix,net_glob, acc_list,clusteded_dic,w_list
else:
return state_matrix,net_glob, acc_list,cluster_acc,w_list,reward_n,done_n
# def get_state(clusteded_dic,state,i):
# i = '%d'%(i)
# device_idx = np.array(clusteded_dic[i])
# return torch.cat(((state[device_idx],state[-1].view(1,-1))),0).view(-1).cpu().detach().numpy()